Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: configurable plugins #711

Merged
merged 30 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d2d77bf
refactor test strings for readability
jmartin-tech May 20, 2024
d720bfb
fixture for site config testing
jmartin-tech May 20, 2024
dace23b
add plugin general config
jmartin-tech May 20, 2024
f925b3d
consolidate some plugin option processing
jmartin-tech May 21, 2024
b4a1089
hierarchy based config in yaml and json
jmartin-tech May 21, 2024
948beb7
improve rest code reuse
jmartin-tech May 23, 2024
120fe4f
enforce only supported_params when available
jmartin-tech May 23, 2024
fac3955
consistent ENV_VAR class level constants
jmartin-tech May 24, 2024
9b70b5d
continue configurable refactor
jmartin-tech May 27, 2024
c0ba6d0
configurable plugins support
jmartin-tech May 30, 2024
323e202
rest generations should be from the constructor or generators config
jmartin-tech May 29, 2024
b5f250b
shift plugin docs link attribute uri->doc_uri
jmartin-tech May 31, 2024
c77c0c1
refactor logic for DEFAULT_CLASS selection
jmartin-tech May 31, 2024
26876e9
more flashy message
jmartin-tech May 31, 2024
12b3706
set instance name early & remove stray comment
jmartin-tech May 31, 2024
a4ec848
configurable `device_map` in `LLaVA`
jmartin-tech May 31, 2024
5bc0955
Clarify configurable expectations
jmartin-tech May 31, 2024
640cdb0
fix invalid assignment for generator_name in interactive
jmartin-tech May 31, 2024
843553f
align dtype as torch_dtype keyword param
jmartin-tech Jun 3, 2024
f644722
adjust tap generator creation to use `config_root`
jmartin-tech Jun 3, 2024
6247c2f
test plugin `_supported_params` include `DEFAULT_PARAMS`
jmartin-tech Jun 3, 2024
54da223
Merge 'origin/main' into feature/configurable-plugins
jmartin-tech Jun 3, 2024
4d1120b
add missing SPDX headers
jmartin-tech Jun 3, 2024
13acb3c
guard for config file argument exists
jmartin-tech Jun 3, 2024
17a9ba9
rollback overzealous rename
jmartin-tech Jun 3, 2024
b8c401b
initializer adjustment in huggingface and nemo
jmartin-tech Jun 4, 2024
7839ee3
configurable provides `_validate_env_var()`
jmartin-tech Jun 4, 2024
a4ec673
`_load_config` calls `_validate_env_vars` for api keys
jmartin-tech Jun 5, 2024
0120665
ensure access to instance variables after `_load_config`
jmartin-tech Jun 5, 2024
5d90543
do not mutate `name` when initializing huggingface
jmartin-tech Jun 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 47 additions & 38 deletions garak/_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@

from garak import _config

PLUGIN_TYPES = ("probes", "detectors", "generators", "harnesses", "buffs")
PLUGIN_CLASSES = ("Probe", "Detector", "Generator", "Harness", "Buff")


def _extract_modules_klasses(base_klass):
    """Return the names of classes that originate from ``base_klass``'s own package.

    Despite the parameter name, ``base_klass`` is treated like a module: its
    members are listed with ``inspect.getmembers`` and a class is kept only
    when its ``__module__`` starts with ``base_klass.__name__``. This drops
    classes that were merely imported into the namespace from elsewhere.

    This is a plain module-level function; ``@staticmethod`` has no meaning
    outside a class body and would bind the name to a ``staticmethod`` object
    instead of a function.

    :param base_klass: module (or any object exposing ``__name__``) to inspect
    :return: matching class names, in the name-sorted order that
        ``inspect.getmembers`` produces
    :rtype: list
    """
    # NOTE: this is a plain prefix comparison, so a class whose module name
    # merely begins with the same text is also accepted.
    source_name = base_klass.__name__
    return [
        name
        for name, klass in inspect.getmembers(base_klass, inspect.isclass)
        if klass.__module__.startswith(source_name)
    ]


