Importing MASE as a Python package

An easy way to use MASE is to import it directly as a Python package.

As mentioned in Getting Started, you can install MASE in editable mode through pip:

pip install -e . -vvv

You can then test the installation with:

python -c "import chop; print(chop)"
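
If the import succeeds, a slightly more thorough check is to import the subpackages used in the examples below. This is only a minimal sanity-check sketch of ours, assuming the editable install above completed without errors.

# sanity-check sketch: confirm the subpackages used in this guide import cleanly
import chop
import chop.passes
import chop.ir.graph

print("chop imported from:", chop.__file__)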

Transforming a torch.nn.Module

chop.passes offers a range of passes that can replace certain components of the original neural network for various purposes. Some of these are module passes, which operate directly on a native torch.nn.Module, which means they can be applied to any arbitrary network.

The following example applies a MASE module-level pass that quantizes a pre-built resnet50 from torchvision.

import chop
from chop.passes import quantize_module_transform_pass
from torchvision.models import resnet50, ResNet50_Weights

# load a pre-trained ResNet-50 from torchvision
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# quantize all linear and conv2d modules to 8-bit fixed point
# (4 fractional bits) for activations, weights and biases
pass_args = {
    "by": "type",
    "linear": {
        "name": "integer",
        "data_in_width": 8,
        "data_in_frac_width": 4,
        "weight_width": 8,
        "weight_frac_width": 4,
        "bias_width": 8,
        "bias_frac_width": 4,
    },
    "conv2d": {
        "name": "integer",
        "data_in_width": 8,
        "data_in_frac_width": 4,
        "weight_width": 8,
        "weight_frac_width": 4,
        "bias_width": 8,
        "bias_frac_width": 4,
    },
}
transformed_model = quantize_module_transform_pass(model, pass_args)
print(transformed_model)
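
To see what the pass actually changed, one simple check is to count the module classes that now appear in the transformed model. This snippet is a sketch of ours rather than part of the original example; it only relies on the standard torch.nn.Module.modules() API and assumes the quantized replacements carry class names that differ from nn.Linear and nn.Conv2d.

# sketch: tally module classes in the transformed model to see what was replaced
from collections import Counter

print(Counter(type(m).__name__ for m in transformed_model.modules()))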

Transforming a MaseGraph

To support manipulation at a finer level, e.g. the graph level, MASE provides built-in functionality to transform any arbitrary torch.nn.Module into a MaseGraph. The ordinary torch.nn.Module view is normally not enough for fine-grained manipulation, since many details are omitted at that level; the MaseGraph representation captures these details. Correspondingly, MASE provides a series of passes that operate at the graph level.
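
As a quick illustration, any plain torch.nn.Module can be wrapped, not only the models shipped with MASE. The snippet below is a minimal sketch: the toy nn.Sequential and the assumption that MaseGraph can trace it directly are ours, not part of the original documentation.

import torch.nn as nn
from chop.ir.graph.mase_graph import MaseGraph

# a toy network defined outside MASE (hypothetical example)
toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))

# wrapping it yields a graph-level view on which graph passes can operate
toy_mg = MaseGraph(model=toy)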

The following example applies a MASE graph-level pass to a vgg7 network. It also uses a number of MASE built-in helpers to fetch the dataset, fetch the model, transform the model into a MaseGraph, and then apply graph-level passes.

import chop

from chop.dataset import MaseDataModule, get_dataset_info
from chop.ir.graph.mase_graph import MaseGraph
from chop import models
from chop.tools.get_input import InputGenerator, get_dummy_input

from chop.passes.graph import (
    init_metadata_analysis_pass,
    add_common_metadata_analysis_pass,
    quantize_transform_pass,
    summarize_quantization_analysis_pass,
    verify_common_metadata_analysis_pass,
)

from chop.passes.graph.utils import deepcopy_mase_graph

model_name = "vgg7"
dataset_name = "cifar10"
BATCH_SIZE = 32

# get dataset information
dataset_info = get_dataset_info(dataset_name)

# get model information
model_info = models.get_model_info(model_name)

# get data module
data_module = MaseDataModule(
    model_name=model_name,
    name=dataset_name,
    batch_size=BATCH_SIZE,
    num_workers=0,
    tokenizer=None,
    max_token_len=None,
)
data_module.prepare_data()
data_module.setup()
# NOTE: We only support vision classification models for now.
dummy_input = get_dummy_input(model_info, data_module, "cls", "cpu")

# get an input generator so that we can draw sample inputs
input_generator = InputGenerator(
    model_info=model_info,
    data_module=data_module,
    task="cls",
    which_dataloader="train",
)

model = models.get_model(model_name, "cls", dataset_info, pretrained=True)

# This line transforms an nn.Module into a MaseGraph
mg = MaseGraph(model=model)

# Apply initialization passes to populate information in the graph
mg, _ = init_metadata_analysis_pass(mg, {})
mg, _ = add_common_metadata_analysis_pass(
    mg, {"dummy_in": dummy_input, "add_value": False}
)
# Sanity check and report
# mg = verify_common_metadata_analysis_pass(mg)
quan_args = {
    "by": "type",
    "default": {"config": {"name": None}},
    "linear": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}

# deep copy is only possible if we set "add_value" to False
ori_mg = deepcopy_mase_graph(mg)
mg, _ = quantize_transform_pass(mg, quan_args)

summarize_quantization_analysis_pass(
    mg, pass_args={"save_dir": "quantize_summary", "original_mg": ori_mg}
)
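
As a follow-up, one way to inspect the effect of the quantization pass is to diff the module classes between the original and transformed graphs. This is a sketch under the assumption that a MaseGraph exposes its underlying torch.fx GraphModule as .model; if the attribute differs in your MASE version, adapt accordingly.

# sketch: report which modules changed class after quantization
ori_modules = dict(ori_mg.model.named_modules())
for name, module in mg.model.named_modules():
    if name in ori_modules and type(module) is not type(ori_modules[name]):
        print(f"{name}: {type(ori_modules[name]).__name__} -> {type(module).__name__}")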