Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 50 additions & 47 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6679,7 +6679,8 @@ def narrow_type_by_identity_equality(
else:
raise AssertionError

partial_type_maps = []
all_if_maps: list[TypeMap] = []
all_else_maps: list[TypeMap] = []

# For each narrowable index, we see what we can narrow based on each relevant target
for i in expr_indices:
Expand All @@ -6690,10 +6691,8 @@ def narrow_type_by_identity_equality(
continue

expr_type = operand_types[i]
expanded_expr_type = try_expanding_sum_type_to_union(
coerce_to_literal(expr_type), None
)
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
expr_type = try_expanding_sum_type_to_union(coerce_to_literal(expr_type), None)
for j in expr_indices:
if i == j:
continue
Expand All @@ -6703,11 +6702,6 @@ def narrow_type_by_identity_equality(
continue
target_type = operand_types[j]
if should_coerce_literals:
# TODO: doing this prevents narrowing a single-member Enum to literal
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment was on the wrong coerce_to_literal call

(In general, it's a very minor thing — there are more important other improvements to be made to narrowing — so I moved this text into the relevant test)

# of its member, because we expand it here and then refuse to add equal
# types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
# `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
# See testMatchEnumSingleChoice
target_type = coerce_to_literal(target_type)

if (
Expand All @@ -6718,24 +6712,21 @@ def narrow_type_by_identity_equality(
continue

target = TypeRange(target_type, is_upper_bound=False)
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))

if is_value_target:
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expanded_expr_type, [target])
)
partial_type_maps.append((if_map, else_map))
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target])
)
if is_target_for_value_narrowing(get_proper_type(target_type)):
all_if_maps.append(if_map)
all_else_maps.append(else_map)
else:
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target])
)
# For value targets, it is safe to narrow in the negative case.
# e.g. if (x: Literal[5] | None) != (y: Literal[5]), we can narrow x to None
# However, for non-value targets, we cannot do this narrowing,
# and so we ignore else_map
# e.g. if (x: str | None) != (y: str), we cannot narrow x to None
if if_map:
partial_type_maps.append((if_map, {}))
if if_map is not None: # TODO: this gate is incorrect and should be removed
all_if_maps.append(if_map)

# Handle narrowing for operands with custom __eq__ methods specially
# In most cases, we won't be able to do any narrowing
Expand All @@ -6757,14 +6748,12 @@ def narrow_type_by_identity_equality(
if should_coerce_literals:
target_type = coerce_to_literal(target_type)
target = TypeRange(target_type, is_upper_bound=False)
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))

if is_value_target:
if is_target_for_value_narrowing(get_proper_type(target_type)):
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target])
)
if else_map:
partial_type_maps.append(({}, else_map))
all_else_maps.append(else_map)
continue

# If our operand with custom __eq__ is a union, where only some members of the union
Expand All @@ -6778,37 +6767,24 @@ def narrow_type_by_identity_equality(
# we narrow to in the if_map
or_if_maps.append({operands[i]: expr_type})

expr_type = coerce_to_literal(try_expanding_sum_type_to_union(expr_type, None))
for j in expr_indices:
if j in custom_eq_indices:
continue
target_type = operand_types[j]
if should_coerce_literals:
target_type = coerce_to_literal(target_type)
target = TypeRange(target_type, is_upper_bound=False)
is_value_target = is_target_for_value_narrowing(get_proper_type(target_type))

if is_value_target:
expr_type = coerce_to_literal(expr_type)
expr_type = try_expanding_sum_type_to_union(expr_type, None)
if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target], default=expr_type)
)
or_if_maps.append(if_map)
if is_value_target:
if is_target_for_value_narrowing(get_proper_type(target_type)):
or_else_maps.append(else_map)

final_if_map: TypeMap = {}
final_else_map: TypeMap = {}
if or_if_maps:
final_if_map = or_if_maps[0]
for if_map in or_if_maps[1:]:
final_if_map = or_conditional_maps(final_if_map, if_map)
if or_else_maps:
final_else_map = or_else_maps[0]
for else_map in or_else_maps[1:]:
final_else_map = or_conditional_maps(final_else_map, else_map)

