diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index 1f2b1bff1d85..8f4ff262091f 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -592,16 +592,27 @@ class BindingSet { void dump(llvm::raw_ostream &out, unsigned indent) const; private: - /// Add a new binding to the set. + /// Introduce a new binding to the set. The binding might not + /// actually be added due to subtyping or other rule like + /// CGFloat/Double implicit conversion. This method should be + /// be preferred over \c addBinding when adding new bindings. /// /// \param binding The binding to add. /// \param isTransitive Indicates whether this binding has been /// acquired through transitive inference and requires validity /// checking. - void addBinding(PotentialBinding binding, bool isTransitive); + void introduceBinding(PotentialBinding binding, bool isTransitive); void addLiteralRequirement(Constraint *literal); + /// Insert the given binding into \c Bindings. + /// + /// This method is going to compute referenced variables before + /// forwarding to the other overload. + void addBinding(const PotentialBinding &&binding); + void addBinding(const PotentialBinding &&binding, + llvm::SmallPtrSetImpl &referencedVars); + void addDefault(Constraint *constraint); StringRef getLiteralBindingKind(LiteralBindingKind K) const { diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index fc46bd46a0ff..18b95fb3d326 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -43,7 +43,7 @@ BindingSet::BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar, : CS(CS), TypeVar(TypeVar), Info(info) { for (const auto &binding : info.Bindings) - addBinding(binding, /*isTransitive=*/false); + introduceBinding(binding, /*isTransitive=*/false); for (auto *constraint : info.Constraints) { switch (constraint->getKind()) { @@ -596,7 +596,7 @@ void BindingSet::inferTransitiveKeyPathBindings() { // Copy the bindings over to the root. for (const auto &binding : bindings.Bindings) - addBinding(binding, /*isTransitive=*/true); + introduceBinding(binding, /*isTransitive=*/true); // Make a note that the key path root is transitively adjacent // to contextual root type variable and all of its variables. @@ -606,7 +606,7 @@ void BindingSet::inferTransitiveKeyPathBindings() { bindings.AdjacentVars.end()); } } else { - addBinding( + introduceBinding( binding.withSameSource(inferredRootTy, AllowedBindingKind::Exact), /*isTransitive=*/true); } @@ -679,7 +679,7 @@ void BindingSet::inferTransitiveSupertypeBindings() { if (ConstraintSystem::typeVarOccursInType(TypeVar, type)) continue; - addBinding(binding.withSameSource(type, AllowedBindingKind::Supertypes), + introduceBinding(binding.withSameSource(type, AllowedBindingKind::Supertypes), /*isTransitive=*/true); } } @@ -713,7 +713,7 @@ void BindingSet::inferTransitiveUnresolvedMemberRefBindings() { continue; } - addBinding({protocolTy, AllowedBindingKind::Exact, constraint}, + introduceBinding({protocolTy, AllowedBindingKind::Exact, constraint}, /*isTransitive=*/false); } } @@ -889,7 +889,22 @@ void BindingSet::finalizeUnresolvedMemberChainResult() { } } -void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) { +void BindingSet::addBinding(const PotentialBinding &&binding) { + SmallPtrSet referencedVars; + binding.BindingType->getTypeVariables(referencedVars); + + addBinding(std::move(binding), referencedVars); +} + +void BindingSet::addBinding(const PotentialBinding &&binding, + SmallPtrSetImpl &referencedVars) { + for (auto *adjacentVar : referencedVars) + AdjacentVars.insert(adjacentVar); + + (void)Bindings.insert(binding); +} + +void BindingSet::introduceBinding(PotentialBinding binding, bool isTransitive) { if (Bindings.count(binding)) return; @@ -944,6 +959,57 @@ void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) { } } + // If the type variable prefers subtypes, diasambiguate a situation + // when this type variable is simultaneously a supertype of `@Sendable` + // function type and a subtype of a non-Sendable one by using a supertype + // binding because it constitutes a "subtype" in this case. + // + // For example: + // + // @Sendable () -> Void conv $T + // $T argument conv () -> Void + // + // Either of the types could also be wrapped in a number of optionals. Even if + // there is an optionality mismatch, let's still prefer a supertype binding + // because that would be easier to diagnose. + // + // In particular, this is helpful with ternary operators where the context is + // non-Sendable, but one or both sides are. + if (TypeVar->getImpl().prefersSubtypeBinding()) { + if (auto *funcType = binding.BindingType->lookThroughAllOptionalTypes() + ->getAs()) { + if (binding.Kind == AllowedBindingKind::Supertypes && + funcType->isSendable()) { + // Note that we are removing the bindings but leaving AdjacentVars + // intact to make sure that this doesn't affect assessment of the + // binding set i.e. \c involvesTypeVariables. + Bindings.remove_if([](const PotentialBinding &existing) { + if (existing.Kind != AllowedBindingKind::Subtypes) + return false; + + auto *existingFn = existing.BindingType->lookThroughAllOptionalTypes() + ->getAs(); + return existingFn && !existingFn->isSendable(); + }); + } + + // If there are existing `@Sendable` supertype bindings, we can skip this + // one. + if (binding.Kind == AllowedBindingKind::Subtypes && + !funcType->isSendable()) { + if (llvm::any_of(Bindings, [](const PotentialBinding &existing) { + if (existing.Kind != AllowedBindingKind::Supertypes) + return false; + auto *existingFn = + existing.BindingType->lookThroughAllOptionalTypes() + ->getAs(); + return existingFn && existingFn->isSendable(); + })) + return; + } + } + } + // If this is a non-defaulted supertype binding, // check whether we can combine it with another // supertype binding by computing the 'join' of the types. @@ -976,7 +1042,7 @@ void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) { } for (const auto &binding : joined) - (void)Bindings.insert(binding); + addBinding(std::move(binding)); // If new binding has been joined with at least one of existing // bindings, there is no reason to include it into the set. @@ -984,10 +1050,7 @@ void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) { return; } - for (auto *adjacentVar : referencedTypeVars) - AdjacentVars.insert(adjacentVar); - - (void)Bindings.insert(std::move(binding)); + addBinding(std::move(binding), referencedTypeVars); } void BindingSet::determineLiteralCoverage() { diff --git a/test/Concurrency/sendable_keypaths.swift b/test/Concurrency/sendable_keypaths.swift index 6db444cbbe0e..1ce25f74b184 100644 --- a/test/Concurrency/sendable_keypaths.swift +++ b/test/Concurrency/sendable_keypaths.swift @@ -247,16 +247,19 @@ do { static func otherFn() {} } - // TODO(rdar://125948508): This shouldn't be ambiguous (@Sendable version should be preferred) func fnRet(cond: Bool) -> () -> Void { - cond ? Test.fn : Test.otherFn // expected-error {{failed to produce diagnostic for expression}} + cond ? Test.fn : Test.otherFn // Ok } func forward(_: T) -> T { } - // TODO(rdar://125948508): This shouldn't be ambiguous (@Sendable version should be preferred) - let _: () -> Void = forward(Test.fn) // expected-error {{conflicting arguments to generic parameter 'T' ('@Sendable () -> ()' vs. '() -> Void')}} + let _: () -> Void = forward(Test.fn) // Ok + + func test(fn1: (@Sendable () -> Void)?, fn2: @escaping () -> Void) { + let _: () -> Void = true ? fn1 : fn2 + // expected-error@-1 {{cannot convert value of type '(@Sendable () -> Void)?' to specified type '() -> Void'}} + } } // https://github.com/swiftlang/swift/issues/77105 diff --git a/test/Concurrency/sendable_methods.swift b/test/Concurrency/sendable_methods.swift index 75038c79da21..90825a9bcca2 100644 --- a/test/Concurrency/sendable_methods.swift +++ b/test/Concurrency/sendable_methods.swift @@ -331,4 +331,21 @@ do { static func ff() {} } -} \ No newline at end of file +} + +// Ambiguity between `@Sendable` method and non-Sendable context injected into an Optional. +do { + struct Test { + func action() -> Void {} + + func onAction(_: (() -> Void)?) {} + + func test() { + onAction(true ? action : nil) // Ok + } + + func test(fn1: (@Sendable () -> Void)?, fn2: @escaping () -> Void) { + let _: () -> Void = fn1 ?? fn2 // Ok + } + } +}