# Source code for chop.nn.quantized.modules.batch_norm2d

from functools import partial

import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F

from chop.nn.quantizers import (
    integer_quantizer,
    binary_quantizer,
)
from ..utils import quantiser_passthrough


class _BatchNorm2dBase(nn.BatchNorm2d):
    """Shared base for the quantized 2-D batch-norm layers in this module.

    Behaves like ``nn.BatchNorm2d`` except that, unless ``self.bypass`` is
    true, the input is first passed through ``self.x_quantizer``.

    Subclasses are expected to set ``bypass`` and ``x_quantizer`` in their own
    ``__init__``. With the defaults assigned here (``bypass=False`` and
    ``x_quantizer=None``) ``forward`` fails with ``TypeError`` because ``None``
    is not callable.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 0.00001,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """Forward all arguments to ``nn.BatchNorm2d`` and set quantizer defaults.

        ``momentum`` may also be ``None``, in which case ``forward`` uses a
        cumulative moving average for the running statistics.
        """
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, device, dtype
        )
        self.bypass = False
        self.x_quantizer = None

    def forward(self, x: Tensor) -> Tensor:
        """Quantize ``x`` (unless bypassed) and apply batch normalization.

        Args:
            x: Input tensor handed to ``F.batch_norm``.

        Returns:
            The normalized tensor.
        """
        # momentum=None selects a cumulative moving average; 0.0 is only a
        # placeholder that is overwritten below whenever stats are updated.
        avg_factor = 0.0 if self.momentum is None else self.momentum

        if self.training and self.track_running_stats:
            # num_batches_tracked is None when running stats are not tracked.
            if self.num_batches_tracked is not None:
                self.num_batches_tracked.add_(1)
                if self.momentum is None:
                    avg_factor = 1.0 / float(self.num_batches_tracked)

        # Use batch statistics while training, or in eval mode when no running
        # statistics buffers exist.
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        if not self.bypass:
            x = self.x_quantizer(x)

        # If buffers are not to be tracked, ensure that they won't be updated
        if not self.training or self.track_running_stats:
            running_mean = self.running_mean
            running_var = self.running_var
        else:
            running_mean = None
            running_var = None

        # NOTE(review): self.weight / self.bias are passed unquantized; the
        # w_quantizer / b_quantizer attributes set by the subclasses in this
        # file are not applied here - confirm whether that is intended.
        return F.batch_norm(
            x,
            running_mean,
            running_var,
            self.weight,
            self.bias,
            bn_training,
            avg_factor,
            self.eps,
        )


class BatchNorm2dInteger(_BatchNorm2dBase):
    """2-D batch norm whose quantizers are built from ``integer_quantizer``.

    ``config`` keys read by ``__init__``:

    * ``"bypass"`` (optional, default ``False``): when true, no quantizer is
      configured and no other key is read.
    * ``"data_in_width"``, ``"data_in_frac_width"``: arguments for the input
      quantizer.
    * ``"weight_width"``, ``"weight_frac_width"``, ``"bias_width"``,
      ``"bias_frac_width"``: read only when ``affine`` is true.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 0.00001,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        """Build the layer and bind quantizer settings taken from ``config``.

        Args:
            num_features, eps, momentum, affine, track_running_stats, device,
                dtype: Forwarded unchanged to ``_BatchNorm2dBase``.
            config: Mapping with the keys listed on the class. Required.

        Raises:
            AssertionError: If ``config`` is ``None``.
            KeyError: If a required key is missing from a dict ``config``.
        """
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, device, dtype
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            # Bypassed layers keep x_quantizer=None and define no w/b quantizer.
            return

        x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"]
        self.x_quantizer = partial(
            integer_quantizer, width=x_width, frac_width=x_frac_width
        )

        # Weight and bias quantizers are only defined for affine layers.
        if affine:
            w_width, w_frac_width = config["weight_width"], config["weight_frac_width"]
            b_width, b_frac_width = config["bias_width"], config["bias_frac_width"]
            self.w_quantizer = partial(
                integer_quantizer, width=w_width, frac_width=w_frac_width
            )
            self.b_quantizer = partial(
                integer_quantizer, width=b_width, frac_width=b_frac_width
            )


class BatchNorm2dBinary(_BatchNorm2dBase):
    """2-D batch norm whose weight quantizer is built from ``binary_quantizer``.

    ``config`` keys read by ``__init__``:

    * ``"bypass"`` (optional, default ``False``): when true, no quantizer is
      configured and no other key is read.
    * ``"weight_stochastic"``, ``"weight_bipolar"``: arguments for the weight
      quantizer. They are read regardless of ``affine``.

    The bias and input quantizers are set to ``quantiser_passthrough``
    (presumably an identity function - confirm in ``..utils``).
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 0.00001,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None,
        config=None,
    ) -> None:
        """Build the layer and bind quantizer settings taken from ``config``.

        Args:
            num_features, eps, momentum, affine, track_running_stats, device,
                dtype: Forwarded unchanged to ``_BatchNorm2dBase``.
            config: Mapping with the keys listed on the class. Required.

        Raises:
            AssertionError: If ``config`` is ``None``.
            KeyError: If a required key is missing from a dict ``config``.
        """
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, device, dtype
        )
        assert config is not None, "config is None!"
        self.config = config
        self.bypass = config.get("bypass", False)
        if self.bypass:
            # Bypassed layers keep x_quantizer=None and define no w/b quantizer.
            return

        w_stochastic = config["weight_stochastic"]
        w_bipolar = config["weight_bipolar"]
        self.w_quantizer = partial(
            binary_quantizer, stochastic=w_stochastic, bipolar=w_bipolar
        )
        self.b_quantizer = quantiser_passthrough
        self.x_quantizer = quantiser_passthrough