chop.passes.module.transforms.onn#
This module provides transformation passes for converting standard neural network modules into Optical Neural Network (ONN) equivalents. The optical transformer implementation is based on the Optical Transformers paper.
Optical neural networks leverage photonic hardware to perform matrix multiplications with reduced power consumption. This transform simulates the quantization effects and constraints of optical compute hardware, enabling model development and evaluation before deployment on physical optical accelerators.
Note
This module requires the mase-triton package to be installed.
Install via: pip install mase-triton
Transform Pass#
optical_transformer_module_transform_pass#
- chop.passes.module.transforms.onn.optical_transformer_module_transform_pass(network: Module, pass_args: dict) → Module[source]#
Transform a neural network by replacing supported modules with their optical transformer equivalents.
This pass simulates optical neural network (ONN) computation by replacing standard PyTorch modules with quantized optical transformer layers. The optical transformer model is based on the Optical Transformers paper.
Supported module replacements:
- torch.nn.Linear → OtLinear
- transformers.models.llama.modeling_llama.LlamaAttention → OtLlamaAttention
- Parameters:
network (torch.nn.Module) – The input network to be transformed.
pass_args (dict) –
A dictionary containing transformation configurations.
- by (str): Layer matching strategy. Either 'name' for exact name matching or 'regex_name' for regex-based pattern matching. Defaults to 'regex_name'.
- default (dict, optional): Default configuration applied to all matching layers.
- <layer_name_or_pattern> (dict): Per-layer configuration. Each layer config can contain the following keys:
  - q_levels (int): Number of quantization levels. Default: 256.
  - q_lut_min (float): Minimum value for lookup table. Default: 0.020040.
  - q_smooth_factor (float): Smoothing factor for running statistics. Default: 0.9.
  - q_init_seed (int): Random seed for quantization initialization. Default: 0.
  - q_bypass (bool): If True, bypass optical quantization. Default: False.
- Returns:
The transformed network with optical transformer modules.
- Return type:
torch.nn.Module
- Raises:
RuntimeError – If mase-triton is not installed.
Example
from chop.passes.module.transforms.onn import optical_transformer_module_transform_pass

# Transform all linear layers with default config
pass_args = {
    "by": "regex_name",
    "default": {
        "q_levels": 256,
        "q_lut_min": 0.020040,
        "q_bypass": False,
    },
}
transformed_model = optical_transformer_module_transform_pass(model, pass_args)
Note
This pass requires the mase-triton package to be installed. Install via pip install mase-triton.
Configuration#
The transform pass accepts configuration through the pass_args dictionary.
Layer matching can be done by exact name or regex patterns.
Example configuration:
pass_args = {
"by": "regex_name", # or "name" for exact matching
"default": {
"q_levels": 256,
"q_lut_min": 0.020040,
"q_smooth_factor": 0.9,
"q_init_seed": 0,
"q_bypass": False,
},
# Override for specific layers using regex
".*mlp.*": {
"q_levels": 128,
"q_bypass": False,
},
}
Configuration Parameters#
| Parameter | Type | Default | Description |
|---|---|---|---|
| q_levels | int | 256 | Number of quantization levels for optical simulation |
| q_lut_min | float | 0.020040 | Minimum value for the lookup table used in quantization |
| q_quantiles | tuple[float, float] or None | None | Quantile range for min/max statistics. If None, uses absolute min/max |
| q_smooth_factor | float | 0.9 | Exponential moving average factor for updating running statistics |
| q_init_seed | int | 0 | Random seed for quantization noise initialization |
| q_bypass | bool | False | If True, bypass optical quantization (useful for debugging) |
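q_quantiles is the only parameter not shown in the configuration example above. The sketch below shows where it would sit; the quantile values are purely illustrative assumptions, not recommended defaults.

pass_args = {
    "by": "regex_name",
    "default": {
        "q_levels": 256,
        "q_lut_min": 0.020040,
        # Hypothetical quantile range: track the 0.1%/99.9% quantiles
        # instead of the absolute min/max of each tensor
        "q_quantiles": (0.001, 0.999),
        "q_smooth_factor": 0.9,
        "q_init_seed": 0,
        "q_bypass": False,
    },
}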
Layers#
OtLinear#
- chop.passes.module.transforms.onn.layers.linear.OtLinear#
Optical Transformer Linear layer.
This is an alias to mase_triton.optical_compute.layers.OpticalTransformerLinear. It replaces standard torch.nn.Linear layers with quantized optical transformer equivalents that simulate optical neural network hardware constraints.

The layer applies quantization to both the input activations and weights during matrix multiplication, and tracks running min/max statistics for calibration.
Class method:
- classmethod from_linear(linear, **kwargs)#
Create an OtLinear from an existing torch.nn.Linear layer.
- Parameters:
linear (torch.nn.Linear) – Source linear layer
kwargs – Quantization parameters (q_levels, q_lut_min, q_smooth_factor, q_init_seed, q_bypass, etc.)
- Returns:
Optical transformer linear layer with copied weights
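For manual conversion of a single layer, a minimal sketch using from_linear, assuming mase-triton is installed and the default quantization settings documented in the table above:

import torch.nn as nn
from chop.passes.module.transforms.onn.layers.linear import OtLinear

# Standard PyTorch layer to convert
fc = nn.Linear(in_features=512, out_features=512)

# Wrap it as an optical transformer linear layer; weights are copied over
ot_fc = OtLinear.from_linear(
    fc,
    q_levels=256,
    q_lut_min=0.020040,
    q_smooth_factor=0.9,
    q_init_seed=0,
    q_bypass=False,
)

In most cases the transform pass handles this automatically; calling from_linear directly is useful when experimenting with a single layer.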
OtLlamaAttention#
- class chop.passes.module.transforms.onn.layers.attn.OtLlamaAttention(config: LlamaConfig, layer_idx: int, q_levels: int = 256, q_lut_min: float = 0.02004, q_quantiles: tuple[float, float] | None = None, q_smooth_factor: float = 0.9, q_init_seed: int = 0, q_bypass: bool = False)[source]#
Bases: Module

Optical Transformer attention module for LLaMA models.
This module replaces the standard HuggingFace LlamaAttention with an optical transformer equivalent that simulates quantized matrix multiplications as would occur in optical neural network hardware. The implementation is based on the Optical Transformers paper.
The attention computation uses optical transformer scaled dot-product attention (SDPA) which applies quantization to the query-key and attention-value matrix multiplications to simulate optical compute constraints.
- Parameters:
config – HuggingFace LLaMA configuration object.
layer_idx (int) – Index of this attention layer in the model.
q_levels (int) – Number of quantization levels for optical simulation. Default: 256.
q_lut_min (float) – Minimum value for the lookup table used in quantization. Default: 0.020040.
q_quantiles (tuple[float, float], optional) – Quantile range for min/max statistics. If None, uses absolute min/max. Default: None.
q_smooth_factor (float) – Exponential moving average factor for updating running min/max statistics during training. Default: 0.9.
q_init_seed (int) – Random seed for quantization noise initialization. Default: 0.
q_bypass (bool) – If True, bypasses optical quantization and uses standard PyTorch attention. Useful for debugging or comparison. Default: False.
- query_min_max#
Running min/max statistics for query tensors.
- Type:
Tensor
- key_min_max#
Running min/max statistics for key tensors.
- Type:
Tensor
- value_min_max#
Running min/max statistics for value tensors.
- Type:
Tensor
- qk_min_max#
Running min/max statistics for query-key products.
- Type:
Tensor
- attn_min_max#
Min/max range for attention weights (fixed at [0, 1]).
- Type:
Tensor
- av_min_max#
Running min/max statistics for attention-value products.
- Type:
Tensor
- seed#
Current random seed state for quantization.
- Type:
Tensor
Example
from chop.passes.module.transforms.onn.layers.attn import OtLlamaAttention

# Create from an existing HuggingFace attention layer
ot_attn = OtLlamaAttention.from_pretrained(
    hf_attention_layer,
    layer_idx=0,
    q_levels=256,
    q_bypass=False,
)
- __init__(config: LlamaConfig, layer_idx: int, q_levels: int = 256, q_lut_min: float = 0.02004, q_quantiles: tuple[float, float] | None = None, q_smooth_factor: float = 0.9, q_init_seed: int = 0, q_bypass: bool = False)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(hidden_states: Tensor, position_embeddings: tuple[Tensor, Tensor], attention_mask: Tensor | None, past_key_value=None, cache_position: LongTensor | None = None, **kwargs) → tuple[Tensor, Tensor | None, tuple[Tensor] | None][source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- classmethod from_pretrained(attn: LlamaAttention, layer_idx: int, q_levels: int = 256, q_lut_min: float = 0.02004, q_quantiles: tuple[float, float] | None = None, q_smooth_factor: float = 0.9, q_init_seed: int = 0, q_bypass: bool = False) → OtLlamaAttention[source]#
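As a concrete sketch, assuming the usual HuggingFace LLaMA module layout (model.model.layers[i].self_attn), a single attention layer could be swapped in place like this; the transform pass does the same thing for every matching layer automatically:

from transformers import AutoModelForCausalLM
from chop.passes.module.transforms.onn.layers.attn import OtLlamaAttention

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Replace the attention module of the first decoder layer in place
layer = model.model.layers[0]
layer.self_attn = OtLlamaAttention.from_pretrained(
    layer.self_attn,
    layer_idx=0,
    q_levels=256,
    q_bypass=False,
)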
Functional API#
ot_eager_attention_forward#
- chop.passes.module.transforms.onn.layers.attn.ot_eager_attention_forward(module: OtLlamaAttention, query: Tensor, key: Tensor, value: Tensor, attention_mask: Tensor | None, scaling: float, dropout: float = 0.0, **kwargs)[source]#
Optical Transformer Scaled Dot-Product Attention.
Computes scaled dot-product attention with quantized matrix multiplications to simulate optical neural network hardware constraints. This function applies quantization to both the query-key and attention-value matrix products.
The quantization statistics (min/max values) are updated in-place during training using an exponential moving average controlled by q_smooth_factor.
- Parameters:
query (Tensor) – Query tensor of shape (batch, heads, seq_len, head_dim).
key (Tensor) – Key tensor of shape (batch, kv_heads, seq_len, head_dim).
value (Tensor) – Value tensor of shape (batch, kv_heads, seq_len, head_dim).
attention_mask (Tensor, optional) – Attention mask. Default: None.
dropout (float) – Dropout probability. Default: 0.0.
scaling (float, optional) – Scaling factor. If None, uses 1/sqrt(head_dim).
- Returns:
Attention output of shape (batch, heads, seq_len, head_dim).
- Return type:
Tensor
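A self-contained sketch of calling the function directly on random tensors. It assumes a toy LlamaConfig (sizes chosen only for illustration), that mase-triton runs on the available hardware, and uses q_bypass=True so the call falls back to standard attention; shapes and the 1/sqrt(head_dim) scaling follow the parameter descriptions above.

import math
import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from chop.passes.module.transforms.onn.layers.attn import (
    OtLlamaAttention,
    ot_eager_attention_forward,
)

# Toy configuration: hidden_size / num_attention_heads gives head_dim = 32
config = LlamaConfig(hidden_size=256, num_attention_heads=8, num_key_value_heads=8)
attn = OtLlamaAttention(config, layer_idx=0, q_bypass=True)

batch, heads, seq_len, head_dim = 2, 8, 16, 32
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# Per the docs above, the result is the attention output
# of shape (batch, heads, seq_len, head_dim)
out = ot_eager_attention_forward(
    attn, query, key, value,
    attention_mask=None,
    scaling=1.0 / math.sqrt(head_dim),
    dropout=0.0,
)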
Usage Example#
Basic usage with a LLaMA model:
from transformers import AutoModelForCausalLM
from chop.passes.module.transforms.onn import optical_transformer_module_transform_pass
# Load a pretrained model
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Define transformation configuration
pass_args = {
"by": "regex_name",
"default": {
"q_levels": 256,
"q_lut_min": 0.020040,
"q_smooth_factor": 0.9,
"q_init_seed": 0,
"q_bypass": False,
},
}
# Apply the optical transformer transform
model = optical_transformer_module_transform_pass(model, pass_args)
# The model now uses OtLinear and OtLlamaAttention layers
# Continue with training or inference as usual
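To confirm which modules were actually replaced, a quick inspection sketch using only the classes documented on this page:

from chop.passes.module.transforms.onn.layers.linear import OtLinear
from chop.passes.module.transforms.onn.layers.attn import OtLlamaAttention

n_linear = sum(isinstance(m, OtLinear) for m in model.modules())
n_attn = sum(isinstance(m, OtLlamaAttention) for m in model.modules())
print(f"OtLinear layers: {n_linear}, OtLlamaAttention layers: {n_attn}")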
Selective Layer Transformation#
Transform only specific layers using regex patterns:
pass_args = {
"by": "regex_name",
# Only transform attention layers
".*self_attn.*": {
"q_levels": 256,
"q_bypass": False,
},
# Transform MLP with different settings
".*mlp.*": {
"q_levels": 128,
"q_bypass": False,
},
}
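The regex patterns are presumably matched against the hierarchical module names, so it helps to list those names first when writing patterns such as .*self_attn.* or .*mlp.*. The snippet below uses plain PyTorch and makes no assumptions about the pass internals:

# Print candidate module names to target with "name" or "regex_name" matching
for name, module in model.named_modules():
    print(name, type(module).__name__)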
Bypass Mode for Debugging#
Use q_bypass=True to disable quantization while keeping the module structure:
pass_args = {
"by": "regex_name",
"default": {
"q_levels": 256,
"q_bypass": True, # Disable quantization
},
}
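A sanity-check sketch built on that idea: transform a copy of the model with q_bypass=True and compare its output against the original. The expectation that bypassed modules reproduce the standard computation up to numerical noise is an assumption, not a documented guarantee.

import copy
import torch
from chop.passes.module.transforms.onn import optical_transformer_module_transform_pass

# "model" is the model from the usage example above; "pass_args" is the
# bypass configuration shown directly above
bypass_model = optical_transformer_module_transform_pass(copy.deepcopy(model), pass_args)

input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
with torch.no_grad():
    ref = model(input_ids).logits
    out = bypass_model(input_ids).logits

# If bypass really skips quantization, the two outputs should match closely
print(torch.allclose(ref, out, atol=1e-4))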