-
Notifications
You must be signed in to change notification settings - Fork 202
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #334 from Saibo-creator/query_builder
add a basic QueryBuilder, test and doc
- Loading branch information
Showing
6 changed files
with
137 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from lmql.api import run, run_sync | ||
|
||
|
||
class QueryExecution:
    """Runnable handle around a finished LMQL query string.

    The query text stays available as ``query_string`` so callers can
    inspect it before (or instead of) executing it.
    """

    def __init__(self, query_string):
        """Store the query text that the run methods will execute."""
        self.query_string = query_string

    async def run(self, *args, **kwargs):
        """Execute the stored query asynchronously.

        All positional and keyword arguments are forwarded unchanged to
        the module-level ``run`` imported from ``lmql.api``; its result is
        returned as-is.
        """
        # Inside the method body, ``run`` resolves to the module-level
        # import, not to this method, so there is no recursion here.
        outcome = await run(self.query_string, *args, **kwargs)
        return outcome

    def run_sync(self, *args, **kwargs):
        """Execute the stored query synchronously.

        All positional and keyword arguments are forwarded unchanged to
        the module-level ``run_sync`` imported from ``lmql.api``; its
        result is returned as-is.
        """
        outcome = run_sync(self.query_string, *args, **kwargs)
        return outcome
|
||
|
||
class QueryBuilder:
    """Fluent builder that assembles an LMQL query string clause by clause.

    Every ``set_*`` method stores one clause and returns ``self`` so calls
    can be chained. ``build()`` joins the clauses that were set, in the
    fixed order decoder, prompt, ``from``, ``where``, ``distribution``,
    and wraps the resulting text in a ``QueryExecution``.
    """

    # Decoder names accepted by set_decoder.
    _DECODERS = ('argmax', 'sample', 'beam', 'beam_var', 'var', 'best_k')

    def __init__(self):
        """Start with no clauses set; unset clauses are omitted by build()."""
        self.decoder = None
        self.prompt = None
        self.model = None
        self.where = None
        self.distribution_expr = None

    def set_decoder(self, decoder='argmax', **kwargs):
        """Select the decoder and its optional keyword arguments.

        Args:
            decoder: One of the names in ``_DECODERS``.
            **kwargs: Decoder options, rendered as ``name(key=value, ...)``.

        Returns:
            This builder, for chaining.

        Raises:
            ValueError: If ``decoder`` is not a known decoder name.
        """
        if decoder not in self._DECODERS:
            raise ValueError(f"Invalid decoder: {decoder}")
        self.decoder = (decoder, kwargs)
        return self

    def set_prompt(self, prompt="What is the capital of France? [ANSWER]"):
        """Set the prompt text; build() wraps it in double quotes as-is."""
        self.prompt = prompt
        return self

    def set_model(self, model="gpt2"):
        """Set the model identifier emitted in the ``from "<model>"`` clause."""
        self.model = model
        return self

    def set_where(self, where="len(TOKENS(ANSWER)) < 10"):
        """
        Add a where clause to the query
        If a where clause already exists, the new clause is appended with an 'and'
        If the user wants to use 'or', they need to put or in the where clause
        such as: "len(TOKENS(ANSWER)) < 10 or len(TOKENS(ANSWER)) > 2"

        Clauses are joined textually, without added parentheses, so a clause
        containing 'or' should be parenthesised by the caller if it is going
        to be combined with further clauses.
        """
        self.where = where if self.where is None else f"{self.where} and {where}"
        return self

    def set_distribution(self, variable="ANSWER", expr='["A", "B"]'):
        """Set the ``distribution <variable> in <expr>`` clause.

        ``expr`` is inserted verbatim into the query text.
        """
        self.distribution_expr = (variable, expr)
        return self

    def _decoder_clause(self):
        """Render the stored decoder as ``name`` or ``name(key=value, ...)``."""
        name, options = self.decoder
        if not options:
            return name
        # Emit keyword-argument syntax rather than the dict's repr, which
        # would produce e.g. "beam({'n': 2})" instead of "beam(n=2)".
        rendered = ", ".join(
            f"{key}={value!r}" for key, value in options.items()
        )
        return f"{name}({rendered})"

    def _query_string(self):
        """Join all clauses that have been set into the final query text."""
        components = []

        if self.decoder:
            components.append(self._decoder_clause())

        if self.prompt:
            components.append(f'"{self.prompt}"')

        if self.model:
            components.append(f'from "{self.model}"')

        if self.where:
            components.append(f'where {self.where}')

        if self.distribution_expr:
            variable, expr = self.distribution_expr
            components.append(f'distribution {variable} in {expr}')

        return ' '.join(components)

    def build(self):
        """Assemble the query and return it wrapped in a ``QueryExecution``.

        The generated text is available as ``query_string`` on the result.
        """
        # Return an instance of QueryExecution instead of a string
        return QueryExecution(self._query_string())
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import lmql | ||
import numpy as np | ||
|
||
from lmql.tests.expr_test_utils import run_all_tests | ||
|
||
|
||
|
||
def test_query_builder():
    """Build a query with two where clauses, check its text, then run it."""
    # Configure the builder one step at a time; each setter mutates the
    # builder in place, so this is equivalent to a single chained call.
    builder = lmql.QueryBuilder()
    builder.set_decoder('argmax')
    builder.set_prompt('What is the capital of France? [ANSWER]')
    builder.set_model('gpt2')
    builder.set_where('len(TOKENS(ANSWER)) < 10')
    builder.set_where('len(TOKENS(ANSWER)) > 2')
    query = builder.build()

    # The second where clause is expected to be joined to the first with "and".
    expected = 'argmax "What is the capital of France? [ANSWER]" from "gpt2" where len(TOKENS(ANSWER)) < 10 and len(TOKENS(ANSWER)) > 2'

    assert query.query_string == expected, f"Expected: {expected}, got: {query.query_string}"

    # Smoke-run the query; the result is not inspected.
    result = query.run_sync()
|
||
def test_query_builder_with_dist():
    """Build a query with a distribution clause, check its text, then run it."""
    # Configure the builder one step at a time; each setter mutates the
    # builder in place, so this is equivalent to a single chained call.
    builder = lmql.QueryBuilder()
    builder.set_decoder('argmax')
    builder.set_prompt('What is the capital of France? [ANSWER]')
    builder.set_model('gpt2')
    builder.set_distribution('ANSWER', '["Paris", "London"]')
    query = builder.build()

    expected = 'argmax "What is the capital of France? [ANSWER]" from "gpt2" distribution ANSWER in ["Paris", "London"]'

    assert query.query_string == expected, f"Expected: {expected}, got: {query.query_string}"

    # Smoke-run the query; the result is not inspected.
    result = query.run_sync()
|
||
if __name__ == "__main__":
    # When executed as a script, pass this module's globals to the shared
    # test helper so it can find and run the test functions defined above.
    run_all_tests(globals())