diff --git a/README.md b/README.md
index 35befa6..c0eeb23 100644
--- a/README.md
+++ b/README.md
@@ -126,7 +126,7 @@ pip install .[ingress_torch_cpu] \
 
 ## Running tests
 
-Running the tests is as simple as `lit .` in the root of the project.
+Running the tests is as simple as `lit .` in the root of the project (in a suitable Python environment, e.g. through `uv run lit .`).
 
 We assume that the [`FileCheck`](https://llvm.org/docs/CommandGuide/FileCheck.html) and
 [`lit`](https://llvm.org/docs/CommandGuide/lit.html) executables are available on the `PATH`.
diff --git a/examples/ingress/convert-kernel-bench-to-mlir.py b/examples/ingress/convert-kernel-bench-to-mlir.py
index 8a2290b..0dbd150 100755
--- a/examples/ingress/convert-kernel-bench-to-mlir.py
+++ b/examples/ingress/convert-kernel-bench-to-mlir.py
@@ -1,4 +1,6 @@
-# RUN: %PYTHON %s 1,1 1,2 1,3 2,1 2,2 2,3
+# RUN: %PYTHON %s 1,1 1,2 2,1
+# REQUIRES: torch
+# REQUIRES: kernel_bench
 
 # Basic conversion of KernelBench PyTorch kernels to mlir kernels, relying on
 # torch-mlir for the conversion. As there are a number of kernels for which
@@ -13,7 +15,7 @@ from typing import Iterable
 
 from mlir import ir, passmanager
 
-from lighthouse.ingress import torch as torch_ingress
+import lighthouse.ingress as lh_ingress
 
 project_root = Path(__file__).parent.parent.parent
 torch_kernels_dir = project_root / "third_party" / "KernelBench" / "KernelBench"
@@ -171,7 +173,7 @@ def process_task(task: KernelConversionTask):
     print("Processing:", kernel_relative_name)
 
     try:
-        mlir_kernel = torch_ingress.import_from_file(task.torch_path, ir_context=ctx)
+        mlir_kernel = lh_ingress.torch.import_from_file(task.torch_path, ir_context=ctx)
         assert isinstance(mlir_kernel, ir.Module)
     except Exception as e:
         print(
diff --git a/examples/ingress/torch/lit.local.cfg b/examples/ingress/torch/lit.local.cfg
index 2813a1b..935efbf 100644
--- a/examples/ingress/torch/lit.local.cfg
+++ b/examples/ingress/torch/lit.local.cfg
@@ -1 +1,4 @@
 config.excludes = ["MLPModel"]
+
+if "torch" not in config.available_features:
+    config.unsupported = True
diff --git a/examples/llama/lit.local.cfg b/examples/llama/lit.local.cfg
index c37d087..872eb0b 100644
--- a/examples/llama/lit.local.cfg
+++ b/examples/llama/lit.local.cfg
@@ -1 +1 @@
-config.excludes = ["ref_model.py"]
\ No newline at end of file
+config.excludes = ["ref_model.py"]
diff --git a/examples/llama/test_llama3.py b/examples/llama/test_llama3.py
index 0866b06..185e660 100644
--- a/examples/llama/test_llama3.py
+++ b/examples/llama/test_llama3.py
@@ -1,21 +1,21 @@
-# RUN: %pytest %s
+# RUN: %PYTHON %s
+# REQUIRES: torch
 
 import functools
 import math as pymath
 
 import pytest
 import torch
-
 from mlir import ir
 from mlir.dialects import transform, func, linalg, tensor, arith, complex, math
 from mlir.dialects.linalg import ElementwiseKind
 from mlir.dialects.transform import structured, bufferization, interpreter
 from mlir.passmanager import PassManager
-from mlir.runtime.np_to_memref import (
-    get_ranked_memref_descriptor,
-)
+from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
 from mlir.execution_engine import ExecutionEngine
 
+from lighthouse import utils as lh_utils
+
 from ref_model import (
     Attention,
     ModelArgs,
@@ -25,7 +25,6 @@ from ref_model import (
     TransformerBlock,
     Transformer,
 )
-from lighthouse import utils as lh_utils
 
 
 def with_mlir_ctx_and_location(func):
@@ -1020,7 +1019,7 @@ def bin_op(a, b, out):
     eng = ExecutionEngine(module, opt_level=2)
     func_ptr = eng.lookup("bin_op")
 
-    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
     a = torch.randn(*shape, dtype=torch_dtype)
     b = torch.randn(*shape, dtype=torch_dtype)
     out_ref = references[op](a, b)
@@ -1030,7 +1029,7 @@ def bin_op(a, b, out):
     a_mem = get_ranked_memref_descriptor(a.numpy())
     b_mem = get_ranked_memref_descriptor(b.numpy())
     out_mem = get_ranked_memref_descriptor(out.numpy())
-    args = lh_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem])
+    args = lh_utils.memref.to_packed_args([a_mem, b_mem, out_mem])
     func_ptr(args)
 
     assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1076,14 +1075,14 @@ def unary_op(a, out):
     eng = ExecutionEngine(module, opt_level=2)
     func_ptr = eng.lookup("unary_op")
 
-    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
     a = torch.randn(*shape, dtype=torch_dtype)
     out_ref = references[op](a)
     out = torch.empty_like(out_ref)
 
     a_mem = get_ranked_memref_descriptor(a.numpy())
     out_mem = get_ranked_memref_descriptor(out.numpy())
-    args = lh_utils.memrefs_to_packed_args([a_mem, out_mem])
+    args = lh_utils.memref.to_packed_args([a_mem, out_mem])
     func_ptr(args)
 
     assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1112,13 +1111,13 @@ def rms_norm(a, out):
     eng = ExecutionEngine(module, opt_level=2)
     func_ptr = eng.lookup("rms_norm")
 
-    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
     a = torch.randn(*shape, dtype=torch_dtype)
     out_ref = references[get_l2_norm](a, eps)
     out = torch.empty_like(out_ref)
     a_mem = get_ranked_memref_descriptor(a.numpy())
     out_mem = get_ranked_memref_descriptor(out.numpy())
-    args = lh_utils.memrefs_to_packed_args([a_mem, out_mem])
+    args = lh_utils.memref.to_packed_args([a_mem, out_mem])
     func_ptr(args)
 
     assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1160,7 +1159,7 @@ def linear_op(x, w, b, out):
     eng = ExecutionEngine(module, opt_level=2)
     func_ptr = eng.lookup("linear_op")
 
-    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
     x = torch.randn(*shape, in_features, dtype=torch_dtype)
     w = torch.randn(out_features, in_features, dtype=torch_dtype)
     b = torch.randn(out_features, dtype=torch_dtype)
@@ -1171,7 +1170,7 @@ def linear_op(x, w, b, out):
     w_mem = get_ranked_memref_descriptor(w.numpy())
     b_mem = get_ranked_memref_descriptor(b.numpy())
     out_mem = get_ranked_memref_descriptor(out.numpy())
-    args = lh_utils.memrefs_to_packed_args([x_mem, w_mem, b_mem, out_mem])
+    args = lh_utils.memref.to_packed_args([x_mem, w_mem, b_mem, out_mem])
     func_ptr(args)
 
     assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1201,7 +1200,7 @@ def polar_op(magnitude, angle, out):
     eng = ExecutionEngine(module, opt_level=2)
     func_ptr = eng.lookup("polar_op")
 
-    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
     magnitude = torch.randn(4, 16, dtype=torch_dtype)
     angle = torch.randn(4, 16, dtype=torch_dtype)
     out_ref = references[get_polar](magnitude, angle)
@@ -1209,7 +1208,7 @@ def polar_op(magnitude, angle, out):
     magnitude_mem = get_ranked_memref_descriptor(magnitude.numpy())
     angle_mem = get_ranked_memref_descriptor(angle.numpy())
     out_mem = get_ranked_memref_descriptor(out.numpy())
-    args = lh_utils.memrefs_to_packed_args([magnitude_mem, angle_mem, out_mem])
+    args = lh_utils.memref.to_packed_args([magnitude_mem, angle_mem, out_mem])
     func_ptr(args)
 
     assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1237,14 +1236,14 @@ def repeat_kv_op(x, out):
     eng = ExecutionEngine(module, opt_level=2)
     func_ptr = eng.lookup("repeat_kv_op")
 
-    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
     x = torch.randn(2, 512, 8, 64, dtype=torch_dtype)
     out_ref = references[get_repeat_kv](x, n_rep)
     out = torch.empty_like(out_ref)
 
     x_mem = get_ranked_memref_descriptor(x.numpy())
     out_mem = get_ranked_memref_descriptor(out.numpy())
-    args = lh_utils.memrefs_to_packed_args([x_mem, out_mem])
+    args = lh_utils.memref.to_packed_args([x_mem, out_mem])
     func_ptr(args)
 
     assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1275,7 +1274,7 @@ def reshape_for_broadcast_op(freqs_cis, x, out):
     eng = ExecutionEngine(module, opt_level=2)
     func_ptr = eng.lookup("reshape_for_broadcast")
 
-    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
     freqs_cis = torch.randn(512, 64, dtype=torch_dtype)
     x = torch.randn(2, 512, 32, 128, dtype=torch_dtype)
     # Convert x to complex view as expected by reshape_for_broadcast
@@ -1286,7 +1285,7 @@ def reshape_for_broadcast_op(freqs_cis, x, out):
     freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis.numpy())
     x_mem = get_ranked_memref_descriptor(x.numpy())
     out_mem = get_ranked_memref_descriptor(out.numpy())
-    args = lh_utils.memrefs_to_packed_args([freqs_cis_mem, x_mem, out_mem])
+    args = lh_utils.memref.to_packed_args([freqs_cis_mem, x_mem, out_mem])
     func_ptr(args)
 
     assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1317,7 +1316,7 @@ def view_as_complex_op(x, out):
     eng = ExecutionEngine(module, opt_level=2)
     func_ptr = eng.lookup("view_as_complex_op")
 
-    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
     x = torch.randn(2, 512, 32, 128, dtype=torch_dtype)
     x_reshaped = x.reshape(2, 512, 32, 64, 2)
     out_ref = torch.view_as_complex(x_reshaped)
@@ -1325,7 +1324,7 @@ def view_as_complex_op(x, out):
 
     x_mem = get_ranked_memref_descriptor(x_reshaped.numpy())
     out_mem = get_ranked_memref_descriptor(out.numpy())
-    args = lh_utils.memrefs_to_packed_args([x_mem, out_mem])
+    args = lh_utils.memref.to_packed_args([x_mem, out_mem])
     func_ptr(args)
 
     assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1353,7 +1352,7 @@ def as_real_op(x, out):
     eng = ExecutionEngine(module, opt_level=2)
     func_ptr = eng.lookup("as_real_op")
 
-    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
     x = torch.randn(2, 512, 32, 64, 2, dtype=torch_dtype)
     x_complex = torch.view_as_complex(x)
     out_ref = torch.view_as_real(x_complex)
@@ -1361,7 +1360,7 @@ def as_real_op(x, out):
 
     x_mem = get_ranked_memref_descriptor(x_complex.numpy())
     out_mem = get_ranked_memref_descriptor(out.numpy())
-    args = lh_utils.memrefs_to_packed_args([x_mem, out_mem])
+    args = lh_utils.memref.to_packed_args([x_mem, out_mem])
     func_ptr(args)
 
     assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1394,7 +1393,7 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out):
         return module
 
     ir_type = to_ir_type(elem_type)
-    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
     xq_shape = (batch_size, seq_len, n_heads, head_dim)
     xk_shape = (batch_size, seq_len, n_kv_heads, head_dim)
     freqs_cis_shape = (seq_len, head_dim // 2)
@@ -1423,7 +1422,7 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out):
     freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis.numpy())
     out1_mem = get_ranked_memref_descriptor(out1.numpy())
     out2_mem = get_ranked_memref_descriptor(out2.numpy())
-    args = lh_utils.memrefs_to_packed_args(
+    args = lh_utils.memref.to_packed_args(
         [a_mem, b_mem, freqs_cis_mem, out1_mem, out2_mem]
     )
     func_ptr(args)
@@ -1488,7 +1487,7 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out):
     eng = ExecutionEngine(module, opt_level=2)
     func_ptr = eng.lookup("feed_forward")
 
-    torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
+    torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
     x = torch.randn(4, 16, dtype=torch_dtype)
     w1 = torch.randn(64, 16, dtype=torch_dtype)
     b1 = torch.randn(64, dtype=torch_dtype)
@@ -1511,7 +1510,7 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out):
     w3_mem = get_ranked_memref_descriptor(w3.numpy())
     b3_mem = get_ranked_memref_descriptor(b3.numpy())
     out_mem = get_ranked_memref_descriptor(out.numpy())
-    args = lh_utils.memrefs_to_packed_args(
+    args = lh_utils.memref.to_packed_args(
         [x_mem, w1_mem, b1_mem, w2_mem, b2_mem, w3_mem, b3_mem, out_mem]
     )
     func_ptr(args)
@@ -1644,7 +1643,7 @@ def attention_op(x, wq, wk, wv, wo, freqs_cis, mask, out):
     freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis_real.numpy())
     mask_mem = get_ranked_memref_descriptor(mask.numpy())
     out_mem = get_ranked_memref_descriptor(out.numpy())
-    args = lh_utils.memrefs_to_packed_args(
+    args = lh_utils.memref.to_packed_args(
         [x_mem, wq_mem, wk_mem, wv_mem, wo_mem, freqs_cis_mem, mask_mem, out_mem]
     )
     func_ptr(args)
@@ -1791,7 +1790,7 @@ def transformer_block_op(
     b3_mem = get_ranked_memref_descriptor(b3.numpy())
     out_mem = get_ranked_memref_descriptor(out.numpy())
 
-    args = lh_utils.memrefs_to_packed_args(
+    args = lh_utils.memref.to_packed_args(
         [
             x_mem,
             wq_mem,
@@ -1980,7 +1979,7 @@ def transformer_op(*params):
     out_mem = get_ranked_memref_descriptor(out.numpy())
     memrefs.append(out_mem)
 
-    args = lh_utils.memrefs_to_packed_args(memrefs)
+    args = lh_utils.memref.to_packed_args(memrefs)
     func_ptr(args)
 
     assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
diff --git a/examples/mlir/compile_and_run.py b/examples/mlir/compile_and_run.py
index 0d8b8f1..20b5cc5 100644
--- a/examples/mlir/compile_and_run.py
+++ b/examples/mlir/compile_and_run.py
@@ -1,4 +1,5 @@
 # RUN: %PYTHON %s
+# REQUIRES: torch
 
 import torch
 import argparse
@@ -9,7 +10,7 @@
 from mlir.execution_engine import ExecutionEngine
 from mlir.passmanager import PassManager
 
-from lighthouse import utils as lh_utils
+import lighthouse.utils as lh_utils
 
 
 def create_kernel(ctx: ir.Context) -> ir.Module:
@@ -167,7 +168,7 @@ def main(args):
     out = torch.empty_like(out_ref)
 
     # Execute the kernel.
-    args = lh_utils.torch_to_packed_args([a, b, out])
+    args = lh_utils.torch.to_packed_args([a, b, out])
     add_func(args)
 
     ### Verification ###
diff --git a/examples/workload/example.py b/examples/workload/example.py
index 3137dad..c37b226 100644
--- a/examples/workload/example.py
+++ b/examples/workload/example.py
@@ -1,4 +1,4 @@
-# RUN: %PYTHON %s | FileCheck %s 
+# RUN: %PYTHON %s | FileCheck %s
 # CHECK: func.func @payload
 # CHECK: PASSED
 # CHECK: Throughput:
@@ -6,26 +6,20 @@
 Workload example: Element-wise sum of two (M, N) float32 arrays on CPU.
 """
 
+import ctypes
+from contextlib import contextmanager
+from functools import cached_property
+from typing import Optional
+
 import numpy as np
 from mlir import ir
 from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
 from mlir.dialects import func, linalg, bufferization
 from mlir.dialects import transform
 from mlir.execution_engine import ExecutionEngine
-from contextlib import contextmanager
-from functools import cached_property
-import ctypes
-from typing import Optional
-from lighthouse.utils.mlir import (
-    apply_registered_pass,
-    canonicalize,
-    match,
-)
-from lighthouse.workload import (
-    Workload,
-    execute,
-    benchmark,
-)
+
+from lighthouse.utils.mlir import apply_registered_pass, canonicalize, match
+from lighthouse.workload import Workload, execute, benchmark
 
 
 class ElementwiseSum(Workload):
diff --git a/examples/workload/example_mlir.py b/examples/workload/example_mlir.py
index 7d3211a..376993e 100644
--- a/examples/workload/example_mlir.py
+++ b/examples/workload/example_mlir.py
@@ -1,4 +1,4 @@
-# RUN: %PYTHON %s | FileCheck %s 
+# RUN: %PYTHON %s | FileCheck %s
 # CHECK: func.func @payload
 # CHECK: PASSED
 # CHECK: Throughput:
@@ -8,6 +8,9 @@
 In this example, allocation and deallocation of input arrays is done in MLIR.
 """
 
+import ctypes
+from contextlib import contextmanager
+
 import numpy as np
 from mlir import ir
 from mlir.runtime.np_to_memref import (
@@ -17,18 +20,11 @@
 )
 from mlir.dialects import func, linalg, arith, memref
 from mlir.execution_engine import ExecutionEngine
-import ctypes
-from contextlib import contextmanager
-from lighthouse.utils import (
-    get_packed_arg,
-    memrefs_to_packed_args,
-    memref_to_ctype,
-)
+
+from lighthouse.workload import execute, benchmark
+import lighthouse.utils as lh_utils
+
 from example import ElementwiseSum
-from lighthouse.workload import (
-    execute,
-    benchmark,
-)
 
 
 def emit_host_alloc(suffix: str, element_type: ir.Type, rank: int = 2):
@@ -114,16 +110,16 @@ def _allocate_array(
         # construct a memref descriptor for the result memref
         shape = (self.M, self.N)
         mref = make_nd_memref_descriptor(len(shape), as_ctype(self.dtype))()
-        ptr_mref = memref_to_ctype(mref)
+        ptr_mref = lh_utils.memref.to_ctype(mref)
         ptr_dims = [ctypes.pointer(ctypes.c_int32(d)) for d in shape]
-        alloc_func(get_packed_arg([ptr_mref, *ptr_dims]))
+        alloc_func(lh_utils.memref.get_packed_arg([ptr_mref, *ptr_dims]))
         self.memrefs[name] = mref
         return mref
 
     def _deallocate_all(self, execution_engine: ExecutionEngine):
         for mref in self.memrefs.values():
             dealloc_func = execution_engine.lookup("host_dealloc_f32")
-            dealloc_func(memrefs_to_packed_args([mref]))
+            dealloc_func(lh_utils.memref.to_packed_args([mref]))
         self.memrefs = {}
 
     def get_input_arrays(
@@ -136,10 +132,9 @@ def get_input_arrays(
         # initialize with MLIR
         fill_zero_func = execution_engine.lookup("host_fill_constant_zero_f32")
         fill_random_func = execution_engine.lookup("host_fill_random_f32")
-        fill_zero_func(memrefs_to_packed_args([C]))
-        fill_random_func(memrefs_to_packed_args([A]))
-        fill_random_func(memrefs_to_packed_args([B]))
-
+        fill_zero_func(lh_utils.memref.to_packed_args([C]))
+        fill_random_func(lh_utils.memref.to_packed_args([A]))
+        fill_random_func(lh_utils.memref.to_packed_args([B]))
         return [A, B, C]
 
     @contextmanager
diff --git a/examples/xegpu_matmul/matmul.py b/examples/xegpu_matmul/matmul.py
index 0a677dc..d86478e 100644
--- a/examples/xegpu_matmul/matmul.py
+++ b/examples/xegpu_matmul/matmul.py
@@ -5,6 +5,12 @@
 XeGPU matrix multiplication benchmark.
""" +import argparse +import ctypes +from typing import Optional +from contextlib import contextmanager +from functools import cached_property + import numpy as np from mlir import ir from mlir.runtime.np_to_memref import ( @@ -13,18 +19,14 @@ as_ctype, ) from mlir.execution_engine import ExecutionEngine -from typing import Optional -import ctypes -from contextlib import contextmanager -from functools import cached_property -from lighthouse.utils import get_packed_arg, memref_to_ctype from lighthouse.workload import Workload, benchmark +from lighthouse.utils.memref import get_packed_arg, to_ctype as memref_to_ctype + +# Import from sibling files: from schedule import get_schedule_module from payload import generate_matmul_payload -import argparse - def numpy_to_ctype(arr: np.ndarray) -> ctypes._Pointer: """Convert numpy array to memref and ctypes **void pointer.""" diff --git a/lighthouse/ingress/__init__.py b/lighthouse/ingress/__init__.py index 2aa3859..989c237 100644 --- a/lighthouse/ingress/__init__.py +++ b/lighthouse/ingress/__init__.py @@ -1 +1,21 @@ -"""Provides functions to convert source objects (code, models, designs) into MLIR files that the MLIR project can consume""" +__all__ = ["mlir_gen", "torch"] + +import sys +import importlib + + +def __getattr__(name): + """Enable lazy loading of submodules. + + Enables `import lighthouse.ingress as lh_ingress; lh_ingress.` with + loading of (the submodule's heavy) depenendencies only upon being needed. + """ + + if name in __all__: + # Import the submodule and cache it on the current module. That is, + # upon the next access __getattr__ will not be called. + submodule = importlib.import_module("lighthouse.ingress." + name) + lighthouse_ingress_mod = sys.modules[__name__] + setattr(lighthouse_ingress_mod, name, submodule) + return submodule + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/lighthouse/utils/__init__.py b/lighthouse/utils/__init__.py index 474b748..6a85edd 100644 --- a/lighthouse/utils/__init__.py +++ b/lighthouse/utils/__init__.py @@ -1,19 +1,21 @@ -"""A collection of utility tools""" +__all__ = ["memref", "mlir", "torch"] -from .runtime_args import ( - get_packed_arg, - memref_to_ctype, - memrefs_to_packed_args, - torch_to_memref, - torch_to_packed_args, - mlir_type_to_torch_dtype, -) +import sys +import importlib -__all__ = [ - "get_packed_arg", - "memref_to_ctype", - "memrefs_to_packed_args", - "mlir_type_to_torch_dtype", - "torch_to_memref", - "torch_to_packed_args", -] + +def __getattr__(name): + """Enable lazy loading of submodules. + + Enables `import lighthouse.utils as lh_utils; lh_utils.` with + loading of (the submodule's heavy) depenendencies only upon being needed. + """ + + if name in __all__: + # Import the submodule and cache it on the current module. That is, + # upon the next access __getattr__ will not be called. + submodule = importlib.import_module("lighthouse.utils." + name) + lighthouse_utils_mod = sys.modules[__name__] + setattr(lighthouse_utils_mod, name, submodule) + return submodule + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/lighthouse/utils/memref.py b/lighthouse/utils/memref.py new file mode 100644 index 0000000..63f67a1 --- /dev/null +++ b/lighthouse/utils/memref.py @@ -0,0 +1,39 @@ +import ctypes +from typing import Sequence + + +def to_ctype(memref_desc) -> ctypes._Pointer: + """ + Convert a memref descriptor into a ctype argument. + + Args: + memref_desc: An MLIR memref descriptor. 
+ """ + return ctypes.pointer(ctypes.pointer(memref_desc)) + + +def get_packed_arg( + ctypes_args: Sequence[ctypes._Pointer], +) -> ctypes.Array[ctypes.c_void_p]: + """ + Return a list of packed ctype arguments compatible with + jitted MLIR function's interface. + + Args: + ctypes_args: A list of ctype pointer arguments. + """ + packed_args = (ctypes.c_void_p * len(ctypes_args))() + for argNum in range(len(ctypes_args)): + packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p) + return packed_args + + +def to_packed_args(memref_descs) -> ctypes.Array[ctypes.c_void_p]: + """ + Convert a list of memref descriptors into packed ctype arguments. + + Args: + memref_descs: A list of memref descriptors. + """ + ctype_args = [to_ctype(memref) for memref in memref_descs] + return get_packed_arg(ctype_args) diff --git a/lighthouse/utils/runtime_args.py b/lighthouse/utils/runtime_args.py deleted file mode 100644 index eb6b22a..0000000 --- a/lighthouse/utils/runtime_args.py +++ /dev/null @@ -1,97 +0,0 @@ -import ctypes -import torch - -from mlir.runtime.np_to_memref import ( - get_ranked_memref_descriptor, -) -from mlir import ir - - -def get_packed_arg(ctypes_args) -> list[ctypes.c_void_p]: - """ - Return a list of packed ctype arguments compatible with - jitted MLIR function's interface. - - Args: - ctypes_args: A list of ctype pointer arguments. - """ - packed_args = (ctypes.c_void_p * len(ctypes_args))() - for argNum in range(len(ctypes_args)): - packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p) - return packed_args - - -def memref_to_ctype(memref_desc) -> ctypes._Pointer: - """ - Convert a memref descriptor into a ctype argument. - - Args: - memref_desc: An MLIR memref descriptor. - """ - return ctypes.pointer(ctypes.pointer(memref_desc)) - - -def memrefs_to_packed_args(memref_descs) -> list[ctypes.c_void_p]: - """ - Convert a list of memref descriptors into packed ctype arguments. - - Args: - memref_descs: A list of memref descriptors. - """ - ctype_args = [memref_to_ctype(memref) for memref in memref_descs] - return get_packed_arg(ctype_args) - - -def torch_to_memref(input: torch.Tensor) -> ctypes.Structure: - """ - Convert a PyTorch tensor into a memref descriptor. - - Args: - input: PyTorch tensor. - """ - return get_ranked_memref_descriptor(input.numpy()) - - -def torch_to_packed_args(inputs: list[torch.Tensor]) -> list[ctypes.c_void_p]: - """ - Convert a list of PyTorch tensors into packed ctype arguments. - - Args: - inputs: A list of PyTorch tensors. - """ - memrefs = [torch_to_memref(input) for input in inputs] - return memrefs_to_packed_args(memrefs) - - -def mlir_type_to_torch_dtype(mlir_type: ir.Type): - """ - Convert an MLIR type to a PyTorch dtype. 
-    Args:
-        mlir_type: An MLIR type (e.g., ir.F32Type, ir.F64Type)
-    Returns:
-        Corresponding PyTorch dtype
-    """
-    import torch
-
-    if isinstance(mlir_type, ir.F32Type):
-        return torch.float32
-    elif isinstance(mlir_type, ir.F64Type):
-        return torch.float64
-    elif isinstance(mlir_type, ir.F16Type):
-        return torch.float16
-    elif isinstance(mlir_type, ir.BF16Type):
-        return torch.bfloat16
-    elif isinstance(mlir_type, ir.IntegerType):
-        width = mlir_type.width
-        if width == 64:
-            return torch.int64
-        elif width == 32:
-            return torch.int32
-        elif width == 16:
-            return torch.int16
-        elif width == 8:
-            return torch.int8
-        elif width == 1:
-            return torch.bool
-
-    raise ValueError(f"Unsupported MLIR type: {mlir_type}")
diff --git a/lighthouse/utils/torch.py b/lighthouse/utils/torch.py
new file mode 100644
index 0000000..4fa3d1c
--- /dev/null
+++ b/lighthouse/utils/torch.py
@@ -0,0 +1,62 @@
+import ctypes
+
+import torch
+from mlir import ir
+from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
+
+from . import memref as memref_utils
+
+
+def to_memref(input: torch.Tensor) -> ctypes.Structure:
+    """
+    Convert a PyTorch tensor into a memref descriptor.
+
+    Args:
+        input: PyTorch tensor.
+    """
+    return get_ranked_memref_descriptor(input.numpy())
+
+
+def to_packed_args(inputs: list[torch.Tensor]) -> ctypes.Array[ctypes.c_void_p]:
+    """
+    Convert a list of PyTorch tensors into packed ctype arguments.
+
+    Args:
+        inputs: A list of PyTorch tensors.
+    """
+    memrefs = [to_memref(input) for input in inputs]
+    return memref_utils.to_packed_args(memrefs)
+
+
+def dtype_from_mlir_type(mlir_type: ir.Type):
+    """
+    Convert an MLIR type to a PyTorch dtype.
+    Args:
+        mlir_type: An MLIR type (e.g., ir.F32Type, ir.F64Type)
+    Returns:
+        Corresponding PyTorch dtype
+    """
+    import torch
+
+    if isinstance(mlir_type, ir.F32Type):
+        return torch.float32
+    elif isinstance(mlir_type, ir.F64Type):
+        return torch.float64
+    elif isinstance(mlir_type, ir.F16Type):
+        return torch.float16
+    elif isinstance(mlir_type, ir.BF16Type):
+        return torch.bfloat16
+    elif isinstance(mlir_type, ir.IntegerType):
+        width = mlir_type.width
+        if width == 64:
+            return torch.int64
+        elif width == 32:
+            return torch.int32
+        elif width == 16:
+            return torch.int16
+        elif width == 8:
+            return torch.int8
+        elif width == 1:
+            return torch.bool
+
+    raise ValueError(f"Unsupported MLIR type: {mlir_type}")
diff --git a/lighthouse/workload/runner.py b/lighthouse/workload/runner.py
index ae5e07e..9e9a5d0 100644
--- a/lighthouse/workload/runner.py
+++ b/lighthouse/workload/runner.py
@@ -9,7 +9,7 @@
 from mlir.execution_engine import ExecutionEngine
 from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
 
 from lighthouse.utils.mlir import get_mlir_library_path
-from lighthouse.utils import memrefs_to_packed_args
+from lighthouse.utils.memref import to_packed_args
 from lighthouse.workload import Workload
 from typing import Optional
@@ -44,7 +44,7 @@ def execute(
 
     with workload.allocate_inputs(execution_engine=engine) as inputs:
         # prepare function arguments
-        packed_args = memrefs_to_packed_args(inputs)
+        packed_args = to_packed_args(inputs)
 
         # handle to payload function
         payload_func = engine.lookup(workload.payload_function_name)
@@ -143,7 +143,7 @@ def benchmark(
         if check_correctness:
             # call payload once to verify correctness
            # prepare function arguments
-            packed_args = memrefs_to_packed_args(inputs)
+            packed_args = to_packed_args(inputs)
             payload_func = engine.lookup(workload.payload_function_name)
             payload_func(packed_args)
 
@@ -156,7 +156,7 @@ def benchmark(
         # allocate buffer for timings and prepare arguments
         time_array = np.zeros((nruns,), dtype=np.float64)
         time_memref = get_ranked_memref_descriptor(time_array)
-        packed_args_with_time = memrefs_to_packed_args(inputs + [time_memref])
+        packed_args_with_time = to_packed_args(inputs + [time_memref])
 
         # call benchmark function
         benchmark_func = engine.lookup("benchmark")
diff --git a/lit.cfg.py b/lit.cfg.py
index 3ed67ad..b9f0dac 100644
--- a/lit.cfg.py
+++ b/lit.cfg.py
@@ -1,18 +1,28 @@
 import os
+import importlib.util
 import lit.formats
 from lit.TestingConfig import TestingConfig
 
-# Imagine that, all your variables defined and with types!
+# Imagine that, all your variables defined and with type information!
 assert isinstance(config := eval("config"), TestingConfig)
 
+project_root = os.path.dirname(__file__)
+
 config.name = "Lighthouse test suite"
 config.test_format = lit.formats.ShTest(True)
-config.test_source_root = os.path.dirname(__file__)
-config.test_exec_root = os.path.dirname(__file__) + "/lit.out"
+config.test_source_root = project_root
+config.test_exec_root = project_root + "/lit.out"
 
-config.substitutions.append(("%CACHE", os.path.dirname(__file__) + "/cache"))
-config.substitutions.append(("%pytest", "uv run"))
-config.substitutions.append(("%PYTHON", "uv run"))
+config.substitutions.append(("%CACHE", project_root + "/cache"))
+python = os.environ.get("PYTHON", "python")
+config.substitutions.append(("%PYTHON", python))
 
 if filecheck_path := os.environ.get("FILECHECK"):
     config.substitutions.append(("FileCheck", filecheck_path))
+
+if importlib.util.find_spec("torch"):
+    config.available_features.add("torch")
+
+torch_kernels_dir = project_root + "/third_party/KernelBench/KernelBench"
+if os.path.isdir(torch_kernels_dir):
+    config.available_features.add("kernel_bench")
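
Usage note (appended for illustration, not part of the patch itself): a minimal sketch of how the lazily loaded `lighthouse.utils` namespace introduced above is meant to be used from caller code. The helper functions are the ones added in this diff; the tensor shapes and the compiled-kernel lookup are hypothetical.

```python
import torch

# Cheap import: lighthouse.utils only defines __all__ and a module-level
# __getattr__ (PEP 562); no heavy dependencies are pulled in yet.
import lighthouse.utils as lh_utils

a = torch.randn(4, 16)
b = torch.randn(4, 16)
out = torch.empty(4, 16)

# First attribute access imports lighthouse.utils.torch and caches it on
# the package; the helper packs the tensors into the void** argument
# array that a JIT-compiled MLIR function expects.
args = lh_utils.torch.to_packed_args([a, b, out])

# Raw memref descriptors go through the memref submodule instead, e.g.:
#   args = lh_utils.memref.to_packed_args([a_mem, b_mem, out_mem])
# A kernel previously looked up via ExecutionEngine.lookup would then be
# invoked as: kernel_func(args)
```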