Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ pip install .[ingress_torch_cpu] \

## Running tests

Running the tests is as simple as `lit .` in the root of the project.
Running the tests is as simple as `lit .` in the root of the project (in a suitable Python environment, e.g. through `uv run lit .`).

We assume that the [`FileCheck`](https://llvm.org/docs/CommandGuide/FileCheck.html) and [`lit`](https://llvm.org/docs/CommandGuide/lit.html) executables are available on the `PATH`.

Expand Down
8 changes: 5 additions & 3 deletions examples/ingress/convert-kernel-bench-to-mlir.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# RUN: %PYTHON %s 1,1 1,2 1,3 2,1 2,2 2,3
# RUN: %PYTHON %s 1,1 1,2 2,1
# REQUIRES: torch
# REQUIRES: kernel_bench

# Basic conversion of KernelBench PyTorch kernels to mlir kernels, relying on
# torch-mlir for the conversion. As there are a number of kernels for which
Expand All @@ -13,7 +15,7 @@
from typing import Iterable

from mlir import ir, passmanager
from lighthouse.ingress import torch as torch_ingress
import lighthouse.ingress as lh_ingress

project_root = Path(__file__).parent.parent.parent
torch_kernels_dir = project_root / "third_party" / "KernelBench" / "KernelBench"
Expand Down Expand Up @@ -171,7 +173,7 @@ def process_task(task: KernelConversionTask):
print("Processing:", kernel_relative_name)

try:
mlir_kernel = torch_ingress.import_from_file(task.torch_path, ir_context=ctx)
mlir_kernel = lh_ingress.torch.import_from_file(task.torch_path, ir_context=ctx)
assert isinstance(mlir_kernel, ir.Module)
except Exception as e:
print(
Expand Down
3 changes: 3 additions & 0 deletions examples/ingress/torch/lit.local.cfg
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
config.excludes = ["MLPModel"]

if "torch" not in config.available_features:
config.unsupported = True
2 changes: 1 addition & 1 deletion examples/llama/lit.local.cfg
Original file line number Diff line number Diff line change
@@ -1 +1 @@
config.excludes = ["ref_model.py"]
config.excludes = ["ref_model.py"]
61 changes: 30 additions & 31 deletions examples/llama/test_llama3.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
# RUN: %pytest %s
# RUN: %PYTHON %s
# REQUIRES: torch

import functools
import math as pymath
import pytest
import torch


from mlir import ir
from mlir.dialects import transform, func, linalg, tensor, arith, complex, math
from mlir.dialects.linalg import ElementwiseKind
from mlir.dialects.transform import structured, bufferization, interpreter
from mlir.passmanager import PassManager
from mlir.runtime.np_to_memref import (
get_ranked_memref_descriptor,
)
from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
from mlir.execution_engine import ExecutionEngine

from lighthouse import utils as lh_utils

from ref_model import (
Attention,
ModelArgs,
Expand All @@ -25,7 +25,6 @@
TransformerBlock,
Transformer,
)
from lighthouse import utils as lh_utils


def with_mlir_ctx_and_location(func):
Expand Down Expand Up @@ -1020,7 +1019,7 @@ def bin_op(a, b, out):
eng = ExecutionEngine(module, opt_level=2)
func_ptr = eng.lookup("bin_op")

torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
a = torch.randn(*shape, dtype=torch_dtype)
b = torch.randn(*shape, dtype=torch_dtype)
out_ref = references[op](a, b)
Expand All @@ -1030,7 +1029,7 @@ def bin_op(a, b, out):
a_mem = get_ranked_memref_descriptor(a.numpy())
b_mem = get_ranked_memref_descriptor(b.numpy())
out_mem = get_ranked_memref_descriptor(out.numpy())
args = lh_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem])
args = lh_utils.memref.to_packed_args([a_mem, b_mem, out_mem])
func_ptr(args)

assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
Expand Down Expand Up @@ -1076,14 +1075,14 @@ def unary_op(a, out):
eng = ExecutionEngine(module, opt_level=2)
func_ptr = eng.lookup("unary_op")

torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
a = torch.randn(*shape, dtype=torch_dtype)
out_ref = references[op](a)
out = torch.empty_like(out_ref)

a_mem = get_ranked_memref_descriptor(a.numpy())
out_mem = get_ranked_memref_descriptor(out.numpy())
args = lh_utils.memrefs_to_packed_args([a_mem, out_mem])
args = lh_utils.memref.to_packed_args([a_mem, out_mem])
func_ptr(args)

assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
Expand Down Expand Up @@ -1112,13 +1111,13 @@ def rms_norm(a, out):

eng = ExecutionEngine(module, opt_level=2)
func_ptr = eng.lookup("rms_norm")
torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
a = torch.randn(*shape, dtype=torch_dtype)
out_ref = references[get_l2_norm](a, eps)
out = torch.empty_like(out_ref)
a_mem = get_ranked_memref_descriptor(a.numpy())
out_mem = get_ranked_memref_descriptor(out.numpy())
args = lh_utils.memrefs_to_packed_args([a_mem, out_mem])
args = lh_utils.memref.to_packed_args([a_mem, out_mem])
func_ptr(args)

assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
Expand Down Expand Up @@ -1160,7 +1159,7 @@ def linear_op(x, w, b, out):

eng = ExecutionEngine(module, opt_level=2)
func_ptr = eng.lookup("linear_op")
torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
x = torch.randn(*shape, in_features, dtype=torch_dtype)
w = torch.randn(out_features, in_features, dtype=torch_dtype)
b = torch.randn(out_features, dtype=torch_dtype)
Expand All @@ -1171,7 +1170,7 @@ def linear_op(x, w, b, out):
w_mem = get_ranked_memref_descriptor(w.numpy())
b_mem = get_ranked_memref_descriptor(b.numpy())
out_mem = get_ranked_memref_descriptor(out.numpy())
args = lh_utils.memrefs_to_packed_args([x_mem, w_mem, b_mem, out_mem])
args = lh_utils.memref.to_packed_args([x_mem, w_mem, b_mem, out_mem])
func_ptr(args)
assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)

Expand Down Expand Up @@ -1201,15 +1200,15 @@ def polar_op(magnitude, angle, out):

eng = ExecutionEngine(module, opt_level=2)
func_ptr = eng.lookup("polar_op")
torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
magnitude = torch.randn(4, 16, dtype=torch_dtype)
angle = torch.randn(4, 16, dtype=torch_dtype)
out_ref = references[get_polar](magnitude, angle)
out = torch.empty_like(out_ref)
magnitude_mem = get_ranked_memref_descriptor(magnitude.numpy())
angle_mem = get_ranked_memref_descriptor(angle.numpy())
out_mem = get_ranked_memref_descriptor(out.numpy())
args = lh_utils.memrefs_to_packed_args([magnitude_mem, angle_mem, out_mem])
args = lh_utils.memref.to_packed_args([magnitude_mem, angle_mem, out_mem])
func_ptr(args)
assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)

Expand Down Expand Up @@ -1237,14 +1236,14 @@ def repeat_kv_op(x, out):
eng = ExecutionEngine(module, opt_level=2)
func_ptr = eng.lookup("repeat_kv_op")

torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
x = torch.randn(2, 512, 8, 64, dtype=torch_dtype)
out_ref = references[get_repeat_kv](x, n_rep)
out = torch.empty_like(out_ref)

x_mem = get_ranked_memref_descriptor(x.numpy())
out_mem = get_ranked_memref_descriptor(out.numpy())
args = lh_utils.memrefs_to_packed_args([x_mem, out_mem])
args = lh_utils.memref.to_packed_args([x_mem, out_mem])
func_ptr(args)

assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
Expand Down Expand Up @@ -1275,7 +1274,7 @@ def reshape_for_broadcast_op(freqs_cis, x, out):
eng = ExecutionEngine(module, opt_level=2)
func_ptr = eng.lookup("reshape_for_broadcast")

torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
freqs_cis = torch.randn(512, 64, dtype=torch_dtype)
x = torch.randn(2, 512, 32, 128, dtype=torch_dtype)
# Convert x to complex view as expected by reshape_for_broadcast
Expand All @@ -1286,7 +1285,7 @@ def reshape_for_broadcast_op(freqs_cis, x, out):
freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis.numpy())
x_mem = get_ranked_memref_descriptor(x.numpy())
out_mem = get_ranked_memref_descriptor(out.numpy())
args = lh_utils.memrefs_to_packed_args([freqs_cis_mem, x_mem, out_mem])
args = lh_utils.memref.to_packed_args([freqs_cis_mem, x_mem, out_mem])
func_ptr(args)

assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
Expand Down Expand Up @@ -1317,15 +1316,15 @@ def view_as_complex_op(x, out):
eng = ExecutionEngine(module, opt_level=2)
func_ptr = eng.lookup("view_as_complex_op")

torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
x = torch.randn(2, 512, 32, 128, dtype=torch_dtype)
x_reshaped = x.reshape(2, 512, 32, 64, 2)
out_ref = torch.view_as_complex(x_reshaped)
out = torch.empty_like(out_ref)

x_mem = get_ranked_memref_descriptor(x_reshaped.numpy())
out_mem = get_ranked_memref_descriptor(out.numpy())
args = lh_utils.memrefs_to_packed_args([x_mem, out_mem])
args = lh_utils.memref.to_packed_args([x_mem, out_mem])
func_ptr(args)

assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
Expand Down Expand Up @@ -1353,15 +1352,15 @@ def as_real_op(x, out):
eng = ExecutionEngine(module, opt_level=2)
func_ptr = eng.lookup("as_real_op")

torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
x = torch.randn(2, 512, 32, 64, 2, dtype=torch_dtype)
x_complex = torch.view_as_complex(x)
out_ref = torch.view_as_real(x_complex)
out = torch.empty_like(out_ref)

