# Source code for chop.nn.quantized.modules.group_norm

import logging
from math import ceil, log2
from functools import partial

import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F

from chop.nn.quantizers import (
    integer_quantizer,
)


from mase_components.scalar_operators.fixed.test.isqrt_sw import isqrt_sw2

# Name the logger after the module's dotted path (not its file path) so it
# participates in the package logger hierarchy and can be configured by name.
logger = logging.getLogger(__name__)
# NOTE(review): forcing DEBUG at import time overrides any level the
# application configures for this logger; kept to preserve existing behaviour.
logger.setLevel(logging.DEBUG)


class _GroupNormBase(nn.GroupNorm):
    """Base class for ``nn.GroupNorm`` variants with an input quantizer.

    ``x_quantizer`` is a callable applied to the input before normalization.
    This class leaves it as ``None``, so a subclass (or caller) must assign it
    unless ``bypass`` is set to ``True``, in which case the input is used as-is.
    """

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 0.00001,
        affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(num_groups, num_channels, eps, affine, device, dtype)

        # When True, forward() skips input quantization.
        self.bypass = False
        # Callable applied to the input in forward(); not configured here.
        self.x_quantizer = None

    def forward(self, x: Tensor) -> Tensor:
        """Optionally quantize ``x``, then apply group normalization.

        The per-channel ``weight``/``bias`` created by ``nn.GroupNorm`` are
        applied; they are ``None`` when the module was built with
        ``affine=False``, giving a plain (non-affine) group norm.
        """
        if not self.bypass:
            x = self.x_quantizer(x)
        # Pass the affine parameters so loaded/trained weights are honoured
        # (previously they were dropped by passing None, None).
        return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)


class GroupNormInteger(_GroupNormBase):
    """``nn.GroupNorm`` whose input is quantized with ``integer_quantizer``.

    ``config`` keys read here:

    * ``"bypass"`` (optional, default ``False``): if truthy, no quantizer is
      configured and ``self.bypass`` is set so the input is not quantized.
    * ``"data_in_width"`` / ``"data_in_frac_width"``: forwarded to
      ``integer_quantizer`` as ``width`` and ``frac_width``. Required unless
      bypassing (a missing key raises ``KeyError``).

    Raises:
        AssertionError: if ``config`` is ``None``.
    """

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 0.00001,
        affine: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        super().__init__(num_groups, num_channels, eps, affine, device, dtype)
        # Explicit check instead of a bare `assert`, which `python -O` strips;
        # the exception type and message are kept for existing callers.
        if config is None:
            raise AssertionError("config is None!")
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            return
        x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
        self.x_quantizer = partial(
            integer_quantizer, width=x_width, frac_width=x_frac_width
        )