diff --git a/loopy/tools.py b/loopy/tools.py index e9f9932b7..a4fa95e8a 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -73,13 +73,17 @@ class LoopyKeyBuilder(KeyBuilderBase): update_for_dict = KeyBuilderBase.update_for_constantdict update_for_defaultdict = KeyBuilderBase.update_for_constantdict - def update_for_BasicSet(self, key_hash, key): # noqa - from islpy import Printer - prn = Printer.to_str(key.get_ctx()) - getattr(prn, "print_"+key._base_name)(key) - key_hash.update(prn.get_str().encode("utf8")) + def update_for_BasicSet(self, key_hash, key): # noqa: N802 + key_hash.update(str(type(key)).encode("utf-8")) + self.rec(key_hash, frozenset(key.get_var_dict().keys())) - def update_for_Map(self, key_hash, key): # noqa + constraints = set() + for constraint in key.get_constraints(): + constraints.add(str(constraint).partition("->")[-1]) + + self.rec(key_hash, frozenset(constraints)) + + def update_for_Map(self, key_hash, key): # noqa: N802 if isinstance(key, isl.Map): self.update_for_BasicSet(key_hash, key) else: diff --git a/test/test_misc.py b/test/test_misc.py index 79c7698e7..9b7a04cbc 100644 --- a/test/test_misc.py +++ b/test/test_misc.py @@ -338,6 +338,31 @@ def test_memoize_on_disk_with_pym_expr(): assert cached_result == uncached_result +def test_basicset_keybuilder(): + # See https://github.com/inducer/loopy/issues/912 for context + import islpy as isl + + # Both sets have the same variables and constraints, but in different order. + # These sets are generated in test_convolution() in test_apps.py + a = isl.BasicSet("[im_w, im_h, nimgs, nfeats] -> " + "{ : im_w >= 7 and im_h >= 7 and nimgs >= 0 and nfeats > 0 }") + + b = isl.BasicSet("[nfeats, nimgs, im_h, im_w] -> " + "{ : nfeats > 0 and nimgs >= 0 and im_h >= 7 and im_w >= 7 }") + + from loopy.tools import LoopyKeyBuilder + + # Equality + assert a == b + assert a.is_equal(b) + assert not a.plain_is_equal(b) + + # Hashing + assert hash(a) != hash(b) + assert a.get_hash() != b.get_hash() + assert LoopyKeyBuilder()(a) == LoopyKeyBuilder()(b) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: