Source code for chop.nn.functional.softermax
from torch import Tensor
[docs]
def softermax(input: Tensor, dim: int) -> Tensor:
"""Softermax implementation, according to "Softermax: Hardware/Software Co-Design of an Efficient Softmax for Transformers" paper (https://arxiv.org/abs/2103.09301).
Args:
input (Tensor): Input tensor
Returns:
Tensor: Output tensor
"""
out = input - input.max(dim=dim, keepdim=True).values.floor()
out = 2**out
row_sum = out.sum(dim=dim, keepdim=True)
# Elementwise division
out = out / row_sum
return out