from functools import partial
from chop.nn.quantized.functional.linear import (
linearBinary,
linearBinaryScaling,
linearBlockFP,
linearBlockLog,
linearBlockMinifloat,
linearInteger,
linearLog,
linearMXIntHardware,
linearMinifloatDenorm,
linearMinifloatIEEE,
linearTernary,
)
import torch
from torch import Tensor
from torch.nn import functional as F
from ..utils import get_stats, quantiser_passthrough
from chop.nn.quantizers import (
residual_sign_quantizer,
block_fp_quantizer,
block_log_quantizer,
block_minifloat_quantizer,
integer_quantizer,
integer_floor_quantizer,
log_quantizer,
minifloat_denorm_quantizer,
minifloat_ieee_quantizer,
binary_quantizer,
ternary_quantizer,
mxint_hardware,
)
# LUTNet
import numpy as np
from typing import Type
from chop.nn.quantizers.LUTNet.BaseTrainer import BaseTrainer, LagrangeTrainer
from chop.nn.quantizers.LUTNet.MaskBase import MaskBase, MaskExpanded
# LogicNets
from chop.nn.quantizers.LogicNets.utils import (
generate_permutation_matrix,
get_int_state_space,
fetch_mask_indices,
)
class _LinearBase(torch.nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
device=None,
dtype=None,
) -> None:
super().__init__(
in_features,
out_features,
bias,
device,
dtype,
)
self.bypass = False
self.pruning_masks = None
        # NOTE: Quantizer properties are not needed for now
# self.x_quantizer = None
# self.w_quantizer = None
# self.b_quantizer = None
# self.out_quantizer = None
# NOTE: This is not needed for now
# def forward(self, x: Tensor) -> Tensor:
# if self.bypass:
# # if bypass, there is no quantization
# return F.linear(x, self.weight, self.bias)
# else:
# x = self.x_quantizer(x)
# w = self.w_quantizer(self.weight)
# bias = self.b_quantizer(self.bias) if self.bias is not None else None
# out = F.linear(x, w, bias)
# if self.out_quantizer is None:
# return out
# return self.out_quantizer(out)
class LinearInteger(_LinearBase):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
out_config=None,
floor=False,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.config = config
self.out_config = out_config
self.bypass = config.get("bypass", False)
if self.bypass:
return
def forward(self, x):
if self.bypass:
return F.linear(x, self.weight, self.bias)
return linearInteger(x, self.weight, self.bias, self.config, self.out_config)
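# A minimal usage sketch for LinearInteger, kept as a comment so importing this
# module has no side effects. The exact config schema is owned by `linearInteger`
# / `integer_quantizer`; the "weight_*" and "bias_*" width / frac_width keys below
# are an assumption modelled on the "data_in_*" fixed-point keys used elsewhere in
# this file, so adjust them to the real schema before use.
#
# example_config = {
#     "data_in_width": 8, "data_in_frac_width": 4,
#     "weight_width": 8, "weight_frac_width": 4,
#     "bias_width": 8, "bias_frac_width": 4,
# }
# layer = LinearInteger(in_features=16, out_features=8, bias=True, config=example_config)
# y = layer(torch.randn(4, 16))  # fixed-point-quantized linear transform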
class LinearMinifloatDenorm(_LinearBase):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
def forward(self, x):
if self.bypass:
return F.linear(x, self.weight, self.bias)
return linearMinifloatDenorm(x, self.weight, self.bias, self.config)
class LinearMinifloatIEEE(_LinearBase):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
def forward(self, x):
if self.bypass:
return F.linear(x, self.weight, self.bias)
return linearMinifloatIEEE(x, self.weight, self.bias, self.config)
class LinearLog(_LinearBase):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
def forward(self, x):
if self.bypass:
return F.linear(x, self.weight, self.bias)
return linearLog(x, self.weight, self.bias, self.config)
class LinearBlockFP(_LinearBase):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
def forward(self, x):
if self.bypass:
return F.linear(x, self.weight, self.bias)
return linearBlockFP(x, self.weight, self.bias, self.config)
class LinearBlockMinifloat(_LinearBase):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
def forward(self, x):
if self.bypass:
return F.linear(x, self.weight, self.bias)
return linearBlockMinifloat(x, self.weight, self.bias, self.config)
class LinearBlockLog(_LinearBase):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
def forward(self, x):
if self.bypass:
return F.linear(x, self.weight, self.bias)
return linearBlockLog(x, self.weight, self.bias, self.config)
class LinearBinary(_LinearBase):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
def forward(self, x):
if self.bypass:
return F.linear(x, self.weight, self.bias)
return linearBinary(x, self.weight, self.bias, self.config)
class LinearBinaryScaling(_LinearBase):
"""
    Binary-scaling variant of the linear transformation layer.

    Config keys:
    - "bypass": Bypass quantization and fall back to a standard linear transformation.
    - "data_in_stochastic", "bias_stochastic", "weight_stochastic": Stochastic quantization settings.
    - "data_in_bipolar", "bias_bipolar", "weight_bipolar": Bipolar settings.
    - "binary_training": Apply binary scaling during training.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
# self.gamma = torch.nn.Parameter(torch.tensor(1.0, requires_grad=True))
if self.bypass:
return
def forward(self, x):
if self.bypass:
return F.linear(x, self.weight, self.bias)
return linearBinaryScaling(x, self.weight, self.bias, self.config)
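# A usage sketch for LinearBinaryScaling based on the config keys listed in its
# docstring; the values are illustrative, and the snippet is kept as a comment so
# nothing executes at import time.
#
# example_config = {
#     "bypass": False,
#     "data_in_stochastic": False, "bias_stochastic": False, "weight_stochastic": False,
#     "data_in_bipolar": True, "bias_bipolar": True, "weight_bipolar": True,
#     "binary_training": True,
# }
# layer = LinearBinaryScaling(in_features=16, out_features=8, bias=True, config=example_config)
# y = layer(torch.randn(4, 16))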
class LinearTernary(_LinearBase):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
w_scaling_factor = config["weight_scaling_factor"]
w_mean = get_stats(config, "weight_mean")
w_median = get_stats(config, "weight_median")
w_max = get_stats(config, "weight_max")
self.w_quantizer = partial(
ternary_quantizer,
scaling_factor=w_scaling_factor,
maximum=w_max,
median=w_median,
mean=w_mean,
)
self.x_quantizer = quantiser_passthrough
self.b_quantizer = quantiser_passthrough
# self.b_quantizer = partial(
# ternary_quantizer,
# scaling_factor=b_scaling_factor,
# maximum=b_max,
# median=b_median,
# mean=b_mean,
# )
def forward(self, x):
if self.bypass:
return F.linear(x, self.weight, self.bias)
return linearTernary(x, self.weight, self.bias, self.config)
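# A usage sketch for LinearTernary. Only "weight_scaling_factor" is read directly
# from the config in __init__ above; the weight statistics (mean / median / max)
# are fetched through get_stats(config, ...), so how they are recorded in the
# config is an assumption here. Kept as a comment for illustration only.
#
# example_config = {
#     "weight_scaling_factor": 1.0,  # illustrative value; see ternary_quantizer
#     # ... plus whatever entries get_stats() expects for "weight_mean",
#     # "weight_median" and "weight_max"
# }
# layer = LinearTernary(in_features=16, out_features=8, bias=True, config=example_config)
# y = layer(torch.randn(4, 16))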
# LUT
class LinearBinaryResidualSign(_LinearBase):
"""
    Binary linear layer with a residual-sign variant of the linear transformation.

    Config keys:
    - "bypass": Bypass quantization and fall back to a standard linear transformation.
    - "data_in_stochastic", "bias_stochastic", "weight_stochastic": Stochastic quantization settings.
    - "data_in_bipolar", "bias_bipolar", "weight_bipolar": Bipolar settings.
    - "binary_training": Apply binary scaling during training.
    - "data_in_levels": The number of residual levels to use.
    - "data_in_residual_sign": Apply residual-sign quantization to the input.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.levels = config.get("data_in_levels", 2) # NOTE: Hardcode 2 for now
self.config = config
self.bypass = config.get("bypass", False)
self.gamma = torch.nn.Parameter(torch.tensor(1.0, requires_grad=True))
        # Initialize the residual level means
ars = np.arange(self.levels) + 1.0
ars = ars[::-1]
means = ars / np.sum(ars)
# Create a torch.nn.Parameter from the means tensor
self.means = (
torch.nn.Parameter(
torch.tensor(means, dtype=torch.float32, requires_grad=True)
)
if self.config.get("data_in_residual_sign", True)
else None
)
        # pruning masks
self.pruning_masks = torch.nn.Parameter(
torch.ones_like(self.weight), requires_grad=False
)
if self.bypass:
return
x_stochastic, w_stochastic = (
config["data_in_stochastic"],
config["weight_stochastic"],
)
x_bipolar, w_bipolar = (
config["data_in_bipolar"],
config["weight_bipolar"],
)
self.binary_training = config["binary_training"]
self.w_quantizer = partial(
binary_quantizer, stochastic=w_stochastic, bipolar=w_bipolar
)
self.x_quantizer = partial(
binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar
)
def forward(self, x: Tensor) -> Tensor:
if self.bypass:
            # if bypass, there is no quantization
return F.linear(x, self.weight, self.bias)
x_expanded = 0
if self.means is not None:
out_bin = residual_sign_quantizer(
levels=self.levels, x_quantizer=self.x_quantizer, means=self.means, x=x
)
for l in range(self.levels):
x_expanded = x_expanded + out_bin[l, :, :]
else:
x_expanded = x
if self.binary_training:
w = self.w_quantizer(self.weight)
return F.linear(
x_expanded,
w * self.gamma.abs() * self.pruning_masks,
self.bias,
)
else:
            self.weight.data.clamp_(-1, 1)  # clamp latent weights in-place to [-1, 1]
return F.linear(
x_expanded,
self.weight * self.gamma.abs() * self.pruning_masks,
self.bias,
)
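# A usage sketch for LinearBinaryResidualSign based on the config keys read in
# __init__ above; values are illustrative and the snippet is kept as a comment so
# nothing runs at import time. With "data_in_residual_sign" enabled, the input is
# decomposed into "data_in_levels" binary residual components (weighted by the
# learnable `means`) and their sum is fed to the binary linear transform.
#
# example_config = {
#     "bypass": False,
#     "data_in_levels": 2,
#     "data_in_residual_sign": True,
#     "data_in_stochastic": False, "weight_stochastic": False,
#     "data_in_bipolar": True, "weight_bipolar": True,
#     "binary_training": True,
# }
# layer = LinearBinaryResidualSign(in_features=16, out_features=8, config=example_config)
# y = layer(torch.randn(4, 16))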
class LinearLUT(torch.nn.Module):
input_mask: torch.Tensor
tables_count: int
in_features: int
out_features: int
trainer: BaseTrainer
mask_builder_type: Type[MaskBase]
mask_builder: MaskBase
def __init__(
self,
        config: dict,
in_features: int,
out_features: int,
mask_builder_type: Type[MaskBase] = MaskExpanded,
trainer_type: Type[BaseTrainer] = LagrangeTrainer,
bias: bool = True,
device: str = None,
) -> None:
super(LinearLUT, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.levels = config.get("data_in_levels", 2)
self.input_expanded = config["data_in_input_expanded"]
self.k = config["data_in_k"]
self.kk = 2 ** config["data_in_k"]
self.mask_builder_type = mask_builder_type
# Initialize mask builder
self.input_mask = self._input_mask_builder()
# TODO: table * output feature map
self.tables_count = self.mask_builder.get_tables_count() * self.out_features
self.trainer = trainer_type(
levels=self.levels,
tables_count=self.tables_count,
k=config["data_in_k"],
binarization_level=(1 if config["data_in_binarization_level"] == 1 else 0),
input_expanded=config["data_in_input_expanded"],
device=device,
)
self.weight = self.trainer.weight
self.pruning_masks = self.trainer.pruning_masks
        # TODO: we might need this later on
# stdv = 1 / np.sqrt(self.in_features)
# w = np.random.normal(loc=0.0, scale=stdv, size=list(self.trainer.weight.shape)).astype(np.float32)
# self.trainer.weight = torch.nn.Parameter(
# torch.tensor(w, requires_grad=True))
self.bias = (
torch.nn.Linear(1, out_features, device=device).bias if bias else None
)
# Residual sign code
self.x_quantizer = partial(binary_quantizer, stochastic=False, bipolar=True)
ars = np.arange(self.levels) + 1.0
ars = ars[::-1]
means = ars / np.sum(ars)
self.means = torch.nn.Parameter(
torch.tensor(means, dtype=torch.float32, requires_grad=True)
)
    def _table_input_selections_builder(self) -> list:
_all_inputs_set = set(range(self.in_features))
result = []
for in_idx in range(self.in_features):
_idx_set = set([in_idx])
_selection = list(_all_inputs_set - _idx_set)
result.append((in_idx, _selection))
return result
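    # Worked example (sketch): with in_features == 3 the builder above returns
    #   [(0, [1, 2]), (1, [0, 2]), (2, [0, 1])]
    # i.e. for each input index, the remaining inputs that a LUT anchored on that
    # index may additionally connect to.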
def _input_mask_builder(self) -> torch.Tensor:
"""
        Initialize the LUT connection tables (using indices for the connections).
"""
result = []
# TODO: elements can appear more than once in the feature-1 input?
for _ in range(self.out_features):
self.mask_builder = self.mask_builder_type(
self.k, self._table_input_selections_builder(), True
)
result.append(self.mask_builder.build())
return np.concatenate(result)
def forward(
self,
input: torch.Tensor,
        targets: torch.Tensor = None,
initalize: bool = False,
):
assert len(input.shape) == 2
batch_size = input.shape[0]
out_bin = residual_sign_quantizer(
levels=self.levels, x_quantizer=self.x_quantizer, means=self.means, x=input
)
expanded_input = out_bin[:, :, self.input_mask] # [levels, batch_size, mask]
output = self.trainer(expanded_input, targets, initalize).squeeze()
output = output.view(batch_size, -1)
assert output.shape[-1] == self.tables_count
output = output.view(
batch_size,
self.out_features,
int(self.tables_count / self.out_features),
)
output = output.sum(-1)
if self.bias is not None:
output = output + self.bias
return output
def pre_initialize(self):
self.trainer.clear_initializion()
def update_initialized_weights(self):
self.trainer.update_initialized_weights()
class LinearLogicNets(_LinearBase):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
        activation_module=None,  # To initialize a LogicNets layer, activation functions are needed
        input_layers=None,  # A LogicNets layer may be merged with one or more input layers such as activations and batchnorm
        output_layers=None,  # A LogicNets layer may be merged with one or more output layers such as activations and batchnorm
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
        # establish quantizers
self.x_width, self.x_frac_width = (
config["data_in_width"],
config["data_in_frac_width"],
)
self.y_width, self.y_frac_width = (
config["data_out_width"],
config["data_out_frac_width"],
)
self.x_quantizer = partial(
integer_quantizer, width=self.x_width, frac_width=self.x_frac_width
)
self.y_quantizer = partial(
integer_quantizer, width=self.y_width, frac_width=self.y_frac_width
)
# self.input_quant = input_quant
# self.output_quant = output_quant
self.activation = activation_module
self.is_lut_inference = True
self.neuron_truth_tables = None
# self.calculate_truth_tables()
# self.apply_input_quant = apply_input_quant
# self.apply_output_quant = apply_output_quant
self.input_layers = input_layers
self.output_layers = output_layers
self.apply_layers = False
    # TODO: This function might be a useful utility outside of this class.
def table_lookup(
self,
connected_input: Tensor,
input_perm_matrix: Tensor,
bin_output_states: Tensor,
) -> Tensor:
fan_in_size = connected_input.shape[1]
ci_bcast = connected_input.unsqueeze(2) # Reshape to B x Fan-in x 1
pm_bcast = input_perm_matrix.t().unsqueeze(
0
) # Reshape to 1 x Fan-in x InputStates
eq = (ci_bcast == pm_bcast).sum(
dim=1
) == fan_in_size # Create a boolean matrix which matches input vectors to possible input states
matches = eq.sum(dim=1) # Count the number of perfect matches per input vector
if not (matches == torch.ones_like(matches, dtype=matches.dtype)).all():
            raise Exception(
                "One or more input vectors are not in the possible input state space"
            )
indices = torch.argmax(eq.type(torch.int64), dim=1)
return bin_output_states[indices]
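    # Worked example (sketch): with a fan-in of 2 and a 1-bit input state space,
    # input_perm_matrix enumerates the four possible states [0, 0], [0, 1],
    # [1, 0], [1, 1] (in whatever order generate_permutation_matrix produces).
    # A connected_input row such as [1, 0] matches exactly one of those rows on
    # all fan-in elements, argmax picks that row's index, and the precomputed
    # bin_output_states entry at that index is returned for the input vector.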
def lut_forward(self, x: Tensor) -> Tensor:
        x = torch.flatten(x, 1)  # flatten all dims except the batch dimension
# if self.apply_input_quant:
# x = self.input_quant(x) # Use this to fetch the bin output of the input, if the input isn't already in binary format
x = self.encode(self.x_quantizer(x))
        y = torch.zeros((x.shape[0], self.out_features), device=x.device)  # allocate the output on the same device as the input
# Perform table lookup for each neuron output
for i in range(self.out_features):
indices, input_perm_matrix, bin_output_states = self.neuron_truth_tables[i]
# Move logicnets tensor to GPU
input_perm_matrix = input_perm_matrix.to(x.device)
bin_output_states = bin_output_states.to(x.device)
connected_input = x[:, indices]
y[:, i] = self.table_lookup(
connected_input, input_perm_matrix, bin_output_states
)
return y
def construct_mask_index(self):
        # construct a mask with the same shape as self.weight: zero elements map to 0, non-zero elements map to 1
self.mask = torch.where(
self.weight != 0, torch.tensor(1), torch.tensor(0)
).reshape(
self.weight.shape[0], -1
) # pay attention to dimension (out_feature, in_feature)
# Consider using masked_select instead of fetching the indices
def calculate_truth_tables(self):
# print(
# "weight", torch.where(self.weight != 0, torch.tensor(1), torch.tensor(0))
# ) # pay attention to dimension (out_feature, in_feature)
with torch.no_grad():
# Precalculate all of the input value permutations
input_state_space = list() # TODO: is a list the right data-structure here?
bin_state_space = list()
# get a neuron_state
for m in range(self.in_features):
neuron_state_space = self.decode(
get_int_state_space(self.x_width)
) # TODO: this call should include the index of the element of interest
bin_space = get_int_state_space(
self.x_width
) # TODO: this call should include the index of the element of interest
input_state_space.append(neuron_state_space)
bin_state_space.append(bin_space)
neuron_truth_tables = list()
self.construct_mask_index() # construct pruning mask
for n in range(self.out_features):
input_mask = self.mask[
n, :
] # N: select row of mask tensor that corresponds to the output feature on this iteration
fan_in = torch.sum(input_mask)
indices = fetch_mask_indices(input_mask)
# Generate a matrix containing all possible input states
input_permutation_matrix = generate_permutation_matrix(
[input_state_space[i] for i in indices]
)
bin_input_permutation_matrix = generate_permutation_matrix(
[bin_state_space[i] for i in indices]
)
# TODO: Update this block to just run inference on the fc layer, once BN has been moved to output_quant
num_permutations = input_permutation_matrix.shape[0]
padded_perm_matrix = torch.zeros((num_permutations, self.in_features))
padded_perm_matrix[:, indices] = input_permutation_matrix
bin_output_states = self.encode(self.math_forward(padded_perm_matrix))[
:, n
] # Calculate bin for the current input
# Append the connectivity, input permutations and output permutations to the neuron truth tables
neuron_truth_tables.append(
(indices, bin_input_permutation_matrix, bin_output_states)
) # Change this to be the binary output states
self.neuron_truth_tables = neuron_truth_tables
def math_forward(self, input: Tensor) -> Tensor:
if self.activation == "unittest":
            # This is for performing a unit test on the layer
return self.y_quantizer(
F.linear(self.x_quantizer(input), self.weight, self.bias)
)
if self.apply_layers:
x = input
if self.input_layers:
x = self.run_layers(x, self.input_layers)
y = self.y_quantizer(F.linear(self.x_quantizer(x), self.weight, self.bias))
if self.output_layers:
y = self.run_layers(y, self.output_layers)
return y
# This is the case where the linear layer is the only module in the LogicNets module
return self.y_quantizer(
F.linear(self.x_quantizer(input), self.weight, self.bias)
)
def set_fused(self, fused: bool):
self.apply_layers = fused
def run_layers(self, input: Tensor, layers) -> Tensor:
assert isinstance(layers, list)
y = input
for layer in layers:
layer_name = layer.__class__.__name__
SUPPORTED_LAYERS = {
"ReLU": 1,
"Tanh": 1,
"BatchNorm1d": 1,
"str": 0,
            }  # "str" is shorthand for "output"; in that case this LogicNets layer is a pure linear layer without an activation.
if layer_name not in SUPPORTED_LAYERS:
raise ValueError(
"Unsupported output layer {}. Please choose from {}".format(
layer_name, list(SUPPORTED_LAYERS.keys())
)
)
if SUPPORTED_LAYERS[layer_name]:
y = layer(y)
return y
def encode(self, input: Tensor) -> Tensor:
return input * 2**self.x_frac_width
def decode(self, input: Tensor) -> Tensor:
return input / 2**self.x_frac_width
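    # Fixed-point interpretation (sketch): with x_frac_width == 4, encode() maps a
    # real value such as 0.75 to the integer code 0.75 * 2**4 == 12, and decode()
    # maps 12 back to 0.75; the LUT lookup path operates on these integer codes.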
def forward(self, x: Tensor) -> Tensor:
if self.is_lut_inference:
return self.decode(self.lut_forward(x))
else:
return self.math_forward(x)
class LinearMXIntHardware(_LinearBase):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
config=None,
out_config=None,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
assert config is not None, "config is None!"
self.config = config
self.out_config = out_config
self.bypass = config.get("bypass", False)
if self.bypass:
return
def forward(self, x):
if self.bypass:
return F.linear(x, self.weight, self.bias)
return linearMXIntHardware(
x, self.weight, self.bias, self.config, self.out_config
)