Language Module
The helion.language module contains the core DSL constructs for writing GPU kernels.
Loop Constructs
tile()
- helion.language.tile(begin_or_end, end_or_none=None, /, block_size=None)[source]
Break up an iteration space defined by a size or sequence of sizes into tiles.
The generated tiles can flatten the iteration space into the product of the sizes, perform multidimensional tiling, swizzle the indices for cache locality, reorder dimensions, etc. The only invariant is that every index in the range of the given sizes is covered exactly once.
The exact tiling strategy is determined by a Config object, typically created through autotuning.
If used at the top level of a function, this becomes the grid of the kernel. Otherwise, it becomes a loop in the output kernel.
The key difference from grid() is that tile gives you Tile objects that load a slice of elements, while grid gives you scalar integer indices. It is recommended to use tile in most cases, since it allows more choices in autotuning.
- Parameters:
  - begin_or_end (int | Tensor | Sequence[int | Tensor]) – If 2+ positional args are provided, the start of the iteration space; otherwise, the end of the iteration space.
  - end_or_none (int | Tensor | Sequence[int | Tensor] | None) – If 2+ positional args are provided, the end of the iteration space.
  - block_size (object) – Fixed block size (overrides autotuning) or None for an autotuned size
- Returns:
  Iterator over tile objects
- Return type:
  Iterator[Tile] or Iterator[Sequence[Tile]]
Examples
One dimensional tiling:
```python
@helion.kernel
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x)
    for tile in hl.tile(x.size(0)):
        # tile processes multiple elements at once
        result[tile] = x[tile] + y[tile]
    return result
```
Multi-dimensional tiling:
```python
@helion.kernel()
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k, n = y.size()
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc
    return out
```
Fixed block size:
```python
@helion.kernel
def process_with_fixed_block(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x)
    for tile in hl.tile(x.size(0), block_size=64):
        # Process with a fixed block size of 64
        result[tile] = x[tile] * 2
    return result
```
Using tile properties:
```python
@helion.kernel
def tile_info_example(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros([x.size(0)], dtype=x.dtype, device=x.device)
    for tile in hl.tile(x.size(0)):
        # Access tile properties
        start = tile.begin
        end = tile.end
        size = tile.block_size
        indices = tile.index  # [start, start+1, ..., end-1]
        # Use in computation
        result[tile] = x[tile] + indices
    return result
```
See also
- grid(): For explicit control over the launch grid
- tile_index(): For getting tile indices
- register_block_size(): For registering block sizes
Note
Similar to range() with multiple forms:
- tile(end) iterates 0 to end-1, autotuned block_size
- tile(begin, end) iterates begin to end-1, autotuned block_size
- tile(begin, end, block_size) iterates begin to end-1, fixed block_size
- tile(end, block_size=block_size) iterates 0 to end-1, fixed block_size

Block sizes can be registered for autotuning explicitly with register_block_size() and passed as the block_size argument if one needs two loops to use the same block size. Passing block_size=None is equivalent to calling register_block_size().
Use tile in most cases. Use grid when you need explicit control over the launch grid.
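For instance, the two-argument form restricts iteration to a sub-range of the tensor. The following is a minimal sketch (the kernel name and the start parameter are illustrative, not from the Helion docs):

```python
import torch
import helion
import helion.language as hl

@helion.kernel
def double_from(x: torch.Tensor, start: int) -> torch.Tensor:
    out = torch.zeros_like(x)
    # tile(begin, end): iterate over indices start .. x.size(0)-1;
    # elements before start keep their zero value
    for tile in hl.tile(start, x.size(0)):
        out[tile] = x[tile] * 2
    return out
```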
The tile() function is the primary way to create parallel loops in Helion kernels. It provides several key features:
Tiling Strategies: The exact tiling strategy is determined by a Config object, typically created through autotuning. This allows for:
- Multidimensional tiling
- Index swizzling for cache locality
- Dimension reordering
- Flattening of iteration spaces
Usage Patterns:
```python
# Simple 1D tiling
for tile in hl.tile(1000):
    # tile.begin, tile.end, tile.block_size are available
    # Load the entire tile (not just the first element)
    data = tensor[tile]  # or hl.load(tensor, tile) for explicit loading

# 2D tiling
for tile_i, tile_j in hl.tile([height, width]):
    # Each tile represents a portion of the 2D space
    pass

# With explicit begin/end/block_size
for tile in hl.tile(0, 1000, block_size=64):
    pass
```
Grid vs Loop Behavior:
- When used at the top level of a kernel function, tile() becomes the grid of the kernel (parallel blocks)
- When used nested inside another loop, it becomes a sequential loop within each block
grid()
- helion.language.grid(begin_or_end, end_or_none=None, /, step=None)[source]
Iterate over individual indices of the given iteration space.
The key difference from tile() is that grid gives you scalar integer indices (torch.SymInt), while tile gives you Tile objects that load a slice of elements. Use tile in most cases. Use grid when you need explicit control over the launch grid or when processing one element at a time.
Semantics are equivalent to:
```python
for i in hl.tile(...):  # i is a Tile object, accesses multiple elements
    data = tensor[i]    # loads a slice of elements (1D tensor)
```
vs:
```python
for i in hl.grid(...):  # i is a scalar index, accesses a single element
    data = tensor[i]    # loads a single element (0D scalar)
```
When used at the top level of a function, this becomes the grid of the kernel. Otherwise, it becomes a loop in the output kernel.
- Parameters:
  - begin_or_end (int | Tensor | Sequence[int | Tensor]) – If 2+ positional args are provided, the start of the iteration space; otherwise, the end of the iteration space.
  - end_or_none (int | Tensor | Sequence[int | Tensor] | None) – If 2+ positional args are provided, the end of the iteration space.
  - step (object) – Step size for iteration (default: 1)
- Returns:
Iterator over scalar indices
- Return type:
Iterator[torch.SymInt] or Iterator[Sequence[torch.SymInt]]
See also
- tile(): For processing multiple elements at once
- tile_index(): For getting tile indices
- arange(): For creating index sequences
Note
Similar to range() with multiple forms:
- grid(end) iterates from 0 to end-1, step 1
- grid(begin, end) iterates from begin to end-1, step 1
- grid(begin, end, step) iterates from begin to end-1, given step
- grid(end, step=step) iterates from 0 to end-1, given step
Use tile in most cases. Use grid when you need explicit control over the launch grid.
The grid() function iterates over individual indices rather than tiles. It is equivalent to tile(size, block_size=1) but returns scalar indices instead of Tile objects.
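This entry has no example of its own, so here is a hedged sketch of a common pattern: a scalar grid over rows combined with a tiled inner loop (the kernel name is illustrative):

```python
import torch
import helion
import helion.language as hl

@helion.kernel
def copy_rows(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for i in hl.grid(x.size(0)):         # i is a scalar row index (torch.SymInt)
        for tile in hl.tile(x.size(1)):  # tile over the columns of row i
            out[i, tile] = x[i, tile]
    return out
```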
Memory Operations
load()
store()
- helion.language.store(tensor, index, value, extra_mask=None)[source]
Store a value to a tensor using a list of indices.
This function is equivalent to tensor[index] = value but allows setting extra_mask= to mask elements beyond the default masking based on the hl.tile range.
- Parameters:
  - tensor (Tensor) – The tensor to store to
  - index (list[object]) – Indices into tensor specifying where to store value
  - value (Tensor | float) – The value to store (tensor or scalar)
  - extra_mask (Tensor | None) – Extra boolean mask applied beyond the default hl.tile masking (default: None)
- Returns:
  None
- Return type:
  None
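A hedged sketch of store() with extra_mask= follows; the even-offset condition is illustrative, and it assumes extra_mask accepts a boolean tensor aligned with the tile:

```python
import torch
import helion
import helion.language as hl

@helion.kernel
def double_even_offsets(x: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(x)
    for tile in hl.tile(x.size(0)):
        # Store only to even offsets; odd offsets keep their zero value
        mask = (tile.index % 2) == 0
        hl.store(out, [tile], x[tile] * 2, extra_mask=mask)
    return out
```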
atomic_add()
- helion.language.atomic_add(target, index, value, sem='relaxed')[source]
Atomically add a value to a target tensor.
Performs an atomic read-modify-write operation that adds value to target[index]. This is safe for concurrent access from multiple threads/blocks.
- Parameters:
  - target (Tensor) – The tensor to add to
  - index (list[object]) – Indices into target for accumulating values
  - value (Tensor | float) – The value to add (tensor or scalar)
  - sem (str) – Memory ordering semantics (default: 'relaxed'):
    - 'relaxed': No ordering constraints
    - 'acquire': Acquire semantics
    - 'release': Release semantics
    - 'acq_rel': Acquire-release semantics
- Returns:
  None
- Return type:
  None
Examples
```python
@helion.kernel
def global_sum(x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    # Each tile computes a local sum, then atomically adds it to the global result
    for tile in hl.tile(x.size(0)):
        local_data = x[tile]
        local_sum = local_data.sum()
        hl.atomic_add(result, [0], local_sum)
    return result
```
Note
- Required for race-free accumulation across parallel execution
- Performance depends on memory access patterns and contention
- Consider using regular operations when atomicity isn't needed
- Higher memory semantics (acquire/release) have performance overhead
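To illustrate index-based accumulation under contention, here is a hedged sketch of a histogram kernel; it assumes atomic_add accepts a tensor of indices, and the input is taken to already hold integer bin ids:

```python
import torch
import helion
import helion.language as hl

@helion.kernel
def histogram(bin_ids: torch.Tensor, num_bins: int) -> torch.Tensor:
    hist = torch.zeros([num_bins], dtype=torch.float32, device=bin_ids.device)
    for tile in hl.tile(bin_ids.size(0)):
        # Concurrent tiles may hit the same bin, so the update must be atomic
        hl.atomic_add(hist, [bin_ids[tile]], 1.0)
    return hist
```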
Tensor Creation
zeros()
- helion.language.zeros(shape, dtype=torch.float32)[source]
Return a device-tensor filled with zeros.
Equivalent to hl.full(shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype).
Note
Only use within hl.tile() loops for creating local tensors. For output tensor creation, use torch.zeros() with proper device placement.
- Parameters:
  - shape (list[object]) – A list of sizes (integers or tile indices)
  - dtype (dtype) – Element type of the result tensor (default: torch.float32)
- Returns:
  A device tensor of the given shape and dtype filled with zeros
- Return type:
  Tensor
Examples
```python
@helion.kernel
def process_kernel(input: torch.Tensor) -> torch.Tensor:
    result = torch.empty_like(input)
    for tile in hl.tile(input.size(0)):
        buffer = hl.zeros([tile], dtype=input.dtype)  # Local buffer
        buffer += input[tile]  # Add input values to buffer
        result[tile] = buffer
    return result
```
full()
- helion.language.full(shape, value, dtype=torch.float32)[source]
Create a device-tensor filled with a specified value.
Note
Only use within hl.tile() loops for creating local tensors. For output tensor creation, use torch.full() with proper device placement.
- Parameters:
  - shape (list[object]) – A list of sizes (integers or tile indices)
  - value (int | float) – The value to fill the tensor with
  - dtype (dtype) – Element type of the result tensor (default: torch.float32)
- Returns:
  A device tensor of the given shape and dtype filled with value
- Return type:
  Tensor
Examples
```python
@helion.kernel
def process_kernel(input: torch.Tensor) -> torch.Tensor:
    result = torch.empty_like(input)
    for tile in hl.tile(input.size(0)):
        # Create a local buffer filled with an initial value
        buffer = hl.full([tile], 0.0, dtype=input.dtype)
        buffer += input[tile]  # Add input values to buffer
        result[tile] = buffer
    return result
```
arange()
See arange()
for details.
Tunable Parameters
register_block_size()
- helion.language.register_block_size(min_or_max, max_or_none=None, /)[source]
Explicitly register a block size that should be autotuned and can be used for allocations and inside hl.tile(…, block_size=…).
This is useful if you have two loops where you want them to share a block size, or if you need to allocate a kernel tensor before the hl.tile() loop.
- The signature can be one of:
```python
hl.register_block_size(max)
hl.register_block_size(min, max)
```
Where min and max are integers that control the range of block_sizes searched by the autotuner. Max may be a symbolic shape, but min must be a constant integer.
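A hedged sketch of the two-loops use case (the kernel name and the row-sum computation are illustrative): both inner loops are tiled with the same registered block size, so the autotuner picks a single value for both.

```python
import torch
import helion
import helion.language as hl

@helion.kernel
def sum_two(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, n = x.size()  # y is assumed to have the same shape
    out = torch.empty([m], dtype=torch.float32, device=x.device)
    block_n = hl.register_block_size(n)
    for tile_m in hl.tile(m):
        acc = hl.zeros([tile_m], dtype=torch.float32)
        # Both inner loops share the autotuned block size block_n
        for tile_n in hl.tile(n, block_size=block_n):
            acc += x[tile_m, tile_n].sum(-1)
        for tile_n in hl.tile(n, block_size=block_n):
            acc += y[tile_m, tile_n].sum(-1)
        out[tile_m] = acc
    return out
```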
register_tunable()
register_reduction_dim()
See register_reduction_dim()
for details.
Tile Operations
Tile Class
- class helion.language.Tile(block_id)[source]
This class should not be instantiated directly; it is the result of hl.tile(...) and represents a single tile of the iteration space.
Tiles can be used as indices into tensors, e.g. tensor[tile]. Tiles can also be used as sizes for allocations, e.g. torch.empty([tile]). There are also properties such as tile.index, tile.begin, tile.end, tile.id, and tile.block_size that can be used to retrieve various information about the tile.
Masking is implicit for tiles: if the final tile is smaller than the block size, loading that tile will only load the valid elements, and reduction operations know to ignore the invalid elements.
- Parameters:
  block_id (int)
- property index: Tensor
  Alias for tile_index(), which retrieves a tensor containing the offsets for a tile.
- property begin: int
  Alias for tile_begin(), which retrieves the start offset of a tile.
- property end: int
  Alias for tile_end(), which retrieves the end offset of a tile.
- property block_size: int
  Alias for tile_block_size(), which retrieves the block_size of a tile.
The Tile class represents a portion of an iteration space with the following key attributes:
- begin: Starting indices of the tile
- end: Ending indices of the tile
- block_size: Size of the tile in each dimension
View Operations
subscript()
- helion.language.subscript(tensor, index)[source]
Equivalent to tensor[index] where tensor is a kernel-tensor (not a host-tensor).
Can be used to add dimensions to the tensor, e.g. tensor[None, :] or tensor[:, None].
- Parameters:
  - tensor (Tensor) – The kernel tensor to index
  - index (list[object]) – Indices, where only None and slice(None) are supported
- Returns:
  The indexed tensor with potentially modified dimensions
- Return type:
  Tensor
Examples
```python
@helion.kernel
def broadcast_multiply(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (N,), y has shape (M,)
    result = torch.empty(
        [x.size(0), y.size(0)], dtype=x.dtype, device=x.device
    )
    for tile_i, tile_j in hl.tile([x.size(0), y.size(0)]):
        # Get tile data
        x_tile = x[tile_i]
        y_tile = y[tile_j]
        # Make x broadcastable: (tile_size, 1)
        # same as hl.subscript(x_tile, [slice(None), None])
        x_expanded = x_tile[:, None]
        # Make y broadcastable: (1, tile_size)
        # same as hl.subscript(y_tile, [None, slice(None)])
        y_expanded = y_tile[None, :]
        result[tile_i, tile_j] = x_expanded * y_expanded
    return result
```
Note
- Only supports None and : (slice(None)) indexing
- Used for reshaping kernel tensors by adding dimensions
- Prefer direct indexing syntax when possible: tensor[None, :]
- Does not support integer indexing or slicing with start/stop
Reduction Operations
reduce()
See reduce()
for details.
Scan Operations
associative_scan()
See associative_scan()
for details.
cumsum()
See cumsum()
for details.
cumprod()
See cumprod()
for details.
tile_index()
- helion.language.tile_index(tile)[source]
Retrieve the index (a 1D tensor containing offsets) of the given tile. This can also be written as: tile.index.
Example usage:
```python
@helion.kernel
def arange(length: int, device: torch.device) -> torch.Tensor:
    out = torch.empty(length, dtype=torch.int32, device=device)
    for tile in hl.tile(length):
        out[tile] = tile.index
    return out
```
tile_begin()
tile_end()
tile_block_size()
tile_id()
Synchronization
wait()
See wait()
for details.
Utilities
device_print()
See device_print()
for details.
Constexpr Operations
constexpr()
See constexpr
for details.
specialize()
See specialize()
for details.