[AutoDiff] Relax @differentiable requirement for protocol witnesses. #29771

Merged
merged 6 commits into from
Feb 23, 2020

dan-zheng

Previously, all witnesses of a @differentiable protocol requirement were
required to have the same attribute (or one with superset parameter indices).

However, this leads to many annotations on witnesses and is not ideal for
usability. @differentiable attributes are really only significant on
public witnesses, so that they are clearly @differentiable at a glance (in
source code, interface files, and API documentation), without looking through
protocol conformance hierarchies.

Now, only public witnesses of @differentiable protocol requirements are
required to have the same attribute (or one with superset parameter indices).
For less-visible witnesses, an implicit @differentiable attribute is created
with the same configuration as the requirement's.
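
The rule above can be sketched as a small decision function. This is a toy model in plain Swift, not the type checker's implementation: `Witness`, `WitnessCheck`, and `checkWitness` are hypothetical names, and parameter indices are modeled as a plain `Set<Int>`.

```swift
// Toy model of the witness rule; all names here are hypothetical.
struct Witness {
  var isPublic: Bool
  // Parameter indices of an explicit `@differentiable` attribute, if any.
  var explicitParameterIndices: Set<Int>?
}

enum WitnessCheck: Equatable {
  case usesExplicitAttribute
  case createsImplicitAttribute(parameterIndices: Set<Int>)
  case missingAttributeError
}

func checkWitness(_ witness: Witness, requirementIndices: Set<Int>) -> WitnessCheck {
  // An explicit attribute satisfies the requirement if its parameter
  // indices are the same as, or a superset of, the requirement's.
  if let explicit = witness.explicitParameterIndices,
     explicit.isSuperset(of: requirementIndices) {
    return .usesExplicitAttribute
  }
  // Public witnesses must spell the attribute out.
  if witness.isPublic {
    return .missingAttributeError
  }
  // Less-visible witnesses get an implicit attribute with the
  // requirement's configuration.
  return .createsImplicitAttribute(parameterIndices: requirementIndices)
}
```

In this model, an internal witness without any attribute yields `.createsImplicitAttribute`, while a public witness without one yields `.missingAttributeError`.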

Resolves TF-1117.


This usability improvement was discussed during the
2020-01-17 Swift for TensorFlow open design review.
We agreed that it's a good idea!

Todo: upstream changes to master.


Example:

public protocol Layer: Differentiable {
  associatedtype Input: Differentiable
  associatedtype Output: Differentiable
  @differentiable(wrt: (self, input))
  func callAsFunction(_ input: Input) -> Output
}

// Internal conforming type.
struct DummyInternalLayer: Layer {
  func callAsFunction(_ input: Float) -> Float {
    return input
  }
}

Before (misleading diagnostic due to TF-1014):

layer.swift:9:8: error: type 'DummyInternalLayer' does not conform to protocol 'Layer'
struct DummyInternalLayer: Layer {
       ^
layer.swift:2:18: note: protocol requires nested type 'Input'; do you want to add it?
  associatedtype Input: Differentiable
                 ^
layer.swift:3:18: note: protocol requires nested type 'Output'; do you want to add it?
  associatedtype Output: Differentiable
                 ^

After: no error.

An implicit @differentiable attribute is created for DummyInternalLayer.callAsFunction:

$ swiftc -print-ast layer.swift
...

internal struct DummyInternalLayer : Layer {
  @differentiable(wrt: (self, input))
  internal func callAsFunction(_ input: Float) -> Float
  ...
}

@dan-zheng dan-zheng added the tensorflow This is for "tensorflow" branch PRs. label Feb 11, 2020
@dan-zheng dan-zheng requested review from rxwei and marcrasi February 11, 2020 23:23

@rxwei rxwei left a comment

This makes the implementation consistent with #29307. LGTM.

@dan-zheng

@swift-ci Please test tensorflow


@rxwei rxwei left a comment

One thing: Creating implicit attributes leads to invalid source locations. Have you handled differentiation diagnostics for such @differentiable attributes? In the transform, the differentiation invoker will be DifferentiableAttribute, and any diagnostics will try to get the source location of the implicit attribute.

@marcrasi

I thought of a possible weird edge case. What if the conformance is declared in a separate file from the file where the function satisfying the requirement is declared? Then the differentiation pass on the file with the function never sees the implicit differentiable attribute because it only exists in the separate compilation where the conformance is declared.

This edge case probably also applies to the existing subset logic, so it's probably an existing problem that can be dealt with later.

@rxwei

rxwei commented Feb 12, 2020

@marcrasi Interesting case. One possible solution is to disallow any implicit attribute inheritance when the witness is in a separate translation unit. We’ll just emit an error in protocol witness matching, stating which attribute the potential witness is missing.

Do not generate implicit `@differentiable` attribute on protocol witness when
conformance is in a different file.

Gardening. Add tests.
@dan-zheng

I thought of a possible weird edge case. What if the conformance is declared in a separate file from the file where the function satisfying the requirement is declared? Then the differentiation pass on the file with the function never sees the implicit differentiable attribute because it only exists in the separate compilation where the conformance is declared.

I added cross-file tests for this edge case:

Indeed, implicit @differentiable attributes cannot be generated for protocol witnesses when the conformance is declared in a different file from the witnesses. Otherwise, compilation of the file containing the conformance creates external references to symbols for implicit @differentiable attributes, even though no such symbols exist.

I implemented @rxwei's solution: a diagnostic is emitted for these cases.
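
The cross-file restriction can be sketched as follows. This is a toy model in plain Swift under stated assumptions, not the compiler's code: `Declaration`, `ImplicitAttributeDecision`, and `decideImplicitAttribute` are hypothetical names, and the message text is illustrative rather than the actual diagnostic.

```swift
// Toy model of the cross-file restriction; names and messages are hypothetical.
struct Declaration {
  var name: String
  var file: String
}

enum ImplicitAttributeDecision: Equatable {
  case createImplicitAttribute
  case diagnoseMissingAttribute(String)
}

func decideImplicitAttribute(
  witness: Declaration,
  conformance: Declaration
) -> ImplicitAttributeDecision {
  // The differentiation pass for the witness's file only sees the implicit
  // attribute if the conformance is checked in that same file.
  guard witness.file == conformance.file else {
    return .diagnoseMissingAttribute(
      "'\(witness.name)' in \(witness.file) needs an explicit '@differentiable' attribute")
  }
  return .createImplicitAttribute
}
```

A witness in `a.swift` with its conformance in `b.swift` takes the diagnostic path, so no external reference to a nonexistent derivative symbol is ever created.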


One thing: Creating implicit attributes leads to invalid source locations. Have you handled differentiation diagnostics for such @differentiable attributes? In the transform, the differentiation invoker will be DifferentiableAttribute, and any diagnostics will try to get the source location of the implicit attribute.

Implicit protocol witness @differentiable attributes are actually created with the source location from the protocol requirement's @differentiable attribute.

I added a test to verify this:

protocol Protocol: Differentiable {
  @differentiable(wrt: (self, x))
  func internalMethod(_ x: Float) -> Float
}

struct ConformingStruct: Protocol {
  // Non-differentiable body!
  func internalMethod(_ x: Float) -> Float {
    return Float(Int(x))
  }
}
test.swift:2:4: error: function is not differentiable
  @differentiable(wrt: (self, x))
  ~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
test.swift:8:8: note: when differentiating this function definition
  func internalMethod(_ x: Float) -> Float {
       ^
test.swift:9:18: note: cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?
    return Float(Int(x))
                 ^

I think these source locations are reasonable. Let me know if you have better suggestions.


Ready for re-review! Changes will be upstreamed to master.

@rxwei

rxwei commented Feb 13, 2020

It is great that we don't have invalid source locations, but this error message could be improved -- it's currently showing an error on the protocol definition as if the protocol was defined wrong, which is misleading. Instead, we should show an error on the declaration that actually triggers the error, along with a note that points us to the protocol requirement that had the @differentiable attribute.

…bute".

Add dedicated diagnostic for missing `@differentiable` attribute for non-public
protocol witness declared in a different file than the conformance.
@dan-zheng

It is great that we don't have invalid source locations, but this error message could be improved -- it's currently showing an error on the protocol definition as if the protocol was defined wrong, which is misleading. Instead, we should show an error on the declaration that actually triggers the error, along with a note that points us to the protocol requirement that had the @differentiable attribute.

Done in 03a22e5:

test.swift:20:8: error: function is not differentiable
  func internalMethod(_ x: Float) -> Float {
       ^~~~~~~~~~~~~~
test.swift:20:8: note: when differentiating this function definition
  func internalMethod(_ x: Float) -> Float {
       ^
test.swift:22:18: note: cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?
    return Float(Int(x))
                 ^

Ready for re-review.

@rxwei

rxwei commented Feb 14, 2020

Hmm the updated implementation is not showing a note on the protocol requirement that had the attribute.

@dan-zheng

Hmm the updated implementation is not showing a note on the protocol requirement that had the attribute.

Done:

test.swift:17:15: error: type 'PublicConformingStruct' does not conform to protocol 'InternalProtocol'
public struct PublicConformingStruct: InternalProtocol {
              ^
test.swift:20:15: note: candidate is missing attribute '@differentiable'
  public func publicMethod(_ x: Float) -> Float {
              ^
         @differentiable
test.swift:8:4: note: required '@differentiable' attribute declared here
  @differentiable(wrt: (self, x))
  ~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
test.swift:9:8: note: protocol requires function 'publicMethod' with type '(Float) -> Float'; do you want to add a stub?
  func publicMethod(_ x: Float) -> Float
       ^

The last diagnostic is misleading because function publicMethod has been declared, but changing/removing it seems non-trivial. I'd like to defer further diagnostic improvements: TF-1154.

@rxwei

rxwei commented Feb 14, 2020

I was not referring to the attribute-missing diagnostic. I was referring to the differentiation error.

test.swift:20:8: error: function is not differentiable
  func internalMethod(_ x: Float) -> Float {
       ^~~~~~~~~~~~~~
test.swift:20:8: note: when differentiating this function definition
  func internalMethod(_ x: Float) -> Float {
       ^
test.swift:22:18: note: cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?
    return Float(Int(x))
                 ^

This error should show a note that points to the trigger of differentiation, i.e., the @differentiable attribute on the protocol requirement that got implicitly inherited.

Use witness declaration's source location instead of protocol requirement
attribute's source location.
Store the location of inherited protocol requirement `@differentiable`
attributes for use in diagnostics.

This improves non-differentiability diagnostics for implicitly inherited
`@differentiable` attributes.
@dan-zheng

This error should show a note that points to the trigger of differentiation, i.e., the @differentiable attribute on the protocol requirement that got implicitly inherited.

Oops, sorry for the misunderstanding.

I reverted a commit in the wrong direction.
Implicitly inherited @differentiable attributes now have a dedicated diagnostic:

protocol Protocol: Differentiable {
  @differentiable(wrt: (self, x))
  func internalMethod(_ x: Float) -> Float
}

struct ConformingStruct: Protocol {
  func internalMethod(_ x: Float) -> Float {
    return Float(Int(x))
  }
}
test.swift:7:8: error: function is not differentiable
  func internalMethod(_ x: Float) -> Float {
       ^~~~~~~~~~~~~~
test.swift:2:4: note: differentiability required by the corresponding protocol requirement here
  @differentiable(wrt: (self, x))
   ^~~~~~~~~~~~~~
test.swift:8:18: note: cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?
    return Float(Int(x))
                 ^
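
The shape of this diagnostic, an error on the witness plus a note on the inherited requirement attribute, can be sketched in plain Swift. This is a toy model, not the differentiation transform's code: `SourceLocation`, `Diagnostic`, and `diagnoseNonDifferentiableWitness` are hypothetical names, and only the two message strings are taken from the output above.

```swift
// Toy model of the diagnostic shape; types and names are hypothetical.
struct SourceLocation: Equatable {
  var line: Int
  var column: Int
}

struct Diagnostic: Equatable {
  enum Kind { case error, note }
  var kind: Kind
  var location: SourceLocation
  var message: String
}

func diagnoseNonDifferentiableWitness(
  witnessLocation: SourceLocation,
  inheritedAttributeLocation: SourceLocation?
) -> [Diagnostic] {
  // The error is reported on the witness declaration itself.
  var diagnostics = [
    Diagnostic(kind: .error, location: witnessLocation,
               message: "function is not differentiable")
  ]
  // Only implicitly inherited attributes carry a requirement location.
  if let requirementLocation = inheritedAttributeLocation {
    diagnostics.append(
      Diagnostic(kind: .note, location: requirementLocation,
                 message: "differentiability required by the corresponding protocol requirement here"))
  }
  return diagnostics
}
```

Keeping the requirement location optional means explicit attributes produce only the error, while implicitly inherited ones also point back at the protocol requirement.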

/// This is set during conformance type-checking, only for implicit
/// `@differentiable` attributes created for non-public protocol witnesses of
/// protocol requirements with `@differentiable` attributes.
SourceLoc ImplicitlyInheritedDifferentiableAttrLocation;

Note: it's possible to generalize this "diagnostic source location" for other kinds of implicit @differentiable attributes too, not just "ones that are implicitly inherited from protocol requirements".

Currently, there are only three kinds of implicit @differentiable attributes:

  1. This case: @differentiable attributes inherited from protocol requirements.
  2. @differentiable attributes synthesized from others on the same declaration with superset parameter indices.
    • @differentiable(wrt: (x, y)) -> @differentiable(wrt: x)
  3. @differentiable attributes synthesized on stored properties of structs/classes that derive a conformance to Differentiable.
enum class ImplicitKind {
  InheritedFromProtocolRequirement, // this case
  SynthesizedSubsetParametersAttribute, // triggered during conformance type-checking
  SynthesizedForStoredProperty, // triggered during `Differentiable` derived conformances
};

I think the first case benefits the most from a "secondary diagnostic source location". Generalization could be done later.
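
One way to picture that generalization is an enum whose cases optionally carry a secondary location. This is a plain Swift sketch of the idea only: `ImplicitAttributeKind` and `secondaryNoteLine` are hypothetical names, and a bare line number stands in for a real source location.

```swift
// Toy Swift mirror of the three kinds listed above; names are hypothetical.
enum ImplicitAttributeKind {
  // Carries the line of the requirement's attribute for a secondary note.
  case inheritedFromProtocolRequirement(requirementAttributeLine: Int)
  case synthesizedSubsetParameters
  case synthesizedForStoredProperty

  // Line for a secondary "required here" note, when one makes sense.
  var secondaryNoteLine: Int? {
    switch self {
    case .inheritedFromProtocolRequirement(let line):
      return line
    case .synthesizedSubsetParameters, .synthesizedForStoredProperty:
      return nil
    }
  }
}
```

Here only the inherited case reports a secondary location; the other two could gain associated values later without changing callers that read `secondaryNoteLine`.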

lib/Sema/TypeCheckProtocol.h (outdated, resolved)

@rxwei rxwei left a comment

Thanks for addressing all the comments!

@dan-zheng

Thanks for the thorough review!
I'll try upstreaming changes to master first, then merging the changes into tensorflow.

@dan-zheng

@swift-ci Please test tensorflow

@compnerd compnerd closed this Feb 22, 2020
@dan-zheng

I accidentally deleted the tensorflow branch, which closed this PR. That was not intentional, sorry!
It would be nice to protect the tensorflow branch against deletion to prevent this from happening again.

@dan-zheng dan-zheng reopened this Feb 22, 2020
@dan-zheng

@swift-ci Please test tensorflow

@dan-zheng dan-zheng merged commit a077071 into swiftlang:tensorflow Feb 23, 2020
@dan-zheng dan-zheng deleted the autodiff-usability branch February 23, 2020 17:41
dan-zheng added a commit that referenced this pull request Mar 25, 2020
…30629)

Upstreams #29771 from tensorflow branch.