From 82aeef10e16b0a02edc13f0e6bd04ab84c878e31 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 8 Dec 2025 05:59:38 -0800 Subject: [PATCH 1/7] [testing] Tests mark themselves unsupported if features aren't available Also loosens dependency upon `uv` for running tests -- all that is required is that `lit` is running in a suitable environment (e.g. through `uv run lit .` or by first entering a venv and then `lit .`). --- README.md | 2 +- .../ingress/convert-kernel-bench-to-mlir.py | 4 +++- examples/ingress/torch/mlp_from_file.py | 3 ++- examples/ingress/torch/mlp_from_model.py | 3 ++- examples/llama/lit.local.cfg | 2 +- examples/llama/test_llama3.py | 3 ++- examples/mlir/compile_and_run.py | 3 ++- ...sform_a_payload_according_to_a_schedule.py | 2 +- examples/xegpu_matmul/matmul.py | 3 ++- lit.cfg.py | 22 ++++++++++++++----- 10 files changed, 32 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 35befa6..c0eeb23 100644 --- a/README.md +++ b/README.md @@ -126,7 +126,7 @@ pip install .[ingress_torch_cpu] \ ## Running tests -Running the tests is as simple as `lit .` in the root of the project. +Running the tests is as simple as `lit .` in the root of the project (in a suitable Python environment, e.g. through `uv run lit .`). We assume that the [`FileCheck`](https://llvm.org/docs/CommandGuide/FileCheck.html) and [`lit`](https://llvm.org/docs/CommandGuide/lit.html) executables are available on the `PATH`. diff --git a/examples/ingress/convert-kernel-bench-to-mlir.py b/examples/ingress/convert-kernel-bench-to-mlir.py index 8a2290b..055ddd6 100755 --- a/examples/ingress/convert-kernel-bench-to-mlir.py +++ b/examples/ingress/convert-kernel-bench-to-mlir.py @@ -1,4 +1,6 @@ -# RUN: %PYTHON %s 1,1 1,2 1,3 2,1 2,2 2,3 +# RUN: python %s 1,1 1,2 1,3 2,1 2,2 2,3 +# REQUIRES: torch +# REQUIRES: kernel_bench # Basic conversion of KernelBench PyTorch kernels to mlir kernels, relying on # torch-mlir for the conversion.
As there are a number of kernels for which diff --git a/examples/ingress/torch/mlp_from_file.py b/examples/ingress/torch/mlp_from_file.py index 748de85..da9c014 100644 --- a/examples/ingress/torch/mlp_from_file.py +++ b/examples/ingress/torch/mlp_from_file.py @@ -1,4 +1,5 @@ -# RUN: %PYTHON %s +# RUN: python %s +# REQUIRES: torch """ Example demonstrating how to load a PyTorch model to MLIR using Lighthouse diff --git a/examples/ingress/torch/mlp_from_model.py b/examples/ingress/torch/mlp_from_model.py index 4590429..e7ebf79 100644 --- a/examples/ingress/torch/mlp_from_model.py +++ b/examples/ingress/torch/mlp_from_model.py @@ -1,4 +1,5 @@ -# RUN: %PYTHON %s +# RUN: python %s +# REQUIRES: torch """ Example demonstrating how to load an already initialized PyTorch model diff --git a/examples/llama/lit.local.cfg b/examples/llama/lit.local.cfg index c37d087..872eb0b 100644 --- a/examples/llama/lit.local.cfg +++ b/examples/llama/lit.local.cfg @@ -1 +1 @@ -config.excludes = ["ref_model.py"] \ No newline at end of file +config.excludes = ["ref_model.py"] diff --git a/examples/llama/test_llama3.py b/examples/llama/test_llama3.py index 0866b06..e81f92e 100644 --- a/examples/llama/test_llama3.py +++ b/examples/llama/test_llama3.py @@ -1,4 +1,5 @@ -# RUN: %pytest %s +# RUN: python %s +# REQUIRES: torch import functools import math as pymath diff --git a/examples/mlir/compile_and_run.py b/examples/mlir/compile_and_run.py index 0d8b8f1..f9e71bd 100644 --- a/examples/mlir/compile_and_run.py +++ b/examples/mlir/compile_and_run.py @@ -1,4 +1,5 @@ -# RUN: %PYTHON %s +# RUN: python %s +# REQUIRES: torch import torch import argparse diff --git a/examples/schedule/transform_a_payload_according_to_a_schedule.py b/examples/schedule/transform_a_payload_according_to_a_schedule.py index fcbbe7b..71cbd4e 100644 --- a/examples/schedule/transform_a_payload_according_to_a_schedule.py +++ b/examples/schedule/transform_a_payload_according_to_a_schedule.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s | FileCheck %s +# RUN: python %s | FileCheck %s # Simply demonstrates applying a schedule to a payload. # To do so generates a basic payload and a basic schedule, purely as an example. diff --git a/examples/xegpu_matmul/matmul.py b/examples/xegpu_matmul/matmul.py index 0a677dc..8e2bf6a 100644 --- a/examples/xegpu_matmul/matmul.py +++ b/examples/xegpu_matmul/matmul.py @@ -1,4 +1,5 @@ -# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# RUN: python %s --dump-kernel=xegpu-wg | FileCheck %s +# REQUIRES: torch # CHECK: module attributes {gpu.container_module} { """ diff --git a/lit.cfg.py b/lit.cfg.py index 3ed67ad..a83f501 100644 --- a/lit.cfg.py +++ b/lit.cfg.py @@ -1,18 +1,28 @@ import os +import importlib.util import lit.formats from lit.TestingConfig import TestingConfig -# Imagine that, all your variables defined and with types! +# Imagine that, all your variables defined and with type information! 
assert isinstance(config := eval("config"), TestingConfig) +project_root = os.path.dirname(__file__) + config.name = "Lighthouse test suite" config.test_format = lit.formats.ShTest(True) -config.test_source_root = os.path.dirname(__file__) -config.test_exec_root = os.path.dirname(__file__) + "/lit.out" +config.test_source_root = project_root +config.test_exec_root = project_root + "/lit.out" -config.substitutions.append(("%CACHE", os.path.dirname(__file__) + "/cache")) -config.substitutions.append(("%pytest", "uv run")) -config.substitutions.append(("%PYTHON", "uv run")) +config.substitutions.append(("%CACHE", project_root + "/cache")) +if python_path := os.environ.get("PYTHON"): + config.substitutions.append(("python", python_path)) if filecheck_path := os.environ.get("FILECHECK"): config.substitutions.append(("FileCheck", filecheck_path)) + +if importlib.util.find_spec("torch"): + config.available_features.add("torch") + +torch_kernels_dir = project_root + "/third_party/KernelBench/KernelBench" +if os.path.isdir(torch_kernels_dir): + config.available_features.add("kernel_bench") From e91ed2137834f9e335fe24c4b4f8a49005794b3d Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Tue, 9 Dec 2025 10:33:33 -0800 Subject: [PATCH 2/7] Split out torch dependent functionality in lighthouse.utils Means at least one more example can run without torch being installed. --- examples/llama/test_llama3.py | 52 ++--- examples/mlir/compile_and_run.py | 4 +- examples/xegpu_matmul/matmul.py | 17 +- examples/xegpu_matmul/runner.py | 177 ++++++++++++++++++ lighthouse/utils/__init__.py | 19 -- lighthouse/utils/runtime/__init__.py | 0 lighthouse/utils/runtime/ffi.py | 37 ++++ .../{runtime_args.py => runtime/torch.py} | 41 +--- 8 files changed, 254 insertions(+), 93 deletions(-) create mode 100644 examples/xegpu_matmul/runner.py delete mode 100644 lighthouse/utils/__init__.py create mode 100644 lighthouse/utils/runtime/__init__.py create mode 100644 lighthouse/utils/runtime/ffi.py rename lighthouse/utils/{runtime_args.py => runtime/torch.py} (59%) diff --git a/examples/llama/test_llama3.py b/examples/llama/test_llama3.py index e81f92e..ed45140 100644 --- a/examples/llama/test_llama3.py +++ b/examples/llama/test_llama3.py @@ -26,7 +26,7 @@ TransformerBlock, Transformer, ) -from lighthouse import utils as lh_utils +from lighthouse.utils.runtime import ffi as ffi_utils, torch as torch_utils def with_mlir_ctx_and_location(func): @@ -1021,7 +1021,7 @@ def bin_op(a, b, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("bin_op") - torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) a = torch.randn(*shape, dtype=torch_dtype) b = torch.randn(*shape, dtype=torch_dtype) out_ref = references[op](a, b) @@ -1031,7 +1031,7 @@ def bin_op(a, b, out): a_mem = get_ranked_memref_descriptor(a.numpy()) b_mem = get_ranked_memref_descriptor(b.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem]) + args = ffi_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1077,14 +1077,14 @@ def unary_op(a, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("unary_op") - torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) a = torch.randn(*shape, dtype=torch_dtype) out_ref = references[op](a) out = 
torch.empty_like(out_ref) a_mem = get_ranked_memref_descriptor(a.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args([a_mem, out_mem]) + args = ffi_utils.memrefs_to_packed_args([a_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1113,13 +1113,13 @@ def rms_norm(a, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("rms_norm") - torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) a = torch.randn(*shape, dtype=torch_dtype) out_ref = references[get_l2_norm](a, eps) out = torch.empty_like(out_ref) a_mem = get_ranked_memref_descriptor(a.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args([a_mem, out_mem]) + args = ffi_utils.memrefs_to_packed_args([a_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1161,7 +1161,7 @@ def linear_op(x, w, b, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("linear_op") - torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) x = torch.randn(*shape, in_features, dtype=torch_dtype) w = torch.randn(out_features, in_features, dtype=torch_dtype) b = torch.randn(out_features, dtype=torch_dtype) @@ -1172,7 +1172,7 @@ def linear_op(x, w, b, out): w_mem = get_ranked_memref_descriptor(w.numpy()) b_mem = get_ranked_memref_descriptor(b.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args([x_mem, w_mem, b_mem, out_mem]) + args = ffi_utils.memrefs_to_packed_args([x_mem, w_mem, b_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1202,7 +1202,7 @@ def polar_op(magnitude, angle, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("polar_op") - torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) magnitude = torch.randn(4, 16, dtype=torch_dtype) angle = torch.randn(4, 16, dtype=torch_dtype) out_ref = references[get_polar](magnitude, angle) @@ -1210,7 +1210,7 @@ def polar_op(magnitude, angle, out): magnitude_mem = get_ranked_memref_descriptor(magnitude.numpy()) angle_mem = get_ranked_memref_descriptor(angle.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args([magnitude_mem, angle_mem, out_mem]) + args = ffi_utils.memrefs_to_packed_args([magnitude_mem, angle_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1238,14 +1238,14 @@ def repeat_kv_op(x, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("repeat_kv_op") - torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) x = torch.randn(2, 512, 8, 64, dtype=torch_dtype) out_ref = references[get_repeat_kv](x, n_rep) out = torch.empty_like(out_ref) x_mem = get_ranked_memref_descriptor(x.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args([x_mem, out_mem]) + args = ffi_utils.memrefs_to_packed_args([x_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1276,7 +1276,7 @@ def reshape_for_broadcast_op(freqs_cis, x, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = 
eng.lookup("reshape_for_broadcast") - torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) freqs_cis = torch.randn(512, 64, dtype=torch_dtype) x = torch.randn(2, 512, 32, 128, dtype=torch_dtype) # Convert x to complex view as expected by reshape_for_broadcast @@ -1287,7 +1287,7 @@ def reshape_for_broadcast_op(freqs_cis, x, out): freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis.numpy()) x_mem = get_ranked_memref_descriptor(x.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args([freqs_cis_mem, x_mem, out_mem]) + args = ffi_utils.memrefs_to_packed_args([freqs_cis_mem, x_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1318,7 +1318,7 @@ def view_as_complex_op(x, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("view_as_complex_op") - torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) x = torch.randn(2, 512, 32, 128, dtype=torch_dtype) x_reshaped = x.reshape(2, 512, 32, 64, 2) out_ref = torch.view_as_complex(x_reshaped) @@ -1326,7 +1326,7 @@ def view_as_complex_op(x, out): x_mem = get_ranked_memref_descriptor(x_reshaped.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args([x_mem, out_mem]) + args = ffi_utils.memrefs_to_packed_args([x_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1354,7 +1354,7 @@ def as_real_op(x, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("as_real_op") - torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) x = torch.randn(2, 512, 32, 64, 2, dtype=torch_dtype) x_complex = torch.view_as_complex(x) out_ref = torch.view_as_real(x_complex) @@ -1362,7 +1362,7 @@ def as_real_op(x, out): x_mem = get_ranked_memref_descriptor(x_complex.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args([x_mem, out_mem]) + args = ffi_utils.memrefs_to_packed_args([x_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1395,7 +1395,7 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out): return module ir_type = to_ir_type(elem_type) - torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) xq_shape = (batch_size, seq_len, n_heads, head_dim) xk_shape = (batch_size, seq_len, n_kv_heads, head_dim) freqs_cis_shape = (seq_len, head_dim // 2) @@ -1424,7 +1424,7 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out): freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis.numpy()) out1_mem = get_ranked_memref_descriptor(out1.numpy()) out2_mem = get_ranked_memref_descriptor(out2.numpy()) - args = lh_utils.memrefs_to_packed_args( + args = ffi_utils.memrefs_to_packed_args( [a_mem, b_mem, freqs_cis_mem, out1_mem, out2_mem] ) func_ptr(args) @@ -1489,7 +1489,7 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("feed_forward") - torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) x = torch.randn(4, 16, dtype=torch_dtype) w1 = torch.randn(64, 16, dtype=torch_dtype) b1 = torch.randn(64, dtype=torch_dtype) @@ -1512,7 +1512,7 @@ def feed_forward(x, w1, b1, w2, b2, 
w3, b3, out): w3_mem = get_ranked_memref_descriptor(w3.numpy()) b3_mem = get_ranked_memref_descriptor(b3.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args( + args = ffi_utils.memrefs_to_packed_args( [x_mem, w1_mem, b1_mem, w2_mem, b2_mem, w3_mem, b3_mem, out_mem] ) func_ptr(args) @@ -1645,7 +1645,7 @@ def attention_op(x, wq, wk, wv, wo, freqs_cis, mask, out): freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis_real.numpy()) mask_mem = get_ranked_memref_descriptor(mask.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args( + args = ffi_utils.memrefs_to_packed_args( [x_mem, wq_mem, wk_mem, wv_mem, wo_mem, freqs_cis_mem, mask_mem, out_mem] ) func_ptr(args) @@ -1792,7 +1792,7 @@ def transformer_block_op( b3_mem = get_ranked_memref_descriptor(b3.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = lh_utils.memrefs_to_packed_args( + args = ffi_utils.memrefs_to_packed_args( [ x_mem, wq_mem, @@ -1981,7 +1981,7 @@ def transformer_op(*params): out_mem = get_ranked_memref_descriptor(out.numpy()) memrefs.append(out_mem) - args = lh_utils.memrefs_to_packed_args(memrefs) + args = ffi_utils.memrefs_to_packed_args(memrefs) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) diff --git a/examples/mlir/compile_and_run.py b/examples/mlir/compile_and_run.py index f9e71bd..b3726ee 100644 --- a/examples/mlir/compile_and_run.py +++ b/examples/mlir/compile_and_run.py @@ -10,7 +10,7 @@ from mlir.execution_engine import ExecutionEngine from mlir.passmanager import PassManager -from lighthouse import utils as lh_utils +from lighthouse.utils.runtime import torch as torch_utils def create_kernel(ctx: ir.Context) -> ir.Module: @@ -168,7 +168,7 @@ def main(args): out = torch.empty_like(out_ref) # Execute the kernel. - args = lh_utils.torch_to_packed_args([a, b, out]) + args = torch_utils.torch_to_packed_args([a, b, out]) add_func(args) ### Verification ### diff --git a/examples/xegpu_matmul/matmul.py b/examples/xegpu_matmul/matmul.py index 8e2bf6a..dd21d8d 100644 --- a/examples/xegpu_matmul/matmul.py +++ b/examples/xegpu_matmul/matmul.py @@ -1,11 +1,16 @@ # RUN: python %s --dump-kernel=xegpu-wg | FileCheck %s -# REQUIRES: torch # CHECK: module attributes {gpu.container_module} { """ XeGPU matrix multiplication benchmark. 
""" +import argparse +import ctypes +from typing import Optional +from contextlib import contextmanager +from functools import cached_property + import numpy as np from mlir import ir from mlir.runtime.np_to_memref import ( @@ -14,18 +19,14 @@ as_ctype, ) from mlir.execution_engine import ExecutionEngine -from typing import Optional -import ctypes -from contextlib import contextmanager -from functools import cached_property -from lighthouse.utils import get_packed_arg, memref_to_ctype +from lighthouse.utils.memref import get_packed_arg, memref_to_ctype from lighthouse.workload import Workload, benchmark + +# Import from sibling files: from schedule import get_schedule_module from payload import generate_matmul_payload -import argparse - def numpy_to_ctype(arr: np.ndarray) -> ctypes._Pointer: """Convert numpy array to memref and ctypes **void pointer.""" diff --git a/examples/xegpu_matmul/runner.py b/examples/xegpu_matmul/runner.py new file mode 100644 index 0000000..219672a --- /dev/null +++ b/examples/xegpu_matmul/runner.py @@ -0,0 +1,177 @@ +import numpy as np +import ctypes +import os +from typing import Optional + +from mlir.dialects import func, arith, scf, memref +from mlir.execution_engine import ExecutionEngine +from mlir import ir +from mlir.runtime.np_to_memref import get_ranked_memref_descriptor + +from lighthouse.utils.memref import get_packed_arg +from mlir_utils import get_mlir_library_path + + +def get_engine(payload_module: ir.Module, opt_level: int = 3) -> ExecutionEngine: + lib_dir = get_mlir_library_path() + libs = [ + "libmlir_levelzero_runtime.so", + "libmlir_runner_utils.so", + "libmlir_c_runner_utils.so", + ] + libs = [os.path.join(lib_dir, lib) for lib in libs] + execution_engine = ExecutionEngine( + payload_module, opt_level=opt_level, shared_libs=libs + ) + execution_engine.initialize() + return execution_engine + + +def apply_transform_schedule( + payload_module: ir.Module, + schedule_module: ir.Module, + dump_kernel: Optional[str] = None, + dump_schedule: bool = False, +): + if not dump_kernel or dump_kernel != "initial": + # apply schedule on payload module + named_seq = schedule_module.body.operations[0] + named_seq.apply(payload_module) + if dump_kernel: + print(payload_module) + if dump_schedule: + print(schedule_module) + + +def lower_payload( + workload: object, + dump_kernel: Optional[str] = None, + dump_schedule: bool = False, + schedule_parameters: Optional[dict] = None, +) -> ir.Module: + payload_module = workload.payload_module() + schedule_module = workload.schedule_module( + dump_kernel=dump_kernel, parameters=schedule_parameters + ) + apply_transform_schedule( + payload_module, + schedule_module, + dump_kernel=dump_kernel, + dump_schedule=dump_schedule, + ) + return payload_module + + +def execute( + workload: object, + check_correctness: bool = True, + schedule_parameters: Optional[dict] = None, + verbose: int = 0, +): + # lower payload with schedule + payload_module = lower_payload(workload, schedule_parameters=schedule_parameters) + # get execution engine + engine = get_engine(payload_module, requirements=workload.requirements()) + + with workload.allocate_inputs(execution_engine=engine) as inputs: + # prepare function arguments + pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs] + packed_args = get_packed_arg(pointers) + + # handle to payload function + payload_func = engine.lookup(workload.payload_function_name) + + # call + payload_func(packed_args) + + if check_correctness: + workload.check_correctness(execution_engine=engine, 
verbose=verbose) + + +def benchmark( + workload: object, + nruns: int = 100, + nwarmup: int = 10, + schedule_parameters: Optional[dict] = None, + check_correctness: bool = True, + verbose: int = 0, +) -> np.ndarray: + # get original payload module + payload_module = workload.payload_module() + + # find payload function + payload_func = None + for op in payload_module.operation.regions[0].blocks[0]: + if ( + isinstance(op, func.FuncOp) + and op.name.value == workload.payload_function_name + ): + payload_func = op + break + assert payload_func is not None, "Could not find payload function" + payload_arguments = payload_func.type.inputs + + # emit benchmark function that calls payload and times it + with ir.InsertionPoint(payload_module.body): + # define rtclock function + f64_t = ir.F64Type.get() + func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") + # emit benchmark function + time_memref_t = ir.MemRefType.get((nruns,), f64_t) + args = payload_arguments + [time_memref_t] + + @func.func(*args) + def benchmark(*args): + index_t = ir.IndexType.get() + zero = arith.constant(index_t, 0) + one = arith.constant(index_t, 1) + nwarmup_cst = arith.constant(index_t, nwarmup) + for i in scf.for_(zero, nwarmup_cst, one): + # FIXME(upstream): func.call is broken for this use case? + func.CallOp(payload_func, list(args[: len(payload_arguments)])) + scf.yield_(()) + nruns_cst = arith.constant(index_t, nruns) + for i in scf.for_(zero, nruns_cst, one): + tic = func.call((f64_t,), "rtclock", ()) + func.CallOp(payload_func, list(args[: len(payload_arguments)])) + toc = func.call((f64_t,), "rtclock", ()) + time = arith.subf(toc, tic) + memref.store(time, args[-1], [i]) + scf.yield_(()) + + benchmark.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + # lower + apply_transform_schedule( + payload_module, + workload.schedule_module(parameters=schedule_parameters), + ) + # get execution engine, rtclock requires mlir_c_runner + engine = get_engine(payload_module) + + with workload.allocate_inputs(execution_engine=engine) as inputs: + pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs] + if check_correctness: + # call payload once to verify correctness + # prepare function arguments + packed_args = get_packed_arg(pointers) + + payload_func = engine.lookup(workload.payload_function_name) + payload_func(packed_args) + success = workload.check_correctness( + execution_engine=engine, verbose=verbose + ) + if not success: + raise ValueError("Benchmark verification failed.") + + # allocate buffer for timings and prepare arguments + time_array = np.zeros((nruns,), dtype=np.float64) + time_memref = get_ranked_memref_descriptor(time_array) + time_pointer = ctypes.pointer(ctypes.pointer(time_memref)) + packed_args_with_time = get_packed_arg(pointers + [time_pointer]) + + # call benchmark function + benchmark_func = engine.lookup("benchmark") + benchmark_func(packed_args_with_time) + + return time_array diff --git a/lighthouse/utils/__init__.py b/lighthouse/utils/__init__.py deleted file mode 100644 index 474b748..0000000 --- a/lighthouse/utils/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""A collection of utility tools""" - -from .runtime_args import ( - get_packed_arg, - memref_to_ctype, - memrefs_to_packed_args, - torch_to_memref, - torch_to_packed_args, - mlir_type_to_torch_dtype, -) - -__all__ = [ - "get_packed_arg", - "memref_to_ctype", - "memrefs_to_packed_args", - "mlir_type_to_torch_dtype", - "torch_to_memref", - "torch_to_packed_args", -] diff --git 
a/lighthouse/utils/runtime/__init__.py b/lighthouse/utils/runtime/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lighthouse/utils/runtime/ffi.py b/lighthouse/utils/runtime/ffi.py new file mode 100644 index 0000000..f47e726 --- /dev/null +++ b/lighthouse/utils/runtime/ffi.py @@ -0,0 +1,37 @@ +import ctypes + + +def get_packed_arg(ctypes_args) -> list[ctypes.c_void_p]: + """ + Return a list of packed ctype arguments compatible with + jitted MLIR function's interface. + + Args: + ctypes_args: A list of ctype pointer arguments. + """ + packed_args = (ctypes.c_void_p * len(ctypes_args))() + for argNum in range(len(ctypes_args)): + packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p) + return packed_args + + +def memref_to_ctype(memref_desc) -> ctypes._Pointer: + """ + Convert a memref descriptor into a ctype argument. + + Args: + memref_desc: An MLIR memref descriptor. + """ + return ctypes.pointer(ctypes.pointer(memref_desc)) + + +def memrefs_to_packed_args(memref_descs) -> list[ctypes.c_void_p]: + """ + Convert a list of memref descriptors into packed ctype arguments. + + Args: + memref_descs: A list of memref descriptors. + """ + ctype_args = [memref_to_ctype(memref) for memref in memref_descs] + return get_packed_arg(ctype_args) + diff --git a/lighthouse/utils/runtime_args.py b/lighthouse/utils/runtime/torch.py similarity index 59% rename from lighthouse/utils/runtime_args.py rename to lighthouse/utils/runtime/torch.py index eb6b22a..43c2f95 100644 --- a/lighthouse/utils/runtime_args.py +++ b/lighthouse/utils/runtime/torch.py @@ -1,45 +1,10 @@ import ctypes -import torch -from mlir.runtime.np_to_memref import ( - get_ranked_memref_descriptor, -) +import torch from mlir import ir +from mlir.runtime.np_to_memref import get_ranked_memref_descriptor - -def get_packed_arg(ctypes_args) -> list[ctypes.c_void_p]: - """ - Return a list of packed ctype arguments compatible with - jitted MLIR function's interface. - - Args: - ctypes_args: A list of ctype pointer arguments. - """ - packed_args = (ctypes.c_void_p * len(ctypes_args))() - for argNum in range(len(ctypes_args)): - packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p) - return packed_args - - -def memref_to_ctype(memref_desc) -> ctypes._Pointer: - """ - Convert a memref descriptor into a ctype argument. - - Args: - memref_desc: An MLIR memref descriptor. - """ - return ctypes.pointer(ctypes.pointer(memref_desc)) - - -def memrefs_to_packed_args(memref_descs) -> list[ctypes.c_void_p]: - """ - Convert a list of memref descriptors into packed ctype arguments. - - Args: - memref_descs: A list of memref descriptors. - """ - ctype_args = [memref_to_ctype(memref) for memref in memref_descs] - return get_packed_arg(ctype_args) +from .ffi import memrefs_to_packed_args def torch_to_memref(input: torch.Tensor) -> ctypes.Structure: From dd87ca40e834cf1ef66c20cb3e8935efd4d09c15 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Tue, 9 Dec 2025 10:53:05 -0800 Subject: [PATCH 3/7] Fixes for recently merged workload component
--- examples/workload/example.py | 2 +- examples/workload/example_mlir.py | 4 ++-- lighthouse/workload/runner.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/workload/example.py b/examples/workload/example.py index 3137dad..34d9088 100644 --- a/examples/workload/example.py +++ b/examples/workload/example.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s | FileCheck %s +# RUN: python %s | FileCheck %s # CHECK: func.func @payload # CHECK: PASSED # CHECK: Throughput: diff --git a/examples/workload/example_mlir.py b/examples/workload/example_mlir.py index 7d3211a..b27586e 100644 --- a/examples/workload/example_mlir.py +++ b/examples/workload/example_mlir.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s | FileCheck %s +# RUN: python %s | FileCheck %s # CHECK: func.func @payload # CHECK: PASSED # CHECK: Throughput: @@ -19,7 +19,7 @@ from mlir.execution_engine import ExecutionEngine import ctypes from contextlib import contextmanager -from lighthouse.utils import ( +from lighthouse.utils.runtime.ffi import ( get_packed_arg, memrefs_to_packed_args, memref_to_ctype, diff --git a/lighthouse/workload/runner.py b/lighthouse/workload/runner.py index ae5e07e..9f475fa 100644 --- a/lighthouse/workload/runner.py +++ b/lighthouse/workload/runner.py @@ -9,7 +9,7 @@ from mlir.execution_engine import ExecutionEngine from mlir.runtime.np_to_memref import get_ranked_memref_descriptor from lighthouse.utils.mlir import get_mlir_library_path -from lighthouse.utils import memrefs_to_packed_args +from lighthouse.utils.runtime.ffi import memrefs_to_packed_args from lighthouse.workload import Workload from typing import Optional From ff999ecebac23f6fdc77457871c2d364f5e3e323 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Tue, 9 Dec 2025 11:10:03 -0800 Subject: [PATCH 4/7] Fewer kernelbench examples to speed it up further --- examples/ingress/convert-kernel-bench-to-mlir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ingress/convert-kernel-bench-to-mlir.py b/examples/ingress/convert-kernel-bench-to-mlir.py index 055ddd6..414e557 100755 --- a/examples/ingress/convert-kernel-bench-to-mlir.py +++ b/examples/ingress/convert-kernel-bench-to-mlir.py @@ -1,4 +1,4 @@ -# RUN: python %s 1,1 1,2 1,3 2,1 2,2 2,3 +# RUN: python %s 1,1 1,2 2,1 # REQUIRES: torch # REQUIRES: kernel_bench From ee64cf39bfa4ddde788e763236673dbe579f1e87 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Tue, 9 Dec 2025 11:10:46 -0800 Subject: [PATCH 5/7] REQUIRES in cfg instead --- examples/ingress/torch/lit.local.cfg | 3 +++ examples/ingress/torch/mlp_from_file.py | 1 - examples/ingress/torch/mlp_from_model.py | 1 - lighthouse/utils/runtime/ffi.py | 1 - 4 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ingress/torch/lit.local.cfg b/examples/ingress/torch/lit.local.cfg index 2813a1b..935efbf 100644 --- a/examples/ingress/torch/lit.local.cfg +++ b/examples/ingress/torch/lit.local.cfg @@ -1 +1,4 @@ config.excludes = ["MLPModel"] + +if "torch" not in config.available_features: + config.unsupported = True diff --git a/examples/ingress/torch/mlp_from_file.py b/examples/ingress/torch/mlp_from_file.py index da9c014..d149151 100644 --- a/examples/ingress/torch/mlp_from_file.py +++ b/examples/ingress/torch/mlp_from_file.py @@ -1,5 +1,4 @@ # RUN: python %s -# REQUIRES: torch """ Example demonstrating how to load a PyTorch model to MLIR using Lighthouse diff --git a/examples/ingress/torch/mlp_from_model.py b/examples/ingress/torch/mlp_from_model.py index e7ebf79..cfdce50 100644 --- 
a/examples/ingress/torch/mlp_from_model.py +++ b/examples/ingress/torch/mlp_from_model.py @@ -1,5 +1,4 @@ # RUN: python %s -# REQUIRES: torch """ Example demonstrating how to load an already initialized PyTorch model diff --git a/lighthouse/utils/runtime/ffi.py b/lighthouse/utils/runtime/ffi.py index f47e726..3c87b44 100644 --- a/lighthouse/utils/runtime/ffi.py +++ b/lighthouse/utils/runtime/ffi.py @@ -34,4 +34,3 @@ def memrefs_to_packed_args(memref_descs) -> list[ctypes.c_void_p]: """ ctype_args = [memref_to_ctype(memref) for memref in memref_descs] return get_packed_arg(ctype_args) - From f095c8d04f8957daf5fdd7a3bf8e5422b840c15c Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Tue, 9 Dec 2025 11:34:10 -0800 Subject: [PATCH 6/7] Revert to %PYTHON in RUN lines --- examples/ingress/convert-kernel-bench-to-mlir.py | 2 +- examples/ingress/torch/mlp_from_file.py | 2 +- examples/ingress/torch/mlp_from_model.py | 2 +- examples/llama/test_llama3.py | 2 +- examples/mlir/compile_and_run.py | 2 +- .../schedule/transform_a_payload_according_to_a_schedule.py | 2 +- examples/workload/example.py | 2 +- examples/workload/example_mlir.py | 2 +- examples/xegpu_matmul/matmul.py | 2 +- lit.cfg.py | 4 ++-- 10 files changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/ingress/convert-kernel-bench-to-mlir.py b/examples/ingress/convert-kernel-bench-to-mlir.py index 414e557..2cd59ba 100755 --- a/examples/ingress/convert-kernel-bench-to-mlir.py +++ b/examples/ingress/convert-kernel-bench-to-mlir.py @@ -1,4 +1,4 @@ -# RUN: python %s 1,1 1,2 2,1 +# RUN: %PYTHON %s 1,1 1,2 2,1 # REQUIRES: torch # REQUIRES: kernel_bench diff --git a/examples/ingress/torch/mlp_from_file.py b/examples/ingress/torch/mlp_from_file.py index d149151..748de85 100644 --- a/examples/ingress/torch/mlp_from_file.py +++ b/examples/ingress/torch/mlp_from_file.py @@ -1,4 +1,4 @@ -# RUN: python %s +# RUN: %PYTHON %s """ Example demonstrating how to load a PyTorch model to MLIR using Lighthouse diff --git a/examples/ingress/torch/mlp_from_model.py b/examples/ingress/torch/mlp_from_model.py index cfdce50..4590429 100644 --- a/examples/ingress/torch/mlp_from_model.py +++ b/examples/ingress/torch/mlp_from_model.py @@ -1,4 +1,4 @@ -# RUN: python %s +# RUN: %PYTHON %s """ Example demonstrating how to load an already initialized PyTorch model diff --git a/examples/llama/test_llama3.py b/examples/llama/test_llama3.py index ed45140..a2115b3 100644 --- a/examples/llama/test_llama3.py +++ b/examples/llama/test_llama3.py @@ -1,4 +1,4 @@ -# RUN: python %s +# RUN: %PYTHON %s # REQUIRES: torch import functools diff --git a/examples/mlir/compile_and_run.py b/examples/mlir/compile_and_run.py index b3726ee..dc2d537 100644 --- a/examples/mlir/compile_and_run.py +++ b/examples/mlir/compile_and_run.py @@ -1,4 +1,4 @@ -# RUN: python %s +# RUN: %PYTHON %s # REQUIRES: torch import torch diff --git a/examples/schedule/transform_a_payload_according_to_a_schedule.py b/examples/schedule/transform_a_payload_according_to_a_schedule.py index 71cbd4e..fcbbe7b 100644 --- a/examples/schedule/transform_a_payload_according_to_a_schedule.py +++ b/examples/schedule/transform_a_payload_according_to_a_schedule.py @@ -1,4 +1,4 @@ -# RUN: python %s | FileCheck %s +# RUN: %PYTHON %s | FileCheck %s # Simply demonstrates applying a schedule to a payload. # To do so generates a basic payload and a basic schedule, purely as an example. 
diff --git a/examples/workload/example.py b/examples/workload/example.py index 34d9088..1293998 100644 --- a/examples/workload/example.py +++ b/examples/workload/example.py @@ -1,4 +1,4 @@ -# RUN: python %s | FileCheck %s +# RUN: %PYTHON %s | FileCheck %s # CHECK: func.func @payload # CHECK: PASSED # CHECK: Throughput: diff --git a/examples/workload/example_mlir.py b/examples/workload/example_mlir.py index b27586e..acf2730 100644 --- a/examples/workload/example_mlir.py +++ b/examples/workload/example_mlir.py @@ -1,4 +1,4 @@ -# RUN: python %s | FileCheck %s +# RUN: %PYTHON %s | FileCheck %s # CHECK: func.func @payload # CHECK: PASSED # CHECK: Throughput: diff --git a/examples/xegpu_matmul/matmul.py b/examples/xegpu_matmul/matmul.py index dd21d8d..ee78e08 100644 --- a/examples/xegpu_matmul/matmul.py +++ b/examples/xegpu_matmul/matmul.py @@ -1,4 +1,4 @@ -# RUN: python %s --dump-kernel=xegpu-wg | FileCheck %s +# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s # CHECK: module attributes {gpu.container_module} { """ diff --git a/lit.cfg.py b/lit.cfg.py index a83f501..b9f0dac 100644 --- a/lit.cfg.py +++ b/lit.cfg.py @@ -15,8 +15,8 @@ config.test_exec_root = project_root + "/lit.out" config.substitutions.append(("%CACHE", project_root + "/cache")) -if python_path := os.environ.get("PYTHON"): - config.substitutions.append(("python", python_path)) +python = os.environ.get("PYTHON", "python") +config.substitutions.append(("%PYTHON", python)) if filecheck_path := os.environ.get("FILECHECK"): config.substitutions.append(("FileCheck", filecheck_path)) From e13be22e3ee57794e8ad3a611d2a70f4b7767330 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 15 Dec 2025 13:40:31 -0800 Subject: [PATCH 7/7] Fix utils modules naming with lazy loading support --- .../ingress/convert-kernel-bench-to-mlir.py | 4 +- examples/llama/test_llama3.py | 58 +++--- examples/mlir/compile_and_run.py | 4 +- examples/workload/example.py | 22 +-- examples/workload/example_mlir.py | 31 ++- examples/xegpu_matmul/matmul.py | 2 +- examples/xegpu_matmul/runner.py | 177 ------------------ lighthouse/ingress/__init__.py | 22 ++- lighthouse/utils/__init__.py | 21 +++ .../utils/{runtime/ffi.py => memref.py} | 29 +-- lighthouse/utils/runtime/__init__.py | 0 lighthouse/utils/{runtime => }/torch.py | 12 +- lighthouse/workload/runner.py | 8 +- 13 files changed, 122 insertions(+), 268 deletions(-) delete mode 100644 examples/xegpu_matmul/runner.py create mode 100644 lighthouse/utils/__init__.py rename lighthouse/utils/{runtime/ffi.py => memref.py} (72%) delete mode 100644 lighthouse/utils/runtime/__init__.py rename lighthouse/utils/{runtime => }/torch.py (80%) diff --git a/examples/ingress/convert-kernel-bench-to-mlir.py b/examples/ingress/convert-kernel-bench-to-mlir.py index 2cd59ba..0dbd150 100755 --- a/examples/ingress/convert-kernel-bench-to-mlir.py +++ b/examples/ingress/convert-kernel-bench-to-mlir.py @@ -15,7 +15,7 @@ from typing import Iterable from mlir import ir, passmanager -from lighthouse.ingress import torch as torch_ingress +import lighthouse.ingress as lh_ingress project_root = Path(__file__).parent.parent.parent torch_kernels_dir = project_root / "third_party" / "KernelBench" / "KernelBench" @@ -173,7 +173,7 @@ def process_task(task: KernelConversionTask): print("Processing:", kernel_relative_name) try: - mlir_kernel = torch_ingress.import_from_file(task.torch_path, ir_context=ctx) + mlir_kernel = lh_ingress.torch.import_from_file(task.torch_path, ir_context=ctx) assert isinstance(mlir_kernel, ir.Module) except Exception 
as e: print( diff --git a/examples/llama/test_llama3.py b/examples/llama/test_llama3.py index a2115b3..185e660 100644 --- a/examples/llama/test_llama3.py +++ b/examples/llama/test_llama3.py @@ -6,17 +6,16 @@ import pytest import torch - from mlir import ir from mlir.dialects import transform, func, linalg, tensor, arith, complex, math from mlir.dialects.linalg import ElementwiseKind from mlir.dialects.transform import structured, bufferization, interpreter from mlir.passmanager import PassManager -from mlir.runtime.np_to_memref import ( - get_ranked_memref_descriptor, -) +from mlir.runtime.np_to_memref import get_ranked_memref_descriptor from mlir.execution_engine import ExecutionEngine +from lighthouse import utils as lh_utils + from ref_model import ( Attention, ModelArgs, @@ -26,7 +25,6 @@ TransformerBlock, Transformer, ) -from lighthouse.utils.runtime import ffi as ffi_utils, torch as torch_utils def with_mlir_ctx_and_location(func): @@ -1021,7 +1019,7 @@ def bin_op(a, b, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("bin_op") - torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type) a = torch.randn(*shape, dtype=torch_dtype) b = torch.randn(*shape, dtype=torch_dtype) out_ref = references[op](a, b) @@ -1031,7 +1029,7 @@ def bin_op(a, b, out): a_mem = get_ranked_memref_descriptor(a.numpy()) b_mem = get_ranked_memref_descriptor(b.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = ffi_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem]) + args = lh_utils.memref.to_packed_args([a_mem, b_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1077,14 +1075,14 @@ def unary_op(a, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("unary_op") - torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type) a = torch.randn(*shape, dtype=torch_dtype) out_ref = references[op](a) out = torch.empty_like(out_ref) a_mem = get_ranked_memref_descriptor(a.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = ffi_utils.memrefs_to_packed_args([a_mem, out_mem]) + args = lh_utils.memref.to_packed_args([a_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1113,13 +1111,13 @@ def rms_norm(a, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("rms_norm") - torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type) a = torch.randn(*shape, dtype=torch_dtype) out_ref = references[get_l2_norm](a, eps) out = torch.empty_like(out_ref) a_mem = get_ranked_memref_descriptor(a.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = ffi_utils.memrefs_to_packed_args([a_mem, out_mem]) + args = lh_utils.memref.to_packed_args([a_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1161,7 +1159,7 @@ def linear_op(x, w, b, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("linear_op") - torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type) x = torch.randn(*shape, in_features, dtype=torch_dtype) w = torch.randn(out_features, in_features, dtype=torch_dtype) b = torch.randn(out_features, dtype=torch_dtype) @@ -1172,7 +1170,7 @@ def linear_op(x, w, b, out): w_mem = 
get_ranked_memref_descriptor(w.numpy()) b_mem = get_ranked_memref_descriptor(b.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = ffi_utils.memrefs_to_packed_args([x_mem, w_mem, b_mem, out_mem]) + args = lh_utils.memref.to_packed_args([x_mem, w_mem, b_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1202,7 +1200,7 @@ def polar_op(magnitude, angle, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("polar_op") - torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type) magnitude = torch.randn(4, 16, dtype=torch_dtype) angle = torch.randn(4, 16, dtype=torch_dtype) out_ref = references[get_polar](magnitude, angle) @@ -1210,7 +1208,7 @@ def polar_op(magnitude, angle, out): magnitude_mem = get_ranked_memref_descriptor(magnitude.numpy()) angle_mem = get_ranked_memref_descriptor(angle.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = ffi_utils.memrefs_to_packed_args([magnitude_mem, angle_mem, out_mem]) + args = lh_utils.memref.to_packed_args([magnitude_mem, angle_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1238,14 +1236,14 @@ def repeat_kv_op(x, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("repeat_kv_op") - torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type) x = torch.randn(2, 512, 8, 64, dtype=torch_dtype) out_ref = references[get_repeat_kv](x, n_rep) out = torch.empty_like(out_ref) x_mem = get_ranked_memref_descriptor(x.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = ffi_utils.memrefs_to_packed_args([x_mem, out_mem]) + args = lh_utils.memref.to_packed_args([x_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1276,7 +1274,7 @@ def reshape_for_broadcast_op(freqs_cis, x, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("reshape_for_broadcast") - torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type) freqs_cis = torch.randn(512, 64, dtype=torch_dtype) x = torch.randn(2, 512, 32, 128, dtype=torch_dtype) # Convert x to complex view as expected by reshape_for_broadcast @@ -1287,7 +1285,7 @@ def reshape_for_broadcast_op(freqs_cis, x, out): freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis.numpy()) x_mem = get_ranked_memref_descriptor(x.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = ffi_utils.memrefs_to_packed_args([freqs_cis_mem, x_mem, out_mem]) + args = lh_utils.memref.to_packed_args([freqs_cis_mem, x_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1318,7 +1316,7 @@ def view_as_complex_op(x, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("view_as_complex_op") - torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type) x = torch.randn(2, 512, 32, 128, dtype=torch_dtype) x_reshaped = x.reshape(2, 512, 32, 64, 2) out_ref = torch.view_as_complex(x_reshaped) @@ -1326,7 +1324,7 @@ def view_as_complex_op(x, out): x_mem = get_ranked_memref_descriptor(x_reshaped.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = ffi_utils.memrefs_to_packed_args([x_mem, out_mem]) + args = lh_utils.memref.to_packed_args([x_mem, out_mem]) 
func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1354,7 +1352,7 @@ def as_real_op(x, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("as_real_op") - torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type) x = torch.randn(2, 512, 32, 64, 2, dtype=torch_dtype) x_complex = torch.view_as_complex(x) out_ref = torch.view_as_real(x_complex) @@ -1362,7 +1360,7 @@ def as_real_op(x, out): x_mem = get_ranked_memref_descriptor(x_complex.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = ffi_utils.memrefs_to_packed_args([x_mem, out_mem]) + args = lh_utils.memref.to_packed_args([x_mem, out_mem]) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) @@ -1395,7 +1393,7 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out): return module ir_type = to_ir_type(elem_type) - torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type) xq_shape = (batch_size, seq_len, n_heads, head_dim) xk_shape = (batch_size, seq_len, n_kv_heads, head_dim) freqs_cis_shape = (seq_len, head_dim // 2) @@ -1424,7 +1422,7 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out): freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis.numpy()) out1_mem = get_ranked_memref_descriptor(out1.numpy()) out2_mem = get_ranked_memref_descriptor(out2.numpy()) - args = ffi_utils.memrefs_to_packed_args( + args = lh_utils.memref.to_packed_args( [a_mem, b_mem, freqs_cis_mem, out1_mem, out2_mem] ) func_ptr(args) @@ -1489,7 +1487,7 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out): eng = ExecutionEngine(module, opt_level=2) func_ptr = eng.lookup("feed_forward") - torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type) + torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type) x = torch.randn(4, 16, dtype=torch_dtype) w1 = torch.randn(64, 16, dtype=torch_dtype) b1 = torch.randn(64, dtype=torch_dtype) @@ -1512,7 +1510,7 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out): w3_mem = get_ranked_memref_descriptor(w3.numpy()) b3_mem = get_ranked_memref_descriptor(b3.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = ffi_utils.memrefs_to_packed_args( + args = lh_utils.memref.to_packed_args( [x_mem, w1_mem, b1_mem, w2_mem, b2_mem, w3_mem, b3_mem, out_mem] ) func_ptr(args) @@ -1645,7 +1643,7 @@ def attention_op(x, wq, wk, wv, wo, freqs_cis, mask, out): freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis_real.numpy()) mask_mem = get_ranked_memref_descriptor(mask.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = ffi_utils.memrefs_to_packed_args( + args = lh_utils.memref.to_packed_args( [x_mem, wq_mem, wk_mem, wv_mem, wo_mem, freqs_cis_mem, mask_mem, out_mem] ) func_ptr(args) @@ -1792,7 +1790,7 @@ def transformer_block_op( b3_mem = get_ranked_memref_descriptor(b3.numpy()) out_mem = get_ranked_memref_descriptor(out.numpy()) - args = ffi_utils.memrefs_to_packed_args( + args = lh_utils.memref.to_packed_args( [ x_mem, wq_mem, @@ -1981,7 +1979,7 @@ def transformer_op(*params): out_mem = get_ranked_memref_descriptor(out.numpy()) memrefs.append(out_mem) - args = ffi_utils.memrefs_to_packed_args(memrefs) + args = lh_utils.memref.to_packed_args(memrefs) func_ptr(args) assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True) diff --git a/examples/mlir/compile_and_run.py b/examples/mlir/compile_and_run.py index dc2d537..20b5cc5 100644 --- 
a/examples/mlir/compile_and_run.py +++ b/examples/mlir/compile_and_run.py @@ -10,7 +10,7 @@ from mlir.execution_engine import ExecutionEngine from mlir.passmanager import PassManager -from lighthouse.utils.runtime import torch as torch_utils +import lighthouse.utils as lh_utils def create_kernel(ctx: ir.Context) -> ir.Module: @@ -168,7 +168,7 @@ def main(args): out = torch.empty_like(out_ref) # Execute the kernel. - args = torch_utils.torch_to_packed_args([a, b, out]) + args = lh_utils.torch.to_packed_args([a, b, out]) add_func(args) ### Verification ### diff --git a/examples/workload/example.py b/examples/workload/example.py index 1293998..c37b226 100644 --- a/examples/workload/example.py +++ b/examples/workload/example.py @@ -6,26 +6,20 @@ Workload example: Element-wise sum of two (M, N) float32 arrays on CPU. """ +import ctypes +from contextlib import contextmanager +from functools import cached_property +from typing import Optional + import numpy as np from mlir import ir from mlir.runtime.np_to_memref import get_ranked_memref_descriptor from mlir.dialects import func, linalg, bufferization from mlir.dialects import transform from mlir.execution_engine import ExecutionEngine -from contextlib import contextmanager -from functools import cached_property -import ctypes -from typing import Optional -from lighthouse.utils.mlir import ( - apply_registered_pass, - canonicalize, - match, -) -from lighthouse.workload import ( - Workload, - execute, - benchmark, -) + +from lighthouse.utils.mlir import apply_registered_pass, canonicalize, match +from lighthouse.workload import Workload, execute, benchmark class ElementwiseSum(Workload): diff --git a/examples/workload/example_mlir.py b/examples/workload/example_mlir.py index acf2730..376993e 100644 --- a/examples/workload/example_mlir.py +++ b/examples/workload/example_mlir.py @@ -8,6 +8,9 @@ In this example, allocation and deallocation of input arrays is done in MLIR. 
""" +import ctypes +from contextlib import contextmanager + import numpy as np from mlir import ir from mlir.runtime.np_to_memref import ( @@ -17,18 +20,11 @@ ) from mlir.dialects import func, linalg, arith, memref from mlir.execution_engine import ExecutionEngine -import ctypes -from contextlib import contextmanager -from lighthouse.utils.runtime.ffi import ( - get_packed_arg, - memrefs_to_packed_args, - memref_to_ctype, -) + +from lighthouse.workload import execute, benchmark +import lighthouse.utils as lh_utils + from example import ElementwiseSum -from lighthouse.workload import ( - execute, - benchmark, -) def emit_host_alloc(suffix: str, element_type: ir.Type, rank: int = 2): @@ -114,16 +110,16 @@ def _allocate_array( # construct a memref descriptor for the result memref shape = (self.M, self.N) mref = make_nd_memref_descriptor(len(shape), as_ctype(self.dtype))() - ptr_mref = memref_to_ctype(mref) + ptr_mref = lh_utils.memref.to_ctype(mref) ptr_dims = [ctypes.pointer(ctypes.c_int32(d)) for d in shape] - alloc_func(get_packed_arg([ptr_mref, *ptr_dims])) + alloc_func(lh_utils.memref.get_packed_arg([ptr_mref, *ptr_dims])) self.memrefs[name] = mref return mref def _deallocate_all(self, execution_engine: ExecutionEngine): for mref in self.memrefs.values(): dealloc_func = execution_engine.lookup("host_dealloc_f32") - dealloc_func(memrefs_to_packed_args([mref])) + dealloc_func(lh_utils.memref.to_packed_args([mref])) self.memrefs = {} def get_input_arrays( @@ -136,10 +132,9 @@ def get_input_arrays( # initialize with MLIR fill_zero_func = execution_engine.lookup("host_fill_constant_zero_f32") fill_random_func = execution_engine.lookup("host_fill_random_f32") - fill_zero_func(memrefs_to_packed_args([C])) - fill_random_func(memrefs_to_packed_args([A])) - fill_random_func(memrefs_to_packed_args([B])) - + fill_zero_func(lh_utils.memref.to_packed_args([C])) + fill_random_func(lh_utils.memref.to_packed_args([A])) + fill_random_func(lh_utils.memref.to_packed_args([B])) return [A, B, C] @contextmanager diff --git a/examples/xegpu_matmul/matmul.py b/examples/xegpu_matmul/matmul.py index ee78e08..d86478e 100644 --- a/examples/xegpu_matmul/matmul.py +++ b/examples/xegpu_matmul/matmul.py @@ -19,9 +19,9 @@ as_ctype, ) from mlir.execution_engine import ExecutionEngine -from lighthouse.utils.memref import get_packed_arg, memref_to_ctype from lighthouse.workload import Workload, benchmark +from lighthouse.utils.memref import get_packed_arg, to_ctype as memref_to_ctype # Import from sibling files: from schedule import get_schedule_module diff --git a/examples/xegpu_matmul/runner.py b/examples/xegpu_matmul/runner.py deleted file mode 100644 index 219672a..0000000 --- a/examples/xegpu_matmul/runner.py +++ /dev/null @@ -1,177 +0,0 @@ -import numpy as np -import ctypes -import os -from typing import Optional - -from mlir.dialects import func, arith, scf, memref -from mlir.execution_engine import ExecutionEngine -from mlir import ir -from mlir.runtime.np_to_memref import get_ranked_memref_descriptor - -from lighthouse.utils.memref import get_packed_arg -from mlir_utils import get_mlir_library_path - - -def get_engine(payload_module: ir.Module, opt_level: int = 3) -> ExecutionEngine: - lib_dir = get_mlir_library_path() - libs = [ - "libmlir_levelzero_runtime.so", - "libmlir_runner_utils.so", - "libmlir_c_runner_utils.so", - ] - libs = [os.path.join(lib_dir, lib) for lib in libs] - execution_engine = ExecutionEngine( - payload_module, opt_level=opt_level, shared_libs=libs - ) - execution_engine.initialize() - return 
execution_engine - - -def apply_transform_schedule( - payload_module: ir.Module, - schedule_module: ir.Module, - dump_kernel: Optional[str] = None, - dump_schedule: bool = False, -): - if not dump_kernel or dump_kernel != "initial": - # apply schedule on payload module - named_seq = schedule_module.body.operations[0] - named_seq.apply(payload_module) - if dump_kernel: - print(payload_module) - if dump_schedule: - print(schedule_module) - - -def lower_payload( - workload: object, - dump_kernel: Optional[str] = None, - dump_schedule: bool = False, - schedule_parameters: Optional[dict] = None, -) -> ir.Module: - payload_module = workload.payload_module() - schedule_module = workload.schedule_module( - dump_kernel=dump_kernel, parameters=schedule_parameters - ) - apply_transform_schedule( - payload_module, - schedule_module, - dump_kernel=dump_kernel, - dump_schedule=dump_schedule, - ) - return payload_module - - -def execute( - workload: object, - check_correctness: bool = True, - schedule_parameters: Optional[dict] = None, - verbose: int = 0, -): - # lower payload with schedule - payload_module = lower_payload(workload, schedule_parameters=schedule_parameters) - # get execution engine - engine = get_engine(payload_module, requirements=workload.requirements()) - - with workload.allocate_inputs(execution_engine=engine) as inputs: - # prepare function arguments - pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs] - packed_args = get_packed_arg(pointers) - - # handle to payload function - payload_func = engine.lookup(workload.payload_function_name) - - # call - payload_func(packed_args) - - if check_correctness: - workload.check_correctness(execution_engine=engine, verbose=verbose) - - -def benchmark( - workload: object, - nruns: int = 100, - nwarmup: int = 10, - schedule_parameters: Optional[dict] = None, - check_correctness: bool = True, - verbose: int = 0, -) -> np.ndarray: - # get original payload module - payload_module = workload.payload_module() - - # find payload function - payload_func = None - for op in payload_module.operation.regions[0].blocks[0]: - if ( - isinstance(op, func.FuncOp) - and op.name.value == workload.payload_function_name - ): - payload_func = op - break - assert payload_func is not None, "Could not find payload function" - payload_arguments = payload_func.type.inputs - - # emit benchmark function that calls payload and times it - with ir.InsertionPoint(payload_module.body): - # define rtclock function - f64_t = ir.F64Type.get() - func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") - # emit benchmark function - time_memref_t = ir.MemRefType.get((nruns,), f64_t) - args = payload_arguments + [time_memref_t] - - @func.func(*args) - def benchmark(*args): - index_t = ir.IndexType.get() - zero = arith.constant(index_t, 0) - one = arith.constant(index_t, 1) - nwarmup_cst = arith.constant(index_t, nwarmup) - for i in scf.for_(zero, nwarmup_cst, one): - # FIXME(upstream): func.call is broken for this use case? 
-                func.CallOp(payload_func, list(args[: len(payload_arguments)]))
-                scf.yield_(())
-            nruns_cst = arith.constant(index_t, nruns)
-            for i in scf.for_(zero, nruns_cst, one):
-                tic = func.call((f64_t,), "rtclock", ())
-                func.CallOp(payload_func, list(args[: len(payload_arguments)]))
-                toc = func.call((f64_t,), "rtclock", ())
-                time = arith.subf(toc, tic)
-                memref.store(time, args[-1], [i])
-                scf.yield_(())
-
-        benchmark.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
-
-    # lower
-    apply_transform_schedule(
-        payload_module,
-        workload.schedule_module(parameters=schedule_parameters),
-    )
-    # get execution engine, rtclock requires mlir_c_runner
-    engine = get_engine(payload_module)
-
-    with workload.allocate_inputs(execution_engine=engine) as inputs:
-        pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs]
-        if check_correctness:
-            # call payload once to verify correctness
-            # prepare function arguments
-            packed_args = get_packed_arg(pointers)
-
-            payload_func = engine.lookup(workload.payload_function_name)
-            payload_func(packed_args)
-            success = workload.check_correctness(
-                execution_engine=engine, verbose=verbose
-            )
-            if not success:
-                raise ValueError("Benchmark verification failed.")
-
-        # allocate buffer for timings and prepare arguments
-        time_array = np.zeros((nruns,), dtype=np.float64)
-        time_memref = get_ranked_memref_descriptor(time_array)
-        time_pointer = ctypes.pointer(ctypes.pointer(time_memref))
-        packed_args_with_time = get_packed_arg(pointers + [time_pointer])
-
-        # call benchmark function
-        benchmark_func = engine.lookup("benchmark")
-        benchmark_func(packed_args_with_time)
-
-        return time_array
diff --git a/lighthouse/ingress/__init__.py b/lighthouse/ingress/__init__.py
index 2aa3859..989c237 100644
--- a/lighthouse/ingress/__init__.py
+++ b/lighthouse/ingress/__init__.py
@@ -1 +1,21 @@
-"""Provides functions to convert source objects (code, models, designs) into MLIR files that the MLIR project can consume"""
+__all__ = ["mlir_gen", "torch"]
+
+import sys
+import importlib
+
+
+def __getattr__(name):
+    """Enable lazy loading of submodules.
+
+    Enables `import lighthouse.ingress as lh_ingress; lh_ingress.<submodule>`
+    with the submodule's (heavy) dependencies loaded only when needed.
+    """
+
+    if name in __all__:
+        # Import the submodule and cache it on the current module so that
+        # subsequent accesses do not go through __getattr__ again.
+        submodule = importlib.import_module("lighthouse.ingress." + name)
+        lighthouse_ingress_mod = sys.modules[__name__]
+        setattr(lighthouse_ingress_mod, name, submodule)
+        return submodule
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
diff --git a/lighthouse/utils/__init__.py b/lighthouse/utils/__init__.py
new file mode 100644
index 0000000..6a85edd
--- /dev/null
+++ b/lighthouse/utils/__init__.py
@@ -0,0 +1,21 @@
+__all__ = ["memref", "mlir", "torch"]
+
+import sys
+import importlib
+
+
+def __getattr__(name):
+    """Enable lazy loading of submodules.
+
+    Enables `import lighthouse.utils as lh_utils; lh_utils.<submodule>` with
+    the submodule's (heavy) dependencies loaded only when needed.
+    """
+
+    if name in __all__:
+        # Import the submodule and cache it on the current module so that
+        # subsequent accesses do not go through __getattr__ again.
+        submodule = importlib.import_module("lighthouse.utils." + name)
+        lighthouse_utils_mod = sys.modules[__name__]
+        setattr(lighthouse_utils_mod, name, submodule)
+        return submodule
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
diff --git a/lighthouse/utils/runtime/ffi.py b/lighthouse/utils/memref.py
similarity index 72%
rename from lighthouse/utils/runtime/ffi.py
rename to lighthouse/utils/memref.py
index 3c87b44..63f67a1 100644
--- a/lighthouse/utils/runtime/ffi.py
+++ b/lighthouse/utils/memref.py
@@ -1,7 +1,20 @@
 import ctypes
+from typing import Sequence
 
 
-def get_packed_arg(ctypes_args) -> list[ctypes.c_void_p]:
+def to_ctype(memref_desc) -> ctypes._Pointer:
+    """
+    Convert a memref descriptor into a ctype argument.
+
+    Args:
+        memref_desc: An MLIR memref descriptor.
+    """
+    return ctypes.pointer(ctypes.pointer(memref_desc))
+
+
+def get_packed_arg(
+    ctypes_args: Sequence[ctypes._Pointer],
+) -> ctypes.Array[ctypes.c_void_p]:
     """
     Return a list of packed ctype arguments compatible with jitted MLIR
     function's interface.
@@ -15,22 +28,12 @@ def get_packed_arg(ctypes_args) -> list[ctypes.c_void_p]:
     return packed_args
 
 
-def memref_to_ctype(memref_desc) -> ctypes._Pointer:
-    """
-    Convert a memref descriptor into a ctype argument.
-
-    Args:
-        memref_desc: An MLIR memref descriptor.
-    """
-    return ctypes.pointer(ctypes.pointer(memref_desc))
-
-
-def memrefs_to_packed_args(memref_descs) -> list[ctypes.c_void_p]:
+def to_packed_args(memref_descs) -> ctypes.Array[ctypes.c_void_p]:
     """
     Convert a list of memref descriptors into packed ctype arguments.
 
     Args:
         memref_descs: A list of memref descriptors.
     """
-    ctype_args = [memref_to_ctype(memref) for memref in memref_descs]
+    ctype_args = [to_ctype(memref) for memref in memref_descs]
     return get_packed_arg(ctype_args)
diff --git a/lighthouse/utils/runtime/__init__.py b/lighthouse/utils/runtime/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/lighthouse/utils/runtime/torch.py b/lighthouse/utils/torch.py
similarity index 80%
rename from lighthouse/utils/runtime/torch.py
rename to lighthouse/utils/torch.py
index 43c2f95..4fa3d1c 100644
--- a/lighthouse/utils/runtime/torch.py
+++ b/lighthouse/utils/torch.py
@@ -4,10 +4,10 @@
 from mlir import ir
 from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
 
-from .ffi import memrefs_to_packed_args
+from . import memref as memref_utils
 
 
-def torch_to_memref(input: torch.Tensor) -> ctypes.Structure:
+def to_memref(input: torch.Tensor) -> ctypes.Structure:
     """
     Convert a PyTorch tensor into a memref descriptor.
 
@@ -17,18 +17,18 @@ def torch_to_memref(input: torch.Tensor) -> ctypes.Structure:
     return get_ranked_memref_descriptor(input.numpy())
 
 
-def torch_to_packed_args(inputs: list[torch.Tensor]) -> list[ctypes.c_void_p]:
+def to_packed_args(inputs: list[torch.Tensor]) -> ctypes.Array[ctypes.c_void_p]:
     """
     Convert a list of PyTorch tensors into packed ctype arguments.
 
     Args:
         inputs: A list of PyTorch tensors.
     """
-    memrefs = [torch_to_memref(input) for input in inputs]
-    return memrefs_to_packed_args(memrefs)
+    memrefs = [to_memref(input) for input in inputs]
+    return memref_utils.to_packed_args(memrefs)
 
 
-def mlir_type_to_torch_dtype(mlir_type: ir.Type):
+def dtype_from_mlir_type(mlir_type: ir.Type):
     """
     Convert an MLIR type to a PyTorch dtype.
Args: diff --git a/lighthouse/workload/runner.py b/lighthouse/workload/runner.py index 9f475fa..9e9a5d0 100644 --- a/lighthouse/workload/runner.py +++ b/lighthouse/workload/runner.py @@ -9,7 +9,7 @@ from mlir.execution_engine import ExecutionEngine from mlir.runtime.np_to_memref import get_ranked_memref_descriptor from lighthouse.utils.mlir import get_mlir_library_path -from lighthouse.utils.runtime.ffi import memrefs_to_packed_args +from lighthouse.utils.memref import to_packed_args from lighthouse.workload import Workload from typing import Optional @@ -44,7 +44,7 @@ def execute( with workload.allocate_inputs(execution_engine=engine) as inputs: # prepare function arguments - packed_args = memrefs_to_packed_args(inputs) + packed_args = to_packed_args(inputs) # handle to payload function payload_func = engine.lookup(workload.payload_function_name) @@ -143,7 +143,7 @@ def benchmark( if check_correctness: # call payload once to verify correctness # prepare function arguments - packed_args = memrefs_to_packed_args(inputs) + packed_args = to_packed_args(inputs) payload_func = engine.lookup(workload.payload_function_name) payload_func(packed_args) @@ -156,7 +156,7 @@ def benchmark( # allocate buffer for timings and prepare arguments time_array = np.zeros((nruns,), dtype=np.float64) time_memref = get_ranked_memref_descriptor(time_array) - packed_args_with_time = memrefs_to_packed_args(inputs + [time_memref]) + packed_args_with_time = to_packed_args(inputs + [time_memref]) # call benchmark function benchmark_func = engine.lookup("benchmark")
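
The module-level __getattr__ added to lighthouse/ingress/__init__.py and
lighthouse/utils/__init__.py uses the PEP 562 hook for lazy submodule loading.
A minimal sketch of the resulting behavior, assuming this patch is applied
(the assertions are illustrative only and presume nothing else has imported
the submodule yet):

    import sys

    import lighthouse.utils as lh_utils

    # Importing the package does not yet import any submodule, so heavy
    # dependencies stay unloaded.
    assert "lighthouse.utils.memref" not in sys.modules

    # The first attribute access invokes __getattr__, which imports the
    # submodule and caches it on the package via setattr.
    memref_utils = lh_utils.memref
    assert "lighthouse.utils.memref" in sys.modules

    # Later accesses are ordinary attribute lookups; __getattr__ is not
    # called again for this name.
    assert lh_utils.memref is memref_utils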
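
The renamed helpers in lighthouse/utils/memref.py implement the packed void**
calling convention of functions jitted with llvm.emit_c_interface. A sketch of
invoking such a function through them; the function name "payload", the
shapes, and the pre-built engine are assumptions for illustration:

    import numpy as np
    from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
    from lighthouse.utils.memref import to_packed_args

    # Wrap numpy buffers in ranked memref descriptors and pack pointers to
    # them into the argument array the ExecutionEngine interface expects.
    a = np.random.rand(4, 4).astype(np.float32)
    b = np.zeros((4, 4), dtype=np.float32)
    packed_args = to_packed_args([get_ranked_memref_descriptor(x) for x in (a, b)])

    # 'engine' is an initialized mlir.execution_engine.ExecutionEngine built
    # from a module defining @payload with llvm.emit_c_interface.
    engine.lookup("payload")(packed_args)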
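
The torch helpers renamed in lighthouse/utils/torch.py compose with the above;
a brief sketch, assuming torch is available:

    import torch
    from lighthouse.utils.torch import to_packed_args

    # Conversion goes through Tensor.numpy(), which shares memory and
    # requires the tensors to live on the CPU.
    tensors = [torch.zeros(4, 4), torch.ones(4, 4)]
    packed = to_packed_args(tensors)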