diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 546fde39dfee6..e9653ae932ed0 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -308,7 +308,8 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness, /// witness. /// - If requirement's `@differentiable` attributes are met, or if `result` is /// not viable, returns `result`. -/// - Otherwise, returns a `DifferentiableConflict` `RequirementMatch`. +/// - Otherwise, returns a "missing `@differentiable` attribute" +/// `RequirementMatch`. // Note: the `result` argument is only necessary for using // `RequirementMatch::WitnessSubstitutions`. static RequirementMatch @@ -386,15 +387,19 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, bool success = false; // If no exact witness derivative configuration was found, check // conditions for creating an implicit witness `@differentiable` attribute - // with the exaxct derivative configuration: + // with the exact derivative configuration: // - If the witness has a "superset" derivative configuration. - // - If the witness is less than public. + // - If the witness is less than public and is declared in the same file + // as the conformance. // - `@differentiable` attributes are really only significant for public // declarations: it improves usability to not require explicit // `@differentiable` attributes for less-visible declarations. bool createImplicitWitnessAttribute = - supersetConfig || witness->getFormalAccess() < AccessLevel::Public; - if (supersetConfig || witness->getFormalAccess() < AccessLevel::Public) { + supersetConfig || + (witness->getFormalAccess() < AccessLevel::Public && + dc->getModuleScopeContext() == + witness->getDeclContext()->getModuleScopeContext()); + if (createImplicitWitnessAttribute) { auto derivativeGenSig = witnessAFD->getGenericSignature(); if (supersetConfig) derivativeGenSig = supersetConfig->derivativeGenericSignature; @@ -428,9 +433,9 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, if (auto *vdWitness = dyn_cast(witness)) { return RequirementMatch( getStandinForAccessor(vdWitness, AccessorKind::Get), - MatchKind::DifferentiableConflict, reqDiffAttr); + MatchKind::MissingDifferentiableAttr, reqDiffAttr); } else { - return RequirementMatch(witness, MatchKind::DifferentiableConflict, + return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr, reqDiffAttr); } } @@ -2328,7 +2333,7 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance, case MatchKind::NonObjC: diags.diagnose(match.Witness, diag::protocol_witness_not_objc); break; - case MatchKind::DifferentiableConflict: { + case MatchKind::MissingDifferentiableAttr: { // Emit a note and fix-it showing the missing requirement `@differentiable` // attribute. auto *reqAttr = cast(match.UnmetAttribute); diff --git a/lib/Sema/TypeCheckProtocol.h b/lib/Sema/TypeCheckProtocol.h index ec45d594fc499..c5bd6a2c75294 100644 --- a/lib/Sema/TypeCheckProtocol.h +++ b/lib/Sema/TypeCheckProtocol.h @@ -210,9 +210,8 @@ enum class MatchKind : uint8_t { /// The witness is explicitly @nonobjc but the requirement is @objc. NonObjC, - /// The witness does not have a `@differentiable` attribute satisfying one - /// from the requirement. - DifferentiableConflict, + /// The witness is missing a `@differentiable` attribute from the requirement. + MissingDifferentiableAttr, }; /// Describes the kind of optional adjustment performed when @@ -363,7 +362,7 @@ struct RequirementMatch { : Witness(witness), Kind(kind), WitnessType(), UnmetAttribute(attr), ReqEnv(None) { assert(!hasWitnessType() && "Should have witness type"); - assert(UnmetAttribute); + assert(hasUnmetAttribute() && "Should have unmet attribute"); } RequirementMatch(ValueDecl *witness, MatchKind kind, @@ -438,7 +437,7 @@ struct RequirementMatch { case MatchKind::RethrowsConflict: case MatchKind::ThrowsConflict: case MatchKind::NonObjC: - case MatchKind::DifferentiableConflict: + case MatchKind::MissingDifferentiableAttr: return false; } @@ -468,7 +467,7 @@ struct RequirementMatch { case MatchKind::RethrowsConflict: case MatchKind::ThrowsConflict: case MatchKind::NonObjC: - case MatchKind::DifferentiableConflict: + case MatchKind::MissingDifferentiableAttr: return false; } @@ -479,7 +478,9 @@ struct RequirementMatch { bool hasRequirement() { return Kind == MatchKind::MissingRequirement; } /// Determine whether this requirement match has an unmet attribute. - bool hasUnmetAttribute() { return Kind == MatchKind::DifferentiableConflict; } + bool hasUnmetAttribute() { + return Kind == MatchKind::MissingDifferentiableAttr; + } swift::Witness getWitness(ASTContext &ctx) const; }; diff --git a/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_cross_file/Inputs/other_file.swift b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_cross_file/Inputs/other_file.swift new file mode 100644 index 0000000000000..20554743e45e1 --- /dev/null +++ b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_cross_file/Inputs/other_file.swift @@ -0,0 +1,49 @@ +protocol Protocol: Differentiable { + // expected-note @+2 {{protocol requires function 'internalMethod1' with type '(Float) -> Float'}} + @differentiable(wrt: (self, x)) + func internalMethod1(_ x: Float) -> Float + + // expected-note @+3 {{protocol requires function 'internalMethod2' with type '(Float) -> Float'}} + @differentiable(wrt: (self, x)) + @differentiable(wrt: x) + func internalMethod2(_ x: Float) -> Float + + @differentiable(wrt: (self, x)) + @differentiable(wrt: x) + func internalMethod3(_ x: Float) -> Float +} + +protocol Protocol2: Differentiable { + @differentiable(wrt: (self, x)) + func internalMethod4(_ x: Float) -> Float +} + +// Note: +// - No `ConformingStruct: Protocol` conformance exists in this file, so this +// file should compile just file. +// - A `ConformingStruct: Protocol` conformance in a different file should be +// diagnosed to prevent linker errors. Without a diagnostic, compilation of +// the other file creates external references to symbols for implicit +// `@differentiable` attributes, even though no such symbols exist. +// Context: https://github.com/apple/swift/pull/29771#issuecomment-585059721 + +struct ConformingStruct: Differentiable { + // Expected: errors for missing `@differentiable` attribute. + // expected-note @+1 {{candidate is missing attribute '@differentiable'}} + func internalMethod1(_ x: Float) -> Float { + x + } + + // Expected: errors for missing `@differentiable` superset attribute. + // expected-note @+2 {{candidate is missing attribute '@differentiable'}} + @differentiable(wrt: x) + func internalMethod2(_ x: Float) -> Float { + x + } + + // Expected: no errors for missing `@differentiable` subset attribute. + @differentiable(wrt: (self, x)) + func internalMethod3(_ x: Float) -> Float { + x + } +} diff --git a/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_cross_file/main.swift b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_cross_file/main.swift new file mode 100644 index 0000000000000..a8ae809918254 --- /dev/null +++ b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_cross_file/main.swift @@ -0,0 +1,35 @@ +// Test missing protocol requirement `@differentiable` attribute errors for +// non-public protocol witnesses, when the protocol conformance is declared in a +// separate file from witnesses. +// +// Implicit `@differentiable` attributes cannot be generated for protocol +// witnesses when the conformance is declared from a separate file from the +// witness. Otherwise, compilation of the file containing the conformance +// creates external references to symbols for implicit `@differentiable` +// attributes, even though no such symbols exist. +// +// Context: https://github.com/apple/swift/pull/29771#issuecomment-585059721 + +// Note: `swiftc main.swift other_file.swift` runs three commands: +// - `swiftc -frontend -primary-file main.swift other_file.swift -o ...` +// - `swiftc -frontend main.swift -primary-file other_file.swift -o ...` +// - `/usr/bin/ld ...` +// +// `%target-build-swift` performs `swiftc main.swift other_file.swift`, so it is expected to fail (hence `not`). +// `swiftc -frontend -primary-file main.swift other_file.swift` should succeed, so no need for `-verify`. +// `swiftc -frontend main.swift -primary-file other_file.swift` should fail, so `-verify` is needed. + +// RUN: %target-swift-frontend -c -verify -primary-file %s %S/Inputs/other_file.swift +// RUN: %target-swift-frontend -c %s -primary-file %S/Inputs/other_file.swift +// RUN: not %target-build-swift %s %S/Inputs/other_file.swift + +// Error: conformance is in different file than witnesses. +// expected-error @+1 {{type 'ConformingStruct' does not conform to protocol 'Protocol'}} +extension ConformingStruct: Protocol {} + +// No error: conformance is in same file as witnesses. +extension ConformingStruct: Protocol2 { + func internalMethod4(_ x: Float) -> Float { + x + } +} diff --git a/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_sil.swift b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_sil.swift new file mode 100644 index 0000000000000..aba4e15958191 --- /dev/null +++ b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_sil.swift @@ -0,0 +1,28 @@ +// RUN: %target-swift-frontend -emit-sil -verify %s + +// Test end-to-end differentiation involving implicit `@differentiable` +// attributes for non-public protocol witnesses. +// +// Specifically, test the diagnostic source locations for implicit attributes. + +protocol Protocol: Differentiable { + // Note: error below comes from the implicit `@differentiable` attribute on + // `PublicConformingStruct.internalMethod`. The source location of the + // implicit attribute is copied from the protocol requirement's attribute. + + // expected-error @+1 {{function is not differentiable}} + @differentiable(wrt: (self, x)) + func internalMethod(_ x: Float) -> Float +} + +struct ConformingStruct: Protocol { + // Expected: + // - No error for missing `@differentiable` attribute on internal protocol witness. + // An implicit `@differentiable` attribute should be created. + // - A non-differentiability error, because the method body is non-differentiable. + // expected-note @+1 {{when differentiating this function definition}} + func internalMethod(_ x: Float) -> Float { + // expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}} + return Float(Int(x)) + } +} diff --git a/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_type_checking.swift b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_type_checking.swift new file mode 100644 index 0000000000000..63ca9f642166b --- /dev/null +++ b/test/AutoDiff/downstream/implicit_nonpublic_differentiable_attr_type_checking.swift @@ -0,0 +1,36 @@ +// RUN: %target-swift-frontend -print-ast -verify %s | %FileCheck %s + +// Test implicit `@differentiable` attributes for non-public protocol witnesses. + +protocol InternalProtocol: Differentiable { + // expected-note @+3 {{protocol requires function 'publicMethod' with type '(Float) -> Float'}} + @differentiable(wrt: self) + @differentiable(wrt: (self, x)) + func publicMethod(_ x: Float) -> Float + + @differentiable(wrt: self) + @differentiable(wrt: (self, x)) + func internalMethod(_ x: Float) -> Float +} + +// expected-error @+1 {{type 'PublicConformingStruct' does not conform to protocol 'InternalProtocol'}} +public struct PublicConformingStruct: InternalProtocol { + // Expected: error for missing `@differentiable` attribute on public protocol witness. + // expected-note @+1 {{candidate is missing attribute '@differentiable'}} + public func publicMethod(_ x: Float) -> Float { + x + } + + // Expected: no error for missing `@differentiable` attribute on internal protocol witness. + // Implicit `@differentiable` attributes should be created. + func internalMethod(_ x: Float) -> Float { + x + } +} + +// CHECK-LABEL: public struct PublicConformingStruct : InternalProtocol { +// CHECK: public func publicMethod(_ x: Float) -> Float +// CHECK: @differentiable(wrt: (self, x)) +// CHECK: @differentiable(wrt: self) +// CHECK: internal func internalMethod(_ x: Float) -> Float +// CHECK: }