Source code for chop.nn.quantized.utils
from torch import Tensor
[docs]
def get_stats(config: dict, stat_name: str) -> float | None:
if not config.get(stat_name) in [
None,
"NA",
]: # if entry does not exists, None is returned, "NA" if no stats available in config file
return config[stat_name]
else:
if "weight" in stat_name:
stat = config["node_meta_stat"]["weight"]["stat"]
elif "bias" in stat_name:
stat = config["node_meta_stat"]["bias"]["stat"]
elif "data_in" in stat_name:
stat = config["node_meta_stat"]["data_in_0"]["stat"]
# TODO FIX MULTI ARG
if "mean" in stat_name:
return stat["abs_mean"]["abs_mean"] if "abs_mean" in stat else None
elif "median" in stat_name:
return stat["range_quantile"]["max"] if "range_quantile" in stat else None
elif "max" in stat_name:
return stat["range_min_max"]["max"] if "range_min_max" in stat else None
[docs]
def quantiser_passthrough(x: Tensor):
    """Identity quantiser: hand back the input object unchanged.

    No copy is made and no quantisation is applied; the very same object
    passed in is returned. Presumably used where a quantiser callable is
    required but quantisation should be skipped — confirm against callers.

    Args:
        x: The value to pass through.

    Returns:
        ``x`` itself.
    """
    unchanged = x
    return unchanged