MaseModule Passes

Modules in PyTorch are composable building blocks for neural networks. A module is defined by subclassing torch.nn.Module and implementing the forward method. More detail on torch.nn.Module can be found in the PyTorch documentation (https://pytorch.org/docs/stable/generated/torch.nn.Module.html).
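For example, the smallest useful module registers its parameters or submodules in __init__ and defines its computation in forward. The Scale module below is a hypothetical illustration, not part of MASE:

import torch

class Scale(torch.nn.Module):
    """Minimal illustrative module: multiplies its input by a learnable scalar."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1))

    def forward(self, x):
        return self.scale * x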

In MASE, we support applying passes directly to torch.nn.Module objects.

import torch
import torch.nn as nn
from chop.passes.module.transforms import quantize_module_transform_pass

class MLP(torch.nn.Module):
    """
    Toy quantized FC model for digit recognition on MNIST
    """

    def __init__(self) -> None:
        super().__init__()

        self.fc1 = nn.Linear(28 * 28, 28 * 28)
        self.fc2 = nn.Linear(28 * 28, 28 * 28 * 4)
        self.fc3 = nn.Linear(28 * 28 * 4, 10)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1, end_dim=-1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x


mlp = MLP()
# Quantization configuration: quantize fc1 (matched by name) to 8-bit
# fixed-point, with 4 fractional bits for activations, weights and biases
pass_args = {
    "by": "name",
    "fc1": {
        "name": "integer",
        "data_in_width": 8,
        "data_in_frac_width": 4,
        "weight_width": 8,
        "weight_frac_width": 4,
        "bias_width": 8,
        "bias_frac_width": 4,
    },
}
# directly apply quantization on top of a native torch model
quantize_module_transform_pass(mlp, pass_args)
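
To check that the transform took effect, inspect the replaced submodule and run a forward pass on dummy data. This is a minimal sketch: it assumes the pass rewrites the model in place (consistent with how it is called above, where the return value is not captured), and the exact class printed for fc1 depends on your MASE version.

# fc1 should now be a quantized Linear variant rather than nn.Linear
print(type(mlp.fc1))

# the quantized model should still accept MNIST-shaped inputs
x = torch.randn(1, 28, 28)
y = mlp(x)
print(y.shape)  # torch.Size([1, 10])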