Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a streaming RAG method #967

Merged
merged 21 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deno.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 34 additions & 2 deletions packages/ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,21 @@ Creates an instance of `EdgeDBAI` with the specified client and options.

Returns a new `EdgeDBAI` instance with an updated query context.

- `async queryRag(message: string, context: QueryContext = this.context): Promise<string>`
- `async queryRag(message: string, context?: QueryContext): Promise<string>`

Sends a query with context to the configured AI model and returns the response as a string.

- `streamRag(message: string, context?: QueryContext): AsyncIterable<StreamingMessage> & PromiseLike<Response>`

It can be used in two ways:

- as **an async iterator** - if you want to process streaming data in real-time as it arrives, ideal for handling long-running streams;
- as **a Promise that resolves to a full Response object** - you have complete control over how you want to handle the stream, this might be useful when you want to manipulate the raw stream or parse it in a custom way.

- `generateEmbeddings(inputs: string[], model: string): Promise<number[]>`

Generates embeddings for the array of strings.

## Example

The following example demonstrates how to use the `@edgedb/ai` package to query an AI model about astronomy and chemistry.
Expand Down Expand Up @@ -76,6 +87,27 @@ console.timeEnd("gpt-3.5 Time");
const fastChemistryAi = fastAstronomyAi.withContext({ query: "Chemistry" });

console.log(
await fastChemistryAi.queryRag("What is the atomic number of gold?")
await fastChemistryAi.queryRag("What is the atomic number of gold?"),
);

// handle the Response object
const response = await fastChemistryAi.streamRag(
"What is the atomic number of gold?",
);
handleReadableStream(response); // custom function that reads the stream

// handle individual chunks as they arrive
for await (const chunk of fastChemistryAi.streamRag(
"What is the atomic number of gold?",
)) {
console.log("chunk", chunk);
}

