Update develop for potential v8.2 #786

Merged

82 commits:
a9b0471
Merge pull request #657 from svlandeg/feature/develop_copy
svlandeg May 9, 2022
f9f4fd5
Move compatibility-related code into a separate `compat` module (#652)
shadeMe May 10, 2022
ffd998c
NumpyOps: Add a method to get a table of C BLAS functions (#643)
danieldk May 10, 2022
40a5a0d
Fix a unit test in the PyTorch wrapper (#663)
danieldk May 12, 2022
abbe0ff
`CupyOps`: Simplify `asarray` (#661)
shadeMe May 17, 2022
d2d7917
NumpyOps: Better type-casting in `asarray` (#656)
shadeMe May 17, 2022
9063168
Fix out-of-bounds writes in NumpyOps/CupyOps (#664)
danieldk May 17, 2022
6d84d00
Set version to v8.1.0.dev0 (#666)
danieldk May 18, 2022
a988fda
Fix model.copy() bug where layer used more than once (#659)
richardpaulhudson May 18, 2022
21d9b86
`conftest.py`: Handle exception caused by `pytest` options being adde…
shadeMe May 19, 2022
43268c2
Auto-format code with `black` + Pin `black` requirement (#673)
shadeMe May 20, 2022
46f3fd1
Add support for bot-invoked slow tests (#672)
shadeMe May 20, 2022
07a7dcf
`Shim`: Fix potential data race when allocated on different threads
shadeMe May 23, 2022
ab16559
Fix two warnings (#676)
danieldk May 23, 2022
99573ff
Replace use of gpu_is_available with has_cupy_gpu (#675)
danieldk May 24, 2022
81035bc
Fixes for slow tests (#671)
shadeMe May 24, 2022
3eb5e52
`test_uniqued`: Disable test timing for `test_uniqued_doesnt_change_r…
shadeMe May 24, 2022
b36492d
`test_to_categorical`: Ensure that `label_smoothing < 0.5` (#680)
shadeMe May 25, 2022
6472d0a
test_ops: do not lower precision in conversion to Torch tensor (#681)
danieldk May 27, 2022
c88e43c
Add `test_slow_gpu` explosion-bot command
shadeMe May 27, 2022
145d782
Auto-format code with black (#682)
github-actions[bot] May 27, 2022
6cc4d97
Azure: pin protobuf to fix Tensorflow
danieldk May 30, 2022
ff32729
Merge pull request #687 from explosion/fix-tensorflow-ciu
danieldk May 31, 2022
9df55f9
Merge pull request #684 from shadeMe/feature/test-slow-gpu-bot
danieldk Jun 1, 2022
abf7d31
Extend typing_extensions to <4.2.0 (#689)
adrianeboyd Jun 2, 2022
46334b5
xp2{tensorflow,torch}: convert NumPy arrays using dlpack (#686)
danieldk Jun 7, 2022
1c6e9f4
`test_model_gpu`: Use TF memory pool if available, feature-gate test …
shadeMe Jun 8, 2022
be65301
Bump version to v8.1.0.dev1 (#694)
danieldk Jun 9, 2022
8d9405f
`NumpyOps`: Do not use global for `CBlas` (#697)
shadeMe Jun 14, 2022
862a489
Merge pytorch-device branch into master (#695)
danieldk Jun 14, 2022
0711e60
Expose `get_torch_default_device` through `thinc.api` (#698)
danieldk Jun 15, 2022
cde169b
Make `CBlas` methods standalone functions to avoid using vtables (#700)
danieldk Jun 15, 2022
2ae2125
Add Dockerfile for building the website (#699)
danieldk Jun 15, 2022
3216d7c
Bump version to v8.1.0.dev2 (#701)
danieldk Jun 15, 2022
3f0082f
Use blis~=0.7.8 (#704)
adrianeboyd Jun 22, 2022
8236b89
Set version to v8.1.0.dev3 (#705)
adrianeboyd Jun 22, 2022
f630270
Speed up HashEmbed layer by avoiding large temporary arrays (#696)
danieldk Jun 23, 2022
0640617
Auto-format code with black (#706)
github-actions[bot] Jun 27, 2022
2ef3f3a
Fix MyPy error when Torch without MPS support is installed (#708)
danieldk Jun 29, 2022
c7b0d67
Check that Torch-verified activations obey `inplace` (#709)
danieldk Jun 30, 2022
7cd060e
Increase test deadline to 30 minutes to prevent spurious test failure…
shadeMe Jul 4, 2022
ea3c08e
`test_mxnet_wrapper`: Feature-gate GPU test (#717)
shadeMe Jul 6, 2022
28b2e8c
Add Ops.reduce_{first,last} plus tests (#710)
danieldk Jul 7, 2022
671d01e
Label smooth threshold fix (#707)
kadarakos Jul 7, 2022
274c41c
Set version to v8.1.0 (#718)
adrianeboyd Jul 7, 2022
9f9e494
`get_array_module` with non-array input returns `None` (#703)
kadarakos Jul 7, 2022
eb5c38b
Update build constraints and requirements for aarch64 wheels (#722)
adrianeboyd Jul 8, 2022
5508f53
Auto-format code with black (#723)
github-actions[bot] Jul 8, 2022
3dcd03d
Fix version string (#724)
adrianeboyd Jul 8, 2022
17846c4
Extend to mypy<0.970 (#725)
adrianeboyd Jul 8, 2022
2da00a3
Fix typo
cclauss Jul 13, 2022
b4b37ce
Merge pull request #727 from cclauss/patch-1
polm Jul 14, 2022
40c129f
Update build constraints for arm64 and aarch64 wheels (#716)
adrianeboyd Jul 18, 2022
5a4f868
Ops: replace FloatsType by constrained typevar (#720)
danieldk Jul 28, 2022
8e5c743
Unroll `argmax` in `maxout` for small sizes of `P` (#702)
danieldk Jul 28, 2022
42b73c9
Change Docker image tag to thinc-ai (#732)
danieldk Aug 3, 2022
69a280f
Add `with_signpost_interval` layer (#711)
danieldk Aug 3, 2022
1846855
Docs: Fix/update `label_smoothing` description, run prettier (#733)
shadeMe Aug 4, 2022
af0e3de
Add Dish activation (#719)
danieldk Aug 4, 2022
7fcdd0f
Auto-format code with black (#737)
github-actions[bot] Aug 5, 2022
d95b5fc
Increment `blis` version upper-bound to `0.10.0` (#736)
shadeMe Aug 5, 2022
01eb6b7
asarrayDf: take `Sequence[float]`, not `Sequence[int]` (#739)
danieldk Aug 5, 2022
a43635e
Use confection for configurations (#745)
rmitsch Aug 26, 2022
eda4c75
`PyTorchGradScaler`: Cache `_found_inf` on the CPU (#746)
shadeMe Aug 29, 2022
a7bbc48
More general remap_ids (#726)
kadarakos Sep 2, 2022
102d654
Auto-format code with black (#753)
github-actions[bot] Sep 5, 2022
fba3bf0
Switch to macos-latest (#755)
adrianeboyd Sep 6, 2022
fc323e1
`util`: Explicitly call `__dlpack__` built-in method in `xp2tensorflo…
shadeMe Sep 7, 2022
9836e9e
Set version to 8.1.1 (#758)
danieldk Sep 9, 2022
97a1a04
Remove references to FastAPI being an Explosion product (#761)
rmitsch Sep 9, 2022
139acbf
Update code example for Ragged (#756)
rmitsch Sep 9, 2022
37958ee
Update setup.cfg (#748)
willfrey Sep 12, 2022
562139e
Update cupy extras, quickstart (#740)
adrianeboyd Sep 13, 2022
20ce703
disable mypy run for Python 3.10 (#768)
svlandeg Sep 15, 2022
4dffe21
Reorder requirements in requirements.txt (#770)
adrianeboyd Sep 16, 2022
cb6edbe
Revert blis range to <0.8.0 (#772)
adrianeboyd Sep 26, 2022
d9c40cf
Set version to v8.1.2 (#773)
adrianeboyd Sep 26, 2022
1eaeb2b
Fix `fix_random_seed` entrypoint in setup.cfg (#775)
pawamoy Sep 27, 2022
3a143d3
Support both Python 3.6 and Pydantic 1.10 (#779)
svlandeg Oct 4, 2022
2e12baa
update to latest mypy and exclude Python 3.6 (#776)
svlandeg Oct 7, 2022
36b691f
Set version to v8.1.3 (#781)
adrianeboyd Oct 7, 2022
07b7a09
Update CI around conflicting extras requirements (#783)
adrianeboyd Oct 10, 2022

Files changed:
44 changes: 44 additions & 0 deletions .github/workflows/autoblack.yml
@@ -0,0 +1,44 @@
# GitHub Action that uses Black to reformat all Python code and submits a PR
# in regular intervals. Inspired by: https://github.com/cclauss/autoblack

name: autoblack
on:
  workflow_dispatch:  # allow manual trigger
  schedule:
    - cron: '0 8 * * 5'  # every Friday at 8am UTC

jobs:
  autoblack:
    if: github.repository_owner == 'explosion'
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
        with:
          ref: ${{ github.head_ref }}
      - uses: actions/setup-python@v2
      - run: pip install black
      - name: Auto-format code if needed
        run: black thinc
      # We can't run black --check here because that returns a non-zero exit
      # code and makes GitHub think the action failed
      - name: Check for modified files
        id: git-check
        run: echo ::set-output name=modified::$(if git diff-index --quiet HEAD --; then echo "false"; else echo "true"; fi)
      - name: Create Pull Request
        if: steps.git-check.outputs.modified == 'true'
        uses: peter-evans/create-pull-request@v3
        with:
          title: Auto-format code with black
          labels: meta
          commit-message: Auto-format code with black
          committer: GitHub <noreply@github.com>
          author: explosion-bot <explosion-bot@users.noreply.github.com>
          body: _This PR is auto-generated._
          branch: autoblack
          delete-branch: true
          draft: false
      - name: Check outputs
        if: steps.git-check.outputs.modified == 'true'
        run: |
          echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}"
          echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}"
2 changes: 1 addition & 1 deletion .github/workflows/explosionbot.yml
@@ -23,5 +23,5 @@ jobs:
         env:
           INPUT_TOKEN: ${{ secrets.EXPLOSIONBOT_TOKEN }}
           INPUT_BK_TOKEN: ${{ secrets.BUILDKITE_SECRET }}
-          ENABLED_COMMANDS: "test_gpu"
+          ENABLED_COMMANDS: "test_gpu,test_slow,test_slow_gpu"
           ALLOWED_TEAMS: "spacy-maintainers"
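
The expanded `ENABLED_COMMANDS` list wires up the bot-invoked slow-test runs from earlier in this merge (support added in #672, with the `test_slow_gpu` command added in a follow-up commit).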
2 changes: 1 addition & 1 deletion README.md
@@ -2,7 +2,7 @@

 # Thinc: A refreshing functional take on deep learning, compatible with your favorite libraries

-### From the makers of [spaCy](https://spacy.io), [Prodigy](https://prodi.gy) and [FastAPI](https://fastapi.tiangolo.com)
+### From the makers of [spaCy](https://spacy.io) and [Prodigy](https://prodi.gy)

 [Thinc](https://thinc.ai) is a **lightweight deep learning library** that offers an elegant,
 type-checked, functional-programming API for **composing models**, with support
29 changes: 21 additions & 8 deletions azure-pipelines.yml
@@ -23,7 +23,7 @@ jobs:
       imageName: 'windows-2019'
       python.version: '3.6'
     Python37Mac:
-      imageName: 'macos-10.15'
+      imageName: 'macos-latest'
       python.version: '3.7'
     Python38Linux:
       imageName: 'ubuntu-latest'
@@ -63,6 +63,7 @@ jobs:
   - script: |
       python -m mypy thinc
     displayName: 'Run mypy'
+    condition: ne(variables['python.version'], '3.6')

   - task: DeleteFiles@1
     inputs:
@@ -82,25 +83,37 @@ jobs:

   - script: |
       pip install -r requirements.txt
-      pip install "tensorflow~=2.5.0"
-      pip install "mxnet; sys_platform != 'win32'"
-      pip install "torch==1.9.0+cpu" -f https://download.pytorch.org/whl/torch_stable.html
       pip install ipykernel pydot graphviz
       python -m ipykernel install --name thinc-notebook-tests --user
-    displayName: 'Install test dependencies'
+      python -m pytest --pyargs thinc --cov=thinc --cov-report=term
+    displayName: 'Run tests without extras'

   - script: |
+      pip install "protobuf~=3.20.0" "tensorflow~=2.5.0"
+      pip install "mxnet; sys_platform != 'win32'"
+      pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
+      # torch does not have a direct numpy requirement but is compiled against
+      # a newer version than the oldest supported numpy for windows and
+      # python 3.10; this version of numpy would not work with
+      # tensorflow~=2.5.0 as specified above, but there is no release for
+      # python 3.10 anyway
+      pip install "numpy~=1.23.0; python_version=='3.10' and sys_platform=='win32'"
+      pip install -r requirements.txt
+      pip uninstall -y mypy
+    displayName: 'Install extras for testing'
+
+  - script: |
       python -m pytest --pyargs thinc --cov=thinc --cov-report=term
-    displayName: 'Run tests'
+    displayName: 'Run tests with extras'

   - script: |
       pip uninstall -y tensorflow
       pip install thinc-apple-ops
       python -m pytest --pyargs thinc_apple_ops
     displayName: 'Run tests for thinc-apple-ops'
-    condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.9'))
+    condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.10'))

   - script: |
       python -m pytest --pyargs thinc
     displayName: 'Run tests with thinc-apple-ops'
-    condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.9'))
+    condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.10'))
6 changes: 4 additions & 2 deletions build-constraints.txt
@@ -1,6 +1,8 @@
 # build version constraints for use with wheelwright + multibuild
-numpy==1.15.0; python_version<='3.7'
-numpy==1.17.3; python_version=='3.8'
+numpy==1.15.0; python_version<='3.7' and platform_machine!='aarch64'
+numpy==1.19.2; python_version<='3.7' and platform_machine=='aarch64'
+numpy==1.17.3; python_version=='3.8' and platform_machine!='aarch64'
+numpy==1.19.2; python_version=='3.8' and platform_machine=='aarch64'
 numpy==1.19.3; python_version=='3.9'
 numpy==1.21.3; python_version=='3.10'
 numpy; python_version>='3.11'
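
Each line above is a requirement with a PEP 508 environment marker that selects the interpreter and architecture it applies to. As a minimal illustration (assuming the third-party `packaging` library, the standard implementation of such markers, is installed), a marker can be evaluated like this:

from packaging.markers import Marker

# Evaluates against the current interpreter's environment by default.
marker = Marker("python_version == '3.8' and platform_machine == 'aarch64'")
print(marker.evaluate())  # True only on an aarch64 build of Python 3.8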
9 changes: 7 additions & 2 deletions examples/transformers_tagger.py
@@ -132,7 +132,9 @@ def forward(
         return TokensPlus(**token_data), lambda d_tokens: []

     return Model(
-        "tokenizer", forward, attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
+        "tokenizer",
+        forward,
+        attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
     )

@@ -166,11 +168,14 @@ def convert_transformer_outputs(model, inputs_outputs, is_train):

     def backprop(d_tokvecs: List[Floats2d]) -> ArgsKwargs:
         # Restore entries for bos and eos markers.
+        shim = model.shims[0]
         row = model.ops.alloc2f(1, d_tokvecs[0].shape[1])
         d_tokvecs = [model.ops.xp.vstack((row, arr, row)) for arr in d_tokvecs]
         return ArgsKwargs(
             args=(torch_tokvecs,),
-            kwargs={"grad_tensors": xp2torch(model.ops.pad(d_tokvecs))},
+            kwargs={
+                "grad_tensors": xp2torch(model.ops.pad(d_tokvecs, device=shim.device))
+            },
         )

     return tokvecs, backprop
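
A minimal, hypothetical sketch of the gradient-padding pattern this hunk follows, with arbitrary array shapes (requires thinc and torch installed; the device argument is omitted since the sketch runs on CPU):

import numpy
from thinc.api import NumpyOps, xp2torch

ops = NumpyOps()
d_tokvecs = [numpy.ones((3, 4), dtype="f"), numpy.ones((2, 4), dtype="f")]
padded = ops.pad(d_tokvecs)  # zero-padded array of shape (2, 3, 4)
grad_tensors = xp2torch(padded)  # Torch tensor to hand to backprop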
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ requires = [
     "murmurhash>=1.0.2,<1.1.0",
     "cymem>=2.0.2,<2.1.0",
     "preshed>=3.0.2,<3.1.0",
-    "blis>=0.4.0,<0.8.0",
+    "blis>=0.7.8,<0.8.0",
     "numpy>=1.15.0",
 ]
 build-backend = "setuptools.build_meta"
10 changes: 6 additions & 4 deletions requirements.txt
@@ -2,17 +2,18 @@
 murmurhash>=1.0.2,<1.1.0
 cymem>=2.0.2,<2.1.0
 preshed>=3.0.2,<3.1.0
-blis>=0.4.0,<0.8.0
+blis>=0.7.8,<0.8.0
 srsly>=2.4.0,<3.0.0
 wasabi>=0.8.1,<1.1.0
 catalogue>=2.0.4,<2.1.0
+confection>=0.0.1,<1.0.0
 ml_datasets>=0.2.0,<0.3.0
 # Third-party dependencies
-pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0
+pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0
 numpy>=1.15.0
 # Backports of modern Python features
 dataclasses>=0.6,<1.0; python_version < "3.7"
-typing_extensions>=3.7.4.1,<4.0.0.0; python_version < "3.8"
+typing_extensions>=3.7.4.1,<4.2.0; python_version < "3.8"
 contextvars>=2.4,<3; python_version < "3.7"
 # Development dependencies
 cython>=0.25.0,<3.0
@@ -22,7 +23,7 @@ pytest-cov>=2.7.0,<2.8.0
 coverage>=5.0.0,<6.0.0
 mock>=2.0.0,<3.0.0
 flake8>=3.5.0,<3.6.0
-mypy>=0.901,<0.960
+mypy>=0.980,<0.990; platform_machine != "aarch64" and python_version >= "3.7"
 types-mock>=0.1.1
 types-contextvars>=0.1.2; python_version < "3.7"
 types-dataclasses>=0.1.3; python_version < "3.7"
@@ -33,3 +34,4 @@ nbconvert>=5.6.1,<6.2.0
 nbformat>=5.0.4,<5.2.0
 # Test to_disk/from_disk against pathlib.Path subclasses
 pathy>=0.3.5
+black>=22.0,<23.0
21 changes: 17 additions & 4 deletions setup.cfg
@@ -35,24 +35,29 @@ setup_requires =
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
     murmurhash>=1.0.2,<1.1.0
-    blis>=0.4.0,<0.8.0
+    blis>=0.7.8,<0.8.0
 install_requires =
     # Explosion-provided dependencies
-    blis>=0.4.0,<0.8.0
+    blis>=0.7.8,<0.8.0
     murmurhash>=1.0.2,<1.1.0
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
     wasabi>=0.8.1,<1.1.0
     srsly>=2.4.0,<3.0.0
     catalogue>=2.0.4,<2.1.0
+    confection>=0.0.1,<1.0.0
     # Third-party dependencies
     setuptools
     numpy>=1.15.0
-    pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0
+    pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0
     # Backports of modern Python features
     dataclasses>=0.6,<1.0; python_version < "3.7"
-    typing_extensions>=3.7.4.1,<4.0.0.0; python_version < "3.8"
+    typing_extensions>=3.7.4.1,<4.2.0; python_version < "3.8"
     contextvars>=2.4,<3; python_version < "3.7"
+
+[options.entry_points]
+pytest_randomly.random_seeder =
+    thinc = thinc.api:fix_random_seed

 [options.extras_require]
 cuda =
@@ -83,6 +88,14 @@ cuda114 =
     cupy-cuda114>=5.0.0b4
 cuda115 =
     cupy-cuda115>=5.0.0b4
+cuda116 =
+    cupy-cuda116>=5.0.0b4
+cuda117 =
+    cupy-cuda117>=5.0.0b4
+cuda11x =
+    cupy-cuda11x>=11.0.0
+cuda-autodetect =
+    cupy-wheel>=11.0.0
 datasets =
     ml_datasets>=0.2.0,<0.3.0
 torch =
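
With these extras, GPU support can be installed against a specific CUDA version (for example `pip install "thinc[cuda116]"`), against the ABI-stable `cupy-cuda11x` wheel, or via `cupy-wheel`, which detects the local CUDA installation at install time (`pip install "thinc[cuda-autodetect]"`).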
3 changes: 2 additions & 1 deletion setup.py
@@ -17,14 +17,15 @@

 PACKAGES = find_packages()
 MOD_NAMES = [
+    "thinc.backends.cblas",
     "thinc.backends.linalg",
     "thinc.backends.numpy_ops",
     "thinc.extra.search",
     "thinc.layers.sparselinear",
 ]
 COMPILE_OPTIONS = {
     "msvc": ["/Ox", "/EHsc"],
-    "other": ["-O3", "-Wno-strict-prototypes", "-Wno-unused-function"],
+    "other": ["-O3", "-Wno-strict-prototypes", "-Wno-unused-function", "-std=c++11"],
 }
 COMPILER_DIRECTIVES = {
     "language_level": -3,
2 changes: 1 addition & 1 deletion thinc/about.py
@@ -1,2 +1,2 @@
-__version__ = "8.0.15"
+__version__ = "8.1.3"
 __release__ = True
7 changes: 5 additions & 2 deletions thinc/api.py
@@ -16,16 +16,18 @@
 from .util import DataValidationError, data_validation
 from .util import to_categorical, get_width, get_array_module, to_numpy
 from .util import torch2xp, xp2torch, tensorflow2xp, xp2tensorflow, mxnet2xp, xp2mxnet
+from .util import get_torch_default_device
+from .compat import has_cupy
 from .backends import get_ops, set_current_ops, get_current_ops, use_ops
-from .backends import Ops, CupyOps, NumpyOps, has_cupy, set_gpu_allocator
+from .backends import Ops, CupyOps, MPSOps, NumpyOps, set_gpu_allocator
 from .backends import use_pytorch_for_gpu_memory, use_tensorflow_for_gpu_memory

 from .layers import Dropout, Embed, expand_window, HashEmbed, LayerNorm, Linear
 from .layers import Maxout, Mish, MultiSoftmax, Relu, softmax_activation, Softmax, LSTM
 from .layers import CauchySimilarity, ParametricAttention, Logistic
 from .layers import resizable, sigmoid_activation, Sigmoid, SparseLinear
 from .layers import ClippedLinear, ReluK, HardTanh, HardSigmoid
-from .layers import HardSwish, HardSwishMobilenet, Swish, Gelu
+from .layers import Dish, HardSwish, HardSwishMobilenet, Swish, Gelu
 from .layers import PyTorchWrapper, PyTorchRNNWrapper, PyTorchLSTM
 from .layers import TensorFlowWrapper, keras_subclass, MXNetWrapper
 from .layers import PyTorchWrapper_v2, Softmax_v2
@@ -38,6 +40,7 @@
 from .layers import with_reshape, with_getitem, strings2arrays, list2array
 from .layers import list2ragged, ragged2list, list2padded, padded2list, remap_ids
 from .layers import array_getitem, with_cpu, with_debug, with_nvtx_range
+from .layers import with_signpost_interval
 from .layers import tuplify

 from .layers import reduce_first, reduce_last, reduce_max, reduce_mean, reduce_sum
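
A minimal sketch (not part of the diff) of the newly exported names in use; the layer sizes and data are arbitrary examples:

import numpy
from thinc.api import Dish, Relu, chain

model = chain(Relu(nO=16), Dish(nO=8))  # Dish is the newly exported activation layer
X = numpy.zeros((4, 32), dtype="f")
model.initialize(X=X)  # layer widths are inferred from the sample data
Y = model.predict(X)
print(Y.shape)  # (4, 8)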
18 changes: 10 additions & 8 deletions thinc/backends/__init__.py
@@ -5,13 +5,15 @@
 import threading

 from .ops import Ops
-from .cupy_ops import CupyOps, has_cupy
+from .cupy_ops import CupyOps
 from .numpy_ops import NumpyOps
+from .mps_ops import MPSOps
 from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator
 from ._param_server import ParamServer
 from ..util import assert_tensorflow_installed, assert_pytorch_installed
-from ..util import is_cupy_array, set_torch_tensor_type_for_ops, require_cpu
+from ..util import get_torch_default_device, is_cupy_array, require_cpu
 from .. import registry
+from ..compat import cupy, has_cupy


 context_ops: ContextVar[Optional[Ops]] = ContextVar("context_ops", default=None)
@@ -46,9 +48,11 @@ def use_pytorch_for_gpu_memory() -> None:  # pragma: no cover
     We'd like to support routing Tensorflow memory allocation via PyTorch as well
     (or vice versa), but do not currently have an implementation for it.
     """
-    import cupy.cuda
-
     assert_pytorch_installed()
+
+    if get_torch_default_device().type != "cuda":
+        return
+
     pools = context_pools.get()
     if "pytorch" not in pools:
         pools["pytorch"] = cupy.cuda.MemoryPool(allocator=cupy_pytorch_allocator)
@@ -65,8 +69,6 @@ def use_tensorflow_for_gpu_memory() -> None:  # pragma: no cover
     We'd like to support routing PyTorch memory allocation via Tensorflow as
     well (or vice versa), but do not currently have an implementation for it.
     """
-    import cupy.cuda
-
     assert_tensorflow_installed()
     pools = context_pools.get()
     if "tensorflow" not in pools:
@@ -94,7 +96,7 @@ def get_ops(name: str, **kwargs) -> Ops:

     cls: Optional[Callable[..., Ops]] = None
     if name == "cpu":
-       _import_extra_cpu_backends()
+        _import_extra_cpu_backends()
         cls = ops_by_name.get("numpy")
         cls = ops_by_name.get("apple", cls)
         cls = ops_by_name.get("bigendian", cls)
@@ -137,7 +139,6 @@ def set_current_ops(ops: Ops) -> None:
     """Change the current backend object."""
     context_ops.set(ops)
     _get_thread_state().ops = ops
-    set_torch_tensor_type_for_ops(ops)


 def contextvars_eq_thread_ops() -> bool:
@@ -173,6 +174,7 @@ def _create_thread_local(
     "ParamServer",
     "Ops",
     "CupyOps",
+    "MPSOps",
     "NumpyOps",
     "has_cupy",
 ]
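
As a rough usage sketch (assuming a CPU-only install, so no GPU backend is required), the backend registry touched above is driven through `get_ops` and `set_current_ops`:

from thinc.api import get_current_ops, get_ops, set_current_ops

ops = get_ops("cpu")  # NumpyOps, or an extra CPU backend if one is registered
set_current_ops(ops)
print(type(get_current_ops()).__name__)  # prints NumpyOps on a default install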