Commit

Merge pull request #14 from google/development
UniSim defaults to Onnx runtime on GPU
MarinaZhang authored May 16, 2024
2 parents c013259 + 5c48b65 commit 0e2e3a0
Showing 5 changed files with 20 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
@@ -25,7 +25,7 @@ jobs:
       - name: Install package
         run: |
-          pip install ".[tensorflow,dev]"
+          pip install ".[dev]"
       - name: Build package
         run: |
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -27,7 +27,7 @@ jobs:
       - name: Install package
         run: |
-          pip install ".[tensorflow,onnx,dev]"
+          pip install ".[dev]"
       - name: Lint with flake8
         run: |
6 changes: 1 addition & 5 deletions README.md
@@ -9,11 +9,7 @@ You can use `pip` to install the latest version of UniSim:
 pip install unisim
 ```
 
-By default, UniSim uses [Onnx](https://github.com/onnx/onnx) when running on CPU, and [TensorFlow](https://www.tensorflow.org/) for GPU acceleration. If you have a GPU, you can additionally install TensorFlow using:
-
-```
-pip install unisim[tensorflow]
-```
+By default, UniSim uses [Onnx](https://github.com/onnx/onnx) as the runtime. You can switch to using TensorFlow by setting the `BACKEND` environment variable (e.g. `os.environ["BACKEND"] = "tf"`).
 
 ## Text UniSim (TextSim)
 
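The README change above documents selecting the backend through the `BACKEND` environment variable. A minimal usage sketch (the `unisim` import is left commented out because the variable must be set before the library is imported; "tf" selects TensorFlow, anything else or unset keeps the default Onnx runtime):

```python
import os

# Choose the TensorFlow backend; this must happen before importing unisim,
# since the backend is resolved at import time.
os.environ["BACKEND"] = "tf"

# import unisim  # would now load the TensorFlow backend

print(os.environ["BACKEND"])  # → tf
```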
13 changes: 11 additions & 2 deletions setup.py
@@ -37,9 +37,18 @@ def get_version(rel_path):
     author_email="unisim@google.com",
     url="https://github.com/google/unisim",
     license="MIT",
-    install_requires=["tabulate", "numpy", "tqdm", "onnx", "jaxtyping", "onnxruntime", "pandas", "usearch>=2.6.0"],
+    install_requires=[
+        "tabulate",
+        "numpy",
+        "tqdm",
+        "onnx",
+        "jaxtyping",
+        "onnxruntime-gpu",
+        "pandas",
+        "tensorflow>=2.11,<2.16",
+        "usearch>=2.6.0",
+    ],
     extras_require={
-        "tensorflow": ["tensorflow>=2.11"],
         "dev": [
             "datasets",
             "mypy",
19 changes: 6 additions & 13 deletions unisim/backend/load_backend.py
@@ -30,18 +30,6 @@
 except ImportError:
     TF_AVAILABLE = False
 
-# detect accelerator
-if TF_AVAILABLE or get_backend() == BackendType.tf:
-    devices_types = [d.device_type for d in tf.config.list_physical_devices()]
-
-    if "GPU" in devices_types:
-        set_accelerator(AcceleratorType.gpu)
-    else:
-        set_accelerator(AcceleratorType.cpu)
-
-else:
-    set_accelerator(AcceleratorType.cpu)
-
 # choose backend if not set by user
 accel = get_accelerator()
 backend = get_backend()
@@ -62,10 +50,15 @@
 
 # post detection
 if get_backend() == BackendType.onnx:
+    import onnxruntime as rt
+
     from .onnx import *  # noqa: F403, F401
 
-    # FIXME onnx accelerator type support
-    set_accelerator(AcceleratorType.cpu)
+    if rt.get_device() == "GPU":
+        set_accelerator(AcceleratorType.gpu)
+    else:
+        set_accelerator(AcceleratorType.cpu)
 
 elif get_backend() == BackendType.tf:
     from .tf import *  # type: ignore # noqa: F403, F401
