Source code for chop.nn.quantized.functional.matmul

from functools import partial
from math import ceil, log2

import torch

# from ....graph.mase_tracer import mark_as_leaf_func
from chop.nn.quantizers import (
    block_fp_quantizer,
    block_log_quantizer,
    block_minifloat_quantizer,
    integer_quantizer,
    integer_floor_quantizer,
    log_quantizer,
    minifloat_denorm_quantizer,
    minifloat_ieee_quantizer,
    binary_quantizer,
    ternary_quantizer,
)

# PyTorch has torch.matmul and torch.bmm for matrix multiplication
matmul_mapping = {"matmul": torch.matmul, "bmm": torch.bmm}


def generic_matmul_integer(x, y, config, style="matmul", out_config=None, floor=False):
    bypass = config.get("bypass", False)
    matmul = matmul_mapping[style]
    if bypass:
        return matmul(x, y)
    else:
        base_quantizer = integer_quantizer
        x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
        y_width, y_frac_width = config["weight_width"], config["weight_frac_width"]
        x_quantizer = partial(base_quantizer, width=x_width, frac_width=x_frac_width)
        y_quantizer = partial(base_quantizer, width=y_width, frac_width=y_frac_width)
        if out_config is not None:
            out_width, out_frac_width = (
                out_config["data_out_width"],
                out_config["data_out_frac_width"],
            )
            out_quantizer = partial(
                integer_floor_quantizer, width=out_width, frac_width=out_frac_width
            )
        x = x_quantizer(x)
        y = y_quantizer(y)
        if out_config is not None:
            return out_quantizer(matmul(x, y))
        else:
            return matmul(x, y)

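# Illustrative sketch (not part of the original module): a minimal config for the
# integer path above. The key names mirror those read by generic_matmul_integer;
# the widths are arbitrary example values.
def _example_matmul_integer():
    x = torch.randn(4, 8)
    y = torch.randn(8, 2)
    config = {
        "data_in_width": 8,
        "data_in_frac_width": 4,
        "weight_width": 8,
        "weight_frac_width": 4,
    }
    out_config = {"data_out_width": 10, "data_out_frac_width": 4}
    # Quantize both operands to 8-bit fixed point, multiply, then floor-quantize
    # the product to the requested output format.
    return generic_matmul_integer(x, y, config, "matmul", out_config)
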
def generic_matmul_binary(x, y, config, style="matmul"):
    bypass = config.get("bypass", False)
    matmul = matmul_mapping[style]
    if bypass:
        return matmul(x, y)
    else:
        x_stochastic = config["data_in_stochastic"]
        x_bipolar = config["data_in_bipolar"]
        x_quantizer = partial(
            binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar
        )
        y_stochastic = config["weight_stochastic"]
        y_bipolar = config["weight_bipolar"]
        y_quantizer = partial(
            binary_quantizer, stochastic=y_stochastic, bipolar=y_bipolar
        )
        x = x_quantizer(x)
        y = y_quantizer(y)
        # y = x_quantizer(y)
        return matmul(x, y)

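# Illustrative sketch (not part of the original module): the binary path reads
# per-operand stochastic/bipolar flags; the flag values here are arbitrary examples.
def _example_matmul_binary():
    x = torch.randn(4, 8)
    y = torch.randn(8, 2)
    config = {
        "data_in_stochastic": False,
        "data_in_bipolar": True,
        "weight_stochastic": False,
        "weight_bipolar": True,
    }
    # Both operands are binarized (bipolar, i.e. two-level +/- values) before the matmul.
    return generic_matmul_binary(x, y, config)
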
def generic_matmul_ternary(x, y, config, style="matmul"):
    bypass = config.get("bypass", False)
    matmul = matmul_mapping[style]
    if bypass:
        return matmul(x, y)
    else:
        x_scaling_factor = config["data_in_scaling_factor"]
        # Use the ternary quantizer here: binary_quantizer does not take a
        # scaling_factor argument, so the original call was broken.
        x_quantizer = partial(ternary_quantizer, scaling_factor=x_scaling_factor)
        x = x_quantizer(x)
        y = x_quantizer(y)
        return matmul(x, y)

def generic_matmul_minifloat_denorm(x, y, config, style="matmul"):
    bypass = config.get("bypass", False)
    matmul = matmul_mapping[style]
    if bypass:
        return matmul(x, y)
    else:
        x_width, x_exponent_width, x_exponent_bias = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
        )
        y_width, y_exponent_width, y_exponent_bias = (
            config["weight_width"],
            config["weight_exponent_width"],
            config["weight_exponent_bias"],
        )
        x_quantizer = partial(
            minifloat_denorm_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias=x_exponent_bias,
        )
        y_quantizer = partial(
            minifloat_denorm_quantizer,
            width=y_width,
            exponent_width=y_exponent_width,
            exponent_bias=y_exponent_bias,
        )
        x = x_quantizer(x)
        y = y_quantizer(y)
        # y = x_quantizer(y)
        return matmul(x, y)

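# Illustrative sketch (not part of the original module): the config keys shared by
# the minifloat paths above; the widths and bias are arbitrary example values.
def _example_matmul_minifloat():
    x = torch.randn(4, 8)
    y = torch.randn(8, 2)
    config = {
        "data_in_width": 8,
        "data_in_exponent_width": 4,
        "data_in_exponent_bias": 7,
        "weight_width": 8,
        "weight_exponent_width": 4,
        "weight_exponent_bias": 7,
    }
    # Both operands are rounded to an 8-bit minifloat format before the matmul.
    return generic_matmul_minifloat_denorm(x, y, config)
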
def generic_matmul_minifloat_ieee(x, y, config, style="matmul"):
    bypass = config.get("bypass", False)
    matmul = matmul_mapping[style]
    if bypass:
        return matmul(x, y)
    else:
        x_width, x_exponent_width, x_exponent_bias = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
        )
        y_width, y_exponent_width, y_exponent_bias = (
            config["weight_width"],
            config["weight_exponent_width"],
            config["weight_exponent_bias"],
        )
        x_quantizer = partial(
            minifloat_ieee_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias=x_exponent_bias,
        )
        y_quantizer = partial(
            minifloat_ieee_quantizer,
            width=y_width,
            exponent_width=y_exponent_width,
            exponent_bias=y_exponent_bias,
        )
        x = x_quantizer(x)
        y = y_quantizer(y)
        return matmul(x, y)

def generic_matmul_log(x, y, config, style="matmul"):
    bypass = config.get("bypass", False)
    matmul = matmul_mapping[style]
    if bypass:
        return matmul(x, y)
    else:
        x_width, x_exponent_bias = (
            config["data_in_width"],
            config["data_in_exponent_bias"],
        )
        y_width, y_exponent_bias = (
            config["weight_width"],
            config["weight_exponent_bias"],
        )
        x_quantizer = partial(
            log_quantizer,
            width=x_width,
            exponent_bias=x_exponent_bias,
        )
        y_quantizer = partial(
            log_quantizer,
            width=y_width,
            exponent_bias=y_exponent_bias,
        )
        x = x_quantizer(x)
        y = y_quantizer(y)
        # y = x_quantizer(y)
        return matmul(x, y)

def generic_matmul_block_fp(x, y, config, style="matmul"):
    bypass = config.get("bypass", False)
    matmul = matmul_mapping[style]
    if bypass:
        return matmul(x, y)
    else:
        x_width, x_exponent_width, x_exponent_bias, x_block_size = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
            config["data_in_block_size"],
        )
        y_width, y_exponent_width, y_exponent_bias, y_block_size = (
            config["weight_width"],
            config["weight_exponent_width"],
            config["weight_exponent_bias"],
            config["weight_block_size"],
        )
        x_more_than_2_dims = x.ndim > 2
        y_more_than_2_dims = y.ndim > 2
        x_quantizer = partial(
            block_fp_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias=x_exponent_bias,
            block_size=x_block_size,
            skip_first_dim=x_more_than_2_dims,
        )
        y_quantizer = partial(
            block_fp_quantizer,
            width=y_width,
            exponent_width=y_exponent_width,
            exponent_bias=y_exponent_bias,
            block_size=y_block_size,
            skip_first_dim=y_more_than_2_dims,
        )
        # flatten all other dims except for the last two dims for performing matmul
        # this is a hack for allowing block/unblock the last two dims of multiple dim tensors
        x_shape = [i for i in x.shape]
        y_shape = [i for i in y.shape]
        if x_more_than_2_dims:
            x = torch.flatten(x, 0, -3)
        if y_more_than_2_dims:
            y = torch.flatten(y, 0, -3)
        x = x_quantizer(x)
        # y = x_quantizer(y)
        y = y_quantizer(y)
        x = torch.reshape(x, x_shape)
        y = torch.reshape(y, y_shape)
        return matmul(x, y)

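# Illustrative sketch (not part of the original module): the block paths flatten
# every dimension except the last two before blocking, so they also accept batched
# (>2-D) inputs. Key names mirror those read above; the width, bias, and block-size
# values are arbitrary examples and may need adjusting for a given quantizer setup.
def _example_matmul_block_fp():
    x = torch.randn(2, 4, 8)  # batched activations
    y = torch.randn(2, 8, 2)  # batched weights
    config = {
        "data_in_width": 8,
        "data_in_exponent_width": 4,
        "data_in_exponent_bias": None,
        "data_in_block_size": [16],
        "weight_width": 8,
        "weight_exponent_width": 4,
        "weight_exponent_bias": None,
        "weight_block_size": [16],
    }
    return generic_matmul_block_fp(x, y, config)
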
def generic_matmul_block_minifloat(x, y, config, style="matmul"):
    bypass = config.get("bypass", False)
    matmul = matmul_mapping[style]
    if bypass:
        return matmul(x, y)
    else:
        x_width, x_exponent_width, x_exponent_bias_width, x_block_size = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias_width"],
            config["data_in_block_size"],
        )
        y_width, y_exponent_width, y_exponent_bias_width, y_block_size = (
            config["weight_width"],
            config["weight_exponent_width"],
            config["weight_exponent_bias_width"],
            config["weight_block_size"],
        )
        x_more_than_2_dims = x.ndim > 2
        y_more_than_2_dims = y.ndim > 2
        x_quantizer = partial(
            block_minifloat_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias_width=x_exponent_bias_width,
            block_size=x_block_size,
            skip_first_dim=x_more_than_2_dims,
        )
        y_quantizer = partial(
            block_minifloat_quantizer,
            width=y_width,
            exponent_width=y_exponent_width,
            exponent_bias_width=y_exponent_bias_width,
            block_size=y_block_size,
            skip_first_dim=y_more_than_2_dims,
        )
        # flatten all other dims except for the last two dims for performing matmul
        # this is a hack for allowing block/unblock the last two dims of multiple dim tensors
        x_shape = [i for i in x.shape]
        y_shape = [i for i in y.shape]
        if x_more_than_2_dims:
            x = torch.flatten(x, 0, -3)
        if y_more_than_2_dims:
            y = torch.flatten(y, 0, -3)
        x = x_quantizer(x)
        # y = x_quantizer(y)
        y = y_quantizer(y)
        x = torch.reshape(x, x_shape)
        y = torch.reshape(y, y_shape)
        return matmul(x, y)

def generic_matmul_block_log(x, y, config, style="matmul"):
    bypass = config.get("bypass", False)
    matmul = matmul_mapping[style]
    if bypass:
        return matmul(x, y)
    else:
        x_width, x_exponent_bias_width, x_block_size = (
            config["data_in_width"],
            config["data_in_exponent_bias_width"],
            config["data_in_block_size"],
        )
        y_width, y_exponent_bias_width, y_block_size = (
            config["weight_width"],
            config["weight_exponent_bias_width"],
            config["weight_block_size"],
        )
        x_more_than_2_dims = x.ndim > 2
        y_more_than_2_dims = y.ndim > 2
        x_quantizer = partial(
            block_log_quantizer,
            width=x_width,
            exponent_bias_width=x_exponent_bias_width,
            block_size=x_block_size,
            skip_first_dim=x_more_than_2_dims,
        )
        y_quantizer = partial(
            block_log_quantizer,
            width=y_width,
            exponent_bias_width=y_exponent_bias_width,
            block_size=y_block_size,
            skip_first_dim=y_more_than_2_dims,
        )
        # flatten all other dims except for the last two dims for performing matmul
        # this is a hack for allowing block/unblock the last two dims of multiple dim tensors
        x_shape = [i for i in x.shape]
        y_shape = [i for i in y.shape]
        if x_more_than_2_dims:
            x = torch.flatten(x, 0, -3)
        if y_more_than_2_dims:
            y = torch.flatten(y, 0, -3)
        x = x_quantizer(x)
        # quantize y as well; the original built y_quantizer but never applied it,
        # which silently skipped weight quantization
        y = y_quantizer(y)
        x = torch.reshape(x, x_shape)
        y = torch.reshape(y, y_shape)
        return matmul(x, y)

def matmul_integer(x, y, config, out_config=None, floor=False):
    return generic_matmul_integer(x, y, config, "matmul", out_config, floor)


def matmul_binary(x, y, config):
    return generic_matmul_binary(x, y, config, "matmul")


def matmul_ternary(x, y, config):
    return generic_matmul_ternary(x, y, config, "matmul")


def matmul_minifloat_denorm(x, y, config):
    return generic_matmul_minifloat_denorm(x, y, config, "matmul")


def matmul_minifloat_ieee(x, y, config):
    return generic_matmul_minifloat_ieee(x, y, config, "matmul")


def matmul_log(x, y, config):
    return generic_matmul_log(x, y, config, "matmul")


def matmul_block_fp(x, y, config):
    return generic_matmul_block_fp(x, y, config, "matmul")


def matmul_block_minifloat(x, y, config):
    return generic_matmul_block_minifloat(x, y, config, "matmul")


def matmul_block_log(x, y, config):
    return generic_matmul_block_log(x, y, config, "matmul")


def bmm_integer(x, y, config):
    return generic_matmul_integer(x, y, config, "bmm")

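# Illustrative sketch (not part of the original module): the bmm_* wrappers are the
# same quantized paths dispatched through torch.bmm, so both inputs must be 3-D with
# matching batch dimensions. The widths are arbitrary example values.
def _example_bmm_integer():
    x = torch.randn(3, 4, 8)
    y = torch.randn(3, 8, 2)
    config = {
        "data_in_width": 8,
        "data_in_frac_width": 4,
        "weight_width": 8,
        "weight_frac_width": 4,
    }
    return bmm_integer(x, y, config)
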
def bmm_binary(x, y, config):
    return generic_matmul_binary(x, y, config, "bmm")


def bmm_ternary(x, y, config):
    return generic_matmul_ternary(x, y, config, "bmm")

# def get_output_bitwidth_bmm_integer(config, x_shape):
#     w_width, w_frac = config["weight_width"], config["weight_frac_width"]
#     x_width, x_frac = config["data_in_width"], config["data_in_frac_width"]
#     ops = x_shape[-1]
#     product_width = w_width + x_width
#     product_frac_width = w_frac + x_frac
#     output_width = product_width + ceil(log2(ops))
#     output_frac_width = product_frac_width
#     o_bitwidth = {}
#     o_bitwidth["data_out_width"] = output_width
#     o_bitwidth["data_out_frac_width"] = output_frac_width
#     return o_bitwidth

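# Illustrative sketch (not part of the original module): the worst-case bit growth
# the commented-out helper above encodes. Multiplying a (x_width, x_frac) value by a
# (w_width, w_frac) fixed-point value needs x_width + w_width bits with
# x_frac + w_frac fractional bits, and accumulating `ops` such products adds
# ceil(log2(ops)) integer bits. The default arguments are arbitrary example values.
def _example_output_bitwidth(x_width=8, x_frac=4, w_width=8, w_frac=4, ops=64):
    product_width = x_width + w_width
    product_frac_width = x_frac + w_frac
    out_width = product_width + ceil(log2(ops))
    return {"data_out_width": out_width, "data_out_frac_width": product_frac_width}
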
def bmm_minifloat_denorm(x, y, config):
    return generic_matmul_minifloat_denorm(x, y, config, "bmm")


def bmm_minifloat_ieee(x, y, config):
    return generic_matmul_minifloat_ieee(x, y, config, "bmm")


def bmm_log(x, y, config):
    return generic_matmul_log(x, y, config, "bmm")


def bmm_block_fp(x, y, config):
    return generic_matmul_block_fp(x, y, config, style="bmm")


def bmm_block_minifloat(x, y, config):
    return generic_matmul_block_minifloat(x, y, config, style="bmm")


def bmm_block_log(x, y, config):
    return generic_matmul_block_log(x, y, config, style="bmm")