x_mem = get_ranked_memref_descriptor(x_complex.numpy())
out_mem = get_ranked_memref_descriptor(out.numpy())
args = lh_utils.memrefs_to_packed_args([x_mem, out_mem])
args = lh_utils.memref.to_packed_args([x_mem, out_mem])
func_ptr(args)

assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
Expand Down Expand Up @@ -1394,7 +1393,7 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out):
return module

ir_type = to_ir_type(elem_type)
torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
xq_shape = (batch_size, seq_len, n_heads, head_dim)
xk_shape = (batch_size, seq_len, n_kv_heads, head_dim)
freqs_cis_shape = (seq_len, head_dim // 2)
Expand Down Expand Up @@ -1423,7 +1422,7 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out):
freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis.numpy())
out1_mem = get_ranked_memref_descriptor(out1.numpy())
out2_mem = get_ranked_memref_descriptor(out2.numpy())
args = lh_utils.memrefs_to_packed_args(
args = lh_utils.memref.to_packed_args(
[a_mem, b_mem, freqs_cis_mem, out1_mem, out2_mem]
)
func_ptr(args)
Expand Down Expand Up @@ -1488,7 +1487,7 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out):
eng = ExecutionEngine(module, opt_level=2)
func_ptr = eng.lookup("feed_forward")

torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
x = torch.randn(4, 16, dtype=torch_dtype)
w1 = torch.randn(64, 16, dtype=torch_dtype)
b1 = torch.randn(64, dtype=torch_dtype)
Expand All @@ -1511,7 +1510,7 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out):
w3_mem = get_ranked_memref_descriptor(w3.numpy())
b3_mem = get_ranked_memref_descriptor(b3.numpy())
out_mem = get_ranked_memref_descriptor(out.numpy())
args = lh_utils.memrefs_to_packed_args(
args = lh_utils.memref.to_packed_args(
[x_mem, w1_mem, b1_mem, w2_mem, b2_mem, w3_mem, b3_mem, out_mem]
)
func_ptr(args)
Expand Down Expand Up @@ -1644,7 +1643,7 @@ def attention_op(x, wq, wk, wv, wo, freqs_cis, mask, out):
freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis_real.numpy())
mask_mem = get_ranked_memref_descriptor(mask.numpy())
out_mem = get_ranked_memref_descriptor(out.numpy())
args = lh_utils.memrefs_to_packed_args(
args = lh_utils.memref.to_packed_args(
[x_mem, wq_mem, wk_mem, wv_mem, wo_mem, freqs_cis_mem, mask_mem, out_mem]
)
func_ptr(args)
Expand Down Expand Up @@ -1791,7 +1790,7 @@ def transformer_block_op(
b3_mem = get_ranked_memref_descriptor(b3.numpy())
out_mem = get_ranked_memref_descriptor(out.numpy())

args = lh_utils.memrefs_to_packed_args(
args = lh_utils.memref.to_packed_args(
[
x_mem,
wq_mem,
Expand Down Expand Up @@ -1980,7 +1979,7 @@ def transformer_op(*params):
out_mem = get_ranked_memref_descriptor(out.numpy())
memrefs.append(out_mem)

args = lh_utils.memrefs_to_packed_args(memrefs)
args = lh_utils.memref.to_packed_args(memrefs)
func_ptr(args)

assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
5 changes: 3 additions & 2 deletions examples/mlir/compile_and_run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# RUN: %PYTHON %s
# REQUIRES: torch

import torch
import argparse
Expand All @@ -9,7 +10,7 @@
from mlir.execution_engine import ExecutionEngine
from mlir.passmanager import PassManager

from lighthouse import utils as lh_utils
import lighthouse.utils as lh_utils


def create_kernel(ctx: ir.Context) -> ir.Module:
Expand Down Expand Up @@ -167,7 +168,7 @@ def main(args):
out = torch.empty_like(out_ref)

# Execute the kernel.
args = lh_utils.torch_to_packed_args([a, b, out])
args = lh_utils.torch.to_packed_args([a, b, out])
add_func(args)

### Verification ###
Expand Down
24 changes: 9 additions & 15 deletions examples/workload/example.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,25 @@
# RUN: %PYTHON %s | FileCheck %s
# RUN: %PYTHON %s | FileCheck %s
# CHECK: func.func @payload
# CHECK: PASSED
# CHECK: Throughput:
"""
Workload example: Element-wise sum of two (M, N) float32 arrays on CPU.
"""

import ctypes
from contextlib import contextmanager
from functools import cached_property
from typing import Optional

import numpy as np
from mlir import ir
from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
from mlir.dialects import func, linalg, bufferization
from mlir.dialects import transform
from mlir.execution_engine import ExecutionEngine
from contextlib import contextmanager
from functools import cached_property
import ctypes
from typing import Optional
from lighthouse.utils.mlir import (
apply_registered_pass,
canonicalize,
match,
)
from lighthouse.workload import (
Workload,
execute,
benchmark,
)

from lighthouse.utils.mlir import apply_registered_pass, canonicalize, match
from lighthouse.workload import Workload, execute, benchmark


class ElementwiseSum(Workload):
Expand Down
Loading