def enumerate_plugins(
category: str = "probes", skip_base_classes=True
Expand All @@ -31,27 +43,12 @@ def enumerate_plugins(
:type category: str
"""

if category not in ("probes", "detectors", "generators", "harnesses", "buffs"):
if category not in PLUGIN_TYPES:
raise ValueError("Not a recognised plugin type:", category)

base_mod = importlib.import_module(f"garak.{category}.base")

if category == "harnesses":
root_plugin_classname = "Harness"
else:
root_plugin_classname = category.title()[:-1]

base_plugin_classnames = set(
[
# be careful with what's imported into base modules
n
for n in dir(base_mod) # everything in the module ..
if "__class__" in dir(getattr(base_mod, n)) # .. that's a class ..
and getattr(base_mod, n).__class__.__name__ # .. and not a base class
== "type"
]
+ [root_plugin_classname]
)
base_plugin_classnames = set(_extract_modules_klasses(base_mod))
jmartin-tech marked this conversation as resolved.
Show resolved Hide resolved

plugin_class_names = []

Expand All @@ -63,17 +60,18 @@ def enumerate_plugins(
if module_filename == "base.py" and skip_base_classes:
continue
module_name = module_filename.replace(".py", "")
mod = importlib.import_module(f"garak.{category}.{module_name}")
module_entries = set(
[entry for entry in dir(mod) if not entry.startswith("__")]
)
mod = importlib.import_module(
f"garak.{category}.{module_name}"
) # import here will access all namespace level imports consider a cache to speed up processing
jmartin-tech marked this conversation as resolved.
Show resolved Hide resolved
module_entries = set(_extract_modules_klasses(mod))
if skip_base_classes:
module_entries = module_entries.difference(base_plugin_classnames)
module_plugin_names = set()
for module_entry in module_entries:
obj = getattr(mod, module_entry)
if inspect.isclass(obj):
if obj.__bases__[-1].__name__ in base_plugin_classnames:
for interface in base_plugin_classnames:
klass = getattr(base_mod, interface)
if issubclass(obj, klass):
module_plugin_names.add((module_entry, obj.active))

for module_plugin_name, active in sorted(module_plugin_names):
Expand All @@ -83,17 +81,7 @@ def enumerate_plugins(
return plugin_class_names


def configure_plugin(plugin_path: str, plugin: object) -> object:
category, module_name, plugin_class_name = plugin_path.split(".")
plugin_name = f"{module_name}.{plugin_class_name}"
plugin_type_config = getattr(_config.plugins, category)
if plugin_name in plugin_type_config:
for k, v in plugin_type_config[plugin_name].items():
setattr(plugin, k, v)
return plugin


def load_plugin(path, break_on_fail=True) -> object:
def load_plugin(path, break_on_fail=True, config_root=_config) -> object:
"""load_plugin takes a path to a plugin class, and attempts to load that class.
If successful, it returns an instance of that class.

Expand All @@ -104,7 +92,25 @@ def load_plugin(path, break_on_fail=True) -> object:
:type break_on_fail: bool
"""
try:
category, module_name, plugin_class_name = path.split(".")
parts = path.split(".")
match len(parts):
case 2:
category, module_name = parts
generator_mod = importlib.import_module(
f"garak.{category}.{module_name}"
)
if generator_mod.DEFAULT_CLASS:
plugin_class_name = generator_mod.DEFAULT_CLASS
else:
raise ValueError(
"module {module_name} has no default class; pass module.ClassName to model_type"
)
case 3:
category, module_name, plugin_class_name = parts
case _:
raise ValueError(
f"Attempted to load {path} with unexpected number of tokens."
)
except ValueError as ve:
if break_on_fail:
raise ValueError(
Expand All @@ -123,7 +129,12 @@ def load_plugin(path, break_on_fail=True) -> object:
return False

try:
plugin_instance = getattr(mod, plugin_class_name)()
klass = getattr(mod, plugin_class_name)
if "config_root" not in inspect.signature(klass.__init__).parameters:
raise AttributeError(
'Incompatible function signature: "config_root" is incompatible with this plugin'
)
plugin_instance = klass(config_root=config_root)
except AttributeError as ae:
logging.warning(
"Exception failed instantiation of %s.%s", module_path, plugin_class_name
Expand All @@ -144,6 +155,4 @@ def load_plugin(path, break_on_fail=True) -> object:
else:
return False

plugin_instance = configure_plugin(path, plugin_instance)

return plugin_instance
9 changes: 6 additions & 3 deletions garak/buffs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import tqdm

import garak.attempt
from garak import _config
jmartin-tech marked this conversation as resolved.
Show resolved Hide resolved
from garak.configurable import Configurable


class Buff:
class Buff(Configurable):
"""Base class for a buff.

A buff should take as input a list of attempts, and return
Expand All @@ -27,11 +29,12 @@ class Buff:
of derivative attempt objects.
"""

uri = ""
doc_uri = ""
bcp47 = None # set of languages this buff should be constrained to
active = True

def __init__(self) -> None:
def __init__(self, config_root=_config) -> None:
self._load_config(config_root)
module = self.__class__.__module__.replace("garak.buffs.", "")
self.fullname = f"{module}.{self.__class__.__name__}"
self.post_buff_hook = False
Expand Down
14 changes: 8 additions & 6 deletions garak/buffs/low_resource_languages.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from os import getenv

import garak.attempt
from garak import _config
from garak.buffs.base import Buff

# Low resource languages supported by DeepL
Expand All @@ -26,21 +27,22 @@ class LRLBuff(Buff):

Uses the DeepL API to translate prompts into low-resource languages"""

uri = "https://arxiv.org/abs/2310.02446"
ENV_VAR = "DEEPL_API_KEY"
doc_uri = "https://arxiv.org/abs/2310.02446"

api_key_error_sent = False

def __init__(self):
super().__init__()
def __init__(self, config_root=_config):
super().__init__(config_root=config_root)
self.post_buff_hook = True

def transform(
self, attempt: garak.attempt.Attempt
) -> Iterable[garak.attempt.Attempt]:
api_key = getenv("DEEPL_API_KEY", None)
api_key = getenv(self.ENV_VAR, None)
jmartin-tech marked this conversation as resolved.
Show resolved Hide resolved
if api_key is None:
if not self.api_key_error_sent:
msg = "DEEPL_API_KEY not set in env, cannot use LRLBuff."
msg = f"{self.ENV_VAR} not set in env, cannot use LRLBuff."
user_msg = (
msg
+ " If you do not have a DeepL API key, sign up at https://www.deepl.com/pro#developer"
Expand All @@ -62,7 +64,7 @@ def transform(
yield self._derive_new_attempt(attempt)

def untransform(self, attempt: garak.attempt.Attempt) -> garak.attempt.Attempt:
api_key = getenv("DEEPL_API_KEY", None)
api_key = getenv(self.ENV_VAR, None)
translator = Translator(api_key)
outputs = attempt.outputs
attempt.notes["original_responses"] = outputs
Expand Down
13 changes: 7 additions & 6 deletions garak/buffs/paraphrase.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
from collections.abc import Iterable

import garak.attempt
from garak import _config
from garak.buffs.base import Buff


class PegasusT5(Buff):
"""Paraphrasing buff using Pegasus model"""

bcp47 = "en"
uri = "https://huggingface.co/tuner007/pegasus_paraphrase"
doc_uri = "https://huggingface.co/tuner007/pegasus_paraphrase"

def __init__(self) -> None:
super().__init__()
def __init__(self, config_root=_config) -> None:
self.para_model_name = "tuner007/pegasus_paraphrase" # https://huggingface.co/tuner007/pegasus_paraphrase
self.max_length = 60
self.temperature = 1.5
Expand All @@ -25,6 +25,7 @@ def __init__(self) -> None:
self.torch_device = None
self.tokenizer = None
self.para_model = None
super().__init__(config_root=config_root)

def _load_model(self):
import torch
Expand Down Expand Up @@ -72,10 +73,9 @@ class Fast(Buff):
"""CPU-friendly paraphrase buff based on Humarin's T5 paraphraser"""

bcp47 = "en"
uri = "https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base"
doc_uri = "https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base"

def __init__(self) -> None:
super().__init__()
def __init__(self, config_root=_config) -> None:
self.para_model_name = "humarin/chatgpt_paraphraser_on_T5_base"
self.num_beams = 5
self.num_beam_groups = 5
Expand All @@ -88,6 +88,7 @@ def __init__(self) -> None:
self.torch_device = None
self.tokenizer = None
self.para_model = None
super().__init__(config_root=config_root)

def _load_model(self):
import torch
Expand Down
Loading
Loading