from functools import partial
from typing import Union
import torch
from torch import Tensor
from torch.nn.common_types import _size_1_t
from chop.nn.quantizers import (
block_fp_quantizer,
block_log_quantizer,
block_minifloat_quantizer,
integer_quantizer,
log_quantizer,
minifloat_denorm_quantizer,
minifloat_ieee_quantizer,
binary_quantizer,
ternary_quantizer,
)
from ..utils import get_stats
class _Conv1dBase(torch.nn.Conv1d):
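"""
Base class for quantized Conv1d layers.

Subclasses are expected to set ``self.w_quantizer``, ``self.x_quantizer`` and
``self.b_quantizer`` from their ``config`` dict; ``forward`` then quantizes the
input, weight and bias before running the standard convolution. Setting
``config["bypass"] = True`` in a subclass skips quantization entirely.
"""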
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
self.bypass = False
self.w_quantizer = None
self.x_quantizer = None
self.b_quantizer = None
self.pruning_masks = None
def forward(self, x: Tensor) -> Tensor:
if self.bypass:
return self._conv_forward(x, self.weight, self.bias)
x = self.x_quantizer(x)
w = self.w_quantizer(self.weight)
bias = self.b_quantizer(self.bias) if self.bias is not None else None
# WARNING: this is a simplification; we assume the accumulation is lossless!
# The accumulation length per output element is (in_channels / groups) * kernel_size
return self._conv_forward(x, w, bias)
# def get_quantized_weight(self) -> Tensor:
# return self.w_quantizer(self.weight)
# def get_quantized_weights_with_inputs(self, x: Tensor) -> Tensor:
# x = self.x_quantizer(x)
# w = self.w_quantizer(self.weight)
# bias = self.b_quantizer(self.bias) if self.bias is not None else None
# y = self._conv_forward(x, w, bias)
# return {
# "x": x,
# "w": w,
# "bias": bias,
# "y": y,
# }
# def get_output_bitwidth(self):
# """output bit width info for HW gen"""
# raise NotImplementedError()
class Conv1dInteger(_Conv1dBase):
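"""
Conv1d layer that quantizes its input, weight and bias with ``integer_quantizer``
(fixed-point, given a total width and a fractional width).

Required ``config`` keys (besides the optional ``bypass`` flag):
``weight_width``, ``weight_frac_width``, ``data_in_width``,
``data_in_frac_width``, ``bias_width``, ``bias_frac_width``.

Illustrative example (the widths and shapes below are arbitrary, not recommendations)::

    config = {
        "weight_width": 8, "weight_frac_width": 4,
        "data_in_width": 8, "data_in_frac_width": 4,
        "bias_width": 8, "bias_frac_width": 4,
    }
    layer = Conv1dInteger(16, 32, kernel_size=3, config=config)
    y = layer(torch.randn(1, 16, 100))
"""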
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros", # TODO: refine this type
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
# establish quantizers
w_width, w_frac_width = config["weight_width"], config["weight_frac_width"]
x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
b_width, b_frac_width = config["bias_width"], config["bias_frac_width"]
self.w_quantizer = partial(
integer_quantizer, width=w_width, frac_width=w_frac_width
)
self.x_quantizer = partial(
integer_quantizer, width=x_width, frac_width=x_frac_width
)
self.b_quantizer = partial(
integer_quantizer, width=b_width, frac_width=b_frac_width
)
# def get_output_bitwidth(self):
# config = self.config
# w_width, w_frac = config["weight_width"], config["weight_frac_width"]
# x_width, x_frac = config["data_in_width"], config["data_in_frac_width"]
# bias_width = config["bias_width"]
# ops = self.in_channels * self.kernel_size[0]
# product_width = w_width + x_width
# product_frac_width = w_frac + x_frac
# # *: +1 for bias
# output_width = max(bias_width, product_width + ceil(log2(ops))) + 1
# output_frac_width = product_frac_width
# o_bitwidth = {}
# o_bitwidth["data_out_width"] = output_width
# o_bitwidth["data_out_frac_width"] = output_frac_width
# # output_bitwidth_info["product_width"] = product_width
# # output_bitwidth_info["product_frac_width"] = product_frac_width
# return o_bitwidth
class Conv1dMinifloatDenorm(_Conv1dBase):
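"""
Conv1d layer that quantizes its input, weight and bias with
``minifloat_denorm_quantizer``.

Required ``config`` keys, for each of ``weight``, ``data_in`` and ``bias``:
``*_width``, ``*_exponent_width``, ``*_exponent_bias``; ``bypass`` is optional.

Illustrative example (values are arbitrary)::

    config = {
        "weight_width": 8, "weight_exponent_width": 4, "weight_exponent_bias": 7,
        "data_in_width": 8, "data_in_exponent_width": 4, "data_in_exponent_bias": 7,
        "bias_width": 8, "bias_exponent_width": 4, "bias_exponent_bias": 7,
    }
    layer = Conv1dMinifloatDenorm(16, 32, kernel_size=3, config=config)
"""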
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
config: dict = None,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
w_width, w_exponent_width, w_exponent_bias = (
config["weight_width"],
config["weight_exponent_width"],
config["weight_exponent_bias"],
)
x_width, x_exponent_width, x_exponent_bias = (
config["data_in_width"],
config["data_in_exponent_width"],
config["data_in_exponent_bias"],
)
b_width, b_exponent_width, b_exponent_bias = (
config["bias_width"],
config["bias_exponent_width"],
config["bias_exponent_bias"],
)
self.w_quantizer = partial(
minifloat_denorm_quantizer,
width=w_width,
exponent_width=w_exponent_width,
exponent_bias=w_exponent_bias,
)
self.x_quantizer = partial(
minifloat_denorm_quantizer,
width=x_width,
exponent_width=x_exponent_width,
exponent_bias=x_exponent_bias,
)
self.b_quantizer = partial(
minifloat_denorm_quantizer,
width=b_width,
exponent_width=b_exponent_width,
exponent_bias=b_exponent_bias,
)
class Conv1dLog(_Conv1dBase):
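"""
Conv1d layer that quantizes its input, weight and bias with ``log_quantizer``.

Required ``config`` keys: ``weight_width``, ``weight_exponent_bias``,
``data_in_width``, ``data_in_exponent_bias``, ``bias_width``,
``bias_exponent_bias``; ``bypass`` is optional.

Illustrative example (values are arbitrary)::

    config = {
        "weight_width": 8, "weight_exponent_bias": 0,
        "data_in_width": 8, "data_in_exponent_bias": 0,
        "bias_width": 8, "bias_exponent_bias": 0,
    }
    layer = Conv1dLog(16, 32, kernel_size=3, config=config)
"""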
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
config: dict = None,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
w_width, w_exponent_bias = (
config["weight_width"],
config["weight_exponent_bias"],
)
x_width, x_exponent_bias = (
config["data_in_width"],
config["data_in_exponent_bias"],
)
b_width, b_exponent_bias = (
config["bias_width"],
config["bias_exponent_bias"],
)
self.w_quantizer = partial(
log_quantizer,
width=w_width,
exponent_bias=w_exponent_bias,
)
self.x_quantizer = partial(
log_quantizer,
width=x_width,
exponent_bias=x_exponent_bias,
)
self.b_quantizer = partial(
log_quantizer,
width=b_width,
exponent_bias=b_exponent_bias,
)
class Conv1dMinifloatIEEE(_Conv1dBase):
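"""
Conv1d layer that quantizes its input, weight and bias with
``minifloat_ieee_quantizer``.

Required ``config`` keys, for each of ``weight``, ``data_in`` and ``bias``:
``*_width``, ``*_exponent_width``, ``*_exponent_bias``; ``bypass`` is optional.

Illustrative example (values are arbitrary)::

    config = {
        "weight_width": 8, "weight_exponent_width": 4, "weight_exponent_bias": 7,
        "data_in_width": 8, "data_in_exponent_width": 4, "data_in_exponent_bias": 7,
        "bias_width": 8, "bias_exponent_width": 4, "bias_exponent_bias": 7,
    }
    layer = Conv1dMinifloatIEEE(16, 32, kernel_size=3, config=config)
"""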
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
config: dict = None,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
w_width, w_exponent_width, w_exponent_bias = (
config["weight_width"],
config["weight_exponent_width"],
config["weight_exponent_bias"],
)
x_width, x_exponent_width, x_exponent_bias = (
config["data_in_width"],
config["data_in_exponent_width"],
config["data_in_exponent_bias"],
)
b_width, b_exponent_width, b_exponent_bias = (
config["bias_width"],
config["bias_exponent_width"],
config["bias_exponent_bias"],
)
self.w_quantizer = partial(
minifloat_ieee_quantizer,
width=w_width,
exponent_width=w_exponent_width,
exponent_bias=w_exponent_bias,
)
self.x_quantizer = partial(
minifloat_ieee_quantizer,
width=x_width,
exponent_width=x_exponent_width,
exponent_bias=x_exponent_bias,
)
self.b_quantizer = partial(
minifloat_ieee_quantizer,
width=b_width,
exponent_width=b_exponent_width,
exponent_bias=b_exponent_bias,
)
class Conv1dBlockFP(_Conv1dBase):
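"""
Conv1d layer that quantizes its input, weight and bias with ``block_fp_quantizer``
(block floating point: a shared exponent per block). The first dimension is
skipped when blocking the weight and input, but not the bias.

Required ``config`` keys, for each of ``weight``, ``data_in`` and ``bias``:
``*_width``, ``*_exponent_width``, ``*_exponent_bias``, ``*_block_size``;
``bypass`` is optional.

Illustrative example (values are arbitrary; the single-element-list block size is
only an assumption, check ``block_fp_quantizer`` for the exact expected format)::

    config = {
        "weight_width": 8, "weight_exponent_width": 4,
        "weight_exponent_bias": 7, "weight_block_size": [16],
        "data_in_width": 8, "data_in_exponent_width": 4,
        "data_in_exponent_bias": 7, "data_in_block_size": [16],
        "bias_width": 8, "bias_exponent_width": 4,
        "bias_exponent_bias": 7, "bias_block_size": [16],
    }
    layer = Conv1dBlockFP(16, 32, kernel_size=3, config=config)
"""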
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
config: dict = None,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
w_width, w_exponent_width, w_exponent_bias, w_block_size = (
config["weight_width"],
config["weight_exponent_width"],
config["weight_exponent_bias"],
config["weight_block_size"],
)
x_width, x_exponent_width, x_exponent_bias, x_block_size = (
config["data_in_width"],
config["data_in_exponent_width"],
config["data_in_exponent_bias"],
config["data_in_block_size"],
)
b_width, b_exponent_width, b_exponent_bias, b_block_size = (
config["bias_width"],
config["bias_exponent_width"],
config["bias_exponent_bias"],
config["bias_block_size"],
)
self.w_quantizer = partial(
block_fp_quantizer,
width=w_width,
exponent_width=w_exponent_width,
exponent_bias=w_exponent_bias,
block_size=w_block_size,
skip_first_dim=True,
)
self.x_quantizer = partial(
block_fp_quantizer,
width=x_width,
exponent_width=x_exponent_width,
exponent_bias=x_exponent_bias,
block_size=x_block_size,
skip_first_dim=True,
)
self.b_quantizer = partial(
block_fp_quantizer,
width=b_width,
exponent_width=b_exponent_width,
exponent_bias=b_exponent_bias,
block_size=b_block_size,
skip_first_dim=False,
)
class Conv1dBlockMinifloat(_Conv1dBase):
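"""
Conv1d layer that quantizes its input, weight and bias with
``block_minifloat_quantizer``. The first dimension is skipped when blocking
the weight and input, but not the bias.

Required ``config`` keys, for each of ``weight``, ``data_in`` and ``bias``:
``*_width``, ``*_exponent_width``, ``*_exponent_bias_width``, ``*_block_size``;
``bypass`` is optional.

Illustrative example (values are arbitrary; the block-size format is an
assumption, check ``block_minifloat_quantizer`` for the exact expected format)::

    config = {
        "weight_width": 8, "weight_exponent_width": 4,
        "weight_exponent_bias_width": 3, "weight_block_size": [16],
        "data_in_width": 8, "data_in_exponent_width": 4,
        "data_in_exponent_bias_width": 3, "data_in_block_size": [16],
        "bias_width": 8, "bias_exponent_width": 4,
        "bias_exponent_bias_width": 3, "bias_block_size": [16],
    }
    layer = Conv1dBlockMinifloat(16, 32, kernel_size=3, config=config)
"""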
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
config: dict = None,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
w_width, w_exponent_width, w_exponent_bias_width, w_block_size = (
config["weight_width"],
config["weight_exponent_width"],
config["weight_exponent_bias_width"],
config["weight_block_size"],
)
x_width, x_exponent_width, x_exponent_bias_width, x_block_size = (
config["data_in_width"],
config["data_in_exponent_width"],
config["data_in_exponent_bias_width"],
config["data_in_block_size"],
)
b_width, b_exponent_width, b_exponent_bias_width, b_block_size = (
config["bias_width"],
config["bias_exponent_width"],
config["bias_exponent_bias_width"],
config["bias_block_size"],
)
self.w_quantizer = partial(
block_minifloat_quantizer,
width=w_width,
exponent_width=w_exponent_width,
exponent_bias_width=w_exponent_bias_width,
block_size=w_block_size,
skip_first_dim=True,
)
self.x_quantizer = partial(
block_minifloat_quantizer,
width=x_width,
exponent_width=x_exponent_width,
exponent_bias_width=x_exponent_bias_width,
block_size=x_block_size,
skip_first_dim=True,
)
self.b_quantizer = partial(
block_minifloat_quantizer,
width=b_width,
exponent_width=b_exponent_width,
exponent_bias_width=b_exponent_bias_width,
block_size=b_block_size,
skip_first_dim=False,
)
class Conv1dBlockLog(_Conv1dBase):
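"""
Conv1d layer that quantizes its input, weight and bias with
``block_log_quantizer``. The first dimension is skipped when blocking the
weight and input, but not the bias.

Required ``config`` keys, for each of ``weight``, ``data_in`` and ``bias``:
``*_width``, ``*_exponent_bias_width``, ``*_block_size``; ``bypass`` is optional.

Illustrative example (values are arbitrary; the block-size format is an
assumption, check ``block_log_quantizer`` for the exact expected format)::

    config = {
        "weight_width": 8, "weight_exponent_bias_width": 3, "weight_block_size": [16],
        "data_in_width": 8, "data_in_exponent_bias_width": 3, "data_in_block_size": [16],
        "bias_width": 8, "bias_exponent_bias_width": 3, "bias_block_size": [16],
    }
    layer = Conv1dBlockLog(16, 32, kernel_size=3, config=config)
"""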
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
config: dict = None,
) -> None:
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
w_width, w_exponent_bias_width, w_block_size = (
config["weight_width"],
config["weight_exponent_bias_width"],
config["weight_block_size"],
)
x_width, x_exponent_bias_width, x_block_size = (
config["data_in_width"],
config["data_in_exponent_bias_width"],
config["data_in_block_size"],
)
b_width, b_exponent_bias_width, b_block_size = (
config["bias_width"],
config["bias_exponent_bias_width"],
config["bias_block_size"],
)
self.w_quantizer = partial(
block_log_quantizer,
width=w_width,
exponent_bias_width=w_exponent_bias_width,
block_size=w_block_size,
skip_first_dim=True,
)
self.x_quantizer = partial(
block_log_quantizer,
width=x_width,
exponent_bias_width=x_exponent_bias_width,
block_size=x_block_size,
skip_first_dim=True,
)
self.b_quantizer = partial(
block_log_quantizer,
width=b_width,
exponent_bias_width=b_exponent_bias_width,
block_size=b_block_size,
skip_first_dim=False,
)
class Conv1dBinary(_Conv1dBase):
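"""
Conv1d layer that binarizes its input, weight and bias with ``binary_quantizer``.
Note that the ``data_in_stochastic`` and ``data_in_bipolar`` settings are reused
for the weight and bias quantizers as well.

Required ``config`` keys: ``data_in_stochastic``, ``data_in_bipolar``;
``bypass`` is optional.

Illustrative example::

    config = {"data_in_stochastic": False, "data_in_bipolar": True}
    layer = Conv1dBinary(16, 32, kernel_size=3, config=config)
"""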
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros", # TODO: refine this type
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
x_stochastic = config["data_in_stochastic"]
x_bipolar = config["data_in_bipolar"]
self.w_quantizer = partial(
binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar
)
self.x_quantizer = partial(
binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar
)
self.b_quantizer = partial(
binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar
)
class Conv1dTernary(_Conv1dBase):
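"""
Conv1d layer that ternarizes its input, weight and bias with ``ternary_quantizer``,
using per-tensor scaling factors together with mean, median and max statistics
looked up through ``get_stats``.

Required ``config`` keys: ``data_in_scaling_factor``, ``weight_scaling_factor``,
``bias_scaling_factor``, plus the statistics entries consumed by ``get_stats``
(``data_in_mean``, ``weight_median``, ``bias_max``, ...); ``bypass`` is optional.
The exact layout of the statistics inside ``config`` depends on ``get_stats``.

Illustrative example (scaling factors are arbitrary)::

    config = {
        "data_in_scaling_factor": 0.7,
        "weight_scaling_factor": 0.7,
        "bias_scaling_factor": 0.7,
        # ... statistics entries expected by get_stats ...
    }
    layer = Conv1dTernary(16, 32, kernel_size=3, config=config)
"""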
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
config=None,
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
assert config is not None, "config is None!"
self.config = config
self.bypass = config.get("bypass", False)
if self.bypass:
return
x_scaling_factor = config["data_in_scaling_factor"]
w_scaling_factor = config["weight_scaling_factor"]
b_scaling_factor = config["bias_scaling_factor"]
x_mean = get_stats(config, "data_in_mean")
x_median = get_stats(config, "data_in_median")
x_max = get_stats(config, "data_in_max")
w_mean = get_stats(config, "weight_mean")
w_median = get_stats(config, "weight_median")
w_max = get_stats(config, "weight_max")
b_mean = get_stats(config, "bias_mean")
b_median = get_stats(config, "bias_median")
b_max = get_stats(config, "bias_max")
self.x_quantizer = partial(
ternary_quantizer,
scaling_factor=x_scaling_factor,
maximum=x_max,
median=x_median,
mean=x_mean,
)
self.w_quantizer = partial(
ternary_quantizer,
scaling_factor=w_scaling_factor,
maximum=w_max,
median=w_median,
mean=w_mean,
)
self.b_quantizer = partial(
ternary_quantizer,
scaling_factor=b_scaling_factor,
maximum=b_max,
median=b_median,
mean=b_mean,
)