-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
large_db_high_cardinality.ts
153 lines (131 loc) · 4.16 KB
/
large_db_high_cardinality.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import { DocumentInterface } from "@langchain/core/documents";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import {
RunnablePassthrough,
RunnableSequence,
} from "@langchain/core/runnables";
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { createSqlQueryChain } from "langchain/chains/sql_db";
import { SqlDatabase } from "langchain/sql_db";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { DataSource } from "typeorm";
// TypeORM connection settings for the Chinook sample database, stored as a
// SQLite file at a path relative to the working directory.
const datasource = new DataSource({
  database: "../../../../Chinook.db",
  type: "sqlite",
});

// Wrap the TypeORM data source in LangChain's SqlDatabase helper, which the
// rest of this script uses to run queries and build the SQL chain.
const db = await SqlDatabase.fromDataSourceParams({ appDataSource: datasource });
/**
 * Runs `query` against `database` and returns the first column of every row
 * as a list of cleaned-up strings.
 *
 * `database.run` must resolve to a JSON-encoded array. Each entry is normally
 * a row object (`{ Name: "AC/DC" }`); one level of array nesting is flattened,
 * and bare primitive entries are used as the value directly. For every value,
 * standalone runs of digits are removed and surrounding whitespace is trimmed.
 *
 * Entries that are `null`, rows with no columns, and rows whose first column
 * is `null` are skipped instead of crashing. Non-string values (e.g. numeric
 * columns) are converted with `String()` before cleaning.
 *
 * @param database - Any object exposing `run(query)` that resolves to a JSON string.
 * @param query - SQL text passed unchanged to `database.run`.
 * @returns The cleaned values, in the order the rows were returned.
 * @throws SyntaxError if the payload is not valid JSON.
 * @throws TypeError if the decoded payload is not an array.
 */
async function queryAsList(
  database: { run(query: string): Promise<string> },
  query: string
): Promise<string[]> {
  const parsed: unknown = JSON.parse(await database.run(query));
  if (!Array.isArray(parsed)) {
    throw new TypeError(
      `Expected a JSON array of rows from the database for query: ${query}`
    );
  }

  const justValues: string[] = [];
  for (const entry of parsed.flat()) {
    if (entry == null) continue;
    // Row objects contribute their first column; primitives are the value itself.
    const value: unknown =
      typeof entry === "object" ? Object.values(entry)[0] : entry;
    // NULL columns and empty rows carry nothing worth indexing.
    if (value == null) continue;
    justValues.push(
      String(value)
        .replace(/\b\d+\b/g, "")
        .trim()
    );
  }
  return justValues;
}
// Gather every artist name, album title and genre name into a single list.
// The three lookups run one after another, in this order.
let properNouns: string[] = [
  ...(await queryAsList(db, "SELECT Name FROM Artist")),
  ...(await queryAsList(db, "SELECT Title FROM Album")),
  ...(await queryAsList(db, "SELECT Name FROM Genre")),
];

console.log(properNouns.length);
// Sample output:
//   647

console.log(properNouns.slice(0, 5));
// Sample output:
//   [
//     'AC/DC',
//     'Accept',
//     'Aerosmith',
//     'Alanis Morissette',
//     'Alice In Chains'
//   ]
// Embed every collected value and keep the vectors in an in-memory store.
// No per-document metadata is attached (empty object).
const embeddings = new OpenAIEmbeddings();
const vectorDb = await MemoryVectorStore.fromTexts(properNouns, {}, embeddings);

// Retriever over the store, configured with 15 as its result count.
const retriever = vectorDb.asRetriever(15);
// Next, assemble a query-construction chain that first retrieves candidate
// values from the vector store and inserts them into the prompt.
// The system prompt is written one line per entry; names in curly braces
// ({top_k}, {table_info}, {proper_nouns}) are prompt template variables.
const system = [
  "You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run.",
  "Unless otherwise specified, do not return more than {top_k} rows.",
  "Here is the relevant table info: {table_info}",
  "Here is a non-exhaustive list of possible feature values.",
  "If filtering on a feature value make sure to check its spelling against this list first:",
  "{proper_nouns}",
].join("\n");
// Chat prompt: the system instructions followed by a human message that
// carries the {input} template variable.
const prompt = ChatPromptTemplate.fromMessages([
  ["system", system],
  ["human", "{input}"],
]);

// Chat model used to write the SQL, with temperature set to 0.
const llm = new ChatOpenAI({ temperature: 0, model: "gpt-4" });

// SQL-writing chain built from the model, the database and the custom prompt.
const queryChain = await createSqlQueryChain({
  dialect: "sqlite",
  prompt,
  db,
  llm,
});
// Question -> retriever -> newline-separated text of the retrieved documents.
const retrieverChain = RunnableSequence.from([
  ({ question }: { question: string }) => question,
  retriever,
  (docs: Array<DocumentInterface>) =>
    docs.map(({ pageContent }) => pageContent).join("\n"),
]);

// Full pipeline: add `proper_nouns` (computed by the retriever chain) to the
// input, then hand the enriched input to the SQL query chain.
const retrievalAssignments = { proper_nouns: retrieverChain };
const chain = RunnablePassthrough.assign(retrievalAssignments).pipe(queryChain);
// To try out our chain, let's see what happens when we filter on
// "elenis moriset", a misspelling of Alanis Morissette, without and with retrieval.

// Without retrieval: call the query chain directly with an empty proper-noun list.
const query = await queryChain.invoke({
  question: "What are all the genres of Elenis Moriset songs?",
  proper_nouns: "",
});
console.log("query", query);
// Sample output:
//   query SELECT DISTINCT Genre.Name
//   FROM Genre
//   JOIN Track ON Genre.GenreId = Track.GenreId
//   JOIN Album ON Track.AlbumId = Album.AlbumId
//   JOIN Artist ON Album.ArtistId = Artist.ArtistId
//   WHERE Artist.Name = 'Elenis Moriset'
//   LIMIT 5;

// Run the generated SQL; in the sample run the misspelled name matched nothing.
const queryResults = await db.run(query);
console.log("db query results", queryResults);
// Sample output:
//   db query results []

// -------------
// You can see a LangSmith trace of the above chain here:
// https://smith.langchain.com/public/b153cb9b-6fbb-43a8-b2ba-4c86715183b9/r
// -------------
// With retrieval: the full chain looks up similar proper nouns first and
// passes them to the query chain alongside the question.
const query2 = await chain.invoke({
  question: "What are all the genres of Elenis Moriset songs?",
});
console.log("query2", query2);
// Sample output:
//   query2 SELECT DISTINCT Genre.Name
//   FROM Genre
//   JOIN Track ON Genre.GenreId = Track.GenreId
//   JOIN Album ON Track.AlbumId = Album.AlbumId
//   JOIN Artist ON Album.ArtistId = Artist.ArtistId
//   WHERE Artist.Name = 'Alanis Morissette';

// Run the generated SQL; in the sample run the corrected spelling returned a row.
const query2Results = await db.run(query2);
console.log("db query results", query2Results);
// Sample output:
//   db query results [{"Name":"Rock"}]

// -------------
// You can see a LangSmith trace of the above chain here:
// https://smith.langchain.com/public/2f4f0e37-3b7f-47b5-837c-e2952489cac0/r
// -------------