From aaa8f043c0d3863d3091034f3b486f8ba8e11de5 Mon Sep 17 00:00:00 2001 From: colin-sentry <161344340+colin-sentry@users.noreply.github.com> Date: Fri, 3 May 2024 10:22:46 -0400 Subject: [PATCH] Reduce API cross-section for huggingface in test (#3042) --- .../huggingface_hub/test_huggingface_hub.py | 56 ++++++------------- 1 file changed, 16 insertions(+), 40 deletions(-) diff --git a/tests/integrations/huggingface_hub/test_huggingface_hub.py b/tests/integrations/huggingface_hub/test_huggingface_hub.py index 062bd4fb31..734778d08a 100644 --- a/tests/integrations/huggingface_hub/test_huggingface_hub.py +++ b/tests/integrations/huggingface_hub/test_huggingface_hub.py @@ -1,14 +1,8 @@ import itertools -import json import pytest from huggingface_hub import ( InferenceClient, - TextGenerationOutput, - TextGenerationOutputDetails, - TextGenerationStreamOutput, - TextGenerationOutputToken, - TextGenerationStreamDetails, ) from huggingface_hub.errors import OverloadedError @@ -35,19 +29,15 @@ def test_nonstreaming_chat_completion( client = InferenceClient("some-model") if details_arg: client.post = mock.Mock( - return_value=json.dumps( - [ - TextGenerationOutput( - generated_text="the model response", - details=TextGenerationOutputDetails( - finish_reason="TextGenerationFinishReason", - generated_tokens=10, - prefill=[], - tokens=[], # not needed for integration - ), - ) - ] - ).encode("utf-8") + return_value=b"""[{ + "generated_text": "the model response", + "details": { + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "tokens": [] + } + }]""" ) else: client.post = mock.Mock( @@ -96,27 +86,13 @@ def test_streaming_chat_completion( client = InferenceClient("some-model") client.post = mock.Mock( return_value=[ - b"data:" - + json.dumps( - TextGenerationStreamOutput( - token=TextGenerationOutputToken( - id=1, special=False, text="the model " - ), - ), - ).encode("utf-8"), - b"data:" - + json.dumps( - TextGenerationStreamOutput( - token=TextGenerationOutputToken( - id=2, special=False, text="response" - ), - details=TextGenerationStreamDetails( - finish_reason="length", - generated_tokens=10, - seed=0, - ), - ) - ).encode("utf-8"), + b"""data:{ + "token":{"id":1, "special": false, "text": "the model "} + }""", + b"""data:{ + "token":{"id":2, "special": false, "text": "response"}, + "details":{"finish_reason": "length", "generated_tokens": 10, "seed": 0} + }""", ] ) with start_transaction(name="huggingface_hub tx"):