Source code for chop.nn.quantized.modules.conv2d

from functools import partial
from typing import Union, Optional
import torch

from torch import Tensor
from torch.nn.common_types import _size_2_t

# LUTNet
import numpy as np
from typing import Type
from chop.nn.quantizers.LUTNet.BaseTrainer import BaseTrainer, LagrangeTrainer
from chop.nn.quantizers.LUTNet.MaskBase import MaskBase, MaskExpanded
from chop.nn.quantizers import (
    residual_sign_quantizer,
    block_fp_quantizer,
    block_log_quantizer,
    block_minifloat_quantizer,
    integer_quantizer,
    log_quantizer,
    minifloat_denorm_quantizer,
    minifloat_ieee_quantizer,
    binary_quantizer,
    ternary_quantizer,
)
from ..utils import get_stats, quantiser_passthrough
import math

# LogicNets
from chop.nn.quantizers.LogicNets.utils import (
    generate_permutation_matrix,
    get_int_state_space,
    fetch_mask_indices,
)


class _Conv2dBase(torch.nn.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t | str = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = False,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        self.bypass = False
        self.x_quantizer = None
        self.w_quantizer = None
        self.b_quantizer = None
        self.pruning_masks = None

    def forward(self, x: Tensor) -> Tensor:
        if self.bypass:
            return self._conv_forward(x, self.weight, self.bias)
        x = self.x_quantizer(x)
        w = self.w_quantizer(self.weight)
        bias = self.b_quantizer(self.bias) if self.bias is not None else None
        # WARNING: this may have been simplified; we assume here that the accumulation is lossless!
        # The addition size is in_channels * K * K
        return self._conv_forward(x, w, bias)

    def get_quantized_weight(self) -> Tensor:
        return self.w_quantizer(self.weight)

    def get_quantized_weights_with_inputs(self, x: Tensor) -> dict:
        x = self.x_quantizer(x)
        w = self.w_quantizer(self.weight)
        bias = self.b_quantizer(self.bias) if self.bias is not None else None
        y = self._conv_forward(x, w, bias)
        return {
            "x": x,
            "w": w,
            "bias": bias,
            "y": y,
        }

    def get_output_bitwidth(self) -> dict:
        raise NotImplementedError()


class Conv2dInteger(_Conv2dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        # TODO: refine this type
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        # establish quantizers
        w_width, w_frac_width = config["weight_width"], config["weight_frac_width"]
        x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
        # the bias has its own integer quantizer config (no fallback to the weight quantizer)
        b_width, b_frac_width = config["bias_width"], config["bias_frac_width"]
        self.w_quantizer = partial(
            integer_quantizer, width=w_width, frac_width=w_frac_width
        )
        self.x_quantizer = partial(
            integer_quantizer, width=x_width, frac_width=x_frac_width
        )
        self.b_quantizer = partial(
            integer_quantizer, width=b_width, frac_width=b_frac_width
        )
    # def get_output_bitwidth(self) -> dict:
    #     config = self.config
    #     w_width, w_frac = config["weight_width"], config["weight_frac_width"]
    #     x_width, x_frac = config["data_in_width"], config["data_in_frac_width"]
    #     bias_width = config["bias_width"]
    #
    #     ops = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
    #     product_width = w_width + x_width
    #     product_frac_width = w_frac + x_frac
    #     # *: +1 for bias
    #     output_width = max(bias_width, product_width + ceil(log2(ops))) + 1
    #     output_frac_width = product_frac_width
    #
    #     o_bitwidth = {}
    #     o_bitwidth["data_out_width"] = output_width
    #     o_bitwidth["data_out_frac_width"] = output_frac_width
    #     # o_bitwidth["product_width"] = product_width
    #     # o_bitwidth["product_frac_width"] = product_frac_width
    #     return o_bitwidth
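
    # Worked example (a sketch) of the accumulator sizing outlined in the commented-out
    # get_output_bitwidth above. The helper name and all numeric widths are illustrative
    # assumptions, not defaults of this module.
    @staticmethod
    def _example_output_bitwidth() -> dict:
        from math import ceil, log2

        w_width, w_frac = 8, 4
        x_width, x_frac = 8, 4
        bias_width = 8
        ops = 16 * 3 * 3  # in_channels * K * K for a hypothetical 16-channel 3x3 conv
        product_width = w_width + x_width  # 16
        # lossless accumulation needs ceil(log2(ops)) guard bits, plus 1 for the bias add
        output_width = max(bias_width, product_width + ceil(log2(ops))) + 1  # 25
        return {
            "data_out_width": output_width,
            "data_out_frac_width": w_frac + x_frac,  # 8
        }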

class Conv2dMinifloatDenorm(_Conv2dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config: dict = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

        w_width, w_exponent_width, w_exponent_bias = (
            config["weight_width"],
            config["weight_exponent_width"],
            config["weight_exponent_bias"],
        )
        x_width, x_exponent_width, x_exponent_bias = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
        )
        b_width, b_exponent_width, b_exponent_bias = (
            config["bias_width"],
            config["bias_exponent_width"],
            config["bias_exponent_bias"],
        )

        self.w_quantizer = partial(
            minifloat_denorm_quantizer,
            width=w_width,
            exponent_width=w_exponent_width,
            exponent_bias=w_exponent_bias,
        )
        self.x_quantizer = partial(
            minifloat_denorm_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias=x_exponent_bias,
        )
        self.b_quantizer = partial(
            minifloat_denorm_quantizer,
            width=b_width,
            exponent_width=b_exponent_width,
            exponent_bias=b_exponent_bias,
        )

class Conv2dMinifloatIEEE(_Conv2dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config: dict = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

        w_width, w_exponent_width, w_exponent_bias = (
            config["weight_width"],
            config["weight_exponent_width"],
            config["weight_exponent_bias"],
        )
        x_width, x_exponent_width, x_exponent_bias = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
        )
        b_width, b_exponent_width, b_exponent_bias = (
            config["bias_width"],
            config["bias_exponent_width"],
            config["bias_exponent_bias"],
        )

        self.w_quantizer = partial(
            minifloat_ieee_quantizer,
            width=w_width,
            exponent_width=w_exponent_width,
            exponent_bias=w_exponent_bias,
        )
        self.x_quantizer = partial(
            minifloat_ieee_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias=x_exponent_bias,
        )
        self.b_quantizer = partial(
            minifloat_ieee_quantizer,
            width=b_width,
            exponent_width=b_exponent_width,
            exponent_bias=b_exponent_bias,
        )

class Conv2dLog(_Conv2dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config: dict = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

        w_width, w_exponent_bias = (
            config["weight_width"],
            config["weight_exponent_bias"],
        )
        x_width, x_exponent_bias = (
            config["data_in_width"],
            config["data_in_exponent_bias"],
        )
        b_width, b_exponent_bias = (
            config["bias_width"],
            config["bias_exponent_bias"],
        )

        self.w_quantizer = partial(
            log_quantizer,
            width=w_width,
            exponent_bias=w_exponent_bias,
        )
        self.x_quantizer = partial(
            log_quantizer,
            width=x_width,
            exponent_bias=x_exponent_bias,
        )
        self.b_quantizer = partial(
            log_quantizer,
            width=b_width,
            exponent_bias=b_exponent_bias,
        )

class Conv2dBlockFP(_Conv2dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config: dict = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

        w_width, w_exponent_width, w_exponent_bias, w_block_size = (
            config["weight_width"],
            config["weight_exponent_width"],
            config["weight_exponent_bias"],
            config["weight_block_size"],
        )
        x_width, x_exponent_width, x_exponent_bias, x_block_size = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
            config["data_in_block_size"],
        )
        b_width, b_exponent_width, b_exponent_bias, b_block_size = (
            config["bias_width"],
            config["bias_exponent_width"],
            config["bias_exponent_bias"],
            config["bias_block_size"],
        )

        # blocking/unblocking 4D kernel/feature map is not supported
        self.w_quantizer = partial(
            block_fp_quantizer,
            width=w_width,
            exponent_width=w_exponent_width,
            exponent_bias=w_exponent_bias,
            block_size=w_block_size,
            skip_first_dim=True,
        )
        self.x_quantizer = partial(
            block_fp_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias=x_exponent_bias,
            block_size=x_block_size,
            skip_first_dim=True,
        )
        self.b_quantizer = partial(
            block_fp_quantizer,
            width=b_width,
            exponent_width=b_exponent_width,
            exponent_bias=b_exponent_bias,
            block_size=b_block_size,
            skip_first_dim=False,
        )

    def forward(self, x: Tensor) -> Tensor:
        if self.bypass:
            return self._conv_forward(x, self.weight, self.bias)
        x_shape = [i for i in x.shape]
        w_shape = [i for i in self.weight.shape]
        # a hack for handling 4D block/unblock
        x = torch.flatten(x, 0, 1)
        x = self.x_quantizer(x)
        x = torch.reshape(x, x_shape)
        w = torch.flatten(self.weight, 0, 1)
        w = self.w_quantizer(w)
        w = torch.reshape(w, w_shape)
        bias = self.b_quantizer(self.bias) if self.bias is not None else None
        # WARNING: this may have been simplified; we assume here that the accumulation is lossless!
        # The addition size is in_channels * K * K
        return self._conv_forward(x, w, bias)
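

# Usage sketch for Conv2dBlockFP (not part of the original module). All widths,
# exponent biases and block sizes below are illustrative assumptions, not repository
# defaults; the block sizes are assumed to be lists in the format block_fp_quantizer expects.
def _example_conv2d_block_fp() -> Tensor:
    cfg = {
        "weight_width": 8,
        "weight_exponent_width": 4,
        "weight_exponent_bias": None,
        "weight_block_size": [16],
        "data_in_width": 8,
        "data_in_exponent_width": 4,
        "data_in_exponent_bias": None,
        "data_in_block_size": [16],
        "bias_width": 8,
        "bias_exponent_width": 4,
        "bias_exponent_bias": None,
        "bias_block_size": [16],
    }
    layer = Conv2dBlockFP(3, 8, kernel_size=3, padding=1, config=cfg)
    return layer(torch.randn(2, 3, 16, 16))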

class Conv2dBlockMinifloat(_Conv2dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config: dict = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

        w_width, w_exponent_width, w_exponent_bias_width, w_block_size = (
            config["weight_width"],
            config["weight_exponent_width"],
            config["weight_exponent_bias_width"],
            config["weight_block_size"],
        )
        x_width, x_exponent_width, x_exponent_bias_width, x_block_size = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias_width"],
            config["data_in_block_size"],
        )
        b_width, b_exponent_width, b_exponent_bias_width, b_block_size = (
            config["bias_width"],
            config["bias_exponent_width"],
            config["bias_exponent_bias_width"],
            config["bias_block_size"],
        )

        # blocking/unblocking 4D kernel/feature map is not supported
        self.w_quantizer = partial(
            block_minifloat_quantizer,
            width=w_width,
            exponent_width=w_exponent_width,
            exponent_bias_width=w_exponent_bias_width,
            block_size=w_block_size,
            skip_first_dim=True,
        )
        self.x_quantizer = partial(
            block_minifloat_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias_width=x_exponent_bias_width,
            block_size=x_block_size,
            skip_first_dim=True,
        )
        self.b_quantizer = partial(
            block_minifloat_quantizer,
            width=b_width,
            exponent_width=b_exponent_width,
            exponent_bias_width=b_exponent_bias_width,
            block_size=b_block_size,
            skip_first_dim=False,
        )

    def forward(self, x: Tensor) -> Tensor:
        if self.bypass:
            return self._conv_forward(x, self.weight, self.bias)
        x_shape = [i for i in x.shape]
        w_shape = [i for i in self.weight.shape]
        x = torch.flatten(x, 0, 1)
        x = self.x_quantizer(x)
        x = torch.reshape(x, x_shape)
        w = torch.flatten(self.weight, 0, 1)
        w = self.w_quantizer(w)
        w = torch.reshape(w, w_shape)
        bias = self.b_quantizer(self.bias) if self.bias is not None else None
        # WARNING: this may have been simplified; we assume here that the accumulation is lossless!
        # The addition size is in_channels * K * K
        return self._conv_forward(x, w, bias)

class Conv2dBlockLog(_Conv2dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t | str = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        config: dict = None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

        # use separate block sizes per tensor so the weight/input/bias quantizers
        # do not silently share the last-assigned value
        w_width, w_exponent_bias_width, w_block_size = (
            config["weight_width"],
            config["weight_exponent_bias_width"],
            config["weight_block_size"],
        )
        x_width, x_exponent_bias_width, x_block_size = (
            config["data_in_width"],
            config["data_in_exponent_bias_width"],
            config["data_in_block_size"],
        )
        b_width, b_exponent_bias_width, b_block_size = (
            config["bias_width"],
            config["bias_exponent_bias_width"],
            config["bias_block_size"],
        )

        # blocking/unblocking 4D kernel/feature map is not supported
        self.w_quantizer = partial(
            block_log_quantizer,
            width=w_width,
            exponent_bias_width=w_exponent_bias_width,
            block_size=w_block_size,
            skip_first_dim=True,
        )
        self.x_quantizer = partial(
            block_log_quantizer,
            width=x_width,
            exponent_bias_width=x_exponent_bias_width,
            block_size=x_block_size,
            skip_first_dim=True,
        )
        self.b_quantizer = partial(
            block_log_quantizer,
            width=b_width,
            exponent_bias_width=b_exponent_bias_width,
            block_size=b_block_size,
            skip_first_dim=False,
        )

    def forward(self, x: Tensor) -> Tensor:
        if self.bypass:
            return self._conv_forward(x, self.weight, self.bias)
        x_shape = [i for i in x.shape]
        w_shape = [i for i in self.weight.shape]
        x = torch.flatten(x, 0, 1)
        x = self.x_quantizer(x)
        x = torch.reshape(x, x_shape)
        w = torch.flatten(self.weight, 0, 1)
        w = self.w_quantizer(w)
        w = torch.reshape(w, w_shape)
        bias = self.b_quantizer(self.bias) if self.bias is not None else None
        # WARNING: this may have been simplified; we assume here that the accumulation is lossless!
        # The addition size is in_channels * K * K
        return self._conv_forward(x, w, bias)

class Conv2dBinary(_Conv2dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        # TODO: refine this type
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

        # establish quantizers
        x_stochastic, b_stochastic, w_stochastic = (
            config["data_in_stochastic"],
            config["bias_stochastic"],
            config["weight_stochastic"],
        )
        x_bipolar, b_bipolar, w_bipolar = (
            config["data_in_bipolar"],
            config["bias_bipolar"],
            config["weight_bipolar"],
        )
        self.w_quantizer = partial(
            binary_quantizer, stochastic=w_stochastic, bipolar=w_bipolar
        )
        self.x_quantizer = partial(
            binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar
        )
        self.b_quantizer = partial(
            binary_quantizer, stochastic=b_stochastic, bipolar=b_bipolar
        )

class Conv2dBinaryScaling(_Conv2dBase):
    """
    Binary scaling variant of the conv2d layer.

    - "bypass": Bypass quantization and apply the standard convolution.
    - "data_in_stochastic", "bias_stochastic", "weight_stochastic": Stochastic quantization settings.
    - "data_in_bipolar", "bias_bipolar", "weight_bipolar": Bipolar quantization settings.
    - "binary_training": Apply binary scaling during training.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        # TODO: refine this type
        device=None,
        dtype=None,
        config=None,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

        # stdv = 1 / np.sqrt(
        #     self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        # )
        # w = np.random.normal(
        #     loc=0.0,
        #     scale=stdv,
        #     size=[
        #         self.out_channels,
        #         self.in_channels,
        #         self.kernel_size[0],
        #         self.kernel_size[1],
        #     ],
        # ).astype(np.float32)
        # self.weight = nn.Parameter(torch.tensor(w, requires_grad=True))
        self.gamma = torch.nn.Parameter(torch.tensor(1.0, requires_grad=True))
        self.binary_training = True

        x_stochastic, b_stochastic, w_stochastic = (
            config["data_in_stochastic"],
            config["bias_stochastic"],
            config["weight_stochastic"],
        )
        x_bipolar, b_bipolar, w_bipolar = (
            config["data_in_bipolar"],
            config["bias_bipolar"],
            config["weight_bipolar"],
        )
        self.w_quantizer = partial(
            binary_quantizer, stochastic=w_stochastic, bipolar=w_bipolar
        )
        self.b_quantizer = quantiser_passthrough
        self.x_quantizer = quantiser_passthrough

    def forward(self, x: Tensor) -> Tensor:
        if self.bypass:
            return self._conv_forward(x, self.weight, self.bias)
        if self.binary_training:
            x = self.x_quantizer(x)
            w = self.w_quantizer(self.weight)
            bias = self.b_quantizer(self.bias) if self.bias is not None else None
            # WARNING: this may have been simplified; we assume here that the accumulation is lossless!
            # The addition size is in_channels * K * K
            return self._conv_forward(x, w * self.gamma.abs(), bias)
        else:
            self.weight.data.clamp_(-1, 1)
            return self._conv_forward(x, self.weight * self.gamma.abs(), self.bias)
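

# Usage sketch for Conv2dBinaryScaling (not part of the original module). The config
# keys mirror the ones read in __init__ above; the values are illustrative assumptions.
def _example_conv2d_binary_scaling() -> Tensor:
    cfg = {
        "bypass": False,
        "data_in_stochastic": False,
        "bias_stochastic": False,
        "weight_stochastic": False,
        "data_in_bipolar": True,
        "bias_bipolar": True,
        "weight_bipolar": True,
    }
    layer = Conv2dBinaryScaling(3, 8, kernel_size=3, padding=1, config=cfg)
    return layer(torch.randn(2, 3, 16, 16))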

class Conv2dBinaryResidualSign(_Conv2dBase):
    """
    Binary conv2d layer with a residual sign variant of the input quantization.

    - "bypass": Bypass quantization and apply the standard convolution.
    - "data_in_stochastic", "bias_stochastic", "weight_stochastic": Stochastic quantization settings.
    - "data_in_bipolar", "bias_bipolar", "weight_bipolar": Bipolar quantization settings.
    - "binary_training": Apply binary scaling during training.
    - "data_in_levels": Number of residual binarization levels to use.
    - "data_in_residual_sign": Apply the residual sign quantizer to the input.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        # TODO: refine this type
        device=None,
        dtype=None,
        config=None,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

        # residual_config
        self.levels, self.binary_training = (
            config.get("data_in_levels", 2),
            config["binary_training"],
        )
        self.gamma = torch.nn.Parameter(torch.tensor(1.0, requires_grad=True))

        # Create a torch.nn.Parameter from the means tensor
        ars = np.arange(self.levels) + 1.0
        ars = ars[::-1]
        means = ars / np.sum(ars)
        self.means = (
            torch.nn.Parameter(
                torch.tensor(means, dtype=torch.float32, requires_grad=True)
            )
            if self.config.get("data_in_residual_sign", True)
            else None
        )

        # pruning masks
        self.pruning_masks = torch.nn.Parameter(
            torch.ones_like(self.weight), requires_grad=False
        )

        x_stochastic, w_stochastic = (
            config["data_in_stochastic"],
            config["weight_stochastic"],
        )
        x_bipolar, w_bipolar = (
            config["data_in_bipolar"],
            config["weight_bipolar"],
        )
        self.w_quantizer = partial(
            binary_quantizer, stochastic=w_stochastic, bipolar=w_bipolar
        )
        self.x_quantizer = partial(
            binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar
        )

    def forward(self, x: Tensor) -> Tensor:
        if self.bypass:
            return self._conv_forward(x, self.weight, self.bias)
        x_expanded = 0
        if self.means is not None:
            out_bin = residual_sign_quantizer(
                levels=self.levels, x_quantizer=self.x_quantizer, means=self.means, x=x
            )
            for l in range(self.levels):
                x_expanded = x_expanded + out_bin[l, :, :]
        else:
            x_expanded = x
        if self.binary_training:
            w = self.w_quantizer(self.weight)
            bias = self.b_quantizer(self.bias) if self.bias is not None else None
            # WARNING: this may have been simplified; we assume here that the accumulation is lossless!
            # The addition size is in_channels * K * K
            return self._conv_forward(
                x_expanded, w * self.gamma.abs() * self.pruning_masks, bias
            )
        else:
            # print(self.gamma.abs(), self.pruning_masks)
            self.weight.data.clamp_(-1, 1)
            return self._conv_forward(
                x_expanded,
                self.weight * self.gamma.abs() * self.pruning_masks,
                self.bias,
            )
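

# Usage sketch for Conv2dBinaryResidualSign (not part of the original module). The
# values are illustrative assumptions; bias is disabled here because this layer does
# not configure a bias quantizer.
def _example_conv2d_binary_residual_sign() -> Tensor:
    cfg = {
        "bypass": False,
        "binary_training": True,
        "data_in_levels": 2,
        "data_in_residual_sign": True,
        "data_in_stochastic": False,
        "weight_stochastic": False,
        "data_in_bipolar": True,
        "weight_bipolar": True,
    }
    layer = Conv2dBinaryResidualSign(
        3, 8, kernel_size=3, padding=1, bias=False, config=cfg
    )
    return layer(torch.randn(2, 3, 16, 16))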

class Conv2dTernary(_Conv2dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        # TODO: refine this type
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        assert config is not None, "config is None!"
        # establish quantizers
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

        w_scaling_factor = config["weight_scaling_factor"]
        w_mean = get_stats(config, "weight_mean")
        w_median = get_stats(config, "weight_median")
        w_max = get_stats(config, "weight_max")
        self.w_quantizer = partial(
            ternary_quantizer,
            scaling_factor=w_scaling_factor,
            maximum=w_max,
            median=w_median,
            mean=w_mean,
        )
        self.x_quantizer = quantiser_passthrough
        self.b_quantizer = quantiser_passthrough
        # self.b_quantizer = partial(
        #     ternary_quantizer,
        #     scaling_factor=b_scaling_factor,
        #     maximum=b_max,
        #     median=b_median,
        #     mean=b_mean,
        # )

class Conv2dLUT(torch.nn.Module):
    in_channels: int
    out_channels: int
    kernel_size: tuple
    stride: tuple
    padding: tuple
    dilation: tuple
    groups: tuple
    bias: torch.Tensor
    padding_mode: str
    input_dim: tuple
    device: Optional[str]
    input_mask: torch.Tensor
    k: int
    trainer: BaseTrainer
    mask_builder_type: Type[MaskBase]
    mask_builder: MaskBase
    tables_count: int

    def __init__(
        self,
        config: None,  # mase configuration
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple],
        stride: Union[int, tuple] = 1,
        padding: Union[int, tuple] = 0,
        dilation: Union[int, tuple] = 1,
        groups: Union[int, tuple] = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        trainer_type: Type[BaseTrainer] = LagrangeTrainer,
        mask_builder_type: Type[MaskBase] = MaskExpanded,
        # k: int = 2,
        # binarization_level: int = 0,
        # input_expanded: bool = True,
        # input_dim: Union[int, tuple] = None,
        device: str = None,
    ):
        super(Conv2dLUT, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = torch.nn.modules.utils._pair(kernel_size)
        self.stride = torch.nn.modules.utils._pair(stride)
        self.padding = torch.nn.modules.utils._pair(padding)
        self.dilation = torch.nn.modules.utils._pair(dilation)
        self.mask_builder_type = mask_builder_type
        self.groups = torch.nn.modules.utils._pair(groups)
        self.bias = None
        self.padding_mode = padding_mode

        # LUT attributes
        self.k = config["data_in_k"]
        self.kk = 2 ** config["data_in_k"]
        self.levels = config.get("data_in_levels", 2)
        self.input_expanded = config["data_in_input_expanded"]
        self.input_dim = torch.nn.modules.utils._pair(config["data_in_dim"])
        self.device = device

        self.input_mask = self._input_mask_builder()
        self.tables_count = self.mask_builder.get_tables_count()
        self.trainer = trainer_type(
            levels=self.levels,
            tables_count=self.tables_count,
            k=config["data_in_k"],
            binarization_level=(
                # binarization_level 1 is binarized weight, 0 is not binarized
                1 if config["data_in_binarization_level"] == 1 else 0
            ),
            input_expanded=config["data_in_input_expanded"],
            device=device,
        )
        self.unfold = torch.nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
        )
        self.fold = torch.nn.Fold(
            output_size=(self._out_dim(0), self._out_dim(1)),
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
        )
        self.weight = self.trainer.weight
        # TODO: Does this work?
        self.pruning_masks = self.trainer.pruning_masks

        # Residual sign code
        self.x_quantizer = partial(binary_quantizer, stochastic=False, bipolar=True)
        ars = np.arange(self.levels) + 1.0
        ars = ars[::-1]
        means = ars / np.sum(ars)
        self.means = torch.nn.Parameter(
            torch.tensor(means, dtype=torch.float32, requires_grad=True)
        )

        # pruning masks
        self.pruning_masks = torch.nn.Parameter(
            torch.ones_like(self.weight), requires_grad=False
        )

    def _get_kernel_selections(self, channel_id):
        result = []
        for kh_index in range(self.kernel_size[0]):
            for kw_index in range(self.kernel_size[1]):
                result.append((channel_id, kh_index, kw_index))
        return set(result)

    def _table_input_selections(self):
        result = []
        for out_index in range(self.out_channels):
            for input_index in range(self.in_channels):
                selections = self._get_kernel_selections(
                    input_index
                )  # [(channel_id, kh, kw)...] 9
                for kh_index in range(self.kernel_size[0]):
                    for kw_index in range(self.kernel_size[1]):
                        conv_index = (input_index, kh_index, kw_index)
                        sub_selections = list(selections - set([conv_index]))
                        result.append(
                            (conv_index, sub_selections)
                        )  # [(channel_id, kh, kw), [(channel_id, kh, kw)...]] 9
        return result

    def _input_mask_builder(self) -> torch.Tensor:
        result = []
        selections = self._table_input_selections()  # [kw * kh * ic * oc]
        self.mask_builder = self.mask_builder_type(self.k, selections, True)
        result.append(self.mask_builder.build())
        return np.concatenate(result)

    def _out_dim(self, dim):
        _out = (
            self.input_dim[dim]
            + 2 * self.padding[dim]
            - self.dilation[dim] * (self.kernel_size[dim] - 1)
            - 1
        ) / self.stride[dim]
        return math.floor(_out + 1)
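
    # Worked example (a sketch) of the output-size formula used by _out_dim above.
    # The helper and its numbers are illustrative assumptions, not part of the module.
    @staticmethod
    def _example_out_dim() -> int:
        input_dim, padding, dilation, kernel_size, stride = 32, 1, 1, 3, 2
        out = (input_dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride
        return math.floor(out + 1)  # floor(31 / 2) + 1 == 16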

    def forward(
        self,
        input: torch.Tensor,  # [10, 256, 3, 3]
        targets: torch.Tensor = None,
        initalize: bool = False,
    ):
        assert len(input.shape) == 4
        batch_size = input.shape[0]
        folded_input = self.unfold(input).transpose(1, 2)  # [10, 1, 2304]
        folded_input = residual_sign_quantizer(
            levels=self.levels,
            x_quantizer=self.x_quantizer,
            means=self.means,
            x=folded_input,
        )
        folded_input = folded_input.view(
            self.levels,
            batch_size,
            -1,
            self.in_channels,
            self.kernel_size[0],
            self.kernel_size[1],
        )  # [levels, batch_size, 1, in_channels, kernel_size[0], kernel_size[1]]

        # print(self.input_mask.shape)
        # [1179648, 3] = 256 * 256 * 9 * 2; NOTE: each element in the kernel corresponds to a table
        expanded_input = folded_input[
            :,
            :,
            :,
            self.input_mask[:, 0],
            self.input_mask[:, 1],
            self.input_mask[:, 2],
        ]  # [levels, batch_size, 1, in_channels * kernel_size[0] * kernel_size[1] * k]
        # [10, 1, 1179648]
        output = self.trainer(
            expanded_input, targets, initalize
        ).squeeze()  # [10, 589824]
        output = output.view(
            batch_size,
            self.out_channels,
            self._out_dim(0),
            self._out_dim(1),
            -1,
        ).sum(
            -1
        )  # [10, 256, 1, 1, 2304] -> [10, 256, 1, 1]
        output = output.view(
            batch_size, self._out_dim(0) * self._out_dim(1), -1
        ).transpose(
            1, 2
        )  # [10, 1, 256] -> [10, 256, 1]
        output = output.view(
            batch_size, self.out_channels, self._out_dim(0), self._out_dim(1)
        )  # [10, 256, 1, 1]
        return output

    def pre_initialize(self):
        self.trainer.clear_initializion()

    def update_initialized_weights(self):
        self.trainer.update_initialized_weights()

class Conv2DLogicNets(_Conv2dBase):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        # TODO: refine this type
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

        self.in_features = in_channels * kernel_size[0] * kernel_size[1]
        self.out_features = out_channels * kernel_size[0] * kernel_size[1]

        # establish quantizers
        self.x_width, self.x_frac_width = (
            config["data_in_width"],
            config["data_in_frac_width"],
        )
        self.y_width, self.y_frac_width = (
            config["data_out_width"],
            config["data_out_frac_width"],
        )
        self.x_quantizer = partial(
            integer_quantizer, width=self.x_width, frac_width=self.x_frac_width
        )
        self.y_quantizer = partial(
            integer_quantizer, width=self.y_width, frac_width=self.y_frac_width
        )
        self.is_lut_inference = True
        self.neuron_truth_tables = None
        # self.calculate_truth_tables()  # We will call this explicitly during the transform

    def table_lookup(
        self,
        connected_input: Tensor,
        input_perm_matrix: Tensor,
        bin_output_states: Tensor,
    ) -> Tensor:
        fan_in_size = connected_input.shape[1]
        ci_bcast = connected_input.unsqueeze(2)  # Reshape to B x Fan-in x 1
        pm_bcast = input_perm_matrix.t().unsqueeze(
            0
        )  # Reshape to 1 x Fan-in x InputStates
        eq = (ci_bcast == pm_bcast).sum(
            dim=1
        ) == fan_in_size  # Create a boolean matrix which matches input vectors to possible input states
        matches = eq.sum(dim=1)  # Count the number of perfect matches per input vector
        if not (matches == torch.ones_like(matches, dtype=matches.dtype)).all():
            raise Exception(
                "One or more vectors in the input is not in the possible input state space"
            )
        indices = torch.argmax(eq.type(torch.int64), dim=1)
        return bin_output_states[indices]
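
    # Minimal sketch of the matching logic used by table_lookup above, on a hand-built
    # 2-input XOR truth table. The helper and its values are illustrative assumptions,
    # not part of the original module.
    @staticmethod
    def _example_table_lookup() -> Tensor:
        input_perm_matrix = torch.tensor(
            [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
        )  # all possible input states, one row per state
        bin_output_states = torch.tensor([0.0, 1.0, 1.0, 0.0])  # XOR of each state
        connected_input = torch.tensor([[1.0, 0.0], [1.0, 1.0]])  # batch of two inputs
        fan_in = connected_input.shape[1]
        # compare B x Fan-in x 1 against 1 x Fan-in x InputStates, then count matches
        eq = (connected_input.unsqueeze(2) == input_perm_matrix.t().unsqueeze(0)).sum(
            dim=1
        ) == fan_in
        indices = torch.argmax(eq.type(torch.int64), dim=1)
        return bin_output_states[indices]  # tensor([1., 0.])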

    def lut_forward(self, x: Tensor) -> Tensor:
        x = torch.flatten(x, 1)  # flatten all dims except the batch dimension
        # if self.apply_input_quant:
        #     x = self.input_quant(x)
        # Use this to fetch the bin output of the input, if the input isn't already in binary format
        x = self.encode(self.x_quantizer(x))
        y = torch.zeros((x.shape[0], self.out_features))
        # Perform table lookup for each neuron output
        for i in range(self.out_features):
            indices, input_perm_matrix, bin_output_states = self.neuron_truth_tables[i]
            connected_input = x[:, indices]
            y[:, i] = self.table_lookup(
                connected_input, input_perm_matrix, bin_output_states
            )
        return y

    def construct_mask_index(self):
        # construct a mask with the same shape as self.weight, where zero elements
        # map to 0 and non-zero elements map to 1
        self.mask = torch.where(
            self.weight != 0, torch.tensor(1), torch.tensor(0)
        ).reshape(
            self.weight.shape[0], -1
        )  # pay attention to the dimension order: (out_feature, in_feature)

    # Consider using masked_select instead of fetching the indices
    def calculate_truth_tables(self):
        with torch.no_grad():
            # Precalculate all of the input value permutations
            input_state_space = list()  # TODO: is a list the right data-structure here?
            bin_state_space = list()
            # get a neuron_state
            for m in range(self.in_features):
                neuron_state_space = self.decode(
                    get_int_state_space(self.x_width)
                )  # TODO: this call should include the index of the element of interest
                bin_space = get_int_state_space(
                    self.x_width
                )  # TODO: this call should include the index of the element of interest
                input_state_space.append(neuron_state_space)
                bin_state_space.append(bin_space)

            neuron_truth_tables = list()
            self.construct_mask_index()  # construct the pruning mask
            for n in range(self.out_features):
                input_mask = self.mask[
                    n, :
                ]  # select the row of the mask tensor that corresponds to this output feature
                fan_in = torch.sum(input_mask)
                indices = fetch_mask_indices(input_mask)

                # Generate a matrix containing all possible input states
                input_permutation_matrix = generate_permutation_matrix(
                    [input_state_space[i] for i in indices]
                )
                bin_input_permutation_matrix = generate_permutation_matrix(
                    [bin_state_space[i] for i in indices]
                )
                # print("in_feature={}, out_feature={}, kernel={}".format(self.in_features, self.out_features, self.kernel_size))
                # print("fan_in", fan_in, "indices", indices, "input_permutation_matrix", input_permutation_matrix.shape, [input_state_space[i] for i in indices], "bin_input_permutation_matrix", bin_input_permutation_matrix.shape, [bin_state_space[i] for i in indices])
                num_permutations = input_permutation_matrix.shape[0]
                padded_perm_matrix = torch.zeros((num_permutations, self.in_features))
                padded_perm_matrix[:, indices] = input_permutation_matrix
                # print("input", padded_perm_matrix.shape)

                # TODO: Update this block to just run inference on the fc layer, once BN has been moved to output_quant
                bin_output_states = self.encode(self.math_forward(padded_perm_matrix))[
                    :, n
                ]  # Calculate bin for the current input

                # Append the connectivity, input permutations and output permutations to the neuron truth tables
                neuron_truth_tables.append(
                    (indices, bin_input_permutation_matrix, bin_output_states)
                )  # Change this to be the binary output states
            self.neuron_truth_tables = neuron_truth_tables

    def forward(self, x: Tensor) -> Tensor:
        if self.is_lut_inference:
            return self.decode(self.lut_forward(x))

    def encode(self, input: Tensor) -> Tensor:
        return input * 2**self.x_frac_width

    def decode(self, input: Tensor) -> Tensor:
        return input / 2**self.x_frac_width

    def math_forward(self, input: Tensor) -> Tensor:
        return self.y_quantizer(
            self._conv_forward(self.x_quantizer(input), self.weight, self.bias)
        )
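
    # Sketch of the encode/decode fixed-point scaling above with an illustrative
    # fractional width (an assumption, not a module default): encode maps reals to
    # integer codes by multiplying with 2**frac_width, and decode reverses it.
    @staticmethod
    def _example_encode_decode() -> Tensor:
        x_frac_width = 4
        x = torch.tensor([0.75, -1.25])
        encoded = x * 2**x_frac_width  # tensor([12., -20.])
        decoded = encoded / 2**x_frac_width  # recovers the original values
        return decoded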