from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from .. import exc
from .._compiler.ast_extension import expr_from_string
from .._compiler.compile_environment import CompileEnvironment
from . import _decorators
if TYPE_CHECKING:
import ast
from .._compiler.inductor_lowering import CodegenState
from .loops import Tile
[docs]
@_decorators.api(tiles_as_sizes=True)
def tile_index(tile: Tile) -> torch.Tensor:
    """
    Return the index of ``tile``: a 1D tensor holding the tile's offsets.

    ``tile.index`` is an equivalent spelling.

    Example usage::

        @helion.kernel
        def arange(length: int, device: torch.device) -> torch.Tensor:
            out = torch.empty(length, dtype=torch.int32, device=device)
            for tile in hl.tile(length):
                out[tile] = tile.index
            return out

    Raises:
        exc.NotInsideKernel: always, when this Python body itself is executed.
    """
    # The body is only a placeholder; the fake and codegen handlers registered
    # for this function in this module supply the real behavior.
    raise exc.NotInsideKernel
@_decorators.register_fake(tile_index)
def _(tile: torch.SymInt) -> torch.Tensor:
    """Fake implementation of ``tile_index``.

    Returns an uninitialized 1D tensor with ``tile`` elements, using the
    index dtype and device taken from the current compile environment.
    """
    assert isinstance(tile, torch.SymInt)
    compile_env = CompileEnvironment.current()
    # The SymInt must be registered with the environment as a block size.
    block_id = compile_env.get_block_id(tile)
    assert block_id is not None
    return torch.empty(
        [tile],
        dtype=compile_env.settings.index_dtype,
        device=compile_env.device,
    )
@_decorators.codegen(tile_index)
def _(state: CodegenState) -> ast.AST:
    """Emit the expression for ``tile_index``: the tile's index variable."""
    # Resolving the block id also disables loop flattening for that block.
    block_id = _disable_flatten_get_tile(state.proxy_arg(0))
    index_var_name = state.codegen.index_var(block_id)
    return expr_from_string(index_var_name)
[docs]
@_decorators.api(tiles_as_sizes=True)
def tile_begin(tile: Tile) -> int:
    """
    Return the start offset of ``tile``.

    ``tile.begin`` is an equivalent spelling.

    Raises:
        exc.NotInsideKernel: always, when this Python body itself is executed.
    """
    # Placeholder body; the fake and codegen handlers registered for this
    # function in this module supply the real behavior.
    raise exc.NotInsideKernel
@_decorators.register_fake(tile_begin)
def _(tile: torch.SymInt) -> torch.SymInt:
    """Fake implementation of ``tile_begin``.

    Returns the unbacked SymInt that the compile environment caches under
    the key ``("tile_begin", tile)``.
    """
    # Called only for its side effect: update the config spec if needed.
    _disable_flatten_get_tile(tile)
    compile_env = CompileEnvironment.current()
    cache_key = ("tile_begin", tile)
    return compile_env.cached_create_unbacked_symint(cache_key)
def _disable_flatten_get_tile(tile: object) -> int:
    """Resolve ``tile`` to its block id and disable loop flattening for it.

    Args:
        tile: Expected to be a ``torch.SymInt`` known to the current compile
            environment as a block size.

    Returns:
        The block id the compile environment associates with ``tile``.

    Raises:
        AssertionError: if ``tile`` is not a ``torch.SymInt`` or the
            environment has no block id for it.
    """
    assert isinstance(tile, torch.SymInt), (type(tile), tile)
    compile_env = CompileEnvironment.current()
    block_id = compile_env.get_block_id(tile)
    assert block_id is not None
    # The functions in this file can't be used in flattened loops.
    compile_env.config_spec.flatten_loops.disable_block_id(block_id)
    return block_id
@_decorators.codegen(tile_begin)
def _(state: CodegenState) -> ast.AST:
    """Emit the expression for ``tile_begin``: the tile's offset variable."""
    # Resolving the block id also disables loop flattening for that block.
    block_id = _disable_flatten_get_tile(state.proxy_arg(0))
    offset_var_name = state.codegen.offset_var(block_id)
    return expr_from_string(offset_var_name)
