# Source code for chop.passes.module.analysis.quantize.calculate_avg_bits_module

import logging
import numpy as np
import torch

logger = logging.getLogger(__name__)


def calculate_avg_bits_module_analysis_pass(
    module: torch.nn.Module, pass_args: dict = None
) -> tuple:
    """
    Analyze the average weight bitwidth of a given module.

    Only ``torch.nn.Linear`` and ``torch.nn.Conv2d`` layers that carry a
    ``config`` attribute (with a ``"weight_width"`` entry) are considered
    quantized and counted.

    :param module: The module to analyze.
    :type module: torch.nn.Module
    :param pass_args: Additional arguments for the analysis pass
        (default: ``None``, treated as ``{}``).
    :type pass_args: dict
    :return: A tuple containing the (unmodified) module and a dictionary
        with the analysis results.
    :rtype: tuple(torch.nn.Module, dict)

    The analysis results dictionary has the following key:

    - 'average_bitwidth' (float): The average number of bits per weight value.

    Example output:

    .. code-block:: python

        info = {
            # this means on average each weight value is represented by 16 bits
            'average_bitwidth': 16.0}
    """
    # Default is None (not {}) to avoid the shared-mutable-default pitfall;
    # callers relying on the old implicit-{} default see identical behavior.
    pass_args = {} if pass_args is None else pass_args
    assert isinstance(module, torch.nn.Module), "module must be a nn.Module instance"
    assert isinstance(pass_args, dict), "pass_args must be a dict instance"

    return_info = {}
    weights_size, weight_bits = 0, 0
    # A simple estimation: only count Linear / Conv2d layers that were
    # explicitly tagged with a quantization config.
    for _, m in module.named_modules():
        if isinstance(m, torch.nn.Linear) and hasattr(m, "config"):
            num_weights = m.in_features * m.out_features
        elif isinstance(m, torch.nn.Conv2d) and hasattr(m, "config"):
            num_weights = (
                m.in_channels * m.out_channels * m.kernel_size[0] * m.kernel_size[1]
            )
        else:
            continue
        weights_size += num_weights
        weight_bits += num_weights * m.config["weight_width"]

    if weight_bits == 0:
        # No quantized layers found: assume full-precision float32 weights.
        logger.warning(
            "No quantized layers found in the model, set average_bitwidth to 32"
        )
        return_info |= {"average_bitwidth": 32}
    else:
        return_info |= {"average_bitwidth": weight_bits / weights_size}
    return module, return_info