
openai_build_prompt() aggregate function, no docs or tests yet, refs #4
simonw committed Jan 12, 2023
1 parent 00b5764 · commit ef6cabf
Showing 1 changed file with 37 additions and 0 deletions.
datasette_openai/__init__.py (37 additions, 0 deletions)
@@ -21,6 +21,41 @@ def count_tokens(text):
    return len(tokenize(text))


class BuildPrompt:
    def __init__(self):
        self.texts = []
        self.first = True
        self.prefix = ""
        self.suffix = ""
        self.completion_tokens = 0
        self.token_limit = 0

    def step(self, text, prefix, suffix, completion_tokens, token_limit=4000):
        if self.first:
            self.first = False
            self.prefix = prefix
            self.suffix = suffix
            self.completion_tokens = completion_tokens
            self.token_limit = token_limit
        self.texts.append(text)

    def finalize(self):
        available_tokens = (
            self.token_limit
            - self.completion_tokens
            - count_tokens(self.prefix)
            - count_tokens(self.suffix)
        )
        if available_tokens < 0:
            return self.prefix + " " + self.suffix
        # Get that many tokens from each of the texts
        tokens_per_text = available_tokens // len(self.texts)
        truncated_texts = []
        for text in self.texts:
            truncated_texts.append(" ".join(tokenize(text)[:tokens_per_text]))
        return self.prefix + " " + " ".join(truncated_texts) + " " + self.suffix


@hookimpl
def prepare_connection(conn):
    conn.create_function("openai_embedding", 2, openai_embedding)
@@ -29,6 +64,8 @@ def prepare_connection(conn):
    conn.create_function("openai_strip_tags", 1, openai_strip_tags)
    conn.create_function("openai_count_tokens", 1, count_tokens)
    conn.create_function("openai_tokenize", 1, lambda s: json.dumps(tokenize(s)))
    conn.create_aggregate("openai_build_prompt", 4, BuildPrompt)
    conn.create_aggregate("openai_build_prompt", 5, BuildPrompt)


def openai_strip_tags(text):
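The commit message notes there are no docs or tests yet, so the snippet below is only a rough sketch of how the new aggregate could be exercised outside Datasette, using Python's sqlite3 module. It assumes the datasette-openai package at this commit is importable; the documents table, its content column, the prompt strings, and the 256 completion-token budget are invented for illustration.

import sqlite3

from datasette_openai import BuildPrompt

conn = sqlite3.connect(":memory:")
# Register under both arities, mirroring prepare_connection() in the diff above
conn.create_aggregate("openai_build_prompt", 4, BuildPrompt)
conn.create_aggregate("openai_build_prompt", 5, BuildPrompt)

# Hypothetical table and rows, purely for demonstration
conn.execute("create table documents (content text)")
conn.executemany(
    "insert into documents values (?)",
    [
        ("First example document about Datasette.",),
        ("Second example document about SQLite.",),
    ],
)

# Arguments follow BuildPrompt.step(): text, prefix, suffix, completion_tokens
# (token_limit is omitted here, so the 4000 default applies)
row = conn.execute(
    """
    select openai_build_prompt(
        content,
        'Answer using this context:',
        'Question: what is Datasette?',
        256
    ) from documents
    """
).fetchone()
print(row[0])

Because step() only captures the prefix, suffix and token limits from the first row, the aggregate expects the same prefix and suffix arguments on every row; finalize() then splits the remaining token budget evenly across the collected texts.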
