Skip to content

Commit

Permalink
corrections to config format for TAP internal generators (#796)
Browse files Browse the repository at this point in the history
* corrections to config format for TAP internal generators

* ensure `name` is set
* wrap config in required `Configurable` structure

Signed-off-by: Jeffrey Martin <jemartin@nvidia.com>

* ensure device is passed for huggingface generator

Signed-off-by: Jeffrey Martin <jemartin@nvidia.com>

---------

Signed-off-by: Jeffrey Martin <jemartin@nvidia.com>
  • Loading branch information
jmartin-tech authored Jul 22, 2024
1 parent 8f7eaa4 commit 6b40f03
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions garak/resources/tap/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def load_generator(
"""

config = {
"name": model_name,
"generations": generations,
"max_tokens": max_tokens,
}
Expand All @@ -55,12 +56,22 @@ def load_generator(

if model_name.lower() in hf_dict.keys():
config["name"] = hf_dict[model_name]
config["device"] = device

if model_name in supported_openai:
generator = OpenAIGenerator(config_root=config)
config_root = {
"generators": {
OpenAIGenerator.__module__.split(".")[-1]: {
OpenAIGenerator.__name__: config
}
}
}
generator = OpenAIGenerator(config_root=config_root)
elif model_name in supported_huggingface:
generator = Model(config_root=config)
config["hf_args"] = {"device": device}
config_root = {
"generators": {Model.__module__.split(".")[-1]: {Model.__name__: config}}
}
generator = Model(config_root=config_root)
else:
msg = (
f"{model_name} is not currently supported for TAP generation. Support is available for the following "
Expand Down

0 comments on commit 6b40f03

Please sign in to comment.