Source code for chop.nn.functional.softermax
from torch import Tensor
[docs]
def softermax(input: Tensor, dim: int) -> Tensor:
    """Compute a base-2 softmax ("Softermax") along one dimension.

    Based on the paper "Softermax: Hardware/Software Co-Design of an
    Efficient Softmax for Transformers" (https://arxiv.org/abs/2103.09301).

    Args:
        input (Tensor): Input tensor
        dim (int): Dimension along which the maximum and the normalising
            sum are taken

    Returns:
        Tensor: Output tensor with the same shape as ``input``, where each
            entry is ``2**(x - floor(max))`` divided by the sum of those
            powers along ``dim``
    """
    # Shift by the floored maximum so the largest exponent lies in [0, 1).
    peak, _ = input.max(dim=dim, keepdim=True)
    powers = 2 ** input.sub(peak.floor())
    # Normalise each slice along `dim` by its own total.
    return powers.div(powers.sum(dim=dim, keepdim=True))