Source code for chop.nn.quantized.modules.conv1d

from functools import partial
from typing import Union
import torch

from torch import Tensor
from torch.nn.common_types import _size_1_t

from chop.nn.quantizers import (
    block_fp_quantizer,
    block_log_quantizer,
    block_minifloat_quantizer,
    integer_quantizer,
    log_quantizer,
    minifloat_denorm_quantizer,
    minifloat_ieee_quantizer,
    binary_quantizer,
    ternary_quantizer,
)
from ..utils import get_stats


class _Conv1dBase(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t | str = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        self.bypass = False
        self.w_quantizer = None
        self.x_quantizer = None
        self.b_quantizer = None
        self.pruning_masks = None

    def forward(self, x: Tensor) -> Tensor:
        if self.bypass:
            return self._conv_forward(x, self.weight, self.bias)
        x = self.x_quantizer(x)
        w = self.w_quantizer(self.weight)
        bias = self.b_quantizer(self.bias) if self.bias is not None else None
        # WARNING: this may have been simplified; we assume here that the accumulation is lossless!
        # The accumulation length per output element is (in_channels / groups) * kernel_size
        return self._conv_forward(x, w, bias)

    # def get_quantized_weight(self) -> Tensor:
    #     return self.w_quantizer(self.weight)

    # def get_quantized_weights_with_inputs(self, x: Tensor) -> Tensor:
    #     x = self.x_quantizer(x)
    #     w = self.w_quantizer(self.weight)
    #     bias = self.b_quantizer(self.bias) if self.bias is not None else None
    #     y = self._conv_forward(x, w, bias)
    #     return {
    #         "x": x,
    #         "w": w,
    #         "bias": bias,
    #         "y": y,
    #     }

    # def get_output_bitwidth(self):
    #     """output bit width info for HW gen"""
    #     raise NotImplementedError()


class Conv1dInteger(_Conv1dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        # TODO: refine this type
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        # establish quantizers
        w_width, w_frac_width = config["weight_width"], config["weight_frac_width"]
        x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
        b_width, b_frac_width = config["bias_width"], config["bias_frac_width"]
        self.w_quantizer = partial(
            integer_quantizer, width=w_width, frac_width=w_frac_width
        )
        self.x_quantizer = partial(
            integer_quantizer, width=x_width, frac_width=x_frac_width
        )
        self.b_quantizer = partial(
            integer_quantizer, width=b_width, frac_width=b_frac_width
        )

    # def get_output_bitwidth(self):
    #     config = self.config
    #     w_width, w_frac = config["weight_width"], config["weight_frac_width"]
    #     x_width, x_frac = config["data_in_width"], config["data_in_frac_width"]
    #     bias_width = config["bias_width"]
    #     ops = self.in_channels * self.kernel_size[0]
    #     product_width = w_width + x_width
    #     product_frac_width = w_frac + x_frac
    #     # *: +1 for bias
    #     output_width = max(bias_width, product_width + ceil(log2(ops))) + 1
    #     output_frac_width = product_frac_width
    #     o_bitwidth = {}
    #     o_bitwidth["data_out_width"] = output_width
    #     o_bitwidth["data_out_frac_width"] = output_frac_width
    #     # output_bitwidth_info["product_width"] = product_width
    #     # output_bitwidth_info["product_frac_width"] = product_frac_width
    #     return o_bitwidth
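
# Illustrative usage sketch (not part of the original module): the config keys
# mirror the lookups in Conv1dInteger.__init__ above; the bit-widths and layer
# shapes are arbitrary placeholder values.
#
#   import torch
#   from chop.nn.quantized.modules.conv1d import Conv1dInteger
#
#   config = {
#       "weight_width": 8, "weight_frac_width": 4,
#       "data_in_width": 8, "data_in_frac_width": 4,
#       "bias_width": 8, "bias_frac_width": 4,
#   }
#   layer = Conv1dInteger(in_channels=16, out_channels=32, kernel_size=3, config=config)
#   y = layer(torch.randn(4, 16, 100))  # x, weight and bias are quantized before the convolution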


class Conv1dMinifloatDenorm(_Conv1dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config: dict = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        w_width, w_exponent_width, w_exponent_bias = (
            config["weight_width"],
            config["weight_exponent_width"],
            config["weight_exponent_bias"],
        )
        x_width, x_exponent_width, x_exponent_bias = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
        )
        b_width, b_exponent_width, b_exponent_bias = (
            config["bias_width"],
            config["bias_exponent_width"],
            config["bias_exponent_bias"],
        )
        self.w_quantizer = partial(
            minifloat_denorm_quantizer,
            width=w_width,
            exponent_width=w_exponent_width,
            exponent_bias=w_exponent_bias,
        )
        self.x_quantizer = partial(
            minifloat_denorm_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias=x_exponent_bias,
        )
        self.b_quantizer = partial(
            minifloat_denorm_quantizer,
            width=b_width,
            exponent_width=b_exponent_width,
            exponent_bias=b_exponent_bias,
        )


class Conv1dLog(_Conv1dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config: dict = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        w_width, w_exponent_bias = (
            config["weight_width"],
            config["weight_exponent_bias"],
        )
        x_width, x_exponent_bias = (
            config["data_in_width"],
            config["data_in_exponent_bias"],
        )
        b_width, b_exponent_bias = (
            config["bias_width"],
            config["bias_exponent_bias"],
        )
        self.w_quantizer = partial(
            log_quantizer,
            width=w_width,
            exponent_bias=w_exponent_bias,
        )
        self.x_quantizer = partial(
            log_quantizer,
            width=x_width,
            exponent_bias=x_exponent_bias,
        )
        self.b_quantizer = partial(
            log_quantizer,
            width=b_width,
            exponent_bias=b_exponent_bias,
        )


class Conv1dMinifloatIEEE(_Conv1dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config: dict = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        w_width, w_exponent_width, w_exponent_bias = (
            config["weight_width"],
            config["weight_exponent_width"],
            config["weight_exponent_bias"],
        )
        x_width, x_exponent_width, x_exponent_bias = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
        )
        b_width, b_exponent_width, b_exponent_bias = (
            config["bias_width"],
            config["bias_exponent_width"],
            config["bias_exponent_bias"],
        )
        self.w_quantizer = partial(
            minifloat_ieee_quantizer,
            width=w_width,
            exponent_width=w_exponent_width,
            exponent_bias=w_exponent_bias,
        )
        self.x_quantizer = partial(
            minifloat_ieee_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias=x_exponent_bias,
        )
        self.b_quantizer = partial(
            minifloat_ieee_quantizer,
            width=b_width,
            exponent_width=b_exponent_width,
            exponent_bias=b_exponent_bias,
        )
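
# Illustrative config sketch (not part of the original module): the two
# minifloat variants above read the same key names and differ only in the
# quantizer used (minifloat_denorm_quantizer vs. minifloat_ieee_quantizer).
# The values below are arbitrary.
#
#   config = {
#       "weight_width": 8, "weight_exponent_width": 4, "weight_exponent_bias": 7,
#       "data_in_width": 8, "data_in_exponent_width": 4, "data_in_exponent_bias": 7,
#       "bias_width": 8, "bias_exponent_width": 4, "bias_exponent_bias": 7,
#   }
#   layer = Conv1dMinifloatIEEE(16, 32, kernel_size=3, config=config)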


class Conv1dBlockFP(_Conv1dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config: dict = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        w_width, w_exponent_width, w_exponent_bias, w_block_size = (
            config["weight_width"],
            config["weight_exponent_width"],
            config["weight_exponent_bias"],
            config["weight_block_size"],
        )
        x_width, x_exponent_width, x_exponent_bias, x_block_size = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
            config["data_in_block_size"],
        )
        b_width, b_exponent_width, b_exponent_bias, b_block_size = (
            config["bias_width"],
            config["bias_exponent_width"],
            config["bias_exponent_bias"],
            config["bias_block_size"],
        )
        self.w_quantizer = partial(
            block_fp_quantizer,
            width=w_width,
            exponent_width=w_exponent_width,
            exponent_bias=w_exponent_bias,
            block_size=w_block_size,
            skip_first_dim=True,
        )
        self.x_quantizer = partial(
            block_fp_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias=x_exponent_bias,
            block_size=x_block_size,
            skip_first_dim=True,
        )
        self.b_quantizer = partial(
            block_fp_quantizer,
            width=b_width,
            exponent_width=b_exponent_width,
            exponent_bias=b_exponent_bias,
            block_size=b_block_size,
            skip_first_dim=False,
        )
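
# Illustrative config sketch (not part of the original module): Conv1dBlockFP
# additionally expects a block size per tensor. The single-element-list form
# below is an assumption about the block_fp_quantizer convention; adjust it to
# whatever shape block_fp_quantizer actually expects. skip_first_dim is fixed
# by the constructor (True for weights/activations, False for the bias).
#
#   config = {
#       "weight_width": 8, "weight_exponent_width": 4, "weight_exponent_bias": 7,
#       "weight_block_size": [16],
#       "data_in_width": 8, "data_in_exponent_width": 4, "data_in_exponent_bias": 7,
#       "data_in_block_size": [16],
#       "bias_width": 8, "bias_exponent_width": 4, "bias_exponent_bias": 7,
#       "bias_block_size": [16],
#   }
#   layer = Conv1dBlockFP(16, 32, kernel_size=3, config=config)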


class Conv1dBlockMinifloat(_Conv1dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config: dict = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        w_width, w_exponent_width, w_exponent_bias_width, w_block_size = (
            config["weight_width"],
            config["weight_exponent_width"],
            config["weight_exponent_bias_width"],
            config["weight_block_size"],
        )
        x_width, x_exponent_width, x_exponent_bias_width, x_block_size = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias_width"],
            config["data_in_block_size"],
        )
        b_width, b_exponent_width, b_exponent_bias_width, b_block_size = (
            config["bias_width"],
            config["bias_exponent_width"],
            config["bias_exponent_bias_width"],
            config["bias_block_size"],
        )
        self.w_quantizer = partial(
            block_minifloat_quantizer,
            width=w_width,
            exponent_width=w_exponent_width,
            exponent_bias_width=w_exponent_bias_width,
            block_size=w_block_size,
            skip_first_dim=True,
        )
        self.x_quantizer = partial(
            block_minifloat_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias_width=x_exponent_bias_width,
            block_size=x_block_size,
            skip_first_dim=True,
        )
        self.b_quantizer = partial(
            block_minifloat_quantizer,
            width=b_width,
            exponent_width=b_exponent_width,
            exponent_bias_width=b_exponent_bias_width,
            block_size=b_block_size,
            skip_first_dim=False,
        )


class Conv1dBlockLog(_Conv1dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t | str = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config: dict = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        w_width, w_exponent_bias_width, w_block_size = (
            config["weight_width"],
            config["weight_exponent_bias_width"],
            config["weight_block_size"],
        )
        x_width, x_exponent_bias_width, x_block_size = (
            config["data_in_width"],
            config["data_in_exponent_bias_width"],
            config["data_in_block_size"],
        )
        b_width, b_exponent_bias_width, b_block_size = (
            config["bias_width"],
            config["bias_exponent_bias_width"],
            config["bias_block_size"],
        )
        self.w_quantizer = partial(
            block_log_quantizer,
            width=w_width,
            exponent_bias_width=w_exponent_bias_width,
            block_size=w_block_size,
            skip_first_dim=True,
        )
        self.x_quantizer = partial(
            block_log_quantizer,
            width=x_width,
            exponent_bias_width=x_exponent_bias_width,
            block_size=x_block_size,
            skip_first_dim=True,
        )
        self.b_quantizer = partial(
            block_log_quantizer,
            width=b_width,
            exponent_bias_width=b_exponent_bias_width,
            block_size=b_block_size,
            skip_first_dim=False,
        )


class Conv1dBinary(_Conv1dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        # TODO: refine this type
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        # NOTE: the data_in stochastic/bipolar settings are reused for the
        # weight and bias quantizers as well
        x_stochastic = config["data_in_stochastic"]
        x_bipolar = config["data_in_bipolar"]
        self.w_quantizer = partial(
            binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar
        )
        self.x_quantizer = partial(
            binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar
        )
        self.b_quantizer = partial(
            binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar
        )
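
# Illustrative config sketch (not part of the original module): Conv1dBinary
# only reads the data_in_* flags and applies them to the weight and bias
# quantizers too, so a minimal config looks like the following (values arbitrary):
#
#   config = {
#       "data_in_stochastic": False,
#       "data_in_bipolar": True,
#   }
#   layer = Conv1dBinary(16, 32, kernel_size=3, config=config)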


class Conv1dTernary(_Conv1dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t | str = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        x_scaling_factor = config["data_in_scaling_factor"]
        w_scaling_factor = config["weight_scaling_factor"]
        b_scaling_factor = config["bias_scaling_factor"]
        x_mean = get_stats(config, "data_in_mean")
        x_median = get_stats(config, "data_in_median")
        x_max = get_stats(config, "data_in_max")
        w_mean = get_stats(config, "weight_mean")
        w_median = get_stats(config, "weight_median")
        w_max = get_stats(config, "weight_max")
        b_mean = get_stats(config, "bias_mean")
        b_median = get_stats(config, "bias_median")
        b_max = get_stats(config, "bias_max")
        self.x_quantizer = partial(
            ternary_quantizer,
            scaling_factor=x_scaling_factor,
            maximum=x_max,
            median=x_median,
            mean=x_mean,
        )
        self.w_quantizer = partial(
            ternary_quantizer,
            scaling_factor=w_scaling_factor,
            maximum=w_max,
            median=w_median,
            mean=w_mean,
        )
        self.b_quantizer = partial(
            ternary_quantizer,
            scaling_factor=b_scaling_factor,
            maximum=b_max,
            median=b_median,
            mean=b_mean,
        )