[docs]
@_decorators.api(tiles_as_sizes=True)
def tile_end(tile: Tile) -> int:
    """
    Return the end offset of ``tile``.

    For every tile except the last, this equals
    ``tile.begin + tile.block_size``. For the last tile, it is the end
    offset that was passed to ``hl.tile()``.

    ``tile.end`` is an equivalent spelling.

    Raises:
        exc.NotInsideKernel: always, when this Python body itself is executed.
    """
    # Placeholder body; the fake and codegen handlers registered for this
    # function in this module supply the real behavior.
    raise exc.NotInsideKernel
@_decorators.register_fake(tile_end)
def _(tile: torch.SymInt) -> torch.SymInt:
    """Fake implementation of ``tile_end``.

    Returns the unbacked SymInt that the compile environment caches under
    the key ``("tile_end", tile)``.
    """
    # Called only for its side effect: update the config spec if needed.
    _disable_flatten_get_tile(tile)
    compile_env = CompileEnvironment.current()
    cache_key = ("tile_end", tile)
    return compile_env.cached_create_unbacked_symint(cache_key)
@_decorators.codegen(tile_end)
def _(state: CodegenState) -> ast.AST:
    """Emit the expression for ``tile_end``.

    Without a mask the result is ``offset + block_size``. With a mask it is
    ``tl.minimum(offset + block_size, end)``, where ``end`` is the end
    variable recorded on the innermost active device loop for this block.
    """
    block_id = _disable_flatten_get_tile(state.proxy_arg(0))
    begin = state.codegen.offset_var(block_id)
    size = state.device_function.block_size_var(block_id)
    # Fall back to the literal "1" when there is no block size variable.
    step = "1" if size is None else size
    unclamped_end = f"{begin} + {step}"

    if state.codegen.mask_var(block_id) is None:
        # No mask: offset + block size is already the end of the tile.
        return expr_from_string(unclamped_end)

    # If masking is used, we must update the end bound of the last tile.
    innermost_loop = state.codegen.active_device_loops[block_id][-1]
    end_var = innermost_loop.block_id_to_info[block_id].end_var_name
    return expr_from_string(f"tl.minimum({unclamped_end}, {end_var})")
[docs]
@_decorators.api(tiles_as_sizes=True)
def tile_block_size(tile: Tile) -> int:
    """
    Return the block size of ``tile``, usually set by the autotuner.

    ``tile.block_size`` is an equivalent spelling.

    Raises:
        exc.NotInsideKernel: always, when this Python body itself is executed.
    """
    # Placeholder body; the fake handler registered for this function in this
    # module supplies the real behavior.
    raise exc.NotInsideKernel
@_decorators.register_fake(tile_block_size)
def _(tile: torch.SymInt) -> torch.SymInt:
    """Fake implementation of ``tile_block_size``: return ``tile`` unchanged."""
    # With tiles_as_sizes=True the argument arrives as the block-size SymInt,
    # so it is already the value to return.
    return tile
# since we return tile above, no codegen is needed for this function.
# codegen is handled in _get_symnode()
[docs]
@_decorators.api(tiles_as_sizes=True)
def tile_id(tile: Tile) -> int:
    """
    Return the tile_id of a given tile or list of tiles.

    The value is equivalent to ``tile.begin // tile.block_size``.

    ``tile.id`` is an equivalent spelling.

    Raises:
        exc.NotInsideKernel: always, when this Python body itself is executed.
    """
    # Placeholder body; the fake and codegen handlers registered for this
    # function in this module supply the real behavior.
    raise exc.NotInsideKernel
@_decorators.register_fake(tile_id)
def _(tile: torch.SymInt) -> torch.SymInt:
    """Fake implementation of ``tile_id``.

    Returns the unbacked SymInt that the compile environment caches under
    the key ``("tile_id", tile)``.
    """
    # Called only for its side effect: update the config spec if needed.
    _disable_flatten_get_tile(tile)
    assert isinstance(tile, torch.SymInt)
    compile_env = CompileEnvironment.current()
    cache_key = ("tile_id", tile)
    return compile_env.cached_create_unbacked_symint(cache_key)
@_decorators.codegen(tile_id)
def _(state: CodegenState) -> ast.AST:
    """Emit the expression for ``tile_id``.

    The result is ``offset // block_size``, or just ``offset`` when the
    device function has no block size variable for this block.
    """
    block_id = _disable_flatten_get_tile(state.proxy_arg(0))
    begin = state.codegen.offset_var(block_id)
    size = state.device_function.block_size_var(block_id)
    source = begin if size is None else f"{begin} // {size}"
    return expr_from_string(source)