Source code for chop.nn.quantized.functional.mult

from functools import partial

import torch

from chop.nn.quantizers import (
    block_fp_quantizer,
    block_log_quantizer,
    block_minifloat_quantizer,
    integer_quantizer,
    log_quantizer,
    minifloat_denorm_quantizer,
    minifloat_ieee_quantizer,
    binary_quantizer,
    ternary_quantizer,
)


def mult_integer(x, y, config):
    bypass = config.get("bypass", False)
    if bypass:
        return x * y
    else:
        # establish quantizers
        x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
        x_quantizer = partial(integer_quantizer, width=x_width, frac_width=x_frac_width)
        x = x_quantizer(x)
        y = x_quantizer(y)
        return x * y


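# Illustrative usage sketch (not part of the original module): a minimal example
# of calling mult_integer with a fixed-point config. The key names follow the
# lookups above; the numeric values are arbitrary examples, not library defaults.
def _example_mult_integer():
    x = torch.randn(4, 4)
    y = torch.randn(4, 4)
    config = {
        "bypass": False,
        "data_in_width": 8,       # total bits per element
        "data_in_frac_width": 4,  # bits after the binary point
    }
    # both operands are quantized with the same quantizer, then multiplied
    return mult_integer(x, y, config)

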
def mult_binary(x, y, config):
    bypass = config.get("bypass", False)
    if bypass:
        return x * y
    else:
        # establish quantizers
        x_stochastic = config["data_in_stochastic"]
        x_bipolar = config["data_in_bipolar"]
        x_quantizer = partial(
            binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar
        )
        x = x_quantizer(x)
        y = x_quantizer(y)
        return x * y


def mult_ternary(x, y, config):
    bypass = config.get("bypass", False)
    if bypass:
        return x * y
    else:
        # quantiser
        x_scaling_factor = config["data_in_scaling_factor"]
        x_quantizer = partial(ternary_quantizer, scaling_factor=x_scaling_factor)
        x = x_quantizer(x)
        y = x_quantizer(y)
        return x * y


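# Illustrative usage sketch (not part of the original module): the binary and
# ternary variants read different config keys from the fixed-point one. The
# values and their interpretations below are assumptions for illustration only.
def _example_mult_binary_ternary():
    x = torch.randn(4, 4)
    y = torch.randn(4, 4)
    binary_config = {
        "data_in_stochastic": False,  # deterministic rather than stochastic binarization (assumed)
        "data_in_bipolar": True,      # +/-1 levels rather than 0/1 (assumed)
    }
    ternary_config = {
        "data_in_scaling_factor": False,  # assumed boolean flag controlling level scaling
    }
    return mult_binary(x, y, binary_config), mult_ternary(x, y, ternary_config)

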
def mult_minifloat_denorm(x, y, config):
    bypass = config.get("bypass", False)
    if bypass:
        return x * y
    else:
        x_width, x_exponent_width, x_exponent_bias = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
        )
        x_quantizer = partial(
            minifloat_denorm_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias=x_exponent_bias,
        )
        x = x_quantizer(x)
        y = x_quantizer(y)
        return x * y


def mult_minifloat_ieee(x, y, config):
    bypass = config.get("bypass", False)
    if bypass:
        return x * y
    else:
        x_width, x_exponent_width, x_exponent_bias = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
        )
        x_quantizer = partial(
            minifloat_ieee_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias=x_exponent_bias,
        )
        x = x_quantizer(x)
        y = x_quantizer(y)
        return x * y


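# Illustrative usage sketch (not part of the original module): the two minifloat
# variants above share the same config keys. The bit widths and bias below are
# arbitrary examples, not library defaults.
def _example_mult_minifloat():
    x = torch.randn(4, 4)
    y = torch.randn(4, 4)
    config = {
        "data_in_width": 8,           # total bits per element
        "data_in_exponent_width": 4,  # bits for the exponent field
        "data_in_exponent_bias": 7,   # fixed exponent bias
    }
    # mult_minifloat_denorm accepts the same keys, using minifloat_denorm_quantizer instead
    return mult_minifloat_ieee(x, y, config)

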
def mult_log(x, y, config):
    bypass = config.get("bypass", False)
    if bypass:
        return x * y
    else:
        x_width, x_exponent_bias = (
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
        )
        x_quantizer = partial(
            log_quantizer, width=x_width, exponent_bias=x_exponent_bias
        )
        x = x_quantizer(x)
        y = x_quantizer(y)
        return x * y


def mult_block_fp(x, y, config):
    bypass = config.get("bypass", False)
    if bypass:
        return x * y
    else:
        x_width, x_exponent_width, x_exponent_bias, x_block_size = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias"],
            config["data_in_block_size"],
        )
        x_more_than_2_dims = x.ndim > 2
        x_quantizer = partial(
            block_fp_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias=x_exponent_bias,
            block_size=x_block_size,
            skip_first_dim=x_more_than_2_dims,
        )
        # a hack to use 2d blocking
        if x.shape != y.shape:
            x, y = torch.broadcast_tensors(x, y)
        # assert x.shape == y.shape
        x_shape = [i for i in x.shape]
        if x_more_than_2_dims:
            x = torch.flatten(x, 0, -3)
            y = torch.flatten(y, 0, -3)
        x = x_quantizer(x)
        y = x_quantizer(y)
        x = torch.reshape(x, x_shape)
        y = torch.reshape(y, x_shape)
        return x * y


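# Illustrative usage sketch (not part of the original module): mult_block_fp
# broadcasts mismatched operands, flattens inputs with more than two dimensions
# so that blocking applies over the trailing dimensions, and restores the
# original shape afterwards. Config values below are arbitrary examples; the
# list form of data_in_block_size is assumed.
def _example_mult_block_fp():
    x = torch.randn(2, 3, 4)  # > 2 dims, so the flattened leading dim is skipped when blocking
    y = torch.randn(4)        # broadcast against x before quantization
    config = {
        "data_in_width": 8,
        "data_in_exponent_width": 4,
        "data_in_exponent_bias": 7,
        "data_in_block_size": [4],  # number of elements sharing one block exponent (assumed list form)
    }
    return mult_block_fp(x, y, config)

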
def mult_block_minifloat(x, y, config):
    bypass = config.get("bypass", False)
    if bypass:
        return x * y
    else:
        x_width, x_exponent_width, x_exponent_bias_width, x_block_size = (
            config["data_in_width"],
            config["data_in_exponent_width"],
            config["data_in_exponent_bias_width"],
            config["data_in_block_size"],
        )
        x_more_than_2_dims = x.ndim > 2
        x_quantizer = partial(
            block_minifloat_quantizer,
            width=x_width,
            exponent_width=x_exponent_width,
            exponent_bias_width=x_exponent_bias_width,
            block_size=x_block_size,
            skip_first_dim=x_more_than_2_dims,
        )
        # a hack to use 2d blocking
        if x.shape != y.shape:
            x, y = torch.broadcast_tensors(x, y)
        # assert x.shape == y.shape
        x_shape = [i for i in x.shape]
        if x_more_than_2_dims:
            x = torch.flatten(x, 0, -3)
            y = torch.flatten(y, 0, -3)
        x = x_quantizer(x)
        y = x_quantizer(y)
        x = torch.reshape(x, x_shape)
        y = torch.reshape(y, x_shape)
        return x * y


def mult_block_log(x, y, config):
    bypass = config.get("bypass", False)
    if bypass:
        return x * y
    else:
        x_width, x_exponent_bias_width, x_block_size = (
            config["data_in_width"],
            config["data_in_exponent_bias_width"],
            config["data_in_block_size"],
        )
        x_more_than_2_dims = x.ndim > 2
        x_quantizer = partial(
            block_log_quantizer,
            width=x_width,
            exponent_bias_width=x_exponent_bias_width,
            block_size=x_block_size,
            skip_first_dim=x_more_than_2_dims,
        )
        # a hack to use 2d blocking
        if x.shape != y.shape:
            x, y = torch.broadcast_tensors(x, y)
        # assert x.shape == y.shape
        x_shape = [i for i in x.shape]
        if x_more_than_2_dims:
            x = torch.flatten(x, 0, -3)
            y = torch.flatten(y, 0, -3)
        x = x_quantizer(x)
        y = x_quantizer(y)
        x = torch.reshape(x, x_shape)
        y = torch.reshape(y, x_shape)
        return x * y