-
Notifications
You must be signed in to change notification settings - Fork 131
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[OPIK-318] - [Prompt library] SDK Prompt implementation - create/get (#…
…557) * update openapi spec * update openapi spec (with prompt) * update openapi spec (WIP) * fix openapi for fern * draft implementation of create_prompt() and get_prompt() * update openapi spec * create/get return PromptDetail * use public Prompt class WIP * fix circular import * add new method create_prompt_detail * update openapi spec * add e2e tests * move prompt-related api calls to new file * remove description param * rename Prompt.template to Prompt.prompt * use double curly brackets in prompts template * make id field private * simplified solution
- Loading branch information
1 parent
73ee9b1
commit 02f2835
Showing
30 changed files
with
2,718 additions
and
29 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .prompt import Prompt | ||
|
||
__all__ = [ | ||
"Prompt", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from typing import Optional | ||
|
||
from opik import Prompt | ||
from opik.rest_api import PromptVersionDetail, client as rest_client | ||
from opik.rest_api.core import ApiError | ||
|
||
|
||
class PromptClient: | ||
def __init__(self, client: rest_client.OpikApi): | ||
self._rest_client = client | ||
|
||
def create_prompt( | ||
self, | ||
name: str, | ||
prompt: str, | ||
) -> Prompt: | ||
""" | ||
Creates the prompt detail for the given prompt name and template. | ||
Parameters: | ||
- name: The name of the prompt. | ||
- prompt: The template content for the prompt. | ||
Returns: | ||
- A Prompt object for the provided prompt name and template. | ||
""" | ||
prompt_version = self._get_latest_version(name) | ||
|
||
if prompt_version is None or prompt_version.template != prompt: | ||
prompt_version = self._create_new_version(name=name, prompt=prompt) | ||
|
||
prompt_obj = Prompt.from_fern_prompt_version( | ||
name=name, prompt_version=prompt_version | ||
) | ||
|
||
return prompt_obj | ||
|
||
def _create_new_version( | ||
self, | ||
name: str, | ||
prompt: str, | ||
) -> PromptVersionDetail: | ||
new_prompt_version_detail_data = PromptVersionDetail(template=prompt) | ||
new_prompt_version_detail: PromptVersionDetail = ( | ||
self._rest_client.prompts.create_prompt_version( | ||
name=name, | ||
version=new_prompt_version_detail_data, | ||
) | ||
) | ||
return new_prompt_version_detail | ||
|
||
def _get_latest_version(self, name: str) -> Optional[PromptVersionDetail]: | ||
try: | ||
prompt_latest_version = self._rest_client.prompts.retrieve_prompt_version( | ||
name=name | ||
) | ||
return prompt_latest_version | ||
except ApiError as e: | ||
if e.status_code != 404: | ||
raise e | ||
return None | ||
|
||
def get_prompt( | ||
self, | ||
name: str, | ||
commit: Optional[str] = None, | ||
) -> Optional[Prompt]: | ||
""" | ||
Retrieve the prompt detail for a given prompt name and commit version. | ||
Parameters: | ||
name: The name of the prompt. | ||
commit: An optional commit version of the prompt. If not provided, the latest version is retrieved. | ||
Returns: | ||
Prompt: The details of the specified prompt. | ||
""" | ||
try: | ||
prompt_version = self._rest_client.prompts.retrieve_prompt_version( | ||
name=name, | ||
commit=commit, | ||
) | ||
prompt_obj = Prompt.from_fern_prompt_version( | ||
name=name, prompt_version=prompt_version | ||
) | ||
|
||
return prompt_obj | ||
|
||
except ApiError as e: | ||
if e.status_code != 404: | ||
raise e | ||
|
||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from typing import Any | ||
|
||
from opik.rest_api import PromptVersionDetail | ||
|
||
|
||
class Prompt: | ||
""" | ||
Prompt class represents a prompt with a name, prompt text/template and commit hash. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
prompt: str, | ||
) -> None: | ||
""" | ||
Initializes a new instance of the class with the given parameters. | ||
Creates a new prompt using the opik client and sets the initial state of the instance attributes based on the created prompt. | ||
Parameters: | ||
name: The name for the prompt. | ||
prompt: The template for the prompt. | ||
""" | ||
# we will import opik client here to avoid circular import issue | ||
from opik.api_objects import opik_client | ||
|
||
client = opik_client.get_client_cached() | ||
|
||
new_instance = client.create_prompt( | ||
name=name, | ||
prompt=prompt, | ||
) | ||
self._name = new_instance.name | ||
self._prompt = new_instance.prompt | ||
self._commit = new_instance.commit | ||
self.__internal_api__prompt_id__: str = new_instance.__internal_api__prompt_id__ | ||
|
||
@property | ||
def name(self) -> str: | ||
"""The name of the prompt.""" | ||
return self._name | ||
|
||
@property | ||
def prompt(self) -> str: | ||
"""The latest template of the prompt.""" | ||
return self._prompt | ||
|
||
@property | ||
def commit(self) -> str: | ||
"""The commit hash of the prompt.""" | ||
return self._commit | ||
|
||
def format(self, **kwargs: Any) -> str: | ||
""" | ||
Replaces placeholders in the template with provided keyword arguments. | ||
Args: | ||
**kwargs: Arbitrary keyword arguments where the key represents the placeholder | ||
in the template and the value is the value to replace the placeholder with. | ||
Returns: | ||
A string with all placeholders replaced by their corresponding values from kwargs. | ||
""" | ||
template = self._prompt | ||
for key, value in kwargs.items(): | ||
template = template.replace(f"{{{{{key}}}}}", str(value)) | ||
return template | ||
|
||
@classmethod | ||
def from_fern_prompt_version( | ||
cls, | ||
name: str, | ||
prompt_version: PromptVersionDetail, | ||
) -> "Prompt": | ||
# will not call __init__ to avoid API calls, create new instance with __new__ | ||
prompt = cls.__new__(cls) | ||
|
||
prompt.__internal_api__prompt_id__ = prompt_version.id | ||
prompt._name = name | ||
prompt._prompt = prompt_version.template | ||
prompt._commit = prompt_version.commit | ||
|
||
return prompt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.