Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dtype casting for constants in activation functions. #3

Merged
merged 6 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
---
name: 'Bug Report'
about: 'Report a bug or unexpected behavior to help us improve the package'
labels: 'bug'
---

Please:

- [ ] Check for duplicate issues.
- [ ] Provide a complete example of how to reproduce the bug, wrapped in triple backticks like this:

```python
import braintools as bt
```

- [ ] If applicable, include full error messages/tracebacks.
10 changes: 10 additions & 0 deletions .github/ISSUE_TEMPLATE/feature_request.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
name: 'Feature Request'
about: 'Suggest a new idea or improvement for braintools'
labels: 'enhancement'
---

Please:

- [ ] Check for duplicate requests.
- [ ] Describe your goal, and if possible provide a code snippet with a motivating example.
37 changes: 37 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
<!--- Provide a general summary of your changes in the Title above -->
<!--- Please do remember to follow the contributing guidelines -->

## Description
<!--- Why is this change required? What problem does it solve? -->
<!--- Describe your changes in detail here to communicate to the maintainers why this pull request should be accepted -->
<!--- Describe your technology stack here if not a documentation update -->
<!--- Tasklist format is recommended for all pull requests and is required for all draft pull requests. You can couple your description with the tasklist -->
<!--- If it fixes an open issue, please link to the issue here in the last line. -->

## How Has This Been Tested
<!--- Please describe in detail how you tested your changes locally -->
<!--- Include details of your testing environment, and the tests you ran to -->
<!--- For example, markdown files should pass markdownlint locally according to the rules -->
<!--- See how your change affects other areas of the code, etc. -->

## Types of changes
<!--- What types of changes does your code introduce? -->
<!--- Only left the line that best describes this pull request -->
- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Documentation (non-breaking change which updates documentation)
- Breaking change (fix or feature that would cause existing functionality to change)
- Code style (formatting, renaming)
- Refactoring (no functional changes, no api changes)
- Other (please describe here):

## Checklist
<!--- Go over all the following points, and put an `x` in all the boxes that apply. -->
<!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! -->
- [ ] Code follows the code style of this project.
- [ ] Changes follow the **CONTRIBUTING** guidelines.
- [ ] Update necessary documentation accordingly.
- [ ] Lint and tests pass locally with the changes.
- [ ] Check issues and pull requests first. You don't want to duplicate effort.

## Other information
26 changes: 26 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates

version: 2
updates:
- package-ecosystem: "pip"
directory: "/"
schedule:
interval: "daily"
allow:
- dependency-type: "all"
commit-message:
prefix: ":arrow_up:"
open-pull-requests-limit: 50

- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "daily"
allow:
- dependency-type: "all"
commit-message:
prefix: ":arrow_up:"
open-pull-requests-limit: 50
110 changes: 110 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Continuous Integration

on:
push:
branches:
- '**' # matches every branch
pull_request:
branches:
- '**' # matches every branch


permissions:
contents: read # to fetch code
actions: write # to cancel previous workflows

# This is what will cancel the workflow
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true


jobs:
test_linux:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [ "3.9", "3.10", "3.11", "3.12" ]

steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.12.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v4
- name: Print concurrency group
run: echo '${{ github.workflow }}-${{ github.ref }}'
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip install .
- name: Test with pytest
run: |
pytest braintools/


test_macos:
runs-on: macos-latest
strategy:
fail-fast: false
matrix:
python-version: [ "3.9", "3.10", "3.11", "3.12" ]

steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.12.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v4
- name: Print concurrency group
run: echo '${{ github.workflow }}-${{ github.ref }}'
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip install .
- name: Test with pytest
run: |
pytest braintools/


test_windows:
runs-on: windows-latest
strategy:
fail-fast: false
matrix:
python-version: [ "3.9", "3.10", "3.11", "3.12" ]

steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@0.12.1
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v4
- name: Print concurrency group
run: echo '${{ github.workflow }}-${{ github.ref }}'
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools
python -m pip install -r requirements-dev.txt
pip install .
- name: Test with pytest
run: |
pytest braintools/ -p no:faulthandler
17 changes: 17 additions & 0 deletions .github/workflows/Publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: Publish to PyPI.org
on:
release:
types: [published]
jobs:
pypi:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- run: python setup.py bdist_wheel --python-tag=py3
- name: Publish package
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<p align="center">
<a href="https://pypi.org/project/braintools/"><img alt="Supported Python Version" src="https://img.shields.io/pypi/pyversions/braintools"></a>
<a href="https://github.com/brainpy/braintools"><img alt="LICENSE" src="https://anaconda.org/brainpy/brainpy/badges/license.svg"></a>
<a href="https://github.com/brainpy/braintools/blob/main/LICENSE"><img alt="LICENSE" src="https://img.shields.io/badge/License-GPLv3-blue.svg"></a>
<a href="https://brainpy.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation" src="https://readthedocs.org/projects/brainpy/badge/?version=latest"></a>
<a href="https://badge.fury.io/py/braintools"><img alt="PyPI version" src="https://badge.fury.io/py/braintools.svg"></a>
<a href="https://github.com/brainpy/braintools/actions/workflows/CI.yml"><img alt="Continuous Integration" src="https://github.com/brainpy/braintools/actions/workflows/CI.yml/badge.svg"></a>
Expand Down
4 changes: 2 additions & 2 deletions braintools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

