Commit

fix: tests

dartpain committed Dec 20, 2024
1 parent c2a95b5 commit 1f75f0c
Showing 5 changed files with 21 additions and 17 deletions.
4 changes: 2 additions & 2 deletions application/llm/anthropic.py
@@ -17,7 +17,7 @@ def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         self.AI_PROMPT = AI_PROMPT
 
     def _raw_gen(
-        self, baseself, model, messages, stream=False, max_tokens=300, **kwargs
+        self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
     ):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
@@ -34,7 +34,7 @@ def _raw_gen(
         return completion.completion
 
     def _raw_gen_stream(
-        self, baseself, model, messages, stream=True, max_tokens=300, **kwargs
+        self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
    ):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
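The change above adds a `tools` keyword to the Anthropic adapter's `_raw_gen` and `_raw_gen_stream`; the same signature change is applied to the SageMaker adapter below. Neither backend uses tools, but the shared gen/gen_stream entry points (and their caching decorators) now pass `tools` down to the adapter hooks, so every adapter has to accept the keyword even if it ignores it. A minimal sketch of that pattern, with made-up class names and none of the repository's real wiring:

# Minimal sketch (hypothetical names, not the repository's real base class):
# the public entry point forwards `tools` to the adapter hook, so adapters
# without tool support must still accept the keyword and may ignore it.
class BaseLLM:
    def gen_stream(self, model, messages, stream=True, tools=None, **kwargs):
        # In the real project this call is wrapped by @stream_cache; here we
        # simply forward `tools` to the adapter-level hook.
        return self._raw_gen_stream(self, model, messages, stream=stream, tools=tools, **kwargs)


class DummyAdapter(BaseLLM):
    def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kwargs):
        # Accept `tools` for interface compatibility, then ignore it.
        for message in messages:
            yield message["content"]


if __name__ == "__main__":
    llm = DummyAdapter()
    chunks = llm.gen_stream("dummy-model", [{"content": "hello"}], tools=None)
    print(list(chunks))  # ['hello'] -- no TypeError from the extra keyword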
4 changes: 2 additions & 2 deletions application/llm/sagemaker.py
@@ -76,7 +76,7 @@ def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         self.endpoint = settings.SAGEMAKER_ENDPOINT
         self.runtime = runtime
 
-    def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
+    def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -105,7 +105,7 @@ def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
         print(result[0]["generated_text"], file=sys.stderr)
         return result[0]["generated_text"][len(prompt) :]
 
-    def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
+    def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kwargs):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
3 changes: 2 additions & 1 deletion tests/llm/test_anthropic.py
@@ -46,14 +46,15 @@ def test_gen_stream(self):
             {"content": "question"}
         ]
         mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")]
+        mock_tools = Mock()
 
         with patch("application.cache.get_redis_instance") as mock_make_redis:
            mock_redis_instance = mock_make_redis.return_value
            mock_redis_instance.get.return_value = None
            mock_redis_instance.set = Mock()
 
            with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
-                responses = list(self.llm.gen_stream("test_model", messages))
+                responses = list(self.llm.gen_stream("test_model", messages, tools=mock_tools))
                self.assertListEqual(responses, ["response_1", "response_2"])
 
         prompt_expected = "### Context \n context \n ### Question \n question"
2 changes: 1 addition & 1 deletion tests/llm/test_sagemaker.py
@@ -76,7 +76,7 @@ def test_gen_stream(self):
 
         with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
                           return_value=self.response) as mock_invoke_endpoint:
-            output = list(self.sagemaker.gen_stream(None, self.messages))
+            output = list(self.sagemaker.gen_stream(None, self.messages, tools=None))
             mock_invoke_endpoint.assert_called_once_with(
                 EndpointName=self.sagemaker.endpoint,
                 ContentType='application/json',
25 changes: 14 additions & 11 deletions tests/test_cache.py
@@ -12,18 +12,21 @@ def test_make_gen_cache_key():
         {'role': 'system', 'content': 'test_system_message'},
     ]
     model = "test_docgpt"
+    tools = None
 
     # Manually calculate the expected hash
-    expected_combined = f"{model}_{json.dumps(messages, sort_keys=True)}"
+    messages_str = json.dumps(messages)
+    tools_str = json.dumps(tools) if tools else ""
+    expected_combined = f"{model}_{messages_str}_{tools_str}"
     expected_hash = get_hash(expected_combined)
-    cache_key = gen_cache_key(*messages, model=model)
+    cache_key = gen_cache_key(messages, model=model, tools=None)
 
     assert cache_key == expected_hash
 
 def test_gen_cache_key_invalid_message_format():
     # Test when messages is not a list
     with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context:
-        gen_cache_key("This is not a list", model="docgpt")
+        gen_cache_key("This is not a list", model="docgpt", tools=None)
     assert str(context.exception) == "All messages must be dictionaries."
 
 # Test for gen_cache decorator
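Aside: the updated assertions above imply a gen_cache_key that takes the message list as a single argument plus model and tools keywords, JSON-encodes both messages and tools, and hashes the combined string. A sketch consistent with that behaviour (illustrative only, not the project's actual application/cache.py; get_hash stands in for the real hash helper the test imports):

# Illustrative sketch of a gen_cache_key matching the updated test -- not the
# project's actual application/cache.py.
import hashlib
import json


def get_hash(text):
    # Stand-in for the real get_hash helper imported by the test.
    return hashlib.md5(text.encode()).hexdigest()


def gen_cache_key(messages, model="docgpt", tools=None):
    if not all(isinstance(msg, dict) for msg in messages):
        raise ValueError("All messages must be dictionaries.")
    messages_str = json.dumps(messages)
    tools_str = json.dumps(tools) if tools else ""
    return get_hash(f"{model}_{messages_str}_{tools_str}")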
@@ -35,14 +38,14 @@ def test_gen_cache_hit(mock_make_redis):
     mock_redis_instance.get.return_value = b"cached_result"  # Simulate a cache hit
 
     @gen_cache
-    def mock_function(self, model, messages):
+    def mock_function(self, model, messages, stream, tools):
         return "new_result"
 
     messages = [{'role': 'user', 'content': 'test_user_message'}]
     model = "test_docgpt"
 
     # Act
-    result = mock_function(None, model, messages)
+    result = mock_function(None, model, messages, stream=False, tools=None)
 
     # Assert
     assert result == "cached_result"  # Should return cached result
@@ -58,7 +61,7 @@ def test_gen_cache_miss(mock_make_redis):
     mock_redis_instance.get.return_value = None  # Simulate a cache miss
 
     @gen_cache
-    def mock_function(self, model, messages):
+    def mock_function(self, model, messages, stream, tools):
         return "new_result"
 
     messages = [
@@ -67,7 +70,7 @@ def mock_function(self, model, messages):
     ]
     model = "test_docgpt"
     # Act
-    result = mock_function(None, model, messages)
+    result = mock_function(None, model, messages, stream=False, tools=None)
 
     # Assert
     assert result == "new_result"
@@ -83,14 +86,14 @@ def test_stream_cache_hit(mock_make_redis):
     mock_redis_instance.get.return_value = cached_chunk
 
     @stream_cache
-    def mock_function(self, model, messages, stream):
+    def mock_function(self, model, messages, stream, tools):
         yield "new_chunk"
 
     messages = [{'role': 'user', 'content': 'test_user_message'}]
     model = "test_docgpt"
 
     # Act
-    result = list(mock_function(None, model, messages, stream=True))
+    result = list(mock_function(None, model, messages, stream=True, tools=None))
 
     # Assert
     assert result == ["chunk1", "chunk2"]  # Should return cached chunks
@@ -106,7 +109,7 @@ def test_stream_cache_miss(mock_make_redis):
     mock_redis_instance.get.return_value = None  # Simulate a cache miss
 
     @stream_cache
-    def mock_function(self, model, messages, stream):
+    def mock_function(self, model, messages, stream, tools):
         yield "new_chunk"
 
     messages = [
@@ -117,7 +120,7 @@ def mock_function(self, model, messages, stream):
     model = "test_docgpt"
 
     # Act
-    result = list(mock_function(None, model, messages, stream=True))
+    result = list(mock_function(None, model, messages, stream=True, tools=None))
 
     # Assert
     assert result == ["new_chunk"]
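The decorator tests assume gen_cache and stream_cache now call the wrapped function with (self, model, messages, stream, tools), short-circuit on a Redis hit, and store the result on a miss. A sketch of a stream_cache-style decorator with that behaviour (illustrative only; get_redis_instance and gen_cache_key stand in for the real helpers in application/cache.py, the latter sketched above):

# Illustrative stream-caching decorator consistent with the tests above:
# replay cached chunks on a hit, otherwise run the generator and record it.
# Not the project's actual application/cache.py.
import json
from functools import wraps


def stream_cache(func):
    @wraps(func)
    def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
        redis_client = get_redis_instance()  # assumed helper, as patched in the tests
        key = gen_cache_key(messages, model=model, tools=tools)
        cached = redis_client.get(key)
        if cached:
            yield from json.loads(cached)  # cache hit: replay stored chunks
            return
        chunks = []
        for chunk in func(self, model, messages, stream, tools, *args, **kwargs):
            chunks.append(chunk)
            yield chunk  # cache miss: stream while recording
        redis_client.set(key, json.dumps(chunks))
    return wrapper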
