diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 5f0f06a..5d87f9b 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -25,7 +25,7 @@ jobs: - name: Install package run: | - pip install ".[tensorflow,dev]" + pip install ".[dev]" - name: Build package run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2c95607..5f5fefe 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: - name: Install package run: | - pip install ".[tensorflow,onnx,dev]" + pip install ".[dev]" - name: Lint with flake8 run: | diff --git a/README.md b/README.md index d46ae36..8f87765 100644 --- a/README.md +++ b/README.md @@ -9,11 +9,7 @@ You can use `pip` to install the latest version of UniSim: pip install unisim ``` -By default, UniSim uses [Onnx](https://github.com/onnx/onnx) when running on CPU, and [TensorFlow](https://www.tensorflow.org/) for GPU acceleration. If you have a GPU, you can additionally install TensorFlow using: - -``` -pip install unisim[tensorflow] -``` +By default, UniSim uses [Onnx](https://github.com/onnx/onnx) as the runtime. You can switch to using TensorFlow by setting the `BACKEND` environment variable (e.g. `os.environ["BACKEND"] = "tf"`). ## Text UniSim (TextSim) diff --git a/setup.py b/setup.py index 41bb532..c31af56 100644 --- a/setup.py +++ b/setup.py @@ -37,9 +37,18 @@ def get_version(rel_path): author_email="unisim@google.com", url="https://github.com/google/unisim", license="MIT", - install_requires=["tabulate", "numpy", "tqdm", "onnx", "jaxtyping", "onnxruntime", "pandas", "usearch>=2.6.0"], + install_requires=[ + "tabulate", + "numpy", + "tqdm", + "onnx", + "jaxtyping", + "onnxruntime-gpu", + "pandas", + "tensorflow>=2.11,<2.16", + "usearch>=2.6.0", + ], extras_require={ - "tensorflow": ["tensorflow>=2.11"], "dev": [ "datasets", "mypy", diff --git a/unisim/backend/load_backend.py b/unisim/backend/load_backend.py index 823f39f..6f6301b 100644 --- a/unisim/backend/load_backend.py +++ b/unisim/backend/load_backend.py @@ -30,18 +30,6 @@ except ImportError: TF_AVAILABLE = False -# detect accelerator -if TF_AVAILABLE or get_backend() == BackendType.tf: - devices_types = [d.device_type for d in tf.config.list_physical_devices()] - - if "GPU" in devices_types: - set_accelerator(AcceleratorType.gpu) - else: - set_accelerator(AcceleratorType.cpu) - -else: - set_accelerator(AcceleratorType.cpu) - # choose backend if not set by user accel = get_accelerator() backend = get_backend() @@ -62,10 +50,15 @@ # post detection if get_backend() == BackendType.onnx: + import onnxruntime as rt + from .onnx import * # noqa: F403, F401 # FIXME onnx accelerator type support - set_accelerator(AcceleratorType.cpu) + if rt.get_device() == "GPU": + set_accelerator(AcceleratorType.gpu) + else: + set_accelerator(AcceleratorType.cpu) elif get_backend() == BackendType.tf: from .tf import * # type: ignore # noqa: F403, F401