Source code for chop.passes.module.transforms.pim.pim_matmul_transform
import torch.nn as nn
import logging
from chop.nn.pim.pim_layer import PIMLinear, PIMConv2d, LoraPIMLinear
from chop.tools import deepsetattr

logger = logging.getLogger(__name__)


def get_module_type(module):
"""
    Categorize a module into a predefined type for PIM transformation.

:param module: A tuple containing (module_name, module_instance).
:type module: tuple
:return: The category name of the module (e.g., 'linear', 'conv2d', 'layer_norm', etc.) or None if not recognized.
:rtype: str or None
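
    A minimal illustration (the module names here are made up for the example):

    .. code-block:: python

        get_module_type(("fc1", nn.Linear(4, 4)))      # -> "linear"
        get_module_type(("conv", nn.Conv2d(3, 8, 3)))  # -> "conv2d"
        get_module_type(("drop", nn.Dropout()))        # -> None (unrecognized)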
"""
class_name = module[1].__class__.__name__
if "Linear" in class_name:
return "linear"
elif "LayerNorm" in class_name:
return "layer_norm"
elif "Attention" in class_name:
if "Head" not in class_name:
return "attention"
else:
return "attention_head"
elif "GELU" in class_name:
return "gelu"
elif "Conv2d" in class_name:
return "conv2d"
elif "ReLU" in class_name:
return "relu"
elif "BatchNorm2d" in class_name:
return "batch_norm"
else:
return None


def parse_q_config(module, q_config):
"""
    Parse the PIM configuration for a specific module based on its name or type.

:param module: A tuple containing (module_name, module_instance).
:type module: tuple
:param q_config: The global PIM configuration dictionary.
:type q_config: dict
:return: The specific configuration dictionary for the module, or None if no match is found.
:rtype: dict or None
    :raises ValueError: If the "by" key in q_config is missing or is neither "name" nor "type".
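
    A minimal illustration (keys and values are hypothetical):

    .. code-block:: python

        cfg = {"by": "type", "linear": {"config": {"num_bits": 8}}}
        parse_q_config(("fc1", nn.Linear(4, 4)), cfg)   # -> {"num_bits": 8}

        cfg = {"by": "name", "fc1": {"config": {"num_bits": 4}}}
        parse_q_config(("fc2", nn.Linear(4, 4)), cfg)   # -> None (no "fc2" entry)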
"""
if q_config.get("by") == "name":
if module[0] in q_config:
return q_config[module[0]]["config"]
else:
return None
elif q_config.get("by") == "type":
module_type = get_module_type(module)
if module_type in q_config:
return q_config[module_type]["config"]
else:
return None
else:
raise ValueError(f"Invalid q_config: {q_config}")


def pim_matmul_transform_pass(model, q_config={}, lora_config=None):
"""
    Apply a PIM (processing-in-memory) transformation to the given nn.Module.

    This pass replaces supported layers (Linear, Conv2d) with their PIM-aware
    counterparts (PIMLinear, PIMConv2d) or, when a ``lora_config`` is supplied,
    with LoRA-enabled PIM layers (LoraPIMLinear).

:param model: The input network to be transformed.
:type model: torch.nn.Module
:param q_config: Configuration for the PIM transformation, specifying how to match modules and their parameters.
:type q_config: dict, optional
:param lora_config: Configuration for LoRA if applying LoRA-enabled PIM transformation.
:type lora_config: dict, optional

    Example q_config:

    .. code-block:: python

q_config = {
"by": "type",
"linear": {
"config": {
"tile_type": "pcm",
"core_size": 256,
"num_bits": 8,
"programming_noise": True,
"read_noise": True,
"ir_drop": True,
"out_noise": True,
}
},
}

    :return: A tuple containing the transformed model and an empty dictionary (for consistency with other passes).
    :rtype: tuple
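
    Example usage with the ``q_config`` above (a minimal sketch; the model is illustrative):

    .. code-block:: python

        model = nn.Sequential(nn.Linear(16, 8), nn.ReLU())
        model, _ = pim_matmul_transform_pass(model, q_config=q_config)
        # the nn.Linear at index "0" has been replaced by a PIMLinear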
"""
    # Materialize the iterator first: deepsetattr swaps modules in the tree,
    # which is unsafe while iterating the live named_modules() generator.
    for module in list(model.named_modules()):
config = parse_q_config(module, q_config)
if config is None:
continue
        module_type = get_module_type(module)
        if module_type == "conv2d":
ori_module = module[1]
new_module = PIMConv2d(
ori_module.in_channels,
ori_module.out_channels,
ori_module.kernel_size,
ori_module.stride,
ori_module.padding,
q_config=config,
)
            # Reuse the original parameters so the pretrained weights (and the
            # bias, which may be None) carry over to the PIM layer.
            new_module.weight = ori_module.weight
            new_module.bias = ori_module.bias
deepsetattr(model, module[0], new_module)
logger.debug(f"Replacing module: {module[0]}")
        elif module_type == "linear":
ori_module = module[1]
if lora_config is not None:
new_module = LoraPIMLinear(
ori_module.in_features,
ori_module.out_features,
q_config=config,
lora_config=lora_config,
)
else:
new_module = PIMLinear(
ori_module.in_features,
ori_module.out_features,
q_config=config,
)
            # Copy the pretrained weight values and reuse the original bias.
            new_module.weight.data = ori_module.weight.data
            new_module.bias = ori_module.bias
logger.debug(f"Replacing module: {module[0]}")
deepsetattr(model, module[0], new_module)
else:
continue
return model, {}