partial_type_maps.append((final_if_map, final_else_map))
all_if_maps.append(reduce_or_conditional_type_maps(or_if_maps))
all_else_maps.append(reduce_or_conditional_type_maps(or_else_maps))

# Handle narrowing for comparisons that produce additional narrowing, like
# `type(x) == T` or `x.__class__ is T`
Expand Down Expand Up @@ -6849,13 +6825,16 @@ def narrow_type_by_identity_equality(
if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo)
else False
)
if not is_final:
else_map = {}
partial_type_maps.append((if_map, else_map))
all_if_maps.append(if_map)
if is_final:
# We can only narrow `type(x) == T` in the negative case if T is final
all_else_maps.append(else_map)

# We will not have duplicate entries in our type maps if we only have two operands,
# so we can skip running meets on the intersections
return reduce_conditional_maps(partial_type_maps, use_meet=len(operands) > 2)
if_map = reduce_and_conditional_type_maps(all_if_maps, use_meet=len(operands) > 2)
else_map = reduce_or_conditional_type_maps(all_else_maps)
return if_map, else_map

def propagate_up_typemap_info(self, new_types: TypeMap) -> TypeMap:
"""Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
Expand Down Expand Up @@ -8491,7 +8470,7 @@ def builtin_item_type(tp: Type) -> Type | None:
return None


def and_conditional_maps(m1: TypeMap, m2: TypeMap, use_meet: bool = False) -> TypeMap:
def and_conditional_maps(m1: TypeMap, m2: TypeMap, *, use_meet: bool = False) -> TypeMap:
"""Calculate what information we can learn from the truth of (e1 and e2)
in terms of the information that we can learn from the truth of e1 and
the truth of e2.
Expand Down Expand Up @@ -8524,7 +8503,7 @@ def and_conditional_maps(m1: TypeMap, m2: TypeMap, use_meet: bool = False) -> Ty
return result


def or_conditional_maps(m1: TypeMap, m2: TypeMap, coalesce_any: bool = False) -> TypeMap:
def or_conditional_maps(m1: TypeMap, m2: TypeMap, *, coalesce_any: bool = False) -> TypeMap:
"""Calculate what information we can learn from the truth of (e1 or e2)
in terms of the information that we can learn from the truth of e1 and
the truth of e2. If coalesce_any is True, consider Any a supertype when
Expand Down Expand Up @@ -8589,6 +8568,30 @@ def reduce_conditional_maps(
return final_if_map, final_else_map


def reduce_or_conditional_type_maps(ms: list[TypeMap]) -> TypeMap:
"""Reduces a list of TypeMaps into a single TypeMap by "or"-ing them together."""
if len(ms) == 0:
return {}
if len(ms) == 1:
return ms[0]
result = ms[0]
for m in ms[1:]:
result = or_conditional_maps(result, m)
return result


def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> TypeMap:
"""Reduces a list of TypeMaps into a single TypeMap by "and"-ing them together."""
if len(ms) == 0:
return {}
if len(ms) == 1:
return ms[0]
result = ms[0]
for m in ms[1:]:
result = and_conditional_maps(result, m, use_meet=use_meet)
return result


BUILTINS_CUSTOM_EQ_CHECKS: Final = {
"builtins.bytes",
"builtins.bytearray",
Expand Down
5 changes: 4 additions & 1 deletion test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1749,7 +1749,10 @@ def f(m: Medal) -> None:
match m:
case Medal.GOLD:
always_assigned = 1
# This should narrow to literal, see TODO in checker::refine_identity_comparison_expression
# Ideally, this should narrow to literal
# However `expr_type = try_expanding_sum_type_to_union(coerce_to_literal(expr_type), None)`
# in checker.py means we expand the type to a LiteralType and then because there is no
# net change we don't end up inserting the LiteralType into the type map
reveal_type(m) # N: Revealed type is "__main__.Medal"
case _:
assert_never(m)
Expand Down