helion.language.associative_scan

helion.language.associative_scan(combine_fn, input_tensor, dim, reverse=False)

Applies an associative scan operation along a specified dimension.

Computes the prefix scan (cumulative operation) along dim using a custom combine function. In the forward direction, output element i combines input elements 0 through i. Unlike reduce(), the result keeps the input's shape.
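To make the semantics concrete, here is a sequential pure-Python reference model of a scan over a 1-D list. It is only a sketch of what the result contains, not how Helion compiles the operation. The helper name scan_1d is hypothetical, and functools.reduce stands in for reduce() to show the difference in output shape.

```python
from functools import reduce
from typing import Callable, List, TypeVar

T = TypeVar("T")


def scan_1d(combine_fn: Callable[[T, T], T], xs: List[T]) -> List[T]:
    """Hypothetical reference model: prefix scan of xs with combine_fn.

    out[i] = combine_fn(...combine_fn(combine_fn(xs[0], xs[1]), xs[2])..., xs[i])
    """
    out: List[T] = []
    for x in xs:
        # The first element is copied as-is; later ones fold into the running value.
        out.append(x if not out else combine_fn(out[-1], x))
    return out


def add(a, b):
    return a + b


xs = [3, 1, 4, 1, 5]
print(scan_1d(add, xs))  # [3, 4, 8, 9, 14]: same length as xs
print(reduce(add, xs))   # 14: a reduction keeps only the final value
```

The scan keeps every intermediate value, while the reduction returns only the last one.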

Parameters:
  • combine_fn (Union[Callable[[Tensor, Tensor], Tensor], Callable[..., tuple[Tensor, ...]]]) – A binary function that combines two elements element-wise. It must be associative for correct results. It may be a tensor -> tensor function or, for tuple inputs, a tuple -> tuple function.

  • input_tensor (Tensor | tuple[Tensor, ...]) – Input tensor, or tuple of tensors, to scan.

  • dim (int) – The dimension along which to scan.

  • reverse (bool) – If True, performs the scan in reverse order, from the last index along dim toward the first.
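The effect of dim and reverse can be illustrated with a small pure-Python model on a 2-D nested list. This is a hypothetical sketch (scan_row and scan_2d are illustrative names, not Helion functions). It supports only dim 0 or 1 and assumes that, in reverse mode, the operand order is preserved so that out[i] = row[i] op ... op row[-1]. That detail only matters for non-commutative combine functions.

```python
import operator
from typing import Callable, List, TypeVar

T = TypeVar("T")


def scan_row(combine_fn: Callable[[T, T], T], row: List[T], reverse: bool = False) -> List[T]:
    """Scan one row; reverse=True runs from the last index toward the first."""
    out = list(row)
    if not reverse:
        for i in range(1, len(out)):
            out[i] = combine_fn(out[i - 1], out[i])
    else:
        # Assumption: operand order is preserved, so out[i] = row[i] op ... op row[-1].
        for i in range(len(out) - 2, -1, -1):
            out[i] = combine_fn(out[i], out[i + 1])
    return out


def scan_2d(combine_fn: Callable[[T, T], T], m: List[List[T]],
            dim: int, reverse: bool = False) -> List[List[T]]:
    """Hypothetical model: scan a 2-D nested list along dim (0 = down columns, 1 = along rows)."""
    if dim == 1:
        return [scan_row(combine_fn, r, reverse) for r in m]
    if dim == 0:
        # Transpose, scan each column as a row, then transpose back.
        cols = [scan_row(combine_fn, list(c), reverse) for c in zip(*m)]
        return [list(r) for r in zip(*cols)]
    raise ValueError("this sketch only supports dim 0 or 1")


m = [[1, 2, 3],
     [4, 5, 6]]
print(scan_2d(operator.add, m, dim=1))                # [[1, 3, 6], [4, 9, 15]]
print(scan_2d(operator.add, m, dim=0))                # [[1, 2, 3], [5, 7, 9]]
print(scan_2d(operator.add, m, dim=1, reverse=True))  # [[6, 5, 3], [15, 11, 6]]
```

Every call returns a 2 x 3 result, matching the input shape.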

Returns:

Tensor(s) with the same shape as the input, containing the scan result

Return type:

torch.Tensor or tuple[torch.Tensor, …]
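When input_tensor is a tuple, the result is a tuple with one output per input. The following pure-Python model is a sketch under the assumption that combine_fn receives two tuples and returns a tuple, following the tuple -> tuple form described under Parameters. scan_tuple and linear_recurrence_combine are hypothetical names. The example scans pairs (a, b) to evaluate the linear recurrence h[t] = a[t] * h[t-1] + b[t], a standard use of tuple scans.

```python
from typing import Callable, List, Sequence, Tuple

Elem = Tuple[float, ...]


def scan_tuple(combine_fn: Callable[[Elem, Elem], Elem],
               inputs: Sequence[List[float]]) -> Tuple[List[float], ...]:
    """Hypothetical model: scan several equal-length lists together.

    Element i of the scan is the tuple (inputs[0][i], inputs[1][i], ...); the
    result is a tuple with one output list per input, each the same length.
    """
    outs: Tuple[List[float], ...] = tuple([] for _ in inputs)
    acc = None
    for elem in zip(*inputs):
        acc = elem if acc is None else combine_fn(acc, elem)
        for out, value in zip(outs, acc):
            out.append(value)
    return outs


def linear_recurrence_combine(left: Elem, right: Elem) -> Elem:
    # Each element (a, b) stands for the map h -> a * h + b; combining composes
    # "apply left, then right", which is associative but not commutative.
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2


a = [0.5, 0.5, 0.5]
b = [1.0, 1.0, 1.0]
prod_a, h = scan_tuple(linear_recurrence_combine, (a, b))
print(h)  # [1.0, 1.5, 1.75], i.e. h[t] = a[t] * h[t-1] + b[t] with h[-1] = 0
```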

See also

  • reduce(): For dimension-reducing operations

  • cumsum(): For cumulative sum

  • cumprod(): For cumulative product

Note

  • combine_fn must be associative (not necessarily commutative)

  • The output has the same shape as the input (unlike reduce())

  • For plain cumulative sums or products, prefer cumsum() or cumprod(), which are faster

  • With reverse=True, the scan applies the operation from right to left, starting at the last index along dim
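The notes on associativity can be checked in plain Python. The sketch below uses a hypothetical sequential reference scan (not Helion's implementation) and a spot-check helper that tests (a op b) op c == a op (b op c) on sample values. A passing spot-check does not prove associativity. String concatenation passes while being non-commutative, subtraction fails, and a scan with addition matches itertools.accumulate, which is the cumulative-sum case that cumsum() covers.

```python
import itertools
import operator
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")


def inclusive_scan(combine_fn: Callable[[T, T], T], xs: List[T]) -> List[T]:
    """Hypothetical sequential reference for a forward scan (not Helion's implementation)."""
    out = list(xs)
    for i in range(1, len(out)):
        out[i] = combine_fn(out[i - 1], out[i])
    return out


def is_associative_on(combine_fn: Callable[[T, T], T], samples: Iterable[T]) -> bool:
    """Spot-check (a op b) op c == a op (b op c) for every triple drawn from samples."""
    return all(
        combine_fn(combine_fn(x, y), z) == combine_fn(x, combine_fn(y, z))
        for x, y, z in itertools.product(list(samples), repeat=3)
    )


# String concatenation is associative but not commutative, so it is a valid combine_fn.
words = ["a", "b", "c"]
print(is_associative_on(operator.add, words))            # True
print(operator.add("a", "b") == operator.add("b", "a"))  # False
print(inclusive_scan(operator.add, words))               # ['a', 'ab', 'abc']

# Subtraction is not associative, so a parallel scan using it can give wrong results.
print(is_associative_on(operator.sub, [1, 2, 3]))        # False

# A scan with addition is a cumulative sum; cumsum() covers this case directly.
nums = [3, 1, 4, 1, 5]
print(inclusive_scan(operator.add, nums) == list(itertools.accumulate(nums)))  # True
```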