Language Module

The helion.language module contains the core DSL constructs for writing GPU kernels.

Loop Constructs

tile()

helion.language.tile(begin_or_end, end_or_none=None, /, block_size=None)

Break up an iteration space defined by a size or sequence of sizes into tiles.

The generated tiles can flatten the iteration space into the product of the sizes, perform multidimensional tiling, swizzle the indices for cache locality, reorder dimensions, etc. The only invariant is that every index in the range of the given sizes is covered exactly once.

The exact tiling strategy is determined by a Config object, typically created through autotuning.

If used at the top level of a function, this becomes the grid of the kernel. Otherwise, it becomes a loop in the output kernel.

The key difference from grid() is that tile yields Tile objects that index a slice of elements, while grid yields scalar integer indices. Prefer tile in most cases, since it gives the autotuner more choices.

Parameters:
  • begin_or_end (int | Tensor | Sequence[int | Tensor]) – If 2+ positional args provided, the start of iteration space. Otherwise, the end of iteration space.

  • end_or_none (int | Tensor | Sequence[int | Tensor] | None) – If 2+ positional args provided, the end of iteration space.

  • block_size (object) – Fixed block size (overrides autotuning) or None for autotuned size

Returns:

Iterator over tile objects

Return type:

Iterator[Tile] or Iterator[Sequence[Tile]]

Examples

One dimensional tiling:

import torch
import helion
import helion.language as hl

@helion.kernel
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x)

    for tile in hl.tile(x.size(0)):
        # tile processes multiple elements at once
        result[tile] = x[tile] + y[tile]

    return result

Multi-dimensional tiling:

@helion.kernel()
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k, n = y.size()
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc


    return out

Fixed block size:

@helion.kernel
def process_with_fixed_block(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x)

    for tile in hl.tile(x.size(0), block_size=64):
        # Process with fixed block size of 64
        result[tile] = x[tile] * 2

    return result

Using tile properties:

@helion.kernel
def tile_info_example(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros([x.size(0)], dtype=x.dtype, device=x.device)

    for tile in hl.tile(x.size(0)):
        # Access tile properties
        start = tile.begin
        end = tile.end
        size = tile.block_size
        indices = tile.index  # [start, start+1, ..., end-1]

        # Use in computation
        result[tile] = x[tile] + indices

    return result


Note

Similar to range() with multiple forms:

  • tile(end) iterates 0 to end-1, autotuned block_size

  • tile(begin, end) iterates begin to end-1, autotuned block_size

  • tile(begin, end, block_size) iterates begin to end-1, fixed block_size

  • tile(end, block_size=block_size) iterates 0 to end-1, fixed block_size

If two loops need to use the same block size, register one explicitly for autotuning with register_block_size() and pass it as the block_size argument to both loops. Passing block_size=None is equivalent to passing a block size obtained from register_block_size().

Use tile in most cases. Use grid when you need explicit control over the launch grid.
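The range()-like forms above can be modeled in plain Python. The sketch below is not Helion code and does not call Helion: `model_tiles` and `covers_exactly_once` are hypothetical helpers, the default block size of 16 is an arbitrary stand-in for an autotuned value, and it assumes the simplest strategy (contiguous 1D blocks visited in order). It checks the one invariant stated above: every index in the range is covered exactly once, with the last tile clipped to the end of the range.

```python
from collections import Counter


def model_tiles(begin_or_end, end_or_none=None, block_size=16):
    """Hypothetical model of hl.tile's 1D forms, returning (begin, end) pairs.

    Mirrors model_tiles(end) and model_tiles(begin, end). Assumes contiguous,
    in-order blocks; a real Config may order or shape tiles differently.
    """
    if end_or_none is None:
        begin, end = 0, begin_or_end
    else:
        begin, end = begin_or_end, end_or_none
    # The last tile is clipped to `end`; on the device this is handled by masking.
    return [(s, min(s + block_size, end)) for s in range(begin, end, block_size)]


def covers_exactly_once(tiles, begin, end):
    """Check the invariant: every index in [begin, end) lies in exactly one tile."""
    counts = Counter(i for lo, hi in tiles for i in range(lo, hi))
    return set(counts) == set(range(begin, end)) and all(c == 1 for c in counts.values())


print(model_tiles(10, block_size=4))     # [(0, 4), (4, 8), (8, 10)]
print(model_tiles(3, 10, block_size=4))  # [(3, 7), (7, 10)]
```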

The tile() function is the primary way to create parallel loops in Helion kernels. It provides several key features:

Tiling Strategies: Because the strategy is chosen by the Config, tiling can include:

  • Multidimensional tiling

  • Index swizzling for cache locality

  • Dimension reordering

  • Flattening of iteration spaces
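Reordering dimensions changes the order in which tiles are visited but not which elements they cover. The plain-Python sketch below models that for a 2D hl.tile([m, n]) loop; `model_tiles_nd` and `covered_cells` are hypothetical helpers, blocks are assumed contiguous, and swizzling and flattening are not modeled.

```python
import itertools


def model_tiles_nd(sizes, block_sizes, loop_order=None):
    """Hypothetical model of hl.tile([...]): yields per-dimension (start, stop) bounds.

    loop_order lists dimensions from outermost to innermost; changing it stands in
    for the dimension reordering a Config might choose.
    """
    if loop_order is None:
        loop_order = tuple(range(len(sizes)))
    # Split each dimension into contiguous blocks; the last block may be short.
    per_dim = [
        [(s, min(s + b, size)) for s in range(0, size, b)]
        for size, b in zip(sizes, block_sizes)
    ]
    for combo in itertools.product(*(per_dim[d] for d in loop_order)):
        bounds = [None] * len(sizes)
        for d, rng in zip(loop_order, combo):
            bounds[d] = rng  # store bounds back in the original dimension order
        yield tuple(bounds)


def covered_cells(tiles):
    """Expand 2D tile bounds into the (i, j) cells they cover."""
    return [
        (i, j)
        for (r0, r1), (c0, c1) in tiles
        for i in range(r0, r1)
        for j in range(c0, c1)
    ]


row_major = list(model_tiles_nd([5, 7], [2, 4]))
col_major = list(model_tiles_nd([5, 7], [2, 4], loop_order=(1, 0)))
```

Both orders visit the same six tiles, just in a different sequence, and together they cover all 35 cells exactly once.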

Usage Patterns:

# These loops must appear inside a @helion.kernel function.
# Simple 1D tiling
for tile in hl.tile(1000):
    # tile.begin, tile.end, tile.block_size are available
    # Load entire tile (not just first element)
    data = tensor[tile]  # or hl.load(tensor, [tile]) for explicit loading
# 2D tiling
for tile_i, tile_j in hl.tile([height, width]):
    # Each tile represents a portion of the 2D space
    pass
# With explicit begin/end/block_size
for tile in hl.tile(0, 1000, block_size=64):
    pass

Grid vs Loop Behavior:

  • When used at the top level of a kernel function, tile() becomes the grid of the kernel (parallel blocks)

  • When used nested inside another loop, it becomes a sequential loop within each block
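The grid-versus-loop split can be illustrated with a plain-Python model of the matmul example shown earlier. The outer (tile_m, tile_n) loops stand in for the kernel's grid; on the GPU each iteration would be an independent program, but here they simply run one after another. The inner tile_k loop is the sequential loop inside each program. Lists of lists stand in for tensors, the block sizes are arbitrary fixed values rather than autotuned ones, and `tiled_matmul` and `ranges` are hypothetical names.

```python
def ranges(size, block):
    """Split [0, size) into contiguous (start, stop) blocks; the last may be short."""
    return [(s, min(s + block, size)) for s in range(0, size, block)]


def tiled_matmul(x, y, bm=2, bn=2, bk=2):
    """Hypothetical CPU model of the tiled matmul kernel's loop structure."""
    m, k = len(x), len(x[0])
    k2, n = len(y), len(y[0])
    assert k == k2, "inner dimensions must match"
    out = [[0.0] * n for _ in range(m)]
    # Top-level loops: one "program" per (tile_m, tile_n), like hl.tile([m, n]).
    for m0, m1 in ranges(m, bm):
        for n0, n1 in ranges(n, bn):
            # Per-program accumulator, like hl.zeros([tile_m, tile_n]).
            acc = [[0.0] * (n1 - n0) for _ in range(m1 - m0)]
            # Nested loop: runs sequentially inside the program, like hl.tile(k).
            for k0, k1 in ranges(k, bk):
                for i in range(m0, m1):
                    for j in range(n0, n1):
                        acc[i - m0][j - n0] += sum(x[i][p] * y[p][j] for p in range(k0, k1))
            # Write the finished tile back, like out[tile_m, tile_n] = acc.
            for i in range(m0, m1):
                for j in range(n0, n1):
                    out[i][j] = acc[i - m0][j - n0]
    return out


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
y = [[1, 0], [0, 1], [1, 1]]
print(tiled_matmul(x, y))  # [[4.0, 5.0], [10.0, 11.0], [16.0, 17.0]]
```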

grid()

helion.language.grid(begin_or_end, end_or_none=None, /, step=None)

Iterate over individual indices of the given iteration space.

The key difference from tile() is that grid yields scalar integer indices (torch.SymInt), while tile yields Tile objects that index a slice of elements. Use tile in most cases. Use grid when you need explicit control over the launch grid or when processing one element at a time.

Compare the two:

for i in hl.tile(...):
    # i is a Tile object, accesses multiple elements
    data = tensor[i]  # loads a slice of elements (1D tensor)

vs:

for i in hl.grid(...):
    # i is a scalar index, accesses a single element
    data = tensor[i]  # loads a single element (0D scalar)

When used at the top level of a function, this becomes the grid of the kernel. Otherwise, it becomes a loop in the output kernel.

Parameters:
  • begin_or_end (int | Tensor | Sequence[int | Tensor]) – If 2+ positional args provided, the start of iteration space. Otherwise, the end of iteration space.

  • end_or_none (int | Tensor | Sequence[int | Tensor] | None) – If 2+ positional args provided, the end of iteration space.

  • step (object) – Step size for iteration (default: 1)

Returns:

Iterator over scalar indices

Return type:

Iterator[torch.SymInt] or Iterator[Sequence[torch.SymInt]]


Note

Similar to range() with multiple forms:

  • grid(end) iterates from 0 to end-1, step 1

  • grid(begin, end) iterates from begin to end-1, step 1

  • grid(begin, end, step) iterates from begin to end-1, given step

  • grid(end, step=step) iterates from 0 to end-1, given step

Use tile in most cases. Use grid when you need explicit control over the launch grid.

The grid() function iterates over individual indices rather than tiles. It’s equivalent to tile(size, block_size=1) but returns scalar indices instead of tile objects.
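A minimal plain-Python model of that equivalence, assuming the range()-like semantics listed in the note above: `model_grid` and `unit_tiles` are hypothetical helpers, not Helion APIs. With block_size=1, each tile's begin offset is exactly the scalar index grid would yield.

```python
def model_grid(begin_or_end, end_or_none=None, step=1):
    """Hypothetical model of hl.grid's forms: scalar indices, like range()."""
    if end_or_none is None:
        begin, end = 0, begin_or_end
    else:
        begin, end = begin_or_end, end_or_none
    return list(range(begin, end, step))


def unit_tiles(end):
    """Hypothetical tiles with block_size=1 over [0, end), as (begin, end) pairs."""
    return [(i, i + 1) for i in range(end)]


print(model_grid(5))           # [0, 1, 2, 3, 4]
print(model_grid(2, 10, 3))    # [2, 5, 8]
print(model_grid(10, step=4))  # [0, 4, 8]
# Each unit tile's begin offset is the scalar index grid would produce.
print([b for b, _ in unit_tiles(5)] == model_grid(5))  # True
```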

Memory Operations

load()

helion.language.load(tensor, index, extra_mask=None)

Load a value from a tensor using a list of indices.

This function is equivalent to tensor[index] but allows setting extra_mask= to mask elements beyond the default masking based on the hl.tile range.

Parameters:
  • tensor (Tensor) – The tensor to load from

  • index (list[object]) – The indices to use to index into the tensor

  • extra_mask (Tensor | None) – The extra mask (beyond automatic tile bounds masking) to apply to the tensor

Returns:

The loaded value

Return type:

torch.Tensor

store()

helion.language.store(tensor, index, value, extra_mask=None)

Store a value to a tensor using a list of indices.

This function is equivalent to tensor[index] = value but allows setting extra_mask= to mask elements beyond the default masking based on the hl.tile range.

Parameters:
  • tensor (Tensor) – The tensor to store to

  • index (list[object]) – The indices to use to index into the tensor

  • value (Tensor | SymInt | float) – The value to store

  • extra_mask (Tensor | None) – The extra mask (beyond automatic tile bounds masking) to apply to the tensor

Return type:

None

Returns:

None
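The masking behavior described for load() and store() can be modeled per lane in plain Python. In this sketch (hypothetical helpers `masked_load` and `masked_store`, not Helion APIs), a lane is active only if it lies inside the tile's bounds and, when given, its `extra_mask` entry is true. Filling inactive lanes of a load with 0 is an assumption of the model; the text above only says such elements are masked.

```python
def masked_load(data, tile_begin, block_size, tile_end, extra_mask=None, other=0):
    """Hypothetical per-lane model of a masked tile load."""
    values = []
    for lane in range(block_size):
        idx = tile_begin + lane
        in_bounds = idx < tile_end  # automatic mask from the hl.tile range
        extra_ok = extra_mask is None or extra_mask[lane]  # optional extra_mask
        # Inactive lanes never read memory; filling them with `other` is an assumption.
        values.append(data[idx] if in_bounds and extra_ok else other)
    return values


def masked_store(data, tile_begin, block_size, tile_end, values, extra_mask=None):
    """Hypothetical per-lane model of a masked tile store: inactive lanes are skipped."""
    for lane in range(block_size):
        idx = tile_begin + lane
        if idx < tile_end and (extra_mask is None or extra_mask[lane]):
            data[idx] = values[lane]


data = [1, 2, 3, 4, 5, 6]
# Last tile of hl.tile(6) with block size 4: lanes 6 and 7 are out of bounds.
print(masked_load(data, 4, 4, 6))  # [5, 6, 0, 0]
```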

atomic_add()

helion.language.atomic_add(target, index, value, sem='relaxed')

Atomically add a value to a target tensor.

Performs an atomic read-modify-write operation that adds value to target[index]. This is safe for concurrent access from multiple threads/blocks.

Parameters:
  • target (Tensor) – The tensor to add to

  • index (list[object]) – Indices into target for accumulating values

  • value (Tensor | float) – The value to add (tensor or scalar)

  • sem (str) – Memory ordering semantics (default: ‘relaxed’):
      - ‘relaxed’: No ordering constraints
      - ‘acquire’: Acquire semantics
      - ‘release’: Release semantics
      - ‘acq_rel’: Acquire-release semantics

Return type:

None

Returns:

None

Examples

@helion.kernel
def global_sum(x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    # Each tile computes local sum, then atomically adds to global
    for tile in hl.tile(x.size(0)):
        local_data = x[tile]
        local_sum = local_data.sum()
        hl.atomic_add(result, [0], local_sum)

    return result


Note

  • Required for race-free accumulation across parallel execution

  • Performance depends on memory access patterns and contention

  • Consider using regular operations when atomicity isn’t needed

  • Stronger memory ordering semantics (acquire, release, acq_rel) add performance overhead
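Why atomicity matters can be shown with a CPU analogy of the global_sum example above: several workers each add a local sum into one shared location. In this sketch, `AtomicCell` is a hypothetical class and a `threading.Lock` stands in for the hardware atomic; only the guarantee (an indivisible read-modify-write) carries over to hl.atomic_add, not the mechanism or the memory-ordering semantics.

```python
import threading


class AtomicCell:
    """Hypothetical model of one memory location updated with atomic adds."""

    def __init__(self, value=0):
        self.value = value
        self._lock = threading.Lock()

    def atomic_add(self, delta):
        # The lock makes read + add + write one indivisible step; this is the
        # guarantee hl.atomic_add provides on the device (by other means).
        with self._lock:
            self.value = self.value + delta


def global_sum(chunks):
    """Each worker sums its chunk locally, then atomically adds it to a shared total."""
    result = AtomicCell(0)

    def worker(chunk):
        local_sum = sum(chunk)        # like local_data.sum() in the kernel example
        result.atomic_add(local_sum)  # like hl.atomic_add(result, [0], local_sum)

    threads = [threading.Thread(target=worker, args=(c,)) for c in chunks]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return result.value


data = list(range(1000))
chunks = [data[i:i + 64] for i in range(0, len(data), 64)]
print(global_sum(chunks))  # 499500
```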

Tensor Creation

zeros()

helion.language.zeros(shape, dtype=torch.float32)

Return a device-tensor filled with zeros.

Equivalent to hl.full(shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype).

Note

Only use within hl.tile() loops for creating local tensors. For output tensor creation, use torch.zeros() with proper device placement.

Parameters:
  • shape (list[object]) – A list of sizes (or tile indices which are implicitly converted to sizes)

  • dtype (dtype) – Data type of the tensor (default: torch.float32)

Returns:

A device tensor of the given shape and dtype filled with zeros

Return type:

torch.Tensor

Examples

@helion.kernel
def process_kernel(input: torch.Tensor) -> torch.Tensor:
    result = torch.empty_like(input)

    for tile in hl.tile(input.size(0)):
        buffer = hl.zeros([tile], dtype=input.dtype)  # Local buffer
        buffer += input[tile]  # Add input values to buffer
        result[tile] = buffer

    return result

See also

  • full(): For filling with arbitrary values

  • arange(): For creating sequences

full()

helion.language.full(shape, value, dtype=torch.float32)

Create a device-tensor filled with a specified value.

Note

Only use within hl.tile() loops for creating local tensors. For output tensor creation, use torch.full() with proper device placement.

Parameters:
  • shape (list[object]) – A list of sizes (or tile indices which are implicitly converted to sizes)

  • value (float) – The value to fill the tensor with

  • dtype (dtype) – The data type of the tensor (default: torch.float32)

Returns:

A device tensor of the given shape and dtype filled with value

Return type:

torch.Tensor

Examples

@helion.kernel
def process_kernel(input: torch.Tensor) -> torch.Tensor:
    result = torch.empty_like(input)

    for tile in hl.tile(input.size(0)):
        # Create local buffer filled with initial value
        buffer = hl.full([tile], 0.0, dtype=input.dtype)
        buffer += input[tile]  # Add input values to buffer
        result[tile] = buffer

    return result


arange()

See arange() for details.

Tunable Parameters

register_block_size()

helion.language.register_block_size(min_or_max, max_or_none=None, /)

Explicitly register a block size that should be autotuned and can be used for allocations and inside hl.tile(…, block_size=…).

This is useful if you have two loops where you want them to share a block size, or if you need to allocate a kernel tensor before the hl.tile() loop.

The signature can be one of:

hl.register_block_size(max)
hl.register_block_size(min, max)

Where min and max are integers that control the range of block_sizes searched by the autotuner. Max may be a symbolic shape, but min must be a constant integer.

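Parameters:
  • min_or_max – If two positional args are provided, the minimum block size. Otherwise, the maximum block size.

  • max_or_none – If two positional args are provided, the maximum block size.
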
Return type:

int
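The sketch below models, in plain Python, why sharing a registered block size keeps two loops aligned. `candidate_block_sizes` and `ranges` are hypothetical helpers; restricting candidates to powers of two and using 1 as the default minimum are assumptions made for illustration, not statements about Helion's actual search space, and the chosen value stands in for whatever the autotuner picks.

```python
def candidate_block_sizes(min_or_max, max_or_none=None):
    """Hypothetical model of the range implied by register_block_size(max) / (min, max).

    Assumption for illustration: candidates are powers of two, and the minimum
    defaults to 1 when only a maximum is given.
    """
    if max_or_none is None:
        lo, hi = 1, min_or_max
    else:
        lo, hi = min_or_max, max_or_none
    sizes, b = [], 1
    while b <= hi:
        if b >= lo:
            sizes.append(b)
        b *= 2
    return sizes


def ranges(size, block):
    """Split [0, size) into contiguous (start, stop) tiles; the last may be short."""
    return [(s, min(s + block, size)) for s in range(0, size, block)]


candidates = candidate_block_sizes(16, 128)  # [16, 32, 64, 128]
chosen = candidates[1]  # stand-in for the autotuner's pick (32 here)
# Two loops that use the same registered block size get identical tile
# boundaries, so a buffer allocated with that size fits every tile of either loop.
first_loop = ranges(100, chosen)
second_loop = ranges(100, chosen)
print(first_loop)  # [(0, 32), (32, 64), (64, 96), (96, 100)]
```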

register_tunable()

helion.language.register_tunable(name, fragment)

Register a tunable parameter for autotuning.

This function allows you to define parameters that can be automatically tuned during the autotuning process. The fragment defines the search space and default value.

Parameters:
  • name (str) – The key for the tunable parameter in the Config().

  • fragment (ConfigSpecFragment) – A ConfigSpecFragment that defines the search space (e.g., PowerOfTwoFragment)

Returns:

The value assigned to this tunable parameter in the current configuration.

Return type:

int
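As a rough model of the lookup described above: `PowerOfTwoSpace` is a hypothetical stand-in for a ConfigSpecFragment (it is not Helion's PowerOfTwoFragment), a plain dict stands in for Config, and falling back to the fragment's default when the key is absent is an assumption of this sketch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PowerOfTwoSpace:
    """Hypothetical fragment: low, 2*low, 4*low, ... up to high (low assumed a power of two)."""

    low: int
    high: int
    default: int

    def candidates(self):
        values, v = [], self.low
        while v <= self.high:
            values.append(v)
            v *= 2
        return values


def lookup_tunable(config, name, fragment):
    """Model: return config[name], falling back to the fragment default (an assumption)."""
    value = config.get(name, fragment.default)
    if value not in fragment.candidates():
        raise ValueError(f"{name}={value} is outside the search space")
    return value


frag = PowerOfTwoSpace(low=16, high=128, default=64)
print(lookup_tunable({}, "chunk", frag))             # 64
print(lookup_tunable({"chunk": 32}, "chunk", frag))  # 32
```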

register_reduction_dim()

See register_reduction_dim() for details.

Tile Operations

Tile Class

class helion.language.Tile(block_id)

This class should not be instantiated directly; it is the result of hl.tile(…) and represents a single tile of the iteration space.

Tiles can be used as indices to tensors, e.g. tensor[tile]. Tiles can also be used as sizes for allocations, e.g. torch.empty([tile]). Properties such as tile.index, tile.begin, tile.end, tile.id and tile.block_size retrieve various information about the tile.

Masking is implicit for tiles: if the final tile is smaller than the block size, loading that tile loads only the valid elements, and reduction operations ignore the invalid elements.

Parameters:

block_id (int)

__init__(block_id)
Parameters:

block_id (int)

property index: Tensor

Alias for tile_index(), which retrieves a tensor containing the offsets for a tile.

property begin: int

Alias for tile_begin(), which retrieves the start offset of a tile.

property end: int

Alias for tile_end(), which retrieves the end offset of a tile.

property block_size: int

Alias for tile_block_size(), which retrieves the block_size of a tile.

property id: int

Alias for tile_id(), which retrieves the id of a tile.

The Tile class represents a portion of one dimension of the iteration space (a multidimensional hl.tile() loop yields one Tile per dimension), with the following key attributes:

  • begin: Starting offset of the tile

  • end: Ending offset of the tile (exclusive)

  • block_size: Block size of the tile

View Operations

subscript()

helion.language.subscript(tensor, index)

Equivalent to tensor[index] where tensor is a kernel-tensor (not a host-tensor).

Can be used to add dimensions to the tensor, e.g. tensor[None, :] or tensor[:, None].

Parameters:
  • tensor (Tensor) – The kernel tensor to index

  • index (list[object]) – List of indices, including None for new dimensions and : for existing dimensions

Returns:

The indexed tensor with potentially modified dimensions

Return type:

torch.Tensor

Examples

@helion.kernel
def broadcast_multiply(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (N,), y has shape (M,)
    result = torch.empty(
        [x.size(0), y.size(0)], dtype=x.dtype, device=x.device
    )

    for tile_i, tile_j in hl.tile([x.size(0), y.size(0)]):
        # Get tile data
        x_tile = x[tile_i]
        y_tile = y[tile_j]

        # Make x broadcastable: (tile_size, 1)
        # same as hl.subscript(x_tile, [slice(None), None])
        x_expanded = x_tile[:, None]
        # Make y broadcastable: (1, tile_size)
        # same as hl.subscript(y_tile, [None, slice(None)])
        y_expanded = y_tile[None, :]

        result[tile_i, tile_j] = x_expanded * y_expanded

    return result

See also

  • load(): For loading tensor values

  • store(): For storing tensor values

Note

  • Only supports None and : (slice(None)) indexing

  • Used for reshaping kernel tensors by adding dimensions

  • Prefer direct indexing syntax when possible: tensor[None, :]

  • Does not support integer indexing or slicing with start/stop
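The shape effect of None and : indexing can be modeled without a GPU. `subscript_shape` and `broadcast_shape` below are hypothetical helpers operating on plain tuples; they reproduce the broadcast_multiply example's shapes, turning (tile_size,) into (tile_size, 1) and (1, tile_size), which then broadcast to a 2D tile.

```python
def subscript_shape(shape, index):
    """Hypothetical model of hl.subscript's effect on shape."""
    out, d = [], 0
    for item in index:
        if item is None:
            out.append(1)  # None inserts a new size-1 dimension
        elif isinstance(item, slice) and item == slice(None):
            if d >= len(shape):
                raise IndexError("too many ':' entries for this shape")
            out.append(shape[d])  # ':' keeps an existing dimension
            d += 1
        else:
            raise TypeError("only None and ':' (slice(None)) are supported")
    return tuple(out) + tuple(shape[d:])  # unmentioned trailing dims are kept


def broadcast_shape(a, b):
    """Broadcast two equal-rank shapes: each dim pair must match or contain a 1."""
    assert len(a) == len(b)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)
    return tuple(out)


x_expanded = subscript_shape((8,), [slice(None), None])  # (8, 1)
y_expanded = subscript_shape((4,), [None, slice(None)])  # (1, 4)
print(broadcast_shape(x_expanded, y_expanded))           # (8, 4)
```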

Reduction Operations

reduce()

See reduce() for details.

Scan Operations

associative_scan()

See associative_scan() for details.

cumsum()

See cumsum() for details.

cumprod()

See cumprod() for details.

tile_index()

helion.language.tile_index(tile)

Retrieve the index (a 1D tensor containing offsets) of the given tile. This can also be written as: tile.index.

Example usage:

@helion.kernel
def arange(length: int, device: torch.device) -> torch.Tensor:
    out = torch.empty(length, dtype=torch.int32, device=device)
    for tile in hl.tile(length):
        out[tile] = tile.index
    return out

Parameters:

tile (Tile)

Return type:

Tensor

tile_begin()

helion.language.tile_begin(tile)

Retrieve the start offset of the given tile. This can also be written as: tile.begin.

Parameters:

tile (Tile)

Return type:

int

tile_end()

helion.language.tile_end(tile)

Retrieve the end offset of the given tile. For every tile except the last, this equals tile.begin + tile.block_size. For the last tile, it is the end offset passed to hl.tile(). This can also be written as: tile.end.

Parameters:

tile (Tile)

Return type:

int

tile_block_size()

helion.language.tile_block_size(tile)

Retrieve the block size of the given tile, usually set by the autotuner. This can also be written as: tile.block_size.

Parameters:

tile (Tile)

Return type:

int

tile_id()

helion.language.tile_id(tile)

Retrieve the tile_id of a given tile or list of tiles. This is equivalent to tile.begin // tile.block_size. This can also be written as: tile.id.

Parameters:

tile (Tile)

Return type:

int
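The accessor relationships above (end, index, and id = begin // block_size) can be checked with a small plain-Python model. `TileModel` is a hypothetical dataclass, not helion.language.Tile, and it assumes contiguous 1D tiles starting at 0 with a fixed block size.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TileModel:
    """Hypothetical CPU stand-in for a 1D tile's accessors."""

    begin: int
    block_size: int
    limit: int  # the end offset passed to hl.tile()

    @property
    def end(self):
        # Every tile but the last ends at begin + block_size; the last is clipped to `limit`.
        return min(self.begin + self.block_size, self.limit)

    @property
    def index(self):
        # Offsets covered by this tile: begin, begin + 1, ..., end - 1.
        return list(range(self.begin, self.end))

    @property
    def id(self):
        return self.begin // self.block_size


tiles = [TileModel(b, 4, 10) for b in range(0, 10, 4)]
print([(t.begin, t.end, t.id) for t in tiles])  # [(0, 4, 0), (4, 8, 1), (8, 10, 2)]
print(tiles[-1].index)                          # [8, 9]
```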

Synchronization

wait()

See wait() for details.

Utilities

device_print()

See device_print() for details.

Constexpr Operations

constexpr()

See constexpr for details.

specialize()

See specialize() for details.