chop.passes.module.transforms.onn#

This module provides transformation passes for converting standard neural network modules into Optical Neural Network (ONN) equivalents. The optical transformer implementation is based on the Optical Transformers paper.

Optical neural networks leverage photonic hardware to perform matrix multiplications with reduced power consumption. This transform simulates the quantization effects and constraints of optical compute hardware, enabling model development and evaluation before deployment on physical optical accelerators.

Note

This module requires the mase-triton package to be installed. Install via: pip install mase-triton

Transform Pass#

optical_transformer_module_transform_pass#

chop.passes.module.transforms.onn.optical_transformer_module_transform_pass(network: Module, pass_args: dict) → Module[source]#

Transform a neural network by replacing supported modules with their optical transformer equivalents.

This pass simulates optical neural network (ONN) computation by replacing standard PyTorch modules with quantized optical transformer layers. The optical transformer model is based on the Optical Transformers paper.

Supported module replacements:

  • torch.nn.Linear → OtLinear

  • transformers.models.llama.modeling_llama.LlamaAttention → OtLlamaAttention

Parameters:
  • network (torch.nn.Module) – The input network to be transformed.

  • pass_args (dict) –

    A dictionary containing transformation configurations.

    • by (str): Layer matching strategy. Either 'name' for exact name matching or 'regex_name' for regex-based pattern matching. Defaults to 'regex_name'.

    • default (dict, optional): Default configuration applied to all matching layers.

    • <layer_name_or_pattern> (dict): Per-layer configuration. Each layer config can contain the following keys:

      • q_levels (int): Number of quantization levels. Default: 256.

      • q_lut_min (float): Minimum value for lookup table. Default: 0.020040.

      • q_quantiles (tuple[float, float] or None): Quantile range for min/max statistics. If None, uses absolute min/max. Default: None.

      • q_smooth_factor (float): Smoothing factor for running statistics. Default: 0.9.

      • q_init_seed (int): Random seed for quantization initialization. Default: 0.

      • q_bypass (bool): If True, bypass optical quantization. Default: False.

Returns:

The transformed network with optical transformer modules.

Return type:

torch.nn.Module

Raises:

RuntimeError – If mase-triton is not installed.

Example

from chop.passes.module.transforms.onn import optical_transformer_module_transform_pass

# Transform all linear layers with default config
pass_args = {
    "by": "regex_name",
    "default": {
        "q_levels": 256,
        "q_lut_min": 0.020040,
        "q_bypass": False,
    }
}
transformed_model = optical_transformer_module_transform_pass(model, pass_args)

Note

This pass requires the mase-triton package to be installed. Install via pip install mase-triton.

Configuration#

The transform pass accepts configuration through the pass_args dictionary. Layer matching can be done by exact name or regex patterns.

Example configuration:

pass_args = {
    "by": "regex_name",  # or "name" for exact matching
    "default": {
        "q_levels": 256,
        "q_lut_min": 0.020040,
        "q_smooth_factor": 0.9,
        "q_init_seed": 0,
        "q_bypass": False,
    },
    # Override for specific layers using regex
    ".*mlp.*": {
        "q_levels": 128,
        "q_bypass": False,
    },
}

Configuration Parameters#

| Parameter       | Type                        | Default  | Description                                                            |
|-----------------|-----------------------------|----------|------------------------------------------------------------------------|
| q_levels        | int                         | 256      | Number of quantization levels for optical simulation                   |
| q_lut_min       | float                       | 0.020040 | Minimum value for the lookup table used in quantization                |
| q_quantiles     | tuple[float, float] or None | None     | Quantile range for min/max statistics. If None, uses absolute min/max  |
| q_smooth_factor | float                       | 0.9      | Exponential moving average factor for updating running statistics      |
| q_init_seed     | int                         | 0        | Random seed for quantization noise initialization                      |
| q_bypass        | bool                        | False    | If True, bypass optical quantization (useful for debugging)            |
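
None of the examples above set q_quantiles. The sketch below shows one way it might be configured, assuming a (lower, upper) quantile pair as documented in the table; the 0.001/0.999 values are illustrative only:

pass_args = {
    "by": "regex_name",
    "default": {
        "q_levels": 256,
        # Calibrate min/max from the 0.1st and 99.9th percentiles instead of
        # the absolute extremes (illustrative values, not a recommendation)
        "q_quantiles": (0.001, 0.999),
        "q_smooth_factor": 0.9,
    },
}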

Layers#

OtLinear#

chop.passes.module.transforms.onn.layers.linear.OtLinear#

Optical Transformer Linear layer.

This is an alias of mase_triton.optical_compute.layers.OpticalTransformerLinear. It serves as a drop-in replacement for standard torch.nn.Linear layers, simulating optical neural network hardware constraints through quantized matrix multiplication.

The layer applies quantization to both the input activations and weights during matrix multiplication, and tracks running min/max statistics for calibration.

Class method:

classmethod from_linear(linear, **kwargs)#

Create an OtLinear from an existing torch.nn.Linear layer.

Parameters:
  • linear (torch.nn.Linear) – Source linear layer

  • kwargs – Quantization parameters (q_levels, q_lut_min, q_smooth_factor, q_init_seed, q_bypass, etc.)

Returns:

Optical transformer linear layer with copied weights
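
A minimal sketch of converting a single layer by hand, assuming mase-triton is installed and OtLinear is importable from the path documented above:

import torch
from chop.passes.module.transforms.onn.layers.linear import OtLinear

# Convert a standalone linear layer into its optical equivalent
linear = torch.nn.Linear(128, 64)
ot_linear = OtLinear.from_linear(
    linear,
    q_levels=256,
    q_lut_min=0.020040,
    q_bypass=False,
)

# Same call interface as torch.nn.Linear
out = ot_linear(torch.randn(4, 128))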

OtLlamaAttention#

class chop.passes.module.transforms.onn.layers.attn.OtLlamaAttention(config: LlamaConfig, layer_idx: int, q_levels: int = 256, q_lut_min: float = 0.02004, q_quantiles: tuple[float, float] | None = None, q_smooth_factor: float = 0.9, q_init_seed: int = 0, q_bypass: bool = False)[source]#

Bases: Module

Optical Transformer attention module for LLaMA models.

This module replaces the standard HuggingFace LlamaAttention with an optical transformer equivalent that simulates quantized matrix multiplications as would occur in optical neural network hardware. The implementation is based on the Optical Transformers paper.

The attention computation uses optical transformer scaled dot-product attention (SDPA) which applies quantization to the query-key and attention-value matrix multiplications to simulate optical compute constraints.

Parameters:
  • config – HuggingFace LLaMA configuration object.

  • layer_idx (int) – Index of this attention layer in the model.

  • q_levels (int) – Number of quantization levels for optical simulation. Default: 256.

  • q_lut_min (float) – Minimum value for the lookup table used in quantization. Default: 0.020040.

  • q_quantiles (tuple[float, float], optional) – Quantile range for min/max statistics. If None, uses absolute min/max. Default: None.

  • q_smooth_factor (float) – Exponential moving average factor for updating running min/max statistics during training. Default: 0.9.

  • q_init_seed (int) – Random seed for quantization noise initialization. Default: 0.

  • q_bypass (bool) – If True, bypasses optical quantization and uses standard PyTorch attention. Useful for debugging or comparison. Default: False.

query_min_max#

Running min/max statistics for query tensors.

Type:

Tensor

key_min_max#

Running min/max statistics for key tensors.

Type:

Tensor

value_min_max#

Running min/max statistics for value tensors.

Type:

Tensor

qk_min_max#

Running min/max statistics for query-key products.

Type:

Tensor

attn_min_max#

Min/max range for attention weights (fixed at [0, 1]).

Type:

Tensor

av_min_max#

Running min/max statistics for attention-value products.

Type:

Tensor

seed#

Current random seed state for quantization.

Type:

Tensor

Example

from chop.passes.module.transforms.onn.layers.attn import OtLlamaAttention

# Create from existing HuggingFace attention layer
ot_attn = OtLlamaAttention.from_pretrained(
    hf_attention_layer,
    layer_idx=0,
    q_levels=256,
    q_bypass=False,
)

__init__(config: LlamaConfig, layer_idx: int, q_levels: int = 256, q_lut_min: float = 0.02004, q_quantiles: tuple[float, float] | None = None, q_smooth_factor: float = 0.9, q_init_seed: int = 0, q_bypass: bool = False)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(hidden_states: Tensor, position_embeddings: tuple[Tensor, Tensor], attention_mask: Tensor | None, past_key_value=None, cache_position: LongTensor | None = None, **kwargs) → tuple[Tensor, Tensor | None, tuple[Tensor] | None][source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

classmethod from_pretrained(attn: LlamaAttention, layer_idx: int, q_levels: int = 256, q_lut_min: float = 0.02004, q_quantiles: tuple[float, float] | None = None, q_smooth_factor: float = 0.9, q_init_seed: int = 0, q_bypass: bool = False) → OtLlamaAttention[source]#

Create an OtLlamaAttention from an existing transformers LlamaAttention layer.

Functional API#

ot_eager_attention_forward#

chop.passes.module.transforms.onn.layers.attn.ot_eager_attention_forward(module: OtLlamaAttention, query: Tensor, key: Tensor, value: Tensor, attention_mask: Tensor | None, scaling: float, dropout: float = 0.0, **kwargs)[source]#

Optical Transformer Scaled Dot-Product Attention.

Computes scaled dot-product attention with quantized matrix multiplications to simulate optical neural network hardware constraints. This function applies quantization to both the query-key and attention-value matrix products.

The quantization statistics (min/max values) are updated in-place during training using an exponential moving average controlled by q_smooth_factor.

Parameters:
  • module (OtLlamaAttention) – The attention module holding the running min/max statistics updated during quantization.

  • query (Tensor) – Query tensor of shape (batch, heads, seq_len, head_dim).

  • key (Tensor) – Key tensor of shape (batch, kv_heads, seq_len, head_dim).

  • value (Tensor) – Value tensor of shape (batch, kv_heads, seq_len, head_dim).

  • attention_mask (Tensor, optional) – Attention mask. Default: None.

  • scaling (float) – Scaling factor applied to the query-key product, typically 1/sqrt(head_dim).

  • dropout (float) – Dropout probability. Default: 0.0.

Returns:

Attention output of shape (batch, heads, seq_len, head_dim).

Return type:

Tensor
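
A minimal sketch of calling the function directly, assuming ot_attn is an OtLlamaAttention built with from_pretrained (see above); the tensor shapes are illustrative:

import torch
from chop.passes.module.transforms.onn.layers.attn import ot_eager_attention_forward

batch, heads, seq_len, head_dim = 1, 8, 16, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# ot_attn supplies the running min/max buffers used for quantization
attn_out = ot_eager_attention_forward(
    ot_attn,
    q, k, v,
    attention_mask=None,
    scaling=head_dim ** -0.5,
)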

Usage Example#

Basic usage with a LLaMA model:

from transformers import AutoModelForCausalLM
from chop.passes.module.transforms.onn import optical_transformer_module_transform_pass

# Load a pretrained model
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Define transformation configuration
pass_args = {
    "by": "regex_name",
    "default": {
        "q_levels": 256,
        "q_lut_min": 0.020040,
        "q_smooth_factor": 0.9,
        "q_init_seed": 0,
        "q_bypass": False,
    },
}

# Apply the optical transformer transform
model = optical_transformer_module_transform_pass(model, pass_args)

# The model now uses OtLinear and OtLlamaAttention layers
# Continue with training or inference as usual
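
To sanity-check the result, one option is to count how many layers were swapped. A sketch, assuming OtLinear is importable from the path documented above:

from chop.passes.module.transforms.onn.layers.linear import OtLinear

n_optical = sum(isinstance(m, OtLinear) for m in model.modules())
print(f"Optical linear layers after transform: {n_optical}")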

Selective Layer Transformation#

Transform only specific layers using regex patterns:

pass_args = {
    "by": "regex_name",
    # Only transform attention layers
    ".*self_attn.*": {
        "q_levels": 256,
        "q_bypass": False,
    },
    # Transform MLP with different settings
    ".*mlp.*": {
        "q_levels": 128,
        "q_bypass": False,
    },
}

Bypass Mode for Debugging#

Use q_bypass=True to disable quantization while keeping the module structure:

pass_args = {
    "by": "regex_name",
    "default": {
        "q_levels": 256,
        "q_bypass": True,  # Disable quantization
    },
}
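
Because bypass mode preserves the module structure, it can also serve as a baseline for measuring quantization error. A sketch, assuming model is the causal LM from the usage example and pass_args_bypass / pass_args_quant are two configurations that differ only in q_bypass:

import copy

import torch

bypass_model = optical_transformer_module_transform_pass(
    copy.deepcopy(model), pass_args_bypass
)
quant_model = optical_transformer_module_transform_pass(
    copy.deepcopy(model), pass_args_quant
)

input_ids = torch.randint(0, 1000, (1, 16))  # dummy token ids
with torch.no_grad():
    delta = (bypass_model(input_ids).logits - quant_model(input_ids).logits).abs().max()
print(f"Max logit deviation introduced by optical quantization: {delta.item():.4f}")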