Source code for chop.passes.module.transforms.quantize.quantize
import torch
from chop.nn.quantized.modules import quantized_module_map
from ...module_modify_helper import replace_by_name, instantiate_module


def get_config(config: dict, name: str):
    """Return the config entry for `name` if present, otherwise the "default" one."""
if name in config:
return config[name]["config"]
else:
return config["default"]["config"]
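
# A small illustration of get_config's fallback behaviour (the module name
# "fc1" is hypothetical): given
#   config = {"default": {"config": {"name": None}},
#             "fc1": {"config": {"name": "integer"}}}
# get_config(config, "fc1") returns {"name": "integer"}, while
# get_config(config, "fc2") falls back to {"name": None}.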


def quantize_by_type(network, pass_args):
    """Swap every module of a type listed in `pass_args` for its quantized counterpart."""
for type_name, config in pass_args.items():
n_m = {}
for n, m in network.named_modules():
n_m[n] = m
if type_name == "linear":
module = torch.nn.Linear
elif type_name == "conv2d":
module = torch.nn.Conv2d
else:
raise ValueError(f"{type_name} is not supported!")
        config = config["config"]
        # "name" selects the quantization scheme and is consumed here; the
        # remaining entries are forwarded to the quantized module.
        postfix = config.pop("name")
for n, m in n_m.items():
if isinstance(m, module):
new_m = instantiate_module(
m, postfix, quantized_module_map, {"config": config}
)
network = replace_by_name(network, n, new_m)
return network


def quantize_by_name(network, pass_args):
    """Swap the modules whose fully qualified names appear as keys in `pass_args`."""
quantize_names = pass_args.keys()
n_m = {}
for n, m in network.named_modules():
n_m[n] = m
for n, m in n_m.items():
if n in quantize_names:
quan_config = pass_args[n]
quan_config = quan_config["config"]
postfix = quan_config.pop("name")
new_m = instantiate_module(
m, postfix, quantized_module_map, {"config": quan_config}
)
network = replace_by_name(network, n, new_m)
return network
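
# A sketch of `pass_args` for the by-name path above: keys are fully qualified
# module names as returned by `network.named_modules()` ("fc1" here is
# hypothetical and depends on the model); each entry uses the same "config"
# layout as the by-type path.
#
#   pass_args = {
#       "fc1": {
#           "config": {
#               "name": "integer",
#               "data_in_width": 8,
#               "data_in_frac_width": 4,
#               "weight_width": 8,
#               "weight_frac_width": 4,
#               "bias_width": 8,
#               "bias_frac_width": 4,
#           }
#       },
#   }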


def quantize_module_transform_pass(network, pass_args):
"""
    Apply a quantization transformation to the given nn.Module.

    :param network: The input network to be transformed.
:type network: torch.nn.Module
:param pass_args: Additional arguments for the transformation.
    :type pass_args: dict, optional

    Example pass_args:

    .. code-block:: python

        pass_args = {
            "by": "type",  # quantize by "type" or by "name"
            # default config, used for any module without a specific entry
            "default": {"config": {"name": None}},
            "linear": {
                "config": {
                    # supported quantization scheme names: "integer", "fixed"
                    # (equivalent to "integer"), "lutnet" (dev mode),
                    # "logicnets" (dev mode), "binary", "binary_residual",
                    # "ternary", "minifloat_ieee", "minifloat_denorm", "log",
                    # "block_fp", "block_minifloat", "block_log"
                    "name": "integer",
                    # data
                    "data_in_width": 8,
                    "data_in_frac_width": 4,
                    # weight
                    "weight_width": 8,
                    "weight_frac_width": 4,
                    # bias
                    "bias_width": 8,
                    "bias_frac_width": 4,
                }
            },
        }

    :return: The transformed torch.nn.Module together with an empty dict of
        additional pass information.
    :rtype: tuple
:raises ValueError: If the quantize "by" argument is unsupported.
"""
    # Dispatch on the requested granularity; the remaining entries of
    # `pass_args` hold the per-type or per-name configs.
    by = pass_args.pop("by")
    match by:
case "type":
network = quantize_by_type(network, pass_args)
case "name":
network = quantize_by_name(network, pass_args)
case _:
raise ValueError(f'Unsupported quantize "by": {by}')
return network, {}
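

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It quantizes a
# toy single-layer model by type, reusing the config layout shown in the
# docstring above; whether the "integer" scheme resolves to a quantized Linear
# depends on the entries registered in `quantized_module_map`. Because of the
# relative imports, run it as a module, e.g.
# `python -m chop.passes.module.transforms.quantize.quantize`.
if __name__ == "__main__":
    toy_model = torch.nn.Sequential(torch.nn.Linear(16, 8))
    example_args = {
        "by": "type",
        "linear": {
            "config": {
                "name": "integer",
                "data_in_width": 8,
                "data_in_frac_width": 4,
                "weight_width": 8,
                "weight_frac_width": 4,
                "bias_width": 8,
                "bias_frac_width": 4,
            }
        },
    }
    toy_model, _ = quantize_module_transform_pass(toy_model, example_args)
    print(toy_model)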