diff --git a/log10/feedback/_summary_feedback_utils.py b/log10/feedback/_summary_feedback_utils.py
index 2dfa1a42..c487d01e 100644
--- a/log10/feedback/_summary_feedback_utils.py
+++ b/log10/feedback/_summary_feedback_utils.py
@@ -72,12 +72,12 @@
     SystemMessage(SUMMARY_SYSTEM_PROMPT),
     UserMessage(SUMMARY_USER_MESSAGE),
     UserMessage("Examples: \n{examples}\n\nTest: \n{prompt}"),
-    model=OpenaiChatModel("gpt-3.5-turbo", temperature=0.2),
+    model=OpenaiChatModel("gpt-4-0125-preview", temperature=0.2),
 )
 def summary_feedback_llm_call(examples, prompt) -> str: ...
 
 
-def get_prompt_response(completion: dict) -> dict:
+def flatten_messages(completion: dict) -> dict:
     request_messages = completion.get("request", {}).get("messages", [])
     if len(request_messages) > 1 and request_messages[1].get("content", ""):
         prompt = request_messages[1].get("content")
diff --git a/log10/feedback/autofeedback.py b/log10/feedback/autofeedback.py
index c264084a..5c4516ff 100644
--- a/log10/feedback/autofeedback.py
+++ b/log10/feedback/autofeedback.py
@@ -5,9 +5,10 @@
 import click
 import openai
+from rich.console import Console
 
 from log10.completions.completions import _get_completion
-from log10.feedback._summary_feedback_utils import get_prompt_response, summary_feedback_llm_call
+from log10.feedback._summary_feedback_utils import flatten_messages, summary_feedback_llm_call
 from log10.feedback.feedback import _get_feedback_list
 from log10.load import log10, log10_session
 
 
@@ -55,6 +56,7 @@ def _get_examples(self):
                     "feedback": json.dumps(feedback_values),
                 }
             )
+        logger.info(f"Sampled completion ids: {[d['completion_id'] for d in few_shot_examples]}")
         return few_shot_examples
 
     def predict(self, text: str = None, completion_id: str = None) -> str:
@@ -64,7 +66,7 @@ def predict(self, text: str = None, completion_id: str = None) -> str:
         # Here assumps the completion is summary, prompt is article, response is summary
         if completion_id and not text:
             completion = _get_completion(completion_id)
-            pr = get_prompt_response(completion.json()["data"])
+            pr = flatten_messages(completion.json()["data"])
             text = json.dumps(pr)
 
         logger.info(f"{text=}")
@@ -88,14 +90,15 @@ def auto_feedback_icl(task_id: str, content: str, file: str, completion_id: str,
         click.echo("Only one of --content, --file, or --completion_id should be provided.")
         return
 
+    console = Console()
     auto_feedback_icl = AutoFeedbackICL(task_id, num_samples=num_samples)
     if completion_id:
         results = auto_feedback_icl.predict(completion_id=completion_id)
-        click.echo(results)
+        console.print_json(results)
         return
 
     if file:
         with open(file, "r") as f:
             content = f.read()
     results = auto_feedback_icl.predict(text=content)
-    click.echo(results)
+    console.print_json(results)