From 307fe776ab2cbfbb49f6468d4b693564004a4d09 Mon Sep 17 00:00:00 2001 From: Jeffrey Martin Date: Wed, 11 Dec 2024 16:26:37 -0600 Subject: [PATCH] pass device to transfomers pipeline When a `pipeline` is created the base class attempts to auto-detect the optimal hardware, since the project accepts configuration for hardware device selection the device must be passed. Signed-off-by: Jeffrey Martin --- garak/detectors/base.py | 4 +++- pyproject.toml | 2 +- requirements.txt | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/garak/detectors/base.py b/garak/detectors/base.py index 82770ba35..82bde822a 100644 --- a/garak/detectors/base.py +++ b/garak/detectors/base.py @@ -120,7 +120,9 @@ def __init__(self, config_root=_config): self.detector_model_path ) self.detector = TextClassificationPipeline( - model=self.detector_model, tokenizer=self.detector_tokenizer + model=self.detector_model, + tokenizer=self.detector_tokenizer, + device=self.device, ) transformers_logging.set_verbosity(orig_loglevel) diff --git a/pyproject.toml b/pyproject.toml index 31f6c1a7a..43bc10004 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ "base2048>=0.1.3", - "transformers>=4.43.0,<4.47.0", + "transformers>=4.43.0", "datasets>=2.14.6,<2.17", "colorama>=0.4.3", "tqdm>=4.64.0", diff --git a/requirements.txt b/requirements.txt index 2d3b71292..50de30fe5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ base2048>=0.1.3 -transformers>=4.43.0,<4.47.0 +transformers>=4.43.0 datasets>=2.14.6,<2.17 colorama>=0.4.3 tqdm>=4.64.0