helion.language.associative_scan
helion.language.associative_scan(combine_fn, input_tensor, dim, reverse=False)
Applies an associative scan operation along a specified dimension.
Computes the prefix scan (cumulative operation) along a dimension using a custom combine function. Unlike reduce(), this preserves the input shape.
- Parameters:
  - combine_fn (Union[Callable[[Tensor, Tensor], Tensor], Callable[..., tuple[Tensor, ...]]]) – A binary function that combines two elements element-wise. Must be associative for correct results. Can be a tensor->tensor function or a tuple->tuple function.
  - input_tensor (Tensor | tuple[Tensor, ...]) – Input tensor or tuple of tensors to scan.
  - dim (int) – The dimension along which to scan.
  - reverse (bool) – If True, performs the scan in reverse order.
- Returns:
- Tensor(s) with the same shape as the input, containing the scan result.
- Return type:
torch.Tensor or tuple[torch.Tensor, …]
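To make the semantics concrete, the sketch below is a sequential pure-Python reference model of what the scan computes along one dimension of a 2-D input. It is not Helion's implementation, whose lowering is not shown here. `reference_scan_1d` and `reference_scan_2d` are hypothetical helpers. One further assumption: for reverse=True the sketch keeps the operands in their original left-to-right order (out[i] = xs[i] op ... op xs[n-1]). That only matters for non-commutative combine functions.

```python
def reference_scan_1d(combine_fn, xs, reverse=False):
    """Inclusive scan of a 1-D list with a binary combine function.

    Forward: out[i] = xs[0] op xs[1] op ... op xs[i]
    Reverse: out[i] = xs[i] op xs[i+1] op ... op xs[n-1]
    (Operand order under reverse=True is an assumption of this sketch.)
    """
    n = len(xs)
    out = [None] * n
    indices = range(n - 1, -1, -1) if reverse else range(n)
    acc = None
    for step, i in enumerate(indices):
        if step == 0:
            acc = xs[i]  # the first visited element starts the accumulation
        elif reverse:
            acc = combine_fn(xs[i], acc)  # xs[i] sits to the left of the running suffix
        else:
            acc = combine_fn(acc, xs[i])  # the running prefix sits to the left of xs[i]
        out[i] = acc
    return out  # same length as xs: a scan never shrinks the scanned dimension


def reference_scan_2d(combine_fn, rows, dim, reverse=False):
    """Scan a rectangular 2-D nested list along dim 0 (down columns) or dim 1 (along rows)."""
    if dim == 1:
        return [reference_scan_1d(combine_fn, row, reverse) for row in rows]
    if dim == 0:
        cols = [list(col) for col in zip(*rows)]
        scanned_cols = [reference_scan_1d(combine_fn, col, reverse) for col in cols]
        return [list(row) for row in zip(*scanned_cols)]
    raise ValueError("this 2-D sketch only supports dim=0 or dim=1")


def add(a, b):
    return a + b


x = [[1, 2, 3], [4, 5, 6]]
print(reference_scan_2d(add, x, dim=1))                # [[1, 3, 6], [4, 9, 15]]
print(reference_scan_2d(add, x, dim=0))                # [[1, 2, 3], [5, 7, 9]]
print(reference_scan_2d(add, x, dim=1, reverse=True))  # [[6, 5, 3], [15, 11, 6]]
```

Every result has the same 2 x 3 shape as x. With add as the combine function, the dim=1 result is simply the row-wise cumulative sum.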
See also
- reduce(): For dimension-reducing operations
- cumsum(): For cumulative sum
- cumprod(): For cumulative product
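The tuple->tuple form of combine_fn is what makes scans beyond cumulative sum and product possible. A running argmax is one example: values and their indices are carried together. The pure-Python model below assumes that the tuple form receives the left element's fields followed by the right element's fields as separate arguments. The `Callable[..., tuple[Tensor, ...]]` type hints at this convention, but this sketch does not confirm it. `reference_tuple_scan` and `argmax_combine` are hypothetical names, and only the forward direction is modelled.

```python
def reference_tuple_scan(combine_fn, inputs):
    """Forward inclusive scan over parallel 1-D lists.

    combine_fn receives the left element's fields followed by the right
    element's fields, and returns a tuple with one value per input.
    """
    n = len(inputs[0])
    if any(len(seq) != n for seq in inputs):
        raise ValueError("all inputs must have the same length")
    elems = list(zip(*inputs))  # elems[i] == (inputs[0][i], inputs[1][i], ...)
    out = []
    acc = None
    for i, elem in enumerate(elems):
        acc = elem if i == 0 else tuple(combine_fn(*acc, *elem))
        out.append(acc)
    # Split back into one list per input: tuple in, tuple out, each with the input's length.
    return tuple(list(field) for field in zip(*out))


def argmax_combine(left_val, left_idx, right_val, right_idx):
    # Keep the left element unless the right one is strictly larger, so ties
    # resolve to the earliest index. This rule is associative but not commutative.
    if right_val > left_val:
        return right_val, right_idx
    return left_val, left_idx


values = [3, 1, 4, 1, 5, 9, 2, 6]
indices = list(range(len(values)))
running_max, running_argmax = reference_tuple_scan(argmax_combine, (values, indices))
print(running_max)     # [3, 3, 4, 4, 5, 9, 9, 9]
print(running_argmax)  # [0, 0, 2, 2, 4, 5, 5, 5]
```

Because ties favour the left operand, swapping the arguments would change which index is reported. This is a small instance of an operator that is associative without being commutative.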
Note
- combine_fn must be associative (not necessarily commutative).
- Output has the same shape as the input (unlike reduce()).
- For standard scans, use cumsum() or cumprod(), which are faster.
- A reverse scan applies the operation from right to left.
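The requirement that combine_fn be associative but not necessarily commutative allows operators such as composition of affine maps. This operator turns the first-order linear recurrence h_t = a_t * h_{t-1} + b_t (with h starting at 0) into a scan over (a_t, b_t) pairs. The sketch below uses plain Python integers and compares the scan against the direct loop. In a Helion kernel the same combine would presumably be written in the tuple form over tensors. That usage is an assumption and is not shown here. `compose` and `inclusive_scan` are hypothetical helpers.

```python
def compose(left, right):
    """Compose affine maps h -> a*h + b: apply `left` first, then `right`."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def inclusive_scan(combine_fn, xs):
    """Sequential forward scan: out[i] = xs[0] op ... op xs[i]."""
    out, acc = [], None
    for i, x in enumerate(xs):
        acc = x if i == 0 else combine_fn(acc, x)
        out.append(acc)
    return out


a = [2, 3, 1, 4]
b = [1, -1, 5, 2]

# Scan formulation: each composed map applied to h = 0 yields its b-component, which is h_t.
h_scan = [bt for _, bt in inclusive_scan(compose, list(zip(a, b)))]

# Direct recurrence h_t = a_t * h_{t-1} + b_t with h starting at 0, for comparison.
h, h_loop = 0, []
for at, bt in zip(a, b):
    h = at * h + bt
    h_loop.append(h)

print(h_scan, h_loop)  # [1, 2, 7, 30] [1, 2, 7, 30]

x, y, z = (2, 1), (3, -1), (1, 5)
print(compose(compose(x, y), z) == compose(x, compose(y, z)))  # True (associative)
print(compose(x, y), compose(y, x))  # (6, 2) (6, -1) (not commutative)
```

Since compose is not commutative, the scan must always keep the earlier element as the left operand. Associativity is what allows an implementation to regroup the combines, for example in a tree, without changing the result.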