
Support configuring precision and quantization in HuggingFaceClient #1912

Merged 3 commits into main on Dec 11, 2023

Conversation

yifanmai (Collaborator):

Support passing the torch_dtype, load_in_8bit, and load_in_4bit parameters to HuggingFaceClient from the ModelDeployment configuration.

Example usage:

prod_env/model_deployments.yaml:

model_deployments:
  # Example model with precision set to bfloat16
  - name: yifanmai/gpt2-bfloat16
    tokenizer_name: "huggingface/gpt2"
    max_sequence_length: 1024
    window_service_spec:
      class_name: "helm.benchmark.window_services.huggingface_window_service.HuggingFaceWindowService"
      args:
        pretrained_model_name_or_path: gpt2
    client_spec:
      class_name: "helm.proxy.clients.huggingface_client.HuggingFaceClient"
      args:
        pretrained_model_name_or_path: gpt2
        torch_dtype: torch.bfloat16
  # Example model with 8-bit quantization
  - name: yifanmai/gpt2-8bit
    tokenizer_name: "huggingface/gpt2"
    max_sequence_length: 1024
    window_service_spec:
      class_name: "helm.benchmark.window_services.huggingface_window_service.HuggingFaceWindowService"
      args:
        pretrained_model_name_or_path: gpt2
    client_spec:
      class_name: "helm.proxy.clients.huggingface_client.HuggingFaceClient"
      args:
        pretrained_model_name_or_path: gpt2
        load_in_8bit: true

Also clean up revision handling: treat it as just another kwarg rather than treating it specially.
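To illustrate the idea, here is a minimal sketch of how the client_spec args could be separated into client-level settings and model-loading kwargs that get forwarded to AutoModelForCausalLM.from_pretrained. The split_client_args helper is hypothetical, for illustration only; HELM's actual wiring may differ.

```python
# Hypothetical sketch: split the client_spec args into the client's own
# parameters and the kwargs forwarded to AutoModelForCausalLM.from_pretrained.
# `split_client_args` is an illustrative helper, not part of HELM.

def split_client_args(args: dict) -> tuple:
    client_keys = {"pretrained_model_name_or_path"}
    client_args = {k: v for k, v in args.items() if k in client_keys}
    model_kwargs = {k: v for k, v in args.items() if k not in client_keys}
    return client_args, model_kwargs

# Args as they would appear for the 8-bit quantization example above.
args = {
    "pretrained_model_name_or_path": "gpt2",
    "load_in_8bit": True,
}
client_args, model_kwargs = split_client_args(args)
# The model would then be loaded roughly as:
# AutoModelForCausalLM.from_pretrained(
#     client_args["pretrained_model_name_or_path"], **model_kwargs)
```

Under this scheme, any new loading parameter (torch_dtype, load_in_4bit, revision, ...) passes through without a code change to the client.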

yifanmai (Collaborator, Author):

Also, eventually we might require trust_remote_code=True to be set explicitly in the same way.

JosselinSomervilleRoberts (Contributor) left a comment:


LGTM

@@ -56,23 +56,20 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
 class HuggingFaceServer:
     """A thin wrapper around a Hugging Face AutoModelForCausalLM for HuggingFaceClient to call."""

-    def __init__(self, pretrained_model_name_or_path: str, revision: Optional[str] = None):
+    def __init__(self, pretrained_model_name_or_path: str, **kwargs):
JosselinSomervilleRoberts (Contributor):

Should we add a comment here describing common kwargs that should be specified such as revision, precision, ... ?

yifanmai (Collaborator, Author):

I would prefer to do this in the docs instead.
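For context, the shape of the revised constructor can be sketched as follows. This is a simplified illustration, not the real HuggingFaceServer (which does more); the from_pretrained call is commented out to keep the sketch self-contained.

```python
# Simplified sketch of the revised signature: revision, torch_dtype,
# load_in_8bit, etc. all arrive as plain kwargs and would be passed through
# unchanged to AutoModelForCausalLM.from_pretrained.

class HuggingFaceServer:
    def __init__(self, pretrained_model_name_or_path: str, **kwargs):
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.model_kwargs = kwargs
        # In the real class, roughly:
        # self.model = AutoModelForCausalLM.from_pretrained(
        #     pretrained_model_name_or_path, **kwargs)

# revision is now just another entry in the kwargs dict.
server = HuggingFaceServer("gpt2", revision="main", load_in_8bit=True)
```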

yifanmai force-pushed the yifanmai/fix-huggingface-precision branch from ee803fa to 21898b8 on November 2, 2023 at 16:54
JosselinSomervilleRoberts (Contributor):

I think this is now handled by #1903 and we can close this

yifanmai force-pushed the yifanmai/fix-huggingface-precision branch from 21898b8 to c882746 on December 2, 2023 at 01:20
yifanmai force-pushed the yifanmai/fix-huggingface-precision branch from be458e3 to 982cdf9 on December 11, 2023 at 22:56
yifanmai merged commit c9b4a7e into main on Dec 11, 2023
6 checks passed
yifanmai deleted the yifanmai/fix-huggingface-precision branch on December 11, 2023 at 23:08
2 participants