Config
The Config class represents kernel optimization parameters that control how Helion kernels are compiled and executed.
- class helion.Config(*, block_sizes=None, loop_orders=None, flatten_loops=None, l2_groupings=None, reduction_loops=None, range_unroll_factors=None, range_warp_specializes=None, range_num_stages=None, range_multi_buffers=None, range_flattens=None, static_ranges=None, num_warps=None, num_stages=None, pid_type=None, indexing=None, **kwargs)[source]
- __init__(*, block_sizes=None, loop_orders=None, flatten_loops=None, l2_groupings=None, reduction_loops=None, range_unroll_factors=None, range_warp_specializes=None, range_num_stages=None, range_multi_buffers=None, range_flattens=None, static_ranges=None, num_warps=None, num_stages=None, pid_type=None, indexing=None, **kwargs)[source]
  Initialize a Config object.
- Parameters:
  - block_sizes (list[int] | None) – Controls tile sizes for hl.tile invocations.
  - loop_orders (list[list[int]] | None) – Permutes the iteration order of tiles.
  - flatten_loops (list[bool] | None) – Controls loop flattening for hl.tile invocations.
  - l2_groupings (list[int] | None) – Reorders program IDs for L2 cache locality.
  - reduction_loops (list[int | None] | None) – Configures reduction loop behavior.
  - range_unroll_factors (list[int] | None) – Loop unroll factors for tl.range calls.
  - range_warp_specializes (list[bool | None] | None) – Warp specialization for tl.range calls.
  - range_num_stages (list[int] | None) – Number of stages for tl.range calls.
  - range_multi_buffers (list[bool | None] | None) – Controls disallow_acc_multi_buffer for tl.range calls.
  - range_flattens (list[bool | None] | None) – Controls the flatten parameter for tl.range calls.
  - static_ranges (list[bool] | None) – Whether to use tl.static_range instead of tl.range.
  - num_warps (int | None) – Number of warps (groups of 32 threads) per block.
  - num_stages (int | None) – Number of stages for software pipelining.
  - pid_type (Literal["flat", "xyz", "persistent_blocked", "persistent_interleaved"] | None) – Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
  - indexing (Literal["pointer", "tensor_descriptor", "block_ptr"] | None) – Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
  - **kwargs (object) – Additional user-defined configuration parameters.
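As a quick illustration, a Config can be built with only the parameters of interest, leaving everything else at its default (a minimal sketch; the values shown are illustrative, not recommendations):
import helion
# Build a config specifying only a subset of parameters
config = helion.Config(block_sizes=[64], num_warps=4)
# Parameters are accessible as attributes
print(config.block_sizes)  # [64]
print(config.num_warps)    # 4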
Overview
Config objects specify optimization parameters that control how Helion kernels run on the hardware.
Key Characteristics
- Performance-focused: Control GPU resource allocation, memory access patterns, and execution strategies
- Autotuned: The autotuner searches through different Config combinations to find optimal performance
- Kernel-specific: Each kernel can have different optimal Config parameters based on its computation pattern
- Hardware-dependent: Optimal configs vary based on GPU architecture and problem size
Config vs Settings
| Aspect | Config | Settings |
|---|---|---|
| Purpose | Control execution performance | Control compilation behavior |
| Autotuning | ✅ Automatically optimized | ❌ Never autotuned |
| When to use | Performance optimization | Development, debugging, environment setup |
Configs are typically discovered automatically through autotuning, but can also be manually specified for more control.
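For example, a kernel declared without any config is tuned automatically the first time it runs (a sketch assuming default autotuning behavior):
import torch
import helion
import helion.language as hl

@helion.kernel  # no config supplied: the autotuner searches on first call
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        out[tile] = x[tile] + 1
    return out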
Configuration Parameters
Block Sizes and Resources
- Config.block_sizes
  List of tile sizes for hl.tile() loops. Each value controls the number of elements processed per GPU thread block for the corresponding tile dimension.
- Config.reduction_loops
Configuration for reduction operations within loops.
- Config.num_warps
Number of warps (groups of 32 threads) per thread block. Higher values increase parallelism but may reduce occupancy.
- Config.num_stages
Number of pipeline stages for software pipelining. Higher values can improve memory bandwidth utilization.
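A config exercising these resource parameters together might look like the following (values are illustrative, not recommendations):
import helion
config = helion.Config(
    block_sizes=[128, 64],   # tile sizes for two hl.tile dimensions
    reduction_loops=[None],  # default strategy for the one reduction loop
    num_warps=4,             # 4 warps = 128 threads per block
    num_stages=3,            # 3-stage software pipeline
)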
Loop Optimizations
- Config.loop_orders
  Permutation of loop iteration order for each hl.tile() loop. Used to optimize memory access patterns.
- Config.flatten_loops
  Whether to flatten nested loops for each hl.tile() invocation.
- Config.range_unroll_factors
  Unroll factors for tl.range loops in generated Triton code.
- Config.range_warp_specializes
  Whether to enable warp specialization for tl.range loops.
- Config.range_num_stages
  Number of pipeline stages for tl.range loops.
- Config.range_multi_buffers
  Controls the disallow_acc_multi_buffer parameter for tl.range loops.
- Config.range_flattens
  Controls the flatten parameter for tl.range loops.
- Config.static_ranges
  Whether to use tl.static_range instead of tl.range.
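Each of these parameters is a list with one entry per corresponding loop in the kernel. A sketch for a kernel with one hl.tile loop nest and one device-side tl.range loop (illustrative values):
import helion
config = helion.Config(
    block_sizes=[64, 64],
    loop_orders=[[1, 0]],      # iterate the hl.tile loop dimensions in reversed order
    range_unroll_factors=[2],  # unroll the tl.range loop by a factor of 2
    range_num_stages=[3],      # 3 pipeline stages for the tl.range loop
    static_ranges=[False],     # keep tl.range rather than tl.static_range
)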
Execution and Indexing
- Config.pid_type
  Program ID layout strategy:
  - "flat": Standard linear program ID assignment
  - "xyz": 3D program ID layout
  - "persistent_blocked": Persistent kernels with blocked work distribution
  - "persistent_interleaved": Persistent kernels with interleaved distribution
- Config.l2_groupings
  Controls reordering of program IDs to improve L2 cache locality.
- Config.indexing
  Memory indexing strategy:
  - "pointer": Pointer-based indexing
  - "tensor_descriptor": Tensor descriptor indexing
  - "block_ptr": Block pointer indexing
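For example (illustrative values; tensor_descriptor indexing may require newer hardware):
import helion
config = helion.Config(
    block_sizes=[64, 64],
    pid_type="persistent_blocked",  # persistent kernel, blocked work distribution
    l2_groupings=[4],               # group program IDs in chunks of 4 for L2 reuse
    indexing="block_ptr",           # use Triton block pointers
)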
Usage Examples
Manual Config Creation
import torch
import helion
import helion.language as hl
# Create a specific configuration
config = helion.Config(
block_sizes=[64, 32], # 64 elements per tile in dim 0, 32 in dim 1
num_warps=8, # Use 8 warps (256 threads) per block
num_stages=4, # 4-stage pipeline
pid_type="xyz" # Use 3D program ID layout
)
# Use with kernel
@helion.kernel(config=config)
def my_kernel(x: torch.Tensor) -> torch.Tensor:
result = torch.zeros_like(x)
for i, j in hl.tile(x.shape):
result[i, j] = x[i, j] * 2
return result
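Because a single config is pinned here, the kernel compiles directly with these parameters and no autotuning search is performed.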
Config Serialization
# Save config to file
config.save("my_config.json")
# Load config from file
loaded_config = helion.Config.load("my_config.json")
# JSON serialization
config_json = config.to_json()
restored_config = helion.Config.from_json(config_json)
Autotuning with Restricted Configs
# Restrict autotuning to specific configurations
configs = [
helion.Config(block_sizes=[32, 32], num_warps=4),
helion.Config(block_sizes=[64, 16], num_warps=8),
helion.Config(block_sizes=[16, 64], num_warps=4),
]
@helion.kernel(configs=configs)
def matrix_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
m, k = a.size()
k2, n = b.size()
assert k == k2, f"size mismatch {k} != {k2}"
out = torch.empty([m, n], dtype=a.dtype, device=a.device)
for tile_m, tile_n in hl.tile([m, n]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
for tile_k in hl.tile(k):
acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
out[tile_m, tile_n] = acc
return out
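With configs=, the autotuner benchmarks only the listed candidates and picks the fastest, which keeps tuning time short and predictable.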
See Also
Settings - Compilation settings and environment variables
Kernel - Kernel execution and autotuning
Autotuner Module - Autotuning configuration