[AutoDiff] Relax `@differentiable` requirement for protocol witnesses. #29771
Previously, all witnesses of a `@differentiable` protocol requirement were required to have the same attribute (or one with superset parameter indices). However, this leads to many annotations on witnesses and is not ideal for usability.

`@differentiable` attributes are really only significant on public witnesses, so that they are clearly `@differentiable` at a glance (in source code, interface files, and API documentation), without looking through protocol conformance hierarchies.

Now, only *public* witnesses of `@differentiable` protocol requirements are required to have the same attribute (or one with superset parameter indices). For less-visible witnesses, an implicit `@differentiable` attribute is created with the same configuration as the requirement's.

Resolves TF-1117.
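To make the relaxed rule concrete, here is a minimal sketch. It assumes a Swift for TensorFlow toolchain that includes this change; the names `Scaling`, `InternalScale`, and `PublicScale` are hypothetical and do not come from the PR.

```swift
// Hypothetical example; assumes a Swift for TensorFlow toolchain with this change.
public protocol Scaling: Differentiable {
  @differentiable(wrt: (self, x))
  func apply(_ x: Float) -> Float
}

// Internal witness: no attribute is written here. Under the relaxed rule,
// an implicit `@differentiable(wrt: (self, x))` attribute is created to
// match the requirement.
struct InternalScale: Scaling {
  var factor: Float

  func apply(_ x: Float) -> Float {
    return factor * x
  }
}

// Public witness: the attribute must still be spelled out, so that the
// function is visibly `@differentiable` in source and interface files.
public struct PublicScale: Scaling {
  public var factor: Float

  public init(factor: Float) {
    self.factor = factor
  }

  @differentiable(wrt: (self, x))
  public func apply(_ x: Float) -> Float {
    return factor * x
  }
}
```

Both witnesses have the same body; only their visibility determines whether the attribute has to be written explicitly.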
This makes the implementation consistent with #29307. LGTM.
@swift-ci Please test tensorflow
One thing: Creating implicit attributes leads to invalid source locations. Have you handled such `@differentiable` attributes in differentiation diagnostics? In the transform, the differentiation invoker will be `DifferentiableAttribute`, and any diagnostics will try to get the source location of the implicit attribute.

I thought of a possible weird edge case. What if the conformance is declared in a separate file from the file where the function satisfying the requirement is declared? Then the differentiation pass on the file with the function never sees the implicit `@differentiable` attribute, because it only exists in the separate compilation where the conformance is declared. This edge case probably also applies to the existing subset logic, so it's probably an existing problem that can be dealt with later.
@marcrasi Interesting case. One possible solution is to disallow any implicit attribute inheritance when the witness is in a separate translation unit. We’ll just emit an error in protocol witness matching, stating which attribute the potential witness is missing.
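The edge case and the proposed rule can be illustrated with a hypothetical two-file layout, shown here as one listing with the file boundaries marked in comments. The file names and the `Scaling` and `Model` names are assumptions for illustration; only the rule itself (no implicit attribute when the conformance and the witness live in different files) comes from this discussion. The sketch assumes a Swift for TensorFlow toolchain.

```swift
// Hypothetical sketch; assumes a Swift for TensorFlow toolchain.

// --- Requirement.swift (hypothetical file) ---
protocol Scaling: Differentiable {
  @differentiable(wrt: (self, x))
  func apply(_ x: Float) -> Float
}

// --- Model.swift (hypothetical file) ---
struct Model: Differentiable {
  var factor: Float

  // The conformance to `Scaling` is declared in another file, so under the
  // proposed rule no implicit attribute would be inherited. The attribute is
  // written explicitly; without it, witness matching would be expected to
  // report which attribute is missing.
  @differentiable(wrt: (self, x))
  func apply(_ x: Float) -> Float {
    return factor * x
  }
}

// --- Model+Scaling.swift (hypothetical file) ---
extension Model: Scaling {}
```

Writing the attribute on the witness keeps the differentiation pass for `Model.swift` aware of it, regardless of where the conformance is declared.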
Do not generate implicit `@differentiable` attribute on protocol witness when conformance is in a different file. Gardening. Add tests.
31797e0 to 4415d58
I added cross-file tests for this edge case: Indeed, implicit … I implemented @rxwei's solution: a diagnostic is emitted for these cases.

Implicit protocol witness … I added a test to verify this:

```swift
protocol Protocol: Differentiable {
  @differentiable(wrt: (self, x))
  func internalMethod(_ x: Float) -> Float
}
struct ConformingStruct: Protocol {
  // Non-differentiable body!
  func internalMethod(_ x: Float) -> Float {
    return Float(Int(x))
  }
}
```

I think these source locations are reasonable. Let me know if you have better suggestions. Ready for re-review! Changes will be upstreamed to …
It is great that we don't have invalid source locations, but this error message could be improved: it's currently showing an error on the protocol definition as if the protocol was defined wrong, which is misleading. Instead, we should show an error on the declaration that actually triggers the error, along with a note that points us to the protocol requirement that had the …
…bute". Add dedicated diagnostic for missing `@differentiable` attribute for non-public protocol witness declared in a different file than the conformance.
adf7bda to 03a22e5
Done in 03a22e5:

```
test.swift:20:8: error: function is not differentiable
  func internalMethod(_ x: Float) -> Float {
       ^~~~~~~~~~~~~~
test.swift:20:8: note: when differentiating this function definition
  func internalMethod(_ x: Float) -> Float {
       ^
test.swift:22:18: note: cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?
    return Float(Int(x))
                 ^
```

Ready for re-review.
Hmm, the updated implementation is not showing a note on the protocol requirement that had the attribute.
Done:

```
test.swift:17:15: error: type 'PublicConformingStruct' does not conform to protocol 'InternalProtocol'
public struct PublicConformingStruct: InternalProtocol {
              ^
test.swift:20:15: note: candidate is missing attribute '@differentiable'
  public func publicMethod(_ x: Float) -> Float {
              ^
  @differentiable
test.swift:8:4: note: required '@differentiable' attribute declared here
  @differentiable(wrt: (self, x))
  ~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
test.swift:9:8: note: protocol requires function 'publicMethod' with type '(Float) -> Float'; do you want to add a stub?
  func publicMethod(_ x: Float) -> Float
       ^
```

The last diagnostic is misleading because function …
I was not referring to the attribute-missing diagnostic. I was referring to the differentiation error.
This error should show a note that points to the trigger of differentiation, aka the …
Use witness declaration's source location instead of protocol requirement attribute's source location.
441a497 to 2461721
Store the location of inherited protocol requirement `@differentiable` attributes for use in diagnostics. This improves non-differentiability diagnostics for implicitly inherited `@differentiable` attributes.
2461721 to 130ccd3
Oops, sorry for the misunderstanding. I reverted a commit in the wrong direction.

```swift
protocol Protocol: Differentiable {
  @differentiable(wrt: (self, x))
  func internalMethod(_ x: Float) -> Float
}

struct ConformingStruct: Protocol {
  func internalMethod(_ x: Float) -> Float {
    return Float(Int(x))
  }
}
```

```
test.swift:7:8: error: function is not differentiable
  func internalMethod(_ x: Float) -> Float {
       ^~~~~~~~~~~~~~
test.swift:2:4: note: differentiability required by the corresponding protocol requirement here
  @differentiable(wrt: (self, x))
   ^~~~~~~~~~~~~~
test.swift:8:18: note: cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?
    return Float(Int(x))
                 ^
```
```cpp
/// This is set during conformance type-checking, only for implicit
/// `@differentiable` attributes created for non-public protocol witnesses of
/// protocol requirements with `@differentiable` attributes.
SourceLoc ImplicitlyInheritedDifferentiableAttrLocation;
```
Note: it's possible to generalize this "diagnostic source location" for other kinds of implicit `@differentiable` attributes too, not just "ones that are implicitly inherited from protocol requirements".

Currently, there are only three kinds of implicit `@differentiable` attributes:
- This case: `@differentiable` attributes inherited from protocol requirements.
- `@differentiable` attributes synthesized from others on the same declaration with superset parameter indices: `@differentiable(wrt: (x, y))` -> `@differentiable(wrt: x)`.
- `@differentiable` attributes synthesized on stored properties of structs/classes that derive a conformance to `Differentiable`.

```cpp
enum class ImplicitKind {
  InheritedFromProtocolRequirement, // this case
  SynthesizedSubsetParametersAttribute, // triggered during conformance type-checking
  SynthesizedForStoredProperty, // triggered during `Differentiable` derived conformances
};
```

I think the first case benefits the most from a "secondary diagnostic source location". Generalization could be done later.
Thanks for addressing all the comments!
Thanks for the thorough review!
@swift-ci Please test tensorflow
I accidentally deleted …
@swift-ci Please test tensorflow
@swift-ci Please test tensorflow
…30629) Previously, all witnesses of a `@differentiable` protocol requirement were required to have the same attribute (or one with superset parameter indices). However, this leads to many annotations on witnesses and is not ideal for usability. `@differentiable` attributes are really only significant on public witnesses, so that they are clearly `@differentiable` at a glance (in source code, interface files, and API documentation), without looking through protocol conformance hierarchies. Now, only *public* witnesses of `@differentiable` protocol requirements are required to have the same attribute (or one with superset parameter indices). For less-visible witnesses, an implicit `@differentiable` attribute is created with the same configuration as the requirement's. Resolves TF-1117. Upstreams #29771 from tensorflow branch.
This usability improvement was discussed during the 2020-01-17 Swift for TensorFlow open design review. We agreed that it's a good idea!

Todo: upstream changes to `master`.

Example:

Before (misleading diagnostic due to TF-1014):

After: no error. An implicit `@differentiable` attribute is created for `DummyInternalLayer.callAsFunction`: