Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[parallelization] new module of brainpy.pnn for auto parallelization of brain models #385

Merged
merged 14 commits into from
Jun 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 105 additions & 105 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [ "3.8", "3.9", "3.10" ]
python-version: [ "3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v2
Expand All @@ -42,45 +42,45 @@ jobs:
cd examples
pytest ../brainpy/

test_linux_py37:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.7"]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip install jax==0.3.25
pip install jaxlib==0.3.25
pip uninstall brainpy -y
python setup.py install
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
cd examples
pytest ../brainpy/
# test_linux_py37:
# runs-on: ubuntu-latest
# strategy:
# fail-fast: false
# matrix:
# python-version: ["3.7"]
#
# steps:
# - uses: actions/checkout@v2
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@v2
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# python -m pip install flake8 pytest
# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
# pip install jax==0.3.25
# pip install jaxlib==0.3.25
# pip uninstall brainpy -y
# python setup.py install
# - name: Lint with flake8
# run: |
# # stop the build if there are Python syntax errors or undefined names
# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics
# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# - name: Test with pytest
# run: |
# cd examples
# pytest ../brainpy/
#
test_macos:
runs-on: macos-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v2
Expand All @@ -106,46 +106,46 @@ jobs:
cd examples
pytest ../brainpy/

test_macos_py37:
runs-on: macos-latest
strategy:
fail-fast: false
matrix:
python-version: [ "3.7" ]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip install jax==0.3.25
pip install jaxlib==0.3.25
pip uninstall brainpy -y
python setup.py install
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
cd examples
pytest ../brainpy/
# test_macos_py37:
# runs-on: macos-latest
# strategy:
# fail-fast: false
# matrix:
# python-version: [ "3.7" ]
#
# steps:
# - uses: actions/checkout@v2
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@v2
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# python -m pip install flake8 pytest
# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
# pip install jax==0.3.25
# pip install jaxlib==0.3.25
# pip uninstall brainpy -y
# python setup.py install
# - name: Lint with flake8
# run: |
# # stop the build if there are Python syntax errors or undefined names
# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics
# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# - name: Test with pytest
# run: |
# cd examples
# pytest ../brainpy/
#

test_windows:
runs-on: windows-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v2
Expand All @@ -158,8 +158,8 @@ jobs:
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install numpy>=1.21.0
python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz
python -m pip install "jaxlib==0.4.10" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install jax==0.4.10
python -m pip install -r requirements-dev.txt
python -m pip install tqdm brainpylib
pip uninstall brainpy -y
Expand All @@ -175,37 +175,37 @@ jobs:
cd examples
pytest ../brainpy/

test_windows_py37:
runs-on: windows-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.7"]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install numpy>=1.21.0
python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz
python -m pip install -r requirements-dev.txt
python -m pip install tqdm brainpylib
pip uninstall brainpy -y
python setup.py install
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
cd examples
pytest ../brainpy/
# test_windows_py37:
# runs-on: windows-latest
# strategy:
# fail-fast: false
# matrix:
# python-version: ["3.7"]
#
# steps:
# - uses: actions/checkout@v2
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@v2
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# python -m pip install flake8 pytest
# python -m pip install numpy>=1.21.0
# python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz
# python -m pip install -r requirements-dev.txt
# python -m pip install tqdm brainpylib
# pip uninstall brainpy -y
# python setup.py install
# - name: Lint with flake8
# run: |
# # stop the build if there are Python syntax errors or undefined names
# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics
# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# - name: Test with pytest
# run: |
# cd examples
# pytest ../brainpy/
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,4 @@ cython_debug/
/examples/training_snn_models/data/
/docs/tutorial_advanced/data/
/my_tests/
/examples/dynamics_simulation/Joglekar_2018_data/
39 changes: 22 additions & 17 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@
# Part 3: Models #
# ---------------- #

from brainpy import (channels, # channel models
layers, # ANN layers
neurons, # neuron groups
synapses, # synapses
rates, # rate models
experimental,
)
from brainpy import (
channels, # channel models
layers, # ANN layers
neurons, # neuron groups
synapses, # synapses
rates, # rate models
experimental,
pnn, # parallel SNN models
)
from brainpy.synapses import (synouts, # synaptic output
synplast, ) # synaptic plasticity

Expand All @@ -78,16 +80,17 @@
from brainpy._src.context import share
from brainpy._src.dynsys import not_pass_shared
# running
from brainpy._src.dyn.runners import (DSRunner as DSRunner)
from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,)
from brainpy._src.runners import (DSRunner as DSRunner)
from brainpy._src.transform import (LoopOverTime as LoopOverTime, )
# DynamicalSystem base classes
from brainpy._src.dynsys import (DynamicalSystemNS as DynamicalSystemNS,
NeuGroupNS as NeuGroupNS,
TwoEndConnNS as TwoEndConnNS,
)
from brainpy._src.dyn.synapses_v2.base import (SynOutNS as SynOutNS,
SynSTPNS as SynSTPNS,
SynConnNS as SynConnNS, )
from brainpy._src.dynsys import (
DynamicalSystemNS as DynamicalSystemNS,
NeuGroupNS as NeuGroupNS,
TwoEndConnNS as TwoEndConnNS,
)
from brainpy._src.synapses_v2.base import (SynOutNS as SynOutNS,
SynSTPNS as SynSTPNS,
SynConnNS as SynConnNS, )


# Part 4: Training #
Expand All @@ -114,6 +117,8 @@
# ---------------------- #


math.__dict__['sparse_matmul'] = math.sparse.seg_matmul

math.__dict__['event_matvec_prob_conn_homo_weight'] = math.jitconn.event_mv_prob_homo
math.__dict__['event_matvec_prob_conn_uniform_weight'] = math.jitconn.event_mv_prob_uniform
math.__dict__['event_matvec_prob_conn_normal_weight'] = math.jitconn.event_mv_prob_normal
Expand Down Expand Up @@ -242,7 +247,7 @@
dyn.__dict__['OUProcess'] = neurons.OUProcess

# synapses
from brainpy._src.dyn.synapses import compat
from brainpy._src.synapses import compat
dyn.__dict__['DeltaSynapse'] = compat.DeltaSynapse
dyn.__dict__['ExpCUBA'] = compat.ExpCUBA
dyn.__dict__['ExpCOBA'] = compat.ExpCOBA
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from brainpy import optim, losses
from brainpy._src.analysis import utils, base, constants
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.dyn.runners import check_and_format_inputs, _f_ops
from brainpy._src.runners import check_and_format_inputs, _f_ops
from brainpy._src.tools.dicts import DotDict
from brainpy.errors import AnalyzerError, UnsupportedError
from brainpy.types import ArrayType
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/analysis/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from brainpy._src.math.environment import get_float
from brainpy._src.math.interoperability import as_jax
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.dyn.runners import DSRunner
from brainpy._src.runners import DSRunner
from brainpy._src.integrators.base import Integrator
from brainpy._src.integrators.joint_eq import JointEq
from brainpy._src.integrators.ode.base import ODEIntegrator
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
6 changes: 2 additions & 4 deletions brainpy/_src/checkpoints/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@
try:
from jax.experimental.array_serialization import get_tensorstore_spec, GlobalAsyncCheckpointManager # noqa
except (ModuleNotFoundError, ImportError):
try:
from jax.experimental.gda_serialization import get_tensorstore_spec, GlobalAsyncCheckpointManager # noqa
except (ModuleNotFoundError, ImportError):
get_tensorstore_spec = None
get_tensorstore_spec = None
GlobalAsyncCheckpointManager = None

from brainpy._src.math.ndarray import Array
from brainpy.errors import (AlreadyExistsError,
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def load(self, key, value: Any = None):
return self._arguments[key]
if value is None:
raise KeyError(f'Cannot found shared data of {key}. '
f'Please define it with "brainpy.share.save()". ')
f'Please define it with "brainpy.share.save({key}=<?>)". ')
else:
return value

Expand Down
Loading