Expose embeddings API (#490)
fantix authored May 28, 2024
1 parent 636bc0e commit 7386cd0
Showing 1 changed file with 43 additions and 26 deletions.
edgedb/ai/core.py: 43 additions, 26 deletions
@@ -85,6 +85,23 @@ def with_context(self, **kwargs) -> typing.Self:
         rv.client = self.client
         return rv
 
+    def _make_rag_request(
+        self,
+        *,
+        message: str,
+        context: typing.Optional[types.QueryContext] = None,
+        stream: bool,
+    ) -> types.RAGRequest:
+        if context is None:
+            context = self.context
+        return types.RAGRequest(
+            model=self.options.model,
+            prompt=self.options.prompt,
+            context=context,
+            query=message,
+            stream=stream,
+        )
+
 
 class EdgeDBAI(BaseEdgeDBAI):
     client: httpx.Client
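
The new `_make_rag_request` helper pulls the default-context fallback and the `types.RAGRequest` construction out of the four call sites below, so the sync and async clients build requests the same way. The request type itself (presumably defined in edgedb/ai/types.py) is not part of this diff; what follows is a hypothetical sketch only, assuming a dataclass whose `to_httpx_request()` returns keyword arguments that can be splatted into `httpx` calls. The field names come from the diff; the endpoint path and serialization are assumptions.

# Hypothetical sketch, not the actual edgedb/ai/types.py definition.
import dataclasses
import json
import typing


@dataclasses.dataclass
class RAGRequest:
    model: str
    prompt: typing.Optional[typing.Any]
    context: typing.Any  # types.QueryContext in the real code
    query: str
    stream: bool

    def to_httpx_request(self) -> dict:
        # Assumed to produce kwargs for client.post(**...) or
        # httpx_sse.connect_sse(client, "post", **...).
        return {
            "url": "/rag",  # assumed path
            "headers": {"Content-Type": "application/json"},
            "content": json.dumps(dataclasses.asdict(self), default=vars),
        }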
@@ -95,14 +112,10 @@ def _init_client(self, **kwargs):
     def query_rag(
         self, message: str, context: typing.Optional[types.QueryContext] = None
     ) -> str:
-        if context is None:
-            context = self.context
         resp = self.client.post(
-            **types.RAGRequest(
-                model=self.options.model,
-                prompt=self.options.prompt,
+            **self._make_rag_request(
                 context=context,
-                query=message,
+                message=message,
                 stream=False,
             ).to_httpx_request()
         )
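
The refactor leaves the public `query_rag` signature and behavior unchanged; only the request construction moved into the shared helper. An illustrative call from application code, assuming the `edgedb.ai.create_ai` factory and a model name (neither appears in this diff):

# Illustrative usage; create_ai and the model name are assumptions.
import edgedb
from edgedb.ai import create_ai

client = edgedb.create_client()
ai = create_ai(client, model="gpt-4-turbo-preview")

answer = ai.query_rag("Which countries appear in the dataset?")
print(answer)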
@@ -111,24 +124,27 @@ def query_rag(
 
     def stream_rag(
         self, message: str, context: typing.Optional[types.QueryContext] = None
-    ):
-        if context is None:
-            context = self.context
+    ) -> typing.Iterator[str]:
         with httpx_sse.connect_sse(
             self.client,
             "post",
-            **types.RAGRequest(
-                model=self.options.model,
-                prompt=self.options.prompt,
+            **self._make_rag_request(
                 context=context,
-                query=message,
+                message=message,
                 stream=True,
             ).to_httpx_request(),
         ) as event_source:
             event_source.response.raise_for_status()
             for sse in event_source.iter_sse():
                 yield sse.data
 
+    def generate_embeddings(self, *inputs: str, model: str) -> list[float]:
+        resp = self.client.post(
+            "/embeddings", json={"input": inputs, "model": model}
+        )
+        resp.raise_for_status()
+        return resp.json()["data"][0]["embedding"]
+
 
 class AsyncEdgeDBAI(BaseEdgeDBAI):
     client: httpx.AsyncClient
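
With `generate_embeddings` on the sync client, callers can fetch an embedding vector directly over the same authenticated `httpx` client. Note that the method indexes `data[0]`, so passing several inputs still returns only the first embedding. An illustrative call, with the same assumptions as above:

# Illustrative usage; create_ai and the model names are assumptions.
import edgedb
from edgedb.ai import create_ai

client = edgedb.create_client()
ai = create_ai(client, model="gpt-4-turbo-preview")

vector = ai.generate_embeddings(
    "EdgeDB is a graph-relational database.",
    model="text-embedding-3-small",
)
print(len(vector))  # dimensionality of the returned embedding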
@@ -139,14 +155,10 @@ def _init_client(self, **kwargs):
     async def query_rag(
         self, message: str, context: typing.Optional[types.QueryContext] = None
     ) -> str:
-        if context is None:
-            context = self.context
         resp = await self.client.post(
-            **types.RAGRequest(
-                model=self.options.model,
-                prompt=self.options.prompt,
+            **self._make_rag_request(
                 context=context,
-                query=message,
+                message=message,
                 stream=False,
             ).to_httpx_request()
         )
@@ -155,20 +167,25 @@ async def query_rag(
 
     async def stream_rag(
         self, message: str, context: typing.Optional[types.QueryContext] = None
-    ):
-        if context is None:
-            context = self.context
+    ) -> typing.Iterator[str]:
         async with httpx_sse.aconnect_sse(
             self.client,
             "post",
-            **types.RAGRequest(
-                model=self.options.model,
-                prompt=self.options.prompt,
+            **self._make_rag_request(
                 context=context,
-                query=message,
+                message=message,
                 stream=True,
             ).to_httpx_request(),
         ) as event_source:
             event_source.response.raise_for_status()
             async for sse in event_source.aiter_sse():
                 yield sse.data
+
+    async def generate_embeddings(
+        self, *inputs: str, model: str
+    ) -> list[float]:
+        resp = await self.client.post(
+            "/embeddings", json={"input": inputs, "model": model}
+        )
+        resp.raise_for_status()
+        return resp.json()["data"][0]["embedding"]
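
The async variant mirrors the sync one, awaiting the HTTP call and parsing the same `data[0].embedding` shape used by OpenAI-compatible embeddings endpoints. A minimal end-to-end sketch, assuming the `edgedb.ai.create_async_ai` factory is awaitable and the model names shown; none of these appear in this diff:

# Illustrative usage; create_async_ai and the model names are assumptions.
import asyncio

import edgedb
from edgedb.ai import create_async_ai


async def main() -> None:
    client = edgedb.create_async_client()
    ai = await create_async_ai(client, model="gpt-4-turbo-preview")

    # Stream a RAG answer chunk by chunk over SSE.
    async for chunk in ai.stream_rag("Summarize the schema in one sentence."):
        print(chunk, end="", flush=True)
    print()

    # Fetch an embedding vector for a single input.
    vector = await ai.generate_embeddings(
        "graph-relational database", model="text-embedding-3-small"
    )
    print(len(vector))


asyncio.run(main())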
