Source code for chop.nn.quantized.modules.attention

from functools import partial

import torch
from torch import Tensor
from torch.nn import functional as F

from transformers.models.bert.modeling_bert import BertSelfAttention

from chop.nn.quantized.modules.linear import (
    LinearInteger,
)
from chop.nn.quantized.functional import fixed_softermax
from chop.nn.quantized.functional import matmul_integer

from typing import Optional, Tuple


class _BertSelfAttentionBase(BertSelfAttention):
    def __init__(
        self,
        config,
        q_config: dict = None,
        out_q_config: dict = None,
        position_embedding_type=None,
        bias=True,
        output_tensor_only=False,
    ) -> None:
        super().__init__(config, position_embedding_type)
        self.bypass = False
        self.q_config = q_config
        self.out_q_config = out_q_config
        self.bias = bias
        self.output_tensor_only = output_tensor_only

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        out = super().forward(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        if self.output_tensor_only:
            return out[0]
        return out


[docs] class BertSelfAttentionInteger(_BertSelfAttentionBase):
[docs] def __init__( self, config, q_config: dict = None, out_q_config: dict = None, position_embedding_type=None, bias=True, floor=False, output_tensor_only=False, ) -> None: super().__init__( config, q_config, out_q_config, position_embedding_type, bias=bias, output_tensor_only=output_tensor_only, ) self.query = LinearInteger( config.hidden_size, config.hidden_size, config=q_config, out_config=out_q_config, bias=bias, floor=floor, ) self.key = LinearInteger( config.hidden_size, config.hidden_size, config=q_config, out_config=out_q_config, bias=bias, floor=floor, ) self.value = LinearInteger( config.hidden_size, config.hidden_size, config=q_config, out_config=out_q_config, bias=bias, floor=floor, ) # * Matmul is used for Q @ K^T and Scores @ V where the input values have already # * been casted to the output precision, so we provide the output precision to the # * software model self.matmul = partial( matmul_integer, config={ "data_in_width": self.q_config["data_out_width"], "data_in_frac_width": self.q_config["data_out_frac_width"], "weight_width": self.q_config["data_out_width"], "weight_frac_width": self.q_config["data_out_frac_width"], }, out_config=out_q_config, floor=floor, )