
Commit

Changed the underlying LLM and lint fix
Kirushikesh committed Sep 14, 2024
1 parent 2e6b204 commit 51713fc
Showing 1 changed file with 43 additions and 31 deletions.
notebook/agentchat_huggingface_langchain.ipynb — 74 changes: 43 additions & 31 deletions
@@ -9,7 +9,7 @@
}
},
"source": [
"# Using LangChain with AutoGen and Hugging Face"
"# Using AutoGen AgentChat with LangChain and Hugging Face\""
]
},
{
@@ -67,13 +67,14 @@
"metadata": {},
"outputs": [],
"source": [
"from types import SimpleNamespace\n",
"import os\n",
"import json\n",
"import os\n",
"from types import SimpleNamespace\n",
"\n",
"from autogen import AssistantAgent, UserProxyAgent, config_list_from_json\n",
"from langchain_core.messages import AIMessage, HumanMessage, SystemMessage\n",
"from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline"
"from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline\n",
"\n",
"from autogen import AssistantAgent, UserProxyAgent, config_list_from_json"
]
},
{
@@ -172,29 +173,26 @@
" gen_config_params = config.get(\"params\", {})\n",
" self.model_name = config[\"model\"]\n",
" pipeline = HuggingFacePipeline.from_model_id(\n",
" model_id=self.model_name,\n",
" task=\"text-generation\",\n",
" pipeline_kwargs=gen_config_params,\n",
" device=self.device\n",
" model_id=self.model_name, task=\"text-generation\", pipeline_kwargs=gen_config_params,device=self.device,\n",
" )\n",
" self.model = ChatHuggingFace(llm=pipeline)\n",
" print(f\"Loaded model {config['model']} to {self.device}\")\n",
"\n",
" def _to_chatml_format(self, message):\n",
" \"\"\"Convert message to ChatML format.\"\"\"\n",
" if message['role'] == 'system':\n",
" if message[\"role\"] == \"system\":\n",
" return SystemMessage(content=message[\"content\"])\n",
" if message['role'] == 'assistant':\n",
" if message[\"role\"] == \"assistant\":\n",
" return AIMessage(content=message[\"content\"])\n",
" if message['role'] == 'user':\n",
" if message[\"role\"] == \"user\":\n",
" return HumanMessage(content=message[\"content\"])\n",
" raise ValueError(f\"Unknown message type: {type(message)}\")\n",
" \n",
"\n",
" def create(self, params):\n",
" \"\"\"Create a response using the model.\"\"\"\n",
" if params.get(\"stream\", False) and \"messages\" in params:\n",
" raise NotImplementedError(\"Local models do not support streaming.\")\n",
" \n",
"\n",
" num_of_responses = params.get(\"n\", 1)\n",
" response = SimpleNamespace()\n",
" inputs = [self._to_chatml_format(m) for m in params[\"messages\"]]\n",
@@ -282,19 +280,23 @@
"metadata": {},
"outputs": [],
"source": [
"os.environ[\"OAI_CONFIG_LIST\"] = json.dumps([{\n",
" \"model\": \"microsoft/Phi-3.5-mini-instruct\",\n",
" \"model_client_cls\": \"CustomModelClient\",\n",
" \"device\": 0,\n",
" \"n\": 1,\n",
" \"params\": {\n",
" \"max_new_tokens\": 100,\n",
" \"top_k\": 50,\n",
" \"temperature\": 0.1,\n",
" \"do_sample\": True,\n",
" \"return_full_text\": False\n",
" }\n",
"}])"
"os.environ[\"OAI_CONFIG_LIST\"] = json.dumps(\n",
" [\n",
" {\n",
" \"model\": \"mistralai/Mistral-7B-Instruct-v0.2\",\n",
" \"model_client_cls\": \"CustomModelClient\",\n",
" \"device\": 0,\n",
" \"n\": 1,\n",
" \"params\": {\n",
" \"max_new_tokens\": 500,\n",
" \"top_k\": 50,\n",
" \"temperature\": 0.1,\n",
" \"do_sample\": True,\n",
" \"return_full_text\": False,\n",
" },\n",
" }\n",
" ]\n",
")"
]
},
{
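The collapsed cell whose closing lines appear in the next hunk loads this JSON back out of the environment with config_list_from_json. For reference, such a call typically looks like the sketch below; the filter_dict is an assumption, since the cell body is hidden in this view:

    # Sketch of the hidden config-loading cell (filter_dict assumed, not shown in the diff).
    config_list_custom = config_list_from_json(
        "OAI_CONFIG_LIST",
        filter_dict={"model_client_cls": ["CustomModelClient"]},
    )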
@@ -309,6 +311,19 @@
")"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import getpass\n",
+ "from huggingface_hub import login\n",
+ "\n",
+ "# Mistral-7B-Instruct-v0.2 is a gated model, so an API token is required to access it\n",
+ "login(token=getpass.getpass(\"Enter your HuggingFace API Token\"))"
+ ]
+ },
{
"attachments": {},
"cell_type": "markdown",
@@ -334,10 +349,7 @@
],
"source": [
"assistant = AssistantAgent(\"assistant\", llm_config={\"config_list\": config_list_custom})\n",
"user_proxy = UserProxyAgent(\n",
" \"user_proxy\",\n",
" code_execution_config=False\n",
")"
"user_proxy = UserProxyAgent(\"user_proxy\", code_execution_config=False)"
]
},
{
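The rest of the diff is collapsed. For context, AutoGen's custom-client flow needs one more step before the conversation can start: every agent configured with model_client_cls must have the client class registered. A sketch of that step plus an illustrative kickoff (the actual task message lies outside the visible diff):

    # The custom client class must be registered with the agent before chatting;
    # register_model_client is AutoGen's hook for model_client_cls configs.
    assistant.register_model_client(model_client_cls=CustomModelClient)

    # Illustrative kickoff; the real message is not visible in this diff.
    user_proxy.initiate_chat(assistant, message="Write Python code to print 'Hello, world!'")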
