Source code for chop.nn.quantized.modules.linear

from functools import partial

from chop.nn.quantized.functional.linear import (
    linearBinary,
    linearBinaryScaling,
    linearBlockFP,
    linearBlockLog,
    linearBlockMinifloat,
    linearInteger,
    linearLog,
    linearMXIntHardware,
    linearMinifloatDenorm,
    linearMinifloatIEEE,
    linearTernary,
)
import torch
from torch import Tensor
from torch.nn import functional as F


from ..utils import get_stats, quantiser_passthrough

from chop.nn.quantizers import (
    residual_sign_quantizer,
    block_fp_quantizer,
    block_log_quantizer,
    block_minifloat_quantizer,
    integer_quantizer,
    integer_floor_quantizer,
    log_quantizer,
    minifloat_denorm_quantizer,
    minifloat_ieee_quantizer,
    binary_quantizer,
    ternary_quantizer,
    mxint_hardware,
)

# LUTNet
import numpy as np
from typing import Type
from chop.nn.quantizers.LUTNet.BaseTrainer import BaseTrainer, LagrangeTrainer
from chop.nn.quantizers.LUTNet.MaskBase import MaskBase, MaskExpanded

# LogicNets
from chop.nn.quantizers.LogicNets.utils import (
    generate_permutation_matrix,
    get_int_state_space,
    fetch_mask_indices,
)


class _LinearBase(torch.nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            in_features,
            out_features,
            bias,
            device,
            dtype,
        )
        self.bypass = False
        self.pruning_masks = None
        # NOTE: Quantizers properties are not needed for now
        # self.x_quantizer = None
        # self.w_quantizer = None
        # self.b_quantizer = None
        # self.out_quantizer = None

    # NOTE: This is not needed for now
    # def forward(self, x: Tensor) -> Tensor:
    #     if self.bypass:
    #         # if bypass, there is no quantization
    #         return F.linear(x, self.weight, self.bias)
    #     else:
    #         x = self.x_quantizer(x)
    #         w = self.w_quantizer(self.weight)
    #         bias = self.b_quantizer(self.bias) if self.bias is not None else None
    #         out = F.linear(x, w, bias)
    #         if self.out_quantizer is None:
    #             return out
    #         return self.out_quantizer(out)



class LinearInteger(_LinearBase):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
        out_config=None,
        floor=False,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.config = config
        self.out_config = out_config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

    def forward(self, x):
        if self.bypass:
            return F.linear(x, self.weight, self.bias)
        return linearInteger(x, self.weight, self.bias, self.config, self.out_config)
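
# Usage sketch (not part of the original module): constructing an integer-quantized
# linear layer from a config dict. Only "bypass" is read in this file; the
# width/frac-width key names below are assumptions about what the linearInteger
# kernel expects and may need to be adapted to your version of chop.
def _example_linear_integer():  # hypothetical helper, for illustration only
    cfg = {
        "data_in_width": 8, "data_in_frac_width": 4,  # assumed key names
        "weight_width": 8, "weight_frac_width": 4,    # assumed key names
        "bias_width": 8, "bias_frac_width": 4,        # assumed key names
    }
    layer = LinearInteger(in_features=16, out_features=32, bias=True, config=cfg)
    x = torch.randn(4, 16)
    # Returns the integer-quantized linear output; setting cfg["bypass"] = True
    # would give the plain F.linear result instead.
    return layer(x)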

class LinearMinifloatDenorm(_LinearBase):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

    def forward(self, x):
        if self.bypass:
            return F.linear(x, self.weight, self.bias)
        return linearMinifloatDenorm(x, self.weight, self.bias, self.config)

class LinearMinifloatIEEE(_LinearBase):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

    def forward(self, x):
        if self.bypass:
            return F.linear(x, self.weight, self.bias)
        return linearMinifloatIEEE(x, self.weight, self.bias, self.config)

class LinearLog(_LinearBase):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

    def forward(self, x):
        if self.bypass:
            return F.linear(x, self.weight, self.bias)
        return linearLog(x, self.weight, self.bias, self.config)

class LinearBlockFP(_LinearBase):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

    def forward(self, x):
        if self.bypass:
            return F.linear(x, self.weight, self.bias)
        return linearBlockFP(x, self.weight, self.bias, self.config)

class LinearBlockMinifloat(_LinearBase):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

    def forward(self, x):
        if self.bypass:
            return F.linear(x, self.weight, self.bias)
        return linearBlockMinifloat(x, self.weight, self.bias, self.config)

class LinearBlockLog(_LinearBase):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

    def forward(self, x):
        if self.bypass:
            return F.linear(x, self.weight, self.bias)
        return linearBlockLog(x, self.weight, self.bias, self.config)

class LinearBinary(_LinearBase):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

    def forward(self, x):
        if self.bypass:
            return F.linear(x, self.weight, self.bias)
        return linearBinary(x, self.weight, self.bias, self.config)

class LinearBinaryScaling(_LinearBase):
    """
    Binary scaling variant of the linear transformation layer.

    - "bypass": Bypass quantization for standard linear transformation.
    - "data_in_stochastic", "bias_stochastic", "weight_stochastic": Stochastic settings.
    - "data_in_bipolar", "bias_bipolar", "weight_bipolar": Bipolar settings.
    - "binary_training": Apply binary scaling during training.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        # self.gamma = torch.nn.Parameter(torch.tensor(1.0, requires_grad=True))
        if self.bypass:
            return

    def forward(self, x):
        if self.bypass:
            return F.linear(x, self.weight, self.bias)
        return linearBinaryScaling(x, self.weight, self.bias, self.config)
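
# Usage sketch (not part of the original module): the config keys below are the
# ones documented in the LinearBinaryScaling docstring; the chosen values are
# illustrative assumptions, and the underlying linearBinaryScaling kernel may
# require additional keys.
def _example_linear_binary_scaling():  # hypothetical helper, for illustration only
    cfg = {
        "bypass": False,
        "data_in_stochastic": False,
        "weight_stochastic": False,
        "bias_stochastic": False,
        "data_in_bipolar": True,
        "weight_bipolar": True,
        "bias_bipolar": True,
        "binary_training": True,
    }
    layer = LinearBinaryScaling(in_features=16, out_features=8, bias=True, config=cfg)
    return layer(torch.randn(4, 16))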

class LinearTernary(_LinearBase):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        w_scaling_factor = config["weight_scaling_factor"]
        w_mean = get_stats(config, "weight_mean")
        w_median = get_stats(config, "weight_median")
        w_max = get_stats(config, "weight_max")
        self.w_quantizer = partial(
            ternary_quantizer,
            scaling_factor=w_scaling_factor,
            maximum=w_max,
            median=w_median,
            mean=w_mean,
        )
        self.x_quantizer = quantiser_passthrough
        self.b_quantizer = quantiser_passthrough
        # self.b_quantizer = partial(
        #     ternary_quantizer,
        #     scaling_factor=b_scaling_factor,
        #     maximum=b_max,
        #     median=b_median,
        #     mean=b_mean,
        # )

    def forward(self, x):
        if self.bypass:
            return F.linear(x, self.weight, self.bias)
        return linearTernary(x, self.weight, self.bias, self.config)
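
# Conceptual illustration (not the library's ternary_quantizer): ternary
# quantization maps each weight to {-s, 0, +s} using a threshold band around
# zero; here the threshold is derived from the mean absolute value, and `s`
# plays the role of the scaling factor configured above. Details of the real
# quantizer, which also uses the weight max/median/mean statistics, may differ.
def _ternary_sketch(w: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    threshold = 0.7 * w.abs().mean()  # common heuristic threshold
    q = torch.zeros_like(w)
    q[w > threshold] = scale          # large positive weights -> +scale
    q[w < -threshold] = -scale        # large negative weights -> -scale
    return q                          # values inside the band stay at 0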

# LUT
class LinearBinaryResidualSign(_LinearBase):
    """
    Binary linear layer with a residual-sign variant of the linear transformation.

    - "bypass": Bypass quantization for standard linear transformation.
    - "data_in_stochastic", "bias_stochastic", "weight_stochastic": Stochastic settings.
    - "data_in_bipolar", "bias_bipolar", "weight_bipolar": Bipolar settings.
    - "binary_training": Apply binary scaling during training.
    - "data_in_levels": The number of residual levels to use.
    - "data_in_residual_sign": Apply residual sign on the input.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.levels = config.get("data_in_levels", 2)  # NOTE: Hardcode 2 for now
        self.config = config
        self.bypass = config.get("bypass", False)
        self.gamma = torch.nn.Parameter(torch.tensor(1.0, requires_grad=True))
        # Initialize the per-level means
        ars = np.arange(self.levels) + 1.0
        ars = ars[::-1]
        means = ars / np.sum(ars)
        # Create a torch.nn.Parameter from the means tensor
        self.means = (
            torch.nn.Parameter(
                torch.tensor(means, dtype=torch.float32, requires_grad=True)
            )
            if self.config.get("data_in_residual_sign", True)
            else None
        )
        # pruning masks
        self.pruning_masks = torch.nn.Parameter(
            torch.ones_like(self.weight), requires_grad=False
        )
        if self.bypass:
            return
        x_stochastic, w_stochastic = (
            config["data_in_stochastic"],
            config["weight_stochastic"],
        )
        x_bipolar, w_bipolar = (
            config["data_in_bipolar"],
            config["weight_bipolar"],
        )
        self.binary_training = config["binary_training"]
        self.w_quantizer = partial(
            binary_quantizer, stochastic=w_stochastic, bipolar=w_bipolar
        )
        self.x_quantizer = partial(
            binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar
        )

    def forward(self, x: Tensor) -> Tensor:
        if self.bypass:
            # if bypass, there is no quantization
            return F.linear(x, self.weight, self.bias)
        x_expanded = 0
        if self.means is not None:
            out_bin = residual_sign_quantizer(
                levels=self.levels, x_quantizer=self.x_quantizer, means=self.means, x=x
            )
            for l in range(self.levels):
                x_expanded = x_expanded + out_bin[l, :, :]
        else:
            x_expanded = x
        if self.binary_training:
            w = self.w_quantizer(self.weight)
            return F.linear(
                x_expanded,
                w * self.gamma.abs() * self.pruning_masks,
                self.bias,
            )
        else:
            # Clamp the latent weights in place to [-1, 1]
            self.weight.data.clamp_(-1, 1)
            return F.linear(
                x_expanded,
                self.weight * self.gamma.abs() * self.pruning_masks,
                self.bias,
            )
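
# Conceptual illustration (not the library's residual_sign_quantizer): a
# residual-sign expansion approximates x as a sum of `levels` scaled sign
# terms, where each level binarizes the residual left over by the previous
# level. The real quantizer learns the per-level scales (`self.means`) jointly
# with the weights; this standalone sketch just fixes them for clarity.
def _residual_sign_sketch(x: torch.Tensor, means=(0.66, 0.33)) -> torch.Tensor:
    approx = torch.zeros_like(x)
    residual = x
    for gamma in means:
        level = gamma * torch.sign(residual)  # one binary (bipolar) component
        approx = approx + level
        residual = residual - level           # next level encodes what is left
    return approx  # analogue of x_expanded: the sum over levels, as in forward() above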

class LinearLUT(torch.nn.Module):
    input_mask: torch.Tensor
    tables_count: int
    in_features: int
    out_features: int
    trainer: BaseTrainer
    mask_builder_type: Type[MaskBase]
    mask_builder: MaskBase

    def __init__(
        self,
        config: None,
        in_features: int,
        out_features: int,
        mask_builder_type: Type[MaskBase] = MaskExpanded,
        trainer_type: Type[BaseTrainer] = LagrangeTrainer,
        bias: bool = True,
        device: str = None,
    ) -> None:
        super(LinearLUT, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.levels = config.get("data_in_levels", 2)
        self.input_expanded = config["data_in_input_expanded"]
        self.k = config["data_in_k"]
        self.kk = 2 ** config["data_in_k"]
        self.mask_builder_type = mask_builder_type
        # Initialize mask builder
        self.input_mask = self._input_mask_builder()
        # TODO: table * output feature map
        self.tables_count = self.mask_builder.get_tables_count() * self.out_features
        self.trainer = trainer_type(
            levels=self.levels,
            tables_count=self.tables_count,
            k=config["data_in_k"],
            binarization_level=(1 if config["data_in_binarization_level"] == 1 else 0),
            input_expanded=config["data_in_input_expanded"],
            device=device,
        )
        self.weight = self.trainer.weight
        self.pruning_masks = self.trainer.pruning_masks
        # TODO: we might need this later on
        # stdv = 1 / np.sqrt(self.in_features)
        # w = np.random.normal(loc=0.0, scale=stdv, size=list(self.trainer.weight.shape)).astype(np.float32)
        # self.trainer.weight = torch.nn.Parameter(
        #     torch.tensor(w, requires_grad=True))
        self.bias = (
            torch.nn.Linear(1, out_features, device=device).bias if bias else None
        )
        # Residual sign code
        self.x_quantizer = partial(binary_quantizer, stochastic=False, bipolar=True)
        ars = np.arange(self.levels) + 1.0
        ars = ars[::-1]
        means = ars / np.sum(ars)
        self.means = torch.nn.Parameter(
            torch.tensor(means, dtype=torch.float32, requires_grad=True)
        )

    def _table_input_selections_builder(self) -> list:
        _all_inputs_set = set(range(self.in_features))
        result = []
        for in_idx in range(self.in_features):
            _idx_set = set([in_idx])
            _selection = list(_all_inputs_set - _idx_set)
            result.append((in_idx, _selection))
        return result

    def _input_mask_builder(self) -> np.ndarray:
        """
        Initialize the tables (using indices for the connections).
        """
        result = []
        # TODO: elements can appear more than once in the feature-1 input?
        for _ in range(self.out_features):
            self.mask_builder = self.mask_builder_type(
                self.k, self._table_input_selections_builder(), True
            )
            result.append(self.mask_builder.build())
        return np.concatenate(result)

    def forward(
        self,
        input: torch.Tensor,
        targets: torch.Tensor = None,
        initalize: bool = False,
    ):
        assert len(input.shape) == 2
        batch_size = input.shape[0]
        out_bin = residual_sign_quantizer(
            levels=self.levels, x_quantizer=self.x_quantizer, means=self.means, x=input
        )
        expanded_input = out_bin[:, :, self.input_mask]  # [levels, batch_size, mask]
        output = self.trainer(expanded_input, targets, initalize).squeeze()
        output = output.view(batch_size, -1)
        assert output.shape[-1] == self.tables_count
        output = output.view(
            batch_size,
            self.out_features,
            int(self.tables_count / self.out_features),
        )
        output = output.sum(-1)
        if self.bias is not None:
            output = output + self.bias
        return output

    def pre_initialize(self):
        self.trainer.clear_initializion()

    def update_initialized_weights(self):
        self.trainer.update_initialized_weights()
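
# Usage sketch (not part of the original module): the config keys are the ones
# read in LinearLUT.__init__ above; the values are illustrative assumptions.
# Whether extra setup (e.g. update_initialized_weights) is required before
# inference depends on the trainer implementation.
def _example_linear_lut():  # hypothetical helper, for illustration only
    cfg = {
        "data_in_levels": 2,              # residual-sign levels
        "data_in_k": 4,                   # LUT fan-in (2**k entries per table)
        "data_in_input_expanded": True,
        "data_in_binarization_level": 1,
    }
    layer = LinearLUT(config=cfg, in_features=16, out_features=8, bias=True)
    x = torch.randn(4, 16)                # forward expects a 2-D [batch, features] input
    return layer(x)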

class LinearLogicNets(_LinearBase):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
        activation_module=None,  # To initialize a LogicNets layer, activation functions are needed
        input_layers=None,  # A LogicNets layer may be merged with one or more input layers such as activations and batchnorm
        output_layers=None,  # A LogicNets layer may be merged with one or more output layers such as activations and batchnorm
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        # establish quantizers
        self.x_width, self.x_frac_width = (
            config["data_in_width"],
            config["data_in_frac_width"],
        )
        self.y_width, self.y_frac_width = (
            config["data_out_width"],
            config["data_out_frac_width"],
        )
        self.x_quantizer = partial(
            integer_quantizer, width=self.x_width, frac_width=self.x_frac_width
        )
        self.y_quantizer = partial(
            integer_quantizer, width=self.y_width, frac_width=self.y_frac_width
        )
        # self.input_quant = input_quant
        # self.output_quant = output_quant
        self.activation = activation_module
        self.is_lut_inference = True
        self.neuron_truth_tables = None
        # self.calculate_truth_tables()
        # self.apply_input_quant = apply_input_quant
        # self.apply_output_quant = apply_output_quant
        self.input_layers = input_layers
        self.output_layers = output_layers
        self.apply_layers = False

    # TODO: This function might be a useful utility outside of this class.
    def table_lookup(
        self,
        connected_input: Tensor,
        input_perm_matrix: Tensor,
        bin_output_states: Tensor,
    ) -> Tensor:
        fan_in_size = connected_input.shape[1]
        ci_bcast = connected_input.unsqueeze(2)  # Reshape to B x Fan-in x 1
        pm_bcast = input_perm_matrix.t().unsqueeze(0)  # Reshape to 1 x Fan-in x InputStates
        # Create a boolean matrix which matches input vectors to possible input states
        eq = (ci_bcast == pm_bcast).sum(dim=1) == fan_in_size
        matches = eq.sum(dim=1)  # Count the number of perfect matches per input vector
        if not (matches == torch.ones_like(matches, dtype=matches.dtype)).all():
            raise Exception(
                "One or more input vectors are not in the possible input state space"
            )
        indices = torch.argmax(eq.type(torch.int64), dim=1)
        return bin_output_states[indices]

    def lut_forward(self, x: Tensor) -> Tensor:
        # Flatten all dimensions except the batch dimension
        x = torch.flatten(x, 1)
        # if self.apply_input_quant:
        #     x = self.input_quant(x)
        # Fetch the binary encoding of the input, in case it isn't already in binary format
        x = self.encode(self.x_quantizer(x))
        y = torch.zeros((x.shape[0], self.out_features))
        # Perform table lookup for each neuron output
        for i in range(self.out_features):
            indices, input_perm_matrix, bin_output_states = self.neuron_truth_tables[i]
            # Move the LogicNets tensors to the same device as the input
            input_perm_matrix = input_perm_matrix.to(x.device)
            bin_output_states = bin_output_states.to(x.device)
            connected_input = x[:, indices]
            y[:, i] = self.table_lookup(
                connected_input, input_perm_matrix, bin_output_states
            )
        return y

    def construct_mask_index(self):
        # Construct a mask with the same shape as self.weight, where zero elements
        # are assigned 0 and non-zero elements are assigned 1
        # (pay attention to the dimension order: (out_features, in_features))
        self.mask = torch.where(
            self.weight != 0, torch.tensor(1), torch.tensor(0)
        ).reshape(self.weight.shape[0], -1)

    # Consider using masked_select instead of fetching the indices
    def calculate_truth_tables(self):
        # print(
        #     "weight", torch.where(self.weight != 0, torch.tensor(1), torch.tensor(0))
        # )  # pay attention to dimension (out_feature, in_feature)
        with torch.no_grad():
            # Precalculate all of the input value permutations
            input_state_space = list()  # TODO: is a list the right data-structure here?
            bin_state_space = list()
            # get a neuron state space per input element
            for m in range(self.in_features):
                neuron_state_space = self.decode(
                    get_int_state_space(self.x_width)
                )  # TODO: this call should include the index of the element of interest
                bin_space = get_int_state_space(
                    self.x_width
                )  # TODO: this call should include the index of the element of interest
                input_state_space.append(neuron_state_space)
                bin_state_space.append(bin_space)
            neuron_truth_tables = list()
            self.construct_mask_index()  # construct pruning mask
            for n in range(self.out_features):
                # Select the row of the mask tensor that corresponds to this output feature
                input_mask = self.mask[n, :]
                fan_in = torch.sum(input_mask)
                indices = fetch_mask_indices(input_mask)
                # Generate a matrix containing all possible input states
                input_permutation_matrix = generate_permutation_matrix(
                    [input_state_space[i] for i in indices]
                )
                bin_input_permutation_matrix = generate_permutation_matrix(
                    [bin_state_space[i] for i in indices]
                )
                # TODO: Update this block to just run inference on the fc layer, once BN has been moved to output_quant
                num_permutations = input_permutation_matrix.shape[0]
                padded_perm_matrix = torch.zeros((num_permutations, self.in_features))
                padded_perm_matrix[:, indices] = input_permutation_matrix
                # Calculate the binary output states for the current input
                bin_output_states = self.encode(self.math_forward(padded_perm_matrix))[
                    :, n
                ]
                # Append the connectivity, input permutations and output permutations
                # to the neuron truth tables
                neuron_truth_tables.append(
                    (indices, bin_input_permutation_matrix, bin_output_states)
                )
            self.neuron_truth_tables = neuron_truth_tables

    def math_forward(self, input: Tensor) -> Tensor:
        if self.activation == "unittest":
            # This path is for performing unit tests on the layer
            return self.y_quantizer(
                F.linear(self.x_quantizer(input), self.weight, self.bias)
            )
        if self.apply_layers:
            x = input
            if self.input_layers:
                x = self.run_layers(x, self.input_layers)
            y = self.y_quantizer(F.linear(self.x_quantizer(x), self.weight, self.bias))
            if self.output_layers:
                y = self.run_layers(y, self.output_layers)
            return y
        # This is the case where the linear layer is the only module in the LogicNets module
        return self.y_quantizer(
            F.linear(self.x_quantizer(input), self.weight, self.bias)
        )

    def set_fused(self, fused: bool):
        self.apply_layers = fused

    def run_layers(self, input: Tensor, layers) -> Tensor:
        assert isinstance(layers, list)
        y = input
        for layer in layers:
            layer_name = layer.__class__.__name__
            # "str" is shorthand for "output"; in that case this LogicNets layer is a
            # pure linear layer without activation.
            SUPPORTED_LAYERS = {
                "ReLU": 1,
                "Tanh": 1,
                "BatchNorm1d": 1,
                "str": 0,
            }
            if layer_name not in SUPPORTED_LAYERS:
                raise ValueError(
                    "Unsupported output layer {}. Please choose from {}".format(
                        layer_name, list(SUPPORTED_LAYERS.keys())
                    )
                )
            if SUPPORTED_LAYERS[layer_name]:
                y = layer(y)
        return y

    def encode(self, input: Tensor) -> Tensor:
        return input * 2**self.x_frac_width

    def decode(self, input: Tensor) -> Tensor:
        return input / 2**self.x_frac_width

    def forward(self, x: Tensor) -> Tensor:
        if self.is_lut_inference:
            return self.decode(self.lut_forward(x))
        else:
            return self.math_forward(x)
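
# Usage sketch (not part of the original module): the four width keys are the
# ones read in LinearLogicNets.__init__; activation_module="unittest" exercises
# the direct fixed-point math path. Note that forward() defaults to LUT
# inference (is_lut_inference=True), which requires calculate_truth_tables()
# to have populated neuron_truth_tables first; math_forward() is used here
# instead. encode()/decode() scale by 2**frac_width and are exact inverses.
def _example_linear_logicnets():  # hypothetical helper, for illustration only
    cfg = {
        "data_in_width": 3,
        "data_in_frac_width": 1,
        "data_out_width": 3,
        "data_out_frac_width": 1,
    }
    layer = LinearLogicNets(
        in_features=4, out_features=2, bias=True, config=cfg,
        activation_module="unittest",
    )
    x = torch.randn(3, 4)
    y = layer.math_forward(x)                        # fixed-point quantized linear
    states = layer.encode(x)                         # real value -> integer state
    assert torch.allclose(layer.decode(states), x)   # decode undoes encode exactly
    return y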

class LinearMXIntHardware(_LinearBase):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        config=None,
        out_config=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        assert config is not None, "config is None!"
        self.config = config
        self.out_config = out_config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return

    def forward(self, x):
        if self.bypass:
            return F.linear(x, self.weight, self.bias)
        return linearMXIntHardware(
            x, self.weight, self.bias, self.config, self.out_config
        )