diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dca3c646..b2f39b1c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ exclude: ^python/tests/__snapshots__/ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.3 + rev: v0.15.0 hooks: - id: ruff-check args: [--fix] - id: ruff-format - repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.9.7 + rev: 0.10.0 hooks: - id: uv-lock diff --git a/docs/explanation/2023_11_17_pytensor.ipynb b/docs/explanation/2023_11_17_pytensor.ipynb index 22f5d6f0..13009583 100644 --- a/docs/explanation/2023_11_17_pytensor.ipynb +++ b/docs/explanation/2023_11_17_pytensor.ipynb @@ -167,7 +167,7 @@ ")\n", "converter(int, IntTuple, lambda i: IntTuple(Int(i64(i))))\n", "converter(i64, IntTuple, lambda i: IntTuple(Int(i)))\n", - "converter(Int, IntTuple, lambda i: IntTuple(i))\n", + "converter(Int, IntTuple, IntTuple)\n", "\n", "\n", "@egraph.register\n", diff --git a/python/egglog/__init__.py b/python/egglog/__init__.py index 7d20dfdb..66e7402b 100644 --- a/python/egglog/__init__.py +++ b/python/egglog/__init__.py @@ -4,7 +4,7 @@ from . import config, ipython_magic # noqa: F401 from .bindings import EggSmolError, StageInfo, TimeOnly, WithPlan # noqa: F401 -from .builtins import * # noqa: UP029 +from .builtins import * from .conversion import * from .deconstruct import * from .egraph import * diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index a9c75f69..fb3dc38e 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -801,7 +801,7 @@ def bool_le(self, other: BigIntLike) -> Bool: ... def bool_ge(self, other: BigIntLike) -> Bool: ... -converter(i64, BigInt, lambda i: BigInt(i)) +converter(i64, BigInt, BigInt) BigIntLike: TypeAlias = BigInt | i64Like diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index b69d8212..4e75e24a 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -297,7 +297,7 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int): yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int")) -converter(i64, Int, lambda x: Int(x)) +converter(i64, Int, Int) IntLike: TypeAlias = Int | i64Like @@ -377,8 +377,8 @@ def __gt__(self, other: FloatLike) -> Boolean: ... def __ge__(self, other: FloatLike) -> Boolean: ... -converter(float, Float, lambda x: Float(x)) -converter(Int, Float, lambda x: Float.from_int(x)) +converter(float, Float, Float) +converter(Int, Float, Float.from_int) FloatLike: TypeAlias = Float | float | IntLike @@ -521,7 +521,7 @@ def deselect(self, indices: TupleIntLike) -> TupleInt: return TupleInt.range(self.length()).filter(lambda i: ~indices.contains(i)).map(lambda i: self[i]) -converter(Vec[Int], TupleInt, lambda x: TupleInt.from_vec(x)) +converter(Vec[Int], TupleInt, TupleInt.from_vec) TupleIntLike: TypeAlias = TupleInt | VecLike[Int, IntLike] @@ -649,7 +649,7 @@ def product(self) -> TupleTupleInt: ) -converter(Vec[TupleInt], TupleTupleInt, lambda x: TupleTupleInt.from_vec(x)) +converter(Vec[TupleInt], TupleTupleInt, TupleTupleInt.from_vec) TupleTupleIntLike: TypeAlias = TupleTupleInt | VecLike[TupleInt, TupleIntLike] @@ -755,8 +755,8 @@ def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ... def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ... -converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x)) -converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x)) +converter(DType, IsDtypeKind, IsDtypeKind.dtype) +converter(str, IsDtypeKind, IsDtypeKind.string) converter( tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL ) @@ -922,8 +922,8 @@ def from_tuple_int(cls, ti: TupleIntLike) -> TupleValue: return TupleValue(ti.length(), lambda i: Value.int(ti[i])) -converter(Vec[Value], TupleValue, lambda x: TupleValue.from_vec(x)) -converter(TupleInt, TupleValue, lambda x: TupleValue.from_tuple_int(x)) +converter(Vec[Value], TupleValue, TupleValue.from_vec) +converter(TupleInt, TupleValue, TupleValue.from_tuple_int) TupleValueLike: TypeAlias = TupleValue | VecLike[Value, ValueLike] | TupleIntLike @@ -1073,9 +1073,9 @@ def ndarray(cls, key: NDArray) -> IndexKey: converter(type(...), IndexKey, lambda _: IndexKey.ELLIPSIS) -converter(Int, IndexKey, lambda i: IndexKey.int(i)) -converter(Slice, IndexKey, lambda s: IndexKey.slice(s)) -converter(MultiAxisIndexKey, IndexKey, lambda m: IndexKey.multi_axis(m)) +converter(Int, IndexKey, IndexKey.int) +converter(Slice, IndexKey, IndexKey.slice) +converter(MultiAxisIndexKey, IndexKey, IndexKey.multi_axis) class Device(Expr, ruleset=array_api_ruleset): ... @@ -1232,13 +1232,13 @@ def if_(cls, b: BooleanLike, i: NDArrayLike, j: NDArrayLike) -> NDArray: ... NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike -converter(NDArray, IndexKey, lambda v: IndexKey.ndarray(v)) -converter(Value, NDArray, lambda v: NDArray.scalar(v)) +converter(NDArray, IndexKey, IndexKey.ndarray) +converter(Value, NDArray, NDArray.scalar) # Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive # to prefer upcasting in the other direction when we can, which is safer at runtime converter(NDArray, Value, lambda n: n.to_value(), 100) -converter(TupleValue, NDArray, lambda v: NDArray.vector(v)) -converter(TupleInt, TupleValue, lambda v: TupleValue.from_tuple_int(v)) +converter(TupleValue, NDArray, NDArray.vector) +converter(TupleInt, TupleValue, TupleValue.from_tuple_int) @array_api_ruleset.register @@ -1322,7 +1322,7 @@ def eval(self) -> tuple[NDArray, ...]: return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec) -converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x)) +converter(Vec[NDArray], TupleNDArray, TupleNDArray.from_vec) TupleNDArrayLike: TypeAlias = TupleNDArray | VecLike[NDArray, NDArrayLike] @@ -1371,7 +1371,7 @@ def some(cls, value: Boolean) -> OptionalBool: ... converter(type(None), OptionalBool, lambda _: OptionalBool.none) -converter(Boolean, OptionalBool, lambda x: OptionalBool.some(x)) +converter(Boolean, OptionalBool, OptionalBool.some) class OptionalDType(Expr, ruleset=array_api_ruleset): @@ -1382,7 +1382,7 @@ def some(cls, value: DType) -> OptionalDType: ... converter(type(None), OptionalDType, lambda _: OptionalDType.none) -converter(DType, OptionalDType, lambda x: OptionalDType.some(x)) +converter(DType, OptionalDType, OptionalDType.some) class OptionalDevice(Expr, ruleset=array_api_ruleset): @@ -1393,7 +1393,7 @@ def some(cls, value: Device) -> OptionalDevice: ... converter(type(None), OptionalDevice, lambda _: OptionalDevice.none) -converter(Device, OptionalDevice, lambda x: OptionalDevice.some(x)) +converter(Device, OptionalDevice, OptionalDevice.some) class OptionalTupleInt(Expr, ruleset=array_api_ruleset): @@ -1404,7 +1404,7 @@ def some(cls, value: TupleIntLike) -> OptionalTupleInt: ... converter(type(None), OptionalTupleInt, lambda _: OptionalTupleInt.none) -converter(TupleInt, OptionalTupleInt, lambda x: OptionalTupleInt.some(x)) +converter(TupleInt, OptionalTupleInt, OptionalTupleInt.some) class IntOrTuple(Expr, ruleset=array_api_ruleset): @@ -1417,8 +1417,8 @@ def int(cls, value: Int) -> IntOrTuple: ... def tuple(cls, value: TupleIntLike) -> IntOrTuple: ... -converter(Int, IntOrTuple, lambda v: IntOrTuple.int(v)) -converter(TupleInt, IntOrTuple, lambda v: IntOrTuple.tuple(v)) +converter(Int, IntOrTuple, IntOrTuple.int) +converter(TupleInt, IntOrTuple, IntOrTuple.tuple) class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset): @@ -1429,7 +1429,7 @@ def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: ... converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none) -converter(IntOrTuple, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.some(v)) +converter(IntOrTuple, OptionalIntOrTuple, OptionalIntOrTuple.some) @function diff --git a/python/egglog/exp/array_api_loopnest.py b/python/egglog/exp/array_api_loopnest.py index 046c3c1a..df15b43a 100644 --- a/python/egglog/exp/array_api_loopnest.py +++ b/python/egglog/exp/array_api_loopnest.py @@ -31,7 +31,7 @@ def shape_api_ruleset(dims: TupleInt, axis: TupleInt): ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i])) ) yield rewrite(s.select(axis), subsume=True).to( - ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: axis.contains(i)).map(lambda i: dims[i])) + ShapeAPI(TupleInt.range(dims.length()).filter(axis.contains).map(lambda i: dims[i])) ) yield rewrite(s.to_tuple(), subsume=True).to(dims) diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 34564ddc..f3c50a7f 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -210,11 +210,14 @@ def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray: return NDArray( outshape, X.dtype, - lambda k: LoopNestAPI.from_tuple(reduce_axis) - .unwrap() - .indices() - .foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0) - .sqrt(), + lambda k: ( + LoopNestAPI + .from_tuple(reduce_axis) + .unwrap() + .indices() + .foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0) + .sqrt() + ), ) @@ -224,9 +227,11 @@ def linalg_norm_v2(X: NDArrayLike, axis: TupleIntLike) -> NDArray: return NDArray( X.shape.deselect(axis), X.dtype, - lambda k: ndindex(X.shape.select(axis)) - .foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0) - .sqrt(), + lambda k: ( + ndindex(X.shape.select(axis)) + .foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0) + .sqrt() + ), ) diff --git a/python/tests/test_convert.py b/python/tests/test_convert.py index 24edcadc..b7a826ee 100644 --- a/python/tests/test_convert.py +++ b/python/tests/test_convert.py @@ -95,7 +95,7 @@ def test_convert_to_generic(): class G(BuiltinExpr, Generic[T]): def __init__(self, x: T) -> None: ... - converter(i64, G[i64], lambda x: G(x)) + converter(i64, G[i64], G) assert expr_parts(convert(10, G[i64])) == expr_parts(G(i64(10))) with pytest.raises(ConvertError): @@ -114,7 +114,7 @@ def test_convert_to_unbound_generic(): class G(BuiltinExpr, Generic[T]): def __init__(self, x: i64) -> None: ... - converter(i64, G, lambda x: G[get_type_args()[0]](x)) # type: ignore[misc, operator] + converter(i64, G, G[get_type_args()[0]]) # type: ignore[misc, operator] assert expr_parts(convert(10, G[String])) == expr_parts(G[String](i64(10))) diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 173e817a..2d598aff 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -535,14 +535,14 @@ def _global_make_tuple(x): def test_eval_fn_globals(): - assert EGraph().extract(PyObject(lambda x: _global_make_tuple(x))(PyObject.from_int(1))).value == (1,) + assert EGraph().extract(PyObject(_global_make_tuple)(PyObject.from_int(1))).value == (1,) def test_eval_fn_locals(): def _locals_make_tuple(x): return (x,) - assert EGraph().extract(PyObject(lambda x: _locals_make_tuple(x))(PyObject.from_int(1))).value == (1,) + assert EGraph().extract(PyObject(_locals_make_tuple)(PyObject.from_int(1))).value == (1,) def test_lazy_types(): @@ -1459,9 +1459,9 @@ def __contains__(self, item: int) -> bool: pytest.param(lambda: int(m), 1000, id="int"), pytest.param(lambda: float(m), 100.0, id="float"), pytest.param(lambda: complex(m), 1 + 0j, id="complex"), - pytest.param(lambda: m.__index__(), 20, id="index"), + pytest.param(m.__index__, 20, id="index"), pytest.param(lambda: len(m), 10, id="len"), - pytest.param(lambda: m.__length_hint__(), 5, id="length_hint"), + pytest.param(m.__length_hint__, 5, id="length_hint"), pytest.param(lambda: list(m), [1], id="iter"), pytest.param(lambda: list(reversed(m)), [10], id="reversed"), pytest.param(lambda: 1 in m, True, id="contains"), diff --git a/python/tests/test_unstable_fn.py b/python/tests/test_unstable_fn.py index eaaab64b..ff5eb22e 100644 --- a/python/tests/test_unstable_fn.py +++ b/python/tests/test_unstable_fn.py @@ -253,7 +253,7 @@ def apply_f(f: Callable[[A], A], x: A) -> A: @r.register def _rewrite(a: A): - yield rewrite(transform_a(a)).to(apply_f(lambda x: my_transform_a(x), a)) + yield rewrite(transform_a(a)).to(apply_f(my_transform_a, a)) assert check_eq(transform_a(A()), my_transform_a(A()), r * 10) @@ -276,7 +276,7 @@ def apply_f(f: Callable[[A], A], x: A) -> A: @ruleset def my_ruleset(a: A): - yield rewrite(transform_a(a)).to(apply_f(lambda x: my_transform_a(x), a)) + yield rewrite(transform_a(a)).to(apply_f(my_transform_a, a)) assert check_eq(transform_a(A()), my_transform_a(A()), (my_ruleset | apply_ruleset) * 10) @@ -296,7 +296,7 @@ def apply_f(f: Callable[[A], A], x: A) -> A: @function(ruleset=r) def transform_a(a: A) -> A: - return apply_f(lambda x: my_transform_a(x), a) + return apply_f(my_transform_a, a) assert check_eq(transform_a(A()), my_transform_a(A()), r * 10) @@ -325,7 +325,7 @@ def higher_order(f: Callable[[A], A]) -> A: ... @function def transform_a(a: A) -> A: ... - v = higher_order(lambda a: transform_a(a)) + v = higher_order(transform_a) assert str(v) == "higher_order(lambda a: transform_a(a))" def test_multiple_same(self):