Deep Learning Extensions #158

Closed
szha opened this issue Apr 7, 2021 · 8 comments
Labels: API extension (adds new functions or objects to the API)

Comments

szha (Member) commented Apr 7, 2021

Hi,

Now that the standard has a good foundation for the core array API, it's probably a good time to start thinking about the neural network extensions. As an initial step, here are the neural network operations from a few selected deep learning frameworks:

TF: https://www.tensorflow.org/api_docs/python/tf/nn
PyTorch: https://pytorch.org/docs/stable/nn.functional.html
MXNet: https://mxnet.apache.org/versions/master/api/python/docs/api/npx/index.html
Flax (JAX): https://flax.readthedocs.io/en/latest/flax.linen.html#linear-modules
Haiku (JAX): https://dm-haiku.readthedocs.io/en/latest/api.html#common-modules

In addition, I think the array API standard can benefit from ONNX's model exchange format definition. Here are the operators currently in the ONNX opsets.
ONNX: https://github.com/onnx/onnx/blob/master/docs/Operators.md

The next step would be to figure out a good set of operators in the intersection and iterate through the design choices for them.

learning-chip commented:

Not sure if relevant, but mlir-npcomp (https://github.com/llvm/mlir-npcomp) does NumPy -> MLIR conversion; it doesn't feel very mature yet.

rgommers (Member) commented:

One question I have is whether there's any commonality between these APIs today. Among the functions I expected to be "simplest", I checked softmax and avg_pool; softmax seems to be the same everywhere except in MXNet. Overall, though, it's not easy to find many functions that overlap well.
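As a rough illustration of that divergence (signatures as documented around this time; they may have drifted since), a minimal core-ops reference softmax with the corresponding framework calls noted in comments:

```python
import numpy as np

def softmax_reference(x, axis=-1):
    # numerically stable reference softmax, written with core array ops only
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

x = np.random.rand(2, 3).astype("float32")
y = softmax_reference(x, axis=-1)

# The framework calls below should all agree with the reference, but note the
# kwarg-name split (dim vs. axis) and MXNet's extra `length` masking argument:
#   torch.nn.functional.softmax(torch.from_numpy(x), dim=-1)   # PyTorch
#   tf.nn.softmax(x, axis=-1)                                   # TensorFlow
#   jax.nn.softmax(jnp.asarray(x), axis=-1)                     # JAX
#   npx.softmax(mx.np.array(x), axis=-1)                        # MXNet npx
```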

szha (Member, Author) commented Apr 12, 2021

@rgommers valid question. I think there is enough commonality for basic use cases to start standardizing. The existence of ONNX, and the possibility of mapping operators from TF/PT/MX to ONNX, is to some extent evidence of that. In fact, ONNX has so far focused on the intersection of the operator sets of different frameworks, so it should provide a good starting point. That said, because deep learning is newer, operators across frameworks are more likely to be semantically equivalent than to have identical definitions.

We will likely need some more analysis and comparison to tell. I'm hoping to contribute some as soon as I have free time.

szha (Member, Author) commented Apr 15, 2021

Activation Functions

Here I summarize a few activation functions in ONNX, PyTorch, Flax (JAX), TensorFlow, and MXNet.

celu

kwargs    ONNX  PyTorch  Flax (JAX)
alpha     Y     Y        Y
inplace   N     Y        N

  • Not implemented in TensorFlow or MXNet.

elu

kwargs    ONNX  PyTorch  Flax (JAX)  TensorFlow  MXNet
alpha     Y     Y        Y           N           Y
inplace   N     Y        N           N           N

  • Implemented as part of leaky_relu in MXNet.

gelu

kwargs       PyTorch  Flax (JAX)  TensorFlow  MXNet
approximate  N        Y           Y           N

  • Not defined in ONNX. MXNet hasn't implemented the approximate version yet.
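Since the approximate flag is the main divergence for gelu, a minimal sketch of the two common definitions, using the textbook formulas in plain NumPy/SciPy rather than any framework's implementation:

```python
import numpy as np
from scipy.special import erf  # any elementwise erf would do

def gelu_exact(x):
    # "exact" GELU: x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

def gelu_tanh_approx(x):
    # the common tanh approximation, i.e. what an `approximate` kwarg selects
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

x = np.linspace(-3.0, 3.0, 7)
print(np.max(np.abs(gelu_exact(x) - gelu_tanh_approx(x))))  # small but nonzero
```

The two definitions agree closely but not exactly, so their gradients differ as well, which is relevant to the gradient-standardization question raised later in this thread.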

log_softmax

kwargs    ONNX  PyTorch  Flax (JAX)  TensorFlow  MXNet
axis/dim  Y     Y        Y           Y           Y
length    N     N        N           N           Y

relu

kwargs    ONNX  PyTorch  Flax (JAX)  TensorFlow  MXNet
inplace   N     Y        N           N           N

sigmoid

All libraries have a consistent definition.

soft_sign

All libraries have a consistent definition except MXNet.

  • Implemented as part of the activation op in MXNet.

softmax

kwargs    ONNX  PyTorch  Flax (JAX)  TensorFlow  MXNet
axis/dim  Y     Y        Y           Y           Y
length    N     N        N           N           Y

silu

kwargs  PyTorch  Flax (JAX)  TensorFlow  MXNet
beta    N        N           N           Y

  • Not defined in ONNX.
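For reference, textbook definitions of a few of the functions above, written with core array ops only; the kwarg defaults (e.g. alpha=1.0) follow common convention and are not a proposal for the standard:

```python
import numpy as np

def elu(x, alpha=1.0):
    # exponential linear unit
    return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))

def celu(x, alpha=1.0):
    # continuously differentiable ELU
    return np.maximum(x, 0.0) + np.minimum(0.0, alpha * (np.exp(x / alpha) - 1.0))

def silu(x):
    # x * sigmoid(x); MXNet's extra `beta` kwarg generalizes this to x * sigmoid(beta * x)
    return x / (1.0 + np.exp(-x))

def softsign(x):
    return x / (1.0 + np.abs(x))

x = np.linspace(-3.0, 3.0, 7)
print(elu(x), celu(x), silu(x), softsign(x), sep="\n")
```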

szha (Member, Author) commented Apr 15, 2021

We mentioned a few open questions in the 4/15 meeting:

  • How large do we expect the API surface to be?
  • What criteria should we use for including operators as part of the standard?
    • What's the half-life of these activation functions and how many useful ones stick around?
  • Do we require the operators to be differentiable? If so, should we standardize the gradient definition (e.g. approximate gelu)?

leofang (Contributor) commented Apr 15, 2021

  • Do we require the operators to be differentiable? If so, should we standardize the gradient definition (e.g. approximate gelu)?

I think it's good to have them differentiable, but then all differentiable functions should be grouped in a separate module, say array_api.dl, so that we can distinguish them from non-differentiable functions in the main namespace (and make this module optional for e.g. NumPy/CuPy).
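If such an optional extension existed, consumers would presumably probe for it the way the standard's existing optional extensions (linalg, fft) are discovered today; a minimal sketch, with the `dl` name and its contents purely hypothetical:

```python
def log_softmax_portable(xp, x, axis=-1):
    # Prefer a hypothetical optional `dl` extension if the array library ships one.
    dl = getattr(xp, "dl", None)
    if dl is not None and hasattr(dl, "log_softmax"):
        return dl.log_softmax(x, axis=axis)  # hypothetical extension call
    # Fallback written with required core ops only.
    shifted = x - xp.max(x, axis=axis, keepdims=True)
    return shifted - xp.log(xp.sum(xp.exp(shifted), axis=axis, keepdims=True))

# e.g. log_softmax_portable(numpy, some_array) takes the core-ops fallback path
```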

kgryte added this to the v2022 milestone on Oct 4, 2021
kgryte added the "API extension" label on Oct 4, 2021
NeilGirdhar commented:

Not sure if this is the place to comment, but sigmoid might be a bad name. "Sigmoid" just means s-shaped. The function should really be called logistic: https://en.wikipedia.org/wiki/Logistic_function

rgommers removed this from the v2022 milestone on Nov 28, 2022
kgryte (Contributor) commented Jun 29, 2023

As this proposal is without a champion, I'll go ahead and close it. Should we see more ecosystem consensus, we can revisit/reopen and consider it as a future specification extension.
