Skip to content

Commit

Permalink
Address review feedback.
Browse files Browse the repository at this point in the history
Do not generate implicit `@differentiable` attribute on protocol witness when
conformance is in a different file.

Gardening. Add tests.
  • Loading branch information
dan-zheng committed Feb 13, 2020
1 parent 50adeee commit 4415d58
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 15 deletions.
21 changes: 13 additions & 8 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness,
/// witness.
/// - If requirement's `@differentiable` attributes are met, or if `result` is
/// not viable, returns `result`.
/// - Otherwise, returns a `DifferentiableConflict` `RequirementMatch`.
/// - Otherwise, returns a "missing `@differentiable` attribute"
/// `RequirementMatch`.
// Note: the `result` argument is only necessary for using
// `RequirementMatch::WitnessSubstitutions`.
static RequirementMatch
Expand Down Expand Up @@ -386,15 +387,19 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
bool success = false;
// If no exact witness derivative configuration was found, check
// conditions for creating an implicit witness `@differentiable` attribute
// with the exaxct derivative configuration:
// with the exact derivative configuration:
// - If the witness has a "superset" derivative configuration.
// - If the witness is less than public.
// - If the witness is less than public and is declared in the same file
// as the conformance.
// - `@differentiable` attributes are really only significant for public
// declarations: it improves usability to not require explicit
// `@differentiable` attributes for less-visible declarations.
bool createImplicitWitnessAttribute =
supersetConfig || witness->getFormalAccess() < AccessLevel::Public;
if (supersetConfig || witness->getFormalAccess() < AccessLevel::Public) {
supersetConfig ||
(witness->getFormalAccess() < AccessLevel::Public &&
dc->getModuleScopeContext() ==
witness->getDeclContext()->getModuleScopeContext());
if (createImplicitWitnessAttribute) {
auto derivativeGenSig = witnessAFD->getGenericSignature();
if (supersetConfig)
derivativeGenSig = supersetConfig->derivativeGenericSignature;
Expand Down Expand Up @@ -428,9 +433,9 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
return RequirementMatch(
getStandinForAccessor(vdWitness, AccessorKind::Get),
MatchKind::DifferentiableConflict, reqDiffAttr);
MatchKind::MissingDifferentiableAttr, reqDiffAttr);
} else {
return RequirementMatch(witness, MatchKind::DifferentiableConflict,
return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr,
reqDiffAttr);
}
}
Expand Down Expand Up @@ -2328,7 +2333,7 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
case MatchKind::NonObjC:
diags.diagnose(match.Witness, diag::protocol_witness_not_objc);
break;
case MatchKind::DifferentiableConflict: {
case MatchKind::MissingDifferentiableAttr: {
// Emit a note and fix-it showing the missing requirement `@differentiable`
// attribute.
auto *reqAttr = cast<DifferentiableAttr>(match.UnmetAttribute);
Expand Down
15 changes: 8 additions & 7 deletions lib/Sema/TypeCheckProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,8 @@ enum class MatchKind : uint8_t {
/// The witness is explicitly @nonobjc but the requirement is @objc.
NonObjC,

/// The witness does not have a `@differentiable` attribute satisfying one
/// from the requirement.
DifferentiableConflict,
/// The witness is missing a `@differentiable` attribute from the requirement.
MissingDifferentiableAttr,
};

/// Describes the kind of optional adjustment performed when
Expand Down Expand Up @@ -363,7 +362,7 @@ struct RequirementMatch {
: Witness(witness), Kind(kind), WitnessType(), UnmetAttribute(attr),
ReqEnv(None) {
assert(!hasWitnessType() && "Should have witness type");
assert(UnmetAttribute);
assert(hasUnmetAttribute() && "Should have unmet attribute");
}

RequirementMatch(ValueDecl *witness, MatchKind kind,
Expand Down Expand Up @@ -438,7 +437,7 @@ struct RequirementMatch {
case MatchKind::RethrowsConflict:
case MatchKind::ThrowsConflict:
case MatchKind::NonObjC:
case MatchKind::DifferentiableConflict:
case MatchKind::MissingDifferentiableAttr:
return false;
}

Expand Down Expand Up @@ -468,7 +467,7 @@ struct RequirementMatch {
case MatchKind::RethrowsConflict:
case MatchKind::ThrowsConflict:
case MatchKind::NonObjC:
case MatchKind::DifferentiableConflict:
case MatchKind::MissingDifferentiableAttr:
return false;
}

Expand All @@ -479,7 +478,9 @@ struct RequirementMatch {
bool hasRequirement() { return Kind == MatchKind::MissingRequirement; }

/// Determine whether this requirement match has an unmet attribute.
bool hasUnmetAttribute() { return Kind == MatchKind::DifferentiableConflict; }
bool hasUnmetAttribute() {
return Kind == MatchKind::MissingDifferentiableAttr;
}

swift::Witness getWitness(ASTContext &ctx) const;
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
protocol Protocol: Differentiable {
// expected-note @+2 {{protocol requires function 'internalMethod1' with type '(Float) -> Float'}}
@differentiable(wrt: (self, x))
func internalMethod1(_ x: Float) -> Float

// expected-note @+3 {{protocol requires function 'internalMethod2' with type '(Float) -> Float'}}
@differentiable(wrt: (self, x))
@differentiable(wrt: x)
func internalMethod2(_ x: Float) -> Float

@differentiable(wrt: (self, x))
@differentiable(wrt: x)
func internalMethod3(_ x: Float) -> Float
}

protocol Protocol2: Differentiable {
@differentiable(wrt: (self, x))
func internalMethod4(_ x: Float) -> Float
}

// Note:
// - No `ConformingStruct: Protocol` conformance exists in this file, so this
// file should compile just file.
// - A `ConformingStruct: Protocol` conformance in a different file should be
// diagnosed to prevent linker errors. Without a diagnostic, compilation of
// the other file creates external references to symbols for implicit
// `@differentiable` attributes, even though no such symbols exist.
// Context: https://github.com/apple/swift/pull/29771#issuecomment-585059721

struct ConformingStruct: Differentiable {
// Expected: errors for missing `@differentiable` attribute.
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
func internalMethod1(_ x: Float) -> Float {
x
}

// Expected: errors for missing `@differentiable` superset attribute.
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
@differentiable(wrt: x)
func internalMethod2(_ x: Float) -> Float {
x
}

// Expected: no errors for missing `@differentiable` subset attribute.
@differentiable(wrt: (self, x))
func internalMethod3(_ x: Float) -> Float {
x
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Test missing protocol requirement `@differentiable` attribute errors for
// non-public protocol witnesses, when the protocol conformance is declared in a
// separate file from witnesses.
//
// Implicit `@differentiable` attributes cannot be generated for protocol
// witnesses when the conformance is declared from a separate file from the
// witness. Otherwise, compilation of the file containing the conformance
// creates external references to symbols for implicit `@differentiable`
// attributes, even though no such symbols exist.
//
// Context: https://github.com/apple/swift/pull/29771#issuecomment-585059721

// Note: `swiftc main.swift other_file.swift` runs three commands:
// - `swiftc -frontend -primary-file main.swift other_file.swift -o ...`
// - `swiftc -frontend main.swift -primary-file other_file.swift -o ...`
// - `/usr/bin/ld ...`
//
// `%target-build-swift` performs `swiftc main.swift other_file.swift`, so it is expected to fail (hence `not`).
// `swiftc -frontend -primary-file main.swift other_file.swift` should succeed, so no need for `-verify`.
// `swiftc -frontend main.swift -primary-file other_file.swift` should fail, so `-verify` is needed.

// RUN: %target-swift-frontend -c -verify -primary-file %s %S/Inputs/other_file.swift
// RUN: %target-swift-frontend -c %s -primary-file %S/Inputs/other_file.swift
// RUN: not %target-build-swift %s %S/Inputs/other_file.swift

// Error: conformance is in different file than witnesses.
// expected-error @+1 {{type 'ConformingStruct' does not conform to protocol 'Protocol'}}
extension ConformingStruct: Protocol {}

// No error: conformance is in same file as witnesses.
extension ConformingStruct: Protocol2 {
func internalMethod4(_ x: Float) -> Float {
x
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: %target-swift-frontend -emit-sil -verify %s

// Test end-to-end differentiation involving implicit `@differentiable`
// attributes for non-public protocol witnesses.
//
// Specifically, test the diagnostic source locations for implicit attributes.

protocol Protocol: Differentiable {
// Note: error below comes from the implicit `@differentiable` attribute on
// `PublicConformingStruct.internalMethod`. The source location of the
// implicit attribute is copied from the protocol requirement's attribute.

// expected-error @+1 {{function is not differentiable}}
@differentiable(wrt: (self, x))
func internalMethod(_ x: Float) -> Float
}

struct ConformingStruct: Protocol {
// Expected:
// - No error for missing `@differentiable` attribute on internal protocol witness.
// An implicit `@differentiable` attribute should be created.
// - A non-differentiability error, because the method body is non-differentiable.
// expected-note @+1 {{when differentiating this function definition}}
func internalMethod(_ x: Float) -> Float {
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
return Float(Int(x))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: %target-swift-frontend -print-ast -verify %s | %FileCheck %s

// Test implicit `@differentiable` attributes for non-public protocol witnesses.

protocol InternalProtocol: Differentiable {
// expected-note @+3 {{protocol requires function 'publicMethod' with type '(Float) -> Float'}}
@differentiable(wrt: self)
@differentiable(wrt: (self, x))
func publicMethod(_ x: Float) -> Float

@differentiable(wrt: self)
@differentiable(wrt: (self, x))
func internalMethod(_ x: Float) -> Float
}

// expected-error @+1 {{type 'PublicConformingStruct' does not conform to protocol 'InternalProtocol'}}
public struct PublicConformingStruct: InternalProtocol {
// Expected: error for missing `@differentiable` attribute on public protocol witness.
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
public func publicMethod(_ x: Float) -> Float {
x
}

// Expected: no error for missing `@differentiable` attribute on internal protocol witness.
// Implicit `@differentiable` attributes should be created.
func internalMethod(_ x: Float) -> Float {
x
}
}

// CHECK-LABEL: public struct PublicConformingStruct : InternalProtocol {
// CHECK: public func publicMethod(_ x: Float) -> Float
// CHECK: @differentiable(wrt: (self, x))
// CHECK: @differentiable(wrt: self)
// CHECK: internal func internalMethod(_ x: Float) -> Float
// CHECK: }

0 comments on commit 4415d58

Please sign in to comment.