// embeddings
console.log(
await fastChemistryAi.generateEmbeddings(
["What is the atomic number of gold?"],
"text-embedding-ada-002",
),
);
```
3 changes: 3 additions & 0 deletions packages/ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,8 @@
},
"peerDependencies": {
"edgedb": "^1.5.0"
},
"dependencies": {
"eventsource-parser": "^1.1.2"
}
}
137 changes: 117 additions & 20 deletions packages/ai/src/core.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
import type { Client } from "edgedb";
import { getHTTPSCRAMAuth } from "edgedb/dist/httpScram.js";
import cryptoUtils from "edgedb/dist/adapter.crypto.node.js";
import {
EventSourceParserStream,
type ParsedEvent,
} from "eventsource-parser/stream";

import type { ResolvedConnectConfig } from "edgedb/dist/conUtils.js";
import {
getAuthenticatedFetch,
type AuthenticatedFetch,
} from "edgedb/dist/utils.js";

import type { AIOptions, QueryContext, RAGRequest } from "./types.js";
import type {
AIOptions,
QueryContext,
RAGRequest,
StreamingMessage,
} from "./types.js";
import { getHTTPSCRAMAuth } from "edgedb/dist/httpScram.js";
import { cryptoUtils } from "edgedb/dist/browserCrypto.js";

export function createAI(client: Client, options: AIOptions) {
return new EdgeDBAI(client, options);
}

const httpSCRAMAuth = getHTTPSCRAMAuth(cryptoUtils.default);
const httpSCRAMAuth = getHTTPSCRAMAuth(cryptoUtils);

export class EdgeDBAI {
/** @internal */
Expand All @@ -36,7 +46,10 @@ export class EdgeDBAI {
}

private static async getAuthenticatedFetch(client: Client) {
const connectConfig = await client.resolveConnectionParams();
const connectConfig: ResolvedConnectConfig = (
await (client as any).pool._getNormalizedConnectConfig()
).connectionParams;

return getAuthenticatedFetch(connectConfig, httpSCRAMAuth, "ext/ai/");
}

Expand All @@ -55,7 +68,7 @@ export class EdgeDBAI {
});
}

private async fetchRag(request: RAGRequest) {
private async fetchRag(request: Omit<RAGRequest, "model" | "prompt">) {
const headers = request.stream
? { Accept: "text/event-stream", "Content-Type": "application/json" }
: { Accept: "application/json", "Content-Type": "application/json" };
Expand All @@ -65,7 +78,11 @@ export class EdgeDBAI {
)("rag", {
method: "POST",
headers,
body: JSON.stringify(request),
body: JSON.stringify({
...request,
model: this.options.model,
prompt: this.options.prompt,
}),
});

if (!response.ok) {
Expand All @@ -76,35 +93,115 @@ export class EdgeDBAI {
return response;
}

async queryRag(
message: string,
context: QueryContext = this.context,
): Promise<string> {
const response = await this.fetchRag({
model: this.options.model,
prompt: this.options.prompt,
async queryRag(query: string, context = this.context): Promise<string> {
const res = await this.fetchRag({
context,
query: message,
query,
stream: false,
});

if (!response.headers.get("content-type")?.includes("application/json")) {
if (!res.headers.get("content-type")?.includes("application/json")) {
throw new Error(
"expected response to have content-type: application/json",
"Expected response to have content-type: application/json",
);
}

const data = await response.json();
const data = await res.json();

if (
!data ||
typeof data !== "object" ||
typeof data.response !== "string"
) {
throw new Error(
"expected response to be object with response key of type string",
"Expected response to be object with response key of type string",
);
}

return data.response;
}

streamRag(
query: string,
context = this.context,
): AsyncIterable<StreamingMessage> & PromiseLike<Response> {
const fetchRag = this.fetchRag.bind(this);

const ragOptions = {
context,
query,
stream: true,
};

return {
async *[Symbol.asyncIterator]() {
const res = await fetchRag(ragOptions);

if (!res.body) {
throw new Error("Expected response to include a body");
}

const reader = res.body
.pipeThrough(new TextDecoderStream())
.pipeThrough(new EventSourceParserStream())
.getReader();
try {
while (true) {
const { done, value } = await reader.read();
if (done) break;
const message = extractMessageFromParsedEvent(value);
yield message;
if (message.type === "message_stop") break;
}
} finally {
reader.releaseLock();
}
},
then<TResult1 = Response, TResult2 = never>(
onfulfilled?:
| ((value: Response) => TResult1 | PromiseLike<TResult1>)
| undefined
| null,
onrejected?:
| ((reason: any) => TResult2 | PromiseLike<TResult2>)
| undefined
| null,
): Promise<TResult1 | TResult2> {
return fetchRag(ragOptions).then(onfulfilled, onrejected);
},
};
}

async generateEmbeddings(inputs: string[], model: string): Promise<number[]> {
const response = await (
await this.authenticatedFetch
)("embeddings", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
model,
input: inputs,
}),
});

if (!response.ok) {
const bodyText = await response.text();
throw new Error(bodyText);
}

const data: { data: { embedding: number[] }[] } = await response.json();
return data.data[0].embedding;
}
}

function extractMessageFromParsedEvent(
parsedEvent: ParsedEvent,
): StreamingMessage {
const { data } = parsedEvent;
if (!data) {
throw new Error("Expected SSE message to include a data payload");
}
return JSON.parse(data) as StreamingMessage;
}
51 changes: 51 additions & 0 deletions packages/ai/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,54 @@ export interface RAGRequest {
query: string;
stream?: boolean;
}

export interface MessageStart {
type: "message_start";
message: {
role: "assistant" | "system" | "user";
id: string;
model: string;
};
}

export interface ContentBlockStart {
type: "content_block_start";
index: number;
content_block: {
text: string;
type: "text";
};
}

export interface ContentBlockDelta {
type: "content_block_delta";
delta: {
type: "text_delta";
text: string;
};
index: number;
}

export interface ContentBlockStop {
type: "content_block_stop";
index: number;
}

export interface MessageDelta {
type: "message_delta";
delta: {
stop_reason: "stop";
};
}

export interface MessageStop {
type: "message_stop";
}

export type StreamingMessage =
| MessageStart
| ContentBlockStart
| ContentBlockDelta
| ContentBlockStop
| MessageDelta
| MessageStop;
2 changes: 1 addition & 1 deletion packages/generate/src/syntax/cardinality.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { Cardinality } from "edgedb/dist/reflection/index";
import type { TypeSet } from "./typesystem";

// Computing cardinality of path
// From base set cadinality and pointer cardinality
// From base set cardinality and pointer cardinality
// Used in path expressions
// Cardinality Empty AtMostOne One Many AtLeastOne
// Empty 0 0 0 0 0
Expand Down
5 changes: 5 additions & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -2370,6 +2370,11 @@ etag@~1.8.1:
resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887"
integrity sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==

eventsource-parser@^1.1.2:
version "1.1.2"
resolved "https://registry.yarnpkg.com/eventsource-parser/-/eventsource-parser-1.1.2.tgz#ed6154a4e3dbe7cda9278e5e35d2ffc58b309f89"
integrity sha512-v0eOBUbiaFojBu2s2NPBfYUoRR9GjcDNvCXVaqEf5vVfpIAh9f8RCo4vXTP8c63QRKCFwoLpMpTdPwwhEKVgzA==

execa@^5.0.0:
version "5.1.1"
resolved "https://registry.yarnpkg.com/execa/-/execa-5.1.1.tgz#f80ad9cbf4298f7bd1d4c9555c21e93741c411dd"
Expand Down