Source code for chop.nn.quantized.modules.pool2d

from functools import partial
from math import ceil, log2
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.common_types import _size_2_t

from ..utils import get_stats, quantiser_passthrough

from chop.nn.quantizers import (
    block_fp_quantizer,
    integer_quantizer,
    minifloat_denorm_quantizer,
    minifloat_ieee_quantizer,
    binary_quantizer,
    ternary_quantizer,
)


class _AvgPool2dBase(torch.nn.AvgPool2d):
    def __init__(
        self,
        kernel_size: _size_2_t,
        stride: _size_2_t | None = None,
        padding: _size_2_t = 0,
        ceil_mode: bool = False,
        count_include_pad: bool = True,
        divisor_override: int | None = None,
    ) -> None:
        super().__init__(
            kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
        )
        self.bypass = False
        self.x_quantizer = None

    def forward(self, x: Tensor) -> Tensor:
        if self.bypass:
            return F.avg_pool2d(
                x,
                self.kernel_size,
                self.stride,
                self.padding,
                self.ceil_mode,
                self.count_include_pad,
                self.divisor_override,
            )
        x = self.x_quantizer(x)
        # As with the quantized conv2d modules, only the input is quantized here;
        # the accumulation inside the pooling sum is assumed to be lossless.
        return F.avg_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.ceil_mode,
            self.count_include_pad,
            self.divisor_override,
        )

    # def get_output_bitwidth(self) -> dict:
    #     raise NotImplementedError


class AvgPool2dInteger(_AvgPool2dBase):
    def __init__(
        self,
        kernel_size: _size_2_t,
        stride: Optional[_size_2_t] = None,
        padding: _size_2_t = 0,
        ceil_mode: bool = False,
        count_include_pad: bool = True,
        divisor_override: Optional[int] = None,
        config=None,
    ) -> None:
        super().__init__(
            kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
        self.x_quantizer = partial(
            integer_quantizer, width=x_width, frac_width=x_frac_width
        )
    # def get_output_bitwidth(self) -> dict:
    #     config = self.config
    #     x_width, x_frac = config["data_in_width"], config["data_in_frac_width"]
    #     num_add_operands = self.kernel_size[0] * self.kernel_size[1]
    #     output_width = x_width + ceil(log2(num_add_operands))
    #     output_frac_width = x_frac
    #     o_bitwidth = {}
    #     o_bitwidth["data_out_width"] = output_width
    #     o_bitwidth["data_out_frac_width"] = output_frac_width
    #     return o_bitwidth


class _AdaptiveAvgPool2dBase(torch.nn.AdaptiveAvgPool2d):
    """
    Emulates adaptive average pooling with a fixed-kernel ``F.avg_pool2d`` by
    padding the input up to a multiple of ``output_size``. Refer to
    https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993#63603993?newreg=f2a34d7176564a5288717e984bdc21c7
    """

    bypass = False

    def __init__(self, output_size) -> None:
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        super().__init__(output_size)
        self.bypass = False
        self.x_quantizer = None

    def forward(self, x: Tensor) -> Tensor:
        if self.bypass:
            return F.adaptive_avg_pool2d(input=x, output_size=self.output_size)
        else:
            pool2d_kwargs = self._get_pool2d_kwargs(x.shape)
            f_padding = pool2d_kwargs.pop("f_padding")
            x = F.pad(x, f_padding)
            x = self.x_quantizer(x)
            return F.avg_pool2d(x, **pool2d_kwargs)

    def _get_pool2d_kwargs(self, x_shape):
        h_in_new = ceil(x_shape[-2] / self.output_size[0]) * self.output_size[0]
        w_in_new = ceil(x_shape[-1] / self.output_size[1]) * self.output_size[1]
        # F.pad takes (left, right, top, bottom), i.e. the last (width) dimension
        # first, so pad W up to w_in_new and H up to h_in_new.
        f_padding = (0, w_in_new - x_shape[-1], 0, h_in_new - x_shape[-2])
        stride = (h_in_new // self.output_size[0], w_in_new // self.output_size[1])
        kernel_size = (
            h_in_new - (self.output_size[0] - 1) * stride[0],
            w_in_new - (self.output_size[1] - 1) * stride[1],
        )
        return {
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": 0,
            "f_padding": f_padding,
        }

    # def get_output_bitwidth(self) -> dict:
    #     raise NotImplementedError
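
# Worked example for _AdaptiveAvgPool2dBase._get_pool2d_kwargs (illustrative only,
# the shapes are made up): for an input of spatial size 7 x 7 and output_size=(3, 3):
#
#   h_in_new = w_in_new = ceil(7 / 3) * 3 = 9
#   f_padding = (0, 2, 0, 2)                # pad W and H from 7 up to 9
#   stride = (9 // 3, 9 // 3) = (3, 3)
#   kernel_size = (9 - (3 - 1) * 3, 9 - (3 - 1) * 3) = (3, 3)
#
# so the adaptive pool becomes a plain 3x3, stride-3 AvgPool2d over the padded,
# quantized input.
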
class AdaptiveAvgPool2dInteger(_AdaptiveAvgPool2dBase):
    def __init__(self, output_size, config) -> None:
        super().__init__(output_size)
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
        self.x_quantizer = partial(
            integer_quantizer, width=x_width, frac_width=x_frac_width
        )
    # def get_output_bitwidth(self, x_shape) -> dict:
    #     config = self.config
    #     pool2d_kwargs = self._get_pool2d_kwargs(x_shape=x_shape)
    #     kernel_size = pool2d_kwargs["kernel_size"]
    #     x_width, x_frac = config["data_in_width"], config["data_in_frac_width"]
    #     num_add_operands = kernel_size[0] * kernel_size[1]
    #     output_width = x_width + ceil(log2(num_add_operands))
    #     output_frac_width = x_frac
    #     o_bitwidth = {}
    #     o_bitwidth["data_out_width"] = output_width
    #     o_bitwidth["data_out_frac_width"] = output_frac_width
    #     return o_bitwidth
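
# Worked example for the bit-width bound sketched in the commented-out method
# above (illustrative numbers): with an 8-bit input (x_width = 8) and a derived
# kernel_size of (3, 3), the accumulation runs over 9 operands, which needs at
# most ceil(log2(9)) = 4 extra integer bits, so output_width = 8 + 4 = 12 while
# the fractional width stays at x_frac.
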
class AvgPool2dBinary(_AvgPool2dBase):
    def __init__(
        self,
        kernel_size: _size_2_t,
        stride: Optional[_size_2_t] = None,
        padding: _size_2_t = 0,
        ceil_mode: bool = False,
        count_include_pad: bool = True,
        divisor_override: Optional[int] = None,
        config=None,
    ) -> None:
        super().__init__(
            kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        self.x_quantizer = quantiser_passthrough
class AvgPool2dTernary(_AvgPool2dBase):
    def __init__(
        self,
        kernel_size: _size_2_t,
        stride: Optional[_size_2_t] = None,
        padding: _size_2_t = 0,
        ceil_mode: bool = False,
        count_include_pad: bool = True,
        divisor_override: Optional[int] = None,
        config=None,
    ) -> None:
        super().__init__(
            kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        self.x_quantizer = quantiser_passthrough
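
# Minimal usage sketch (illustrative, not part of the module): the integer
# variants expect a config dict carrying at least "data_in_width" and
# "data_in_frac_width"; the shapes and widths below are arbitrary.
if __name__ == "__main__":
    x = torch.randn(1, 3, 8, 8)

    int_pool = AvgPool2dInteger(
        kernel_size=2,
        stride=2,
        config={"data_in_width": 8, "data_in_frac_width": 4},
    )
    y = int_pool(x)  # input quantised to 8-bit fixed point (4 fractional bits), then pooled

    adaptive_pool = AdaptiveAvgPool2dInteger(
        output_size=(1, 1),
        config={"data_in_width": 8, "data_in_frac_width": 4},
    )
    z = adaptive_pool(x)  # emulated adaptive average pooling on the quantised input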