__version__ = "0.0.1"

from . import metric
from . import inputs
from . import init
from . import optim
from . import functional

__all__ = ['inputs', 'init', 'optim', 'functional']

__all__ = ['inputs', 'init', 'optim', 'functional', 'metric']

4 changes: 2 additions & 2 deletions braintools/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@

from ._activations import *
from ._activations import __all__ as __activations_all__
from ._others import *
from ._others import __all__ as __others_all__
from .normalization import *
from .normalization import __all__ as __others_all__
from ._spikes import *
from ._spikes import __all__ as __spikes_all__

Expand Down
48 changes: 39 additions & 9 deletions braintools/functional/_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike
import braincore as bc

__all__ = [
"relu",
Expand Down Expand Up @@ -42,6 +43,18 @@
]


def _get_dtype(x: ArrayLike):
if hasattr(x, 'dtype'):
return x.dtype
else:
if isinstance(x, float):
return bc.environ.dftype()
elif isinstance(x, int):
return bc.environ.dftype()
else:
raise ValueError(f'Unsupported type: {type(x)}')


def softmin(x, axis=-1):
r"""Applies the Softmin function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
Expand Down Expand Up @@ -93,7 +106,10 @@ def prelu(x, a=0.25):
parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
a separate :math:`a` is used for each input channel.
"""
return jnp.where(x >= 0., x, a * x)
dtype = _get_dtype(x)
return jnp.where(x >= jnp.asarray(0., dtype),
x,
jnp.asarray(a, dtype) * x)


def soft_shrink(x, lambd=0.5):
Expand All @@ -114,7 +130,11 @@ def soft_shrink(x, lambd=0.5):
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- Output: :math:`(*)`, same shape as the input.
"""
return jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0.))
dtype = _get_dtype(x)
lambd = jnp.asarray(lambd, dtype)
return jnp.where(x > lambd,
x - lambd,
jnp.where(x < -lambd, x + lambd, jnp.asarray(0., dtype)))


def mish(x):
Expand All @@ -135,7 +155,7 @@ def mish(x):
return x * jnp.tanh(softplus(x))


def rrelu(key, x, lower=0.125, upper=0.3333333333333333):
def rrelu(x, lower=0.125, upper=0.3333333333333333):
r"""Applies the randomized leaky rectified liner unit function, element-wise,
as described in the paper:

Expand Down Expand Up @@ -166,9 +186,9 @@ def rrelu(key, x, lower=0.125, upper=0.3333333333333333):
.. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
https://arxiv.org/abs/1505.00853
"""
x = jnp.asarray(x)
a = jax.random.uniform(key, x.shape, x.dtype, lower, upper)
return jnp.where(x >= 0., x, a * x)
dtype = _get_dtype(x)
a = bc.random.uniform(lower, upper, size=jnp.shape(x), dtype=dtype)
return jnp.where(x >= jnp.asarray(0., dtype), x, jnp.asarray(a, dtype) * x)


def hard_shrink(x, lambd=0.5):
Expand All @@ -192,7 +212,11 @@ def hard_shrink(x, lambd=0.5):
- Output: :math:`(*)`, same shape as the input.

"""
return jnp.where(x > lambd, x, jnp.where(x < -lambd, x, 0.))
dtype = _get_dtype(x)
lambd = jnp.asarray(lambd, dtype)
return jnp.where(x > lambd,
x,
jnp.where(x < -lambd, x, jnp.asarray(0., dtype)))


def relu(x: ArrayLike) -> jax.Array:
Expand Down Expand Up @@ -229,7 +253,6 @@ def relu(x: ArrayLike) -> jax.Array:
return jax.nn.relu(x)



def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
r"""Squareplus activation function.

Expand All @@ -244,7 +267,8 @@ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
x : input array
b : smoothness parameter
"""
return jax.nn.squareplus(x, b)
dtype = _get_dtype(x)
return jax.nn.squareplus(x, jnp.asarray(b, dtype))


def softplus(x: ArrayLike) -> jax.Array:
Expand Down Expand Up @@ -362,6 +386,8 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
See also:
:func:`selu`
"""
dtype = _get_dtype(x)
alpha = jnp.asarray(alpha, dtype)
return jax.nn.elu(x, alpha)


Expand All @@ -388,6 +414,8 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> jax.Array:
See also:
:func:`relu`
"""
dtype = _get_dtype(x)
negative_slope = jnp.asarray(negative_slope, dtype)
return jax.nn.leaky_relu(x, negative_slope=negative_slope)


Expand Down Expand Up @@ -434,6 +462,8 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
Returns:
An array.
"""
dtype = _get_dtype(x)
alpha = jnp.asarray(alpha, dtype)
return jax.nn.celu(x, alpha)


Expand Down
File renamed without changes.
Loading
Loading