Skip to content

Commit

Permalink
Specify model inference and embedding endpoint separately (#286)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Nov 7, 2023
1 parent 8ad1209 commit e2a685a
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 22 deletions.
22 changes: 20 additions & 2 deletions memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def configure():
# search for key in environment
openai_key = os.getenv("OPENAI_API_KEY")
if not openai_key:
openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()
print("Missing enviornment variables for OpenAI. Please set them and run `memgpt configure` again.")
# TODO: eventually stop relying on env variables and pass in keys explicitly
# openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()

# azure credentials
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=False).ask()
Expand Down Expand Up @@ -77,7 +79,21 @@ def configure():
if len(endpoint_options) == 1:
default_endpoint = endpoint_options[0]
else:
default_endpoint = questionary.select("Select default endpoint:", endpoint_options).ask()
default_endpoint = questionary.select("Select default inference endpoint:", endpoint_options).ask()

# configure embedding provider
endpoint_options.append("local") # can compute embeddings locally
if len(endpoint_options) == 1:
default_embedding_endpoint = endpoint_options[0]
print(f"Using embedding endpoint {default_embedding_endpoint}")
else:
default_embedding_endpoint = questionary.select("Select default embedding endpoint:", endpoint_options).ask()

# configure embedding dimensions
default_embedding_dim = 1536
if default_embedding_endpoint == "local":
# HF model uses lower dimensionality
default_embedding_dim = 384

# configure preset
default_preset = questionary.select("Select default preset:", preset_options, default=DEFAULT_PRESET).ask()
Expand Down Expand Up @@ -127,6 +143,8 @@ def configure():
model=default_model,
preset=default_preset,
model_endpoint=default_endpoint,
embedding_model=default_embedding_endpoint,
embedding_dim=default_embedding_dim,
default_persona=default_persona,
default_human=default_human,
default_agent=default_agent,
Expand Down
2 changes: 1 addition & 1 deletion memgpt/cli/cli_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def store_docs(name, docs, show_progress=True):
text = node.text.replace("\x00", "\uFFFD") # hacky fix for error on null characters
assert (
len(node.embedding) == config.embedding_dim
), f"Expected embedding dimension {config.embedding_dim}, got {len(node.embedding)}"
), f"Expected embedding dimension {config.embedding_dim}, got {len(node.embedding)}: {node.embedding}"
passages.append(Passage(text=text, embedding=vector))

# insert into storage
Expand Down
2 changes: 0 additions & 2 deletions memgpt/connectors/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ def delete(self):
self.db_model.__table__.drop(self.engine)

def save(self):
# don't need to save
print("Saving db")
return

@staticmethod
Expand Down
4 changes: 0 additions & 4 deletions memgpt/connectors/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def get(self, id: str) -> Passage:

def insert(self, passage: Passage):
nodes = [TextNode(text=passage.text, embedding=passage.embedding)]
print("nodes", nodes)
self.nodes += nodes
if isinstance(self.index, EmptyIndex):
self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True)
Expand All @@ -96,7 +95,6 @@ def insert_many(self, passages: List[Passage]):
self.nodes += nodes
if isinstance(self.index, EmptyIndex):
self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True)
print("new size", len(self.get_nodes()))
else:
orig_size = len(self.get_nodes())
self.index.insert_nodes(nodes)
Expand All @@ -113,15 +111,13 @@ def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Pas
)
nodes = retriever.retrieve(query)
results = [Passage(embedding=node.embedding, text=node.text) for node in nodes]
print(results)
return results

def save(self):
# assert len(self.nodes) == len(self.get_nodes()), f"Expected {len(self.nodes)} nodes, got {len(self.get_nodes())} nodes"
self.nodes = self.get_nodes()
os.makedirs(self.save_directory, exist_ok=True)
pickle.dump(self.nodes, open(self.save_path, "wb"))
print("Saved local", self.save_path)

@staticmethod
def list_loaded_data():
Expand Down
25 changes: 12 additions & 13 deletions memgpt/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typer
import os
from llama_index.embeddings import OpenAIEmbedding


Expand All @@ -10,10 +11,11 @@ def embedding_model():
# load config
config = MemGPTConfig.load()

# TODO: use embedding_endpoint in the future
if config.model_endpoint == "openai":
return OpenAIEmbedding()
elif config.model_endpoint == "azure":
endpoint = config.embedding_model
if endpoint == "openai":
model = OpenAIEmbedding(api_base="https://api.openai.com/v1", api_key=config.openai_key)
return model
elif endpoint == "azure":
return OpenAIEmbedding(
model="text-embedding-ada-002",
deployment_name=config.azure_embedding_deployment,
Expand All @@ -22,17 +24,14 @@ def embedding_model():
api_type="azure",
api_version=config.azure_version,
)
else:
elif endpoint == "local":
# default to hugging face model
from llama_index.embeddings import HuggingFaceEmbedding

os.environ["TOKENIZERS_PARALLELISM"] = "False"
model = "BAAI/bge-small-en-v1.5"
typer.secho(
f"Warning: defaulting to HuggingFace embedding model {model} since model endpoint is not OpenAI or Azure.",
fg=typer.colors.YELLOW,
)
typer.secho(f"Warning: ensure torch and transformers are installed")
# return f"local:{model}"

# loads BAAI/bge-small-en-v1.5
return HuggingFaceEmbedding(model_name=model)
else:
# use env variable OPENAI_API_BASE
model = OpenAIEmbedding()
return model

0 comments on commit e2a685a

Please sign in to comment.