Skip to content

Commit

Permalink
fix and add
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaobeicn committed Sep 27, 2024
1 parent 0a902cd commit 4b09171
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 27 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ python examples/storm_examples/run_storm_wiki_gpt.py \
To run Co-STORM with `gpt` family models with default configurations,
1. Add `BING_SEARCH_API_KEY="xxx"`to `secrets.toml`
1. Add `BING_SEARCH_API_KEY="xxx"` and `ENCODER_API_TYPE="xxx"` to `secrets.toml`
2. Run the following command
```bash
Expand Down
5 changes: 5 additions & 0 deletions examples/costorm_examples/run_costorm_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def main(args):
with open(os.path.join(args.output_dir, "report.md"), "w") as f:
f.write(article)

# Save instance dump
instance_copy = costorm_runner.to_dict()
with open(os.path.join(args.output_dir, "instance_dump.json"), "w") as f:
json.dump(instance_copy, f, indent=2)

# Save logging
log_dump = costorm_runner.dump_logging_and_reset()
with open(os.path.join(args.output_dir, "log.json"), "w") as f:
Expand Down
27 changes: 1 addition & 26 deletions knowledge_storm/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,13 @@


class EmbeddingModel:
def __init__():
def __init__(self):
pass

def get_embedding(self, text: str) -> Tuple[np.ndarray, int]:
raise Exception("Not implemented")


class OpenAIEmbeddingModel(EmbeddingModel):
def __init__(self, model: str = "text-embedding-3-small", api_key: str = None):
if not api_key:
self.api_key = os.getenv("OPENAI_API_KEY")

self.url = "https://api.openai.com/v1/embeddings"
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
self.model = model

def get_embedding(self, text: str) -> Tuple[np.ndarray, int]:
data = {"input": text, "model": "text-embedding-3-small"}

response = requests.post(self.url, headers=self.headers, json=data)
if response.status_code == 200:
data = response.json()
embedding = np.array(data["data"][0]["embedding"])
token = data["usage"]["prompt_tokens"]
return embedding, token
else:
response.raise_for_status()


class OpenAIEmbeddingModel(EmbeddingModel):
def __init__(self, model: str = "text-embedding-3-small", api_key: str = None):
if not api_key:
Expand Down

0 comments on commit 4b09171

Please sign in to comment.