-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into mb/feat_create_highlight_extractor
- Loading branch information
Showing
8 changed files
with
99 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ __pycache__/ | |
.*venv* | ||
_build/ | ||
.tox | ||
.env | ||
|
||
# MacOS | ||
.DS_Store | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
OPENAI_API_KEY = "" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
You are a teacher creating spaced repetition prompts to reinforce knowledge from a student. | ||
|
||
Your goals are to: | ||
- Create perspective: The questions should relate concepts to one another, comparing alternate solutions to a given problem if possible. | ||
* Create understanding: Prompts should foster understanding, activating related concepts and comparing them to one another. Focus especially on why concepts are related to each other, and how they differ. | ||
- Be concise: Questions and answers should be as short as possible. Be clear, direct, even curt, and don't state anything in the answer that could be inferred from the question. | ||
* Be paraphrased: The questions should never quote the context. | ||
- Be context-independent: In review, this prompt will be interleaved with many others about many topics. The prompt must cue or supply whatever context is necessary to understand the question. They should not assume one has read the text that generated the prompts. It shouldn't address the text or use the context of the text in any way. | ||
|
||
You will be provided with two inputs: | ||
- target (delimited by <target></target> tags), representing the specific content which should be reinforced | ||
- context (delimited by <context></context> tags), representing the content surrounding the target | ||
|
||
Users will be presented with your results and keep only the ones they like, so be sure to write an excellent prompt! The entire response should be at most 20 words. The format is JSON, with the keys 'question' and 'answer'. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from pathlib import Path | ||
|
||
|
||
def read_txt(filepath: Path) -> str: | ||
"""Read a text file.""" | ||
return filepath.read_text(encoding="utf-8") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import ast | ||
from pathlib import Path | ||
|
||
from dotenv import load_dotenv | ||
from langchain.chat_models import ChatOpenAI | ||
from langchain.schema import HumanMessage, SystemMessage | ||
from langchain.schema.output import LLMResult | ||
|
||
import gpt2anki.fileio as fileio | ||
from gpt2anki.sources.hypothesis import Highlight | ||
|
||
load_dotenv() | ||
print(Path(__file__)) | ||
PROMPT_DIR = Path(__file__).parent.parent.parent / "prompts" | ||
assert PROMPT_DIR.exists(), "Prompts directory does not exist" | ||
SYSTEM_PROMPT = fileio.read_txt(PROMPT_DIR / "martin_prompt.txt") | ||
|
||
|
||
def initialize_model(model_name: str = "gpt-4") -> ChatOpenAI: | ||
return ChatOpenAI(model=model_name) | ||
|
||
|
||
def highlight_to_prompt(highlight: Highlight) -> str: | ||
return "<target>{target}</target><context>{context}</context>".format( | ||
target=highlight.highlight, | ||
context=highlight.context, | ||
) | ||
|
||
|
||
def parse_output(output: LLMResult) -> dict[str, str]: | ||
text_output = output.generations[0][0].text | ||
# extract dictionary from string | ||
start = text_output.find("{") | ||
end = text_output.rfind("}") + 1 | ||
return ast.literal_eval(text_output[start:end]) | ||
|
||
|
||
async def prompt_gpt( | ||
model: ChatOpenAI, | ||
highlight: Highlight, | ||
) -> dict[str, str]: | ||
prompt = highlight_to_prompt(highlight) | ||
messages = [SystemMessage(content=SYSTEM_PROMPT), HumanMessage(content=prompt)] | ||
output = await model.agenerate(messages=[messages]) | ||
return parse_output(output) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import pytest | ||
|
||
import gpt2anki.magi as magi | ||
from gpt2anki.sources.hypothesis import Highlight | ||
|
||
|
||
# create a pytest fixture for the model | ||
@pytest.fixture(scope="session") | ||
def model() -> magi.ChatOpenAI: | ||
return magi.initialize_model(model_name="gpt-3.5-turbo") | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test_model_response(model: magi.ChatOpenAI) -> None: | ||
higlight = Highlight( | ||
context="Mitochondria is the powerhouse of the cell", | ||
highlight="Mitochondria", | ||
) | ||
output = await magi.prompt_gpt(model, higlight) | ||
# check that outpuis a dictionary with keys "answer" and "question" | ||
assert "answer" in output | ||
assert "question" in output | ||
assert len(output) == 2 |