Source code for chop.nn.quantized.modules.attention_head
import torch
from torch import Tensor
import torch.nn as nn
import math
from typing import Optional, Tuple
from functools import partial
from chop.nn.quantized.functional.matmul import (
    generic_matmul_integer,
)
from chop.nn.quantizers.integer import integer_quantizer


class _BertSelfAttentionHeadBase(torch.nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # TODO: replace these with quantized functions?
        self.matmul = torch.matmul
        self.softmax = nn.functional.softmax

    def self_attention_head(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> Tensor:
        attention_scores = self.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel's forward() function).
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = self.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but it is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = self.matmul(attention_probs, value_layer)
        return context_layer

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> Tensor:
        return self.self_attention_head(
            query_layer=query_layer,
            key_layer=key_layer,
            value_layer=value_layer,
            attention_mask=attention_mask,
        )
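

# The base head above implements standard scaled dot-product attention:
#
#     context = softmax(Q @ K.T / sqrt(d_head) + mask) @ V
#
# where d_head = hidden_size // num_attention_heads. The matmul and softmax
# callables are stored as attributes, presumably so that quantized subclasses
# can swap in integer replacements (see the TODO in __init__ above).
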
class BertSelfAttentionHeadInteger(_BertSelfAttentionHeadBase):
    def __init__(self, config, q_config: Optional[dict] = None) -> None:
        super().__init__(config)
        self.query_quantizer = partial(
            integer_quantizer,
            **q_config,
        )
        self.key_quantizer = partial(integer_quantizer, **q_config)
        self.value_quantizer = partial(integer_quantizer, **q_config)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> Tensor:
        # Quantize the query, key and value inputs before computing attention.
        query_layer = self.query_quantizer(query_layer)
        key_layer = self.key_quantizer(key_layer)
        value_layer = self.value_quantizer(value_layer)
        return self.self_attention_head(
            query_layer=query_layer,
            key_layer=key_layer,
            value_layer=value_layer,
            attention_mask=attention_mask,
        )
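

# Usage sketch (illustrative, not part of the original module). It assumes the
# q_config keys match integer_quantizer's fixed-point keyword arguments
# ("width" and "frac_width"); adjust them if your chop version expects a
# different schema. The config object only needs the three attributes read in
# _BertSelfAttentionHeadBase.__init__.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=768,
        num_attention_heads=12,
        attention_probs_dropout_prob=0.1,
    )
    q_config = {"width": 8, "frac_width": 4}  # assumed fixed-point settings
    head = BertSelfAttentionHeadInteger(config, q_config=q_config)

    batch, seq_len = 2, 16
    head_size = config.hidden_size // config.num_attention_heads  # 64
    query = torch.randn(batch, seq_len, head_size)
    key = torch.randn(batch, seq_len, head_size)
    value = torch.randn(batch, seq_len, head_size)

    context = head(query, key, value)
    print(context.shape)  # expected: torch.Size([2, 16, 64])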