From 51713fcc4e591229b175222ab51f4b55b85e0259 Mon Sep 17 00:00:00 2001
From: Kirushikesh
Date: Sat, 14 Sep 2024 08:52:00 -0400
Subject: [PATCH] Changed the underlying LLM and lint fix

---
 .../agentchat_huggingface_langchain.ipynb | 74 +++++++++++--------
 1 file changed, 43 insertions(+), 31 deletions(-)

diff --git a/notebook/agentchat_huggingface_langchain.ipynb b/notebook/agentchat_huggingface_langchain.ipynb
index a902b0acd19..000bcbdbcfd 100644
--- a/notebook/agentchat_huggingface_langchain.ipynb
+++ b/notebook/agentchat_huggingface_langchain.ipynb
@@ -9,7 +9,7 @@
     }
    },
    "source": [
-    "# Using LangChain with AutoGen and Hugging Face"
+    "# Using AutoGen AgentChat with LangChain and Hugging Face"
    ]
   },
   {
@@ -67,13 +67,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from types import SimpleNamespace\n",
-    "import os\n",
     "import json\n",
+    "import os\n",
+    "from types import SimpleNamespace\n",
     "\n",
-    "from autogen import AssistantAgent, UserProxyAgent, config_list_from_json\n",
     "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage\n",
-    "from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline"
+    "from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline\n",
+    "\n",
+    "from autogen import AssistantAgent, UserProxyAgent, config_list_from_json"
    ]
   },
   {
@@ -172,29 +173,26 @@
     "        gen_config_params = config.get(\"params\", {})\n",
     "        self.model_name = config[\"model\"]\n",
     "        pipeline = HuggingFacePipeline.from_model_id(\n",
-    "            model_id=self.model_name,\n",
-    "            task=\"text-generation\",\n",
-    "            pipeline_kwargs=gen_config_params,\n",
-    "            device=self.device\n",
+    "            model_id=self.model_name, task=\"text-generation\", pipeline_kwargs=gen_config_params, device=self.device\n",
     "        )\n",
     "        self.model = ChatHuggingFace(llm=pipeline)\n",
     "        print(f\"Loaded model {config['model']} to {self.device}\")\n",
     "\n",
     "    def _to_chatml_format(self, message):\n",
     "        \"\"\"Convert message to ChatML format.\"\"\"\n",
-    "        if message['role'] == 'system':\n",
+    "        if message[\"role\"] == \"system\":\n",
     "            return SystemMessage(content=message[\"content\"])\n",
-    "        if message['role'] == 'assistant':\n",
+    "        if message[\"role\"] == \"assistant\":\n",
     "            return AIMessage(content=message[\"content\"])\n",
-    "        if message['role'] == 'user':\n",
+    "        if message[\"role\"] == \"user\":\n",
     "            return HumanMessage(content=message[\"content\"])\n",
     "        raise ValueError(f\"Unknown message type: {type(message)}\")\n",
-    "    \n",
+    "\n",
     "    def create(self, params):\n",
     "        \"\"\"Create a response using the model.\"\"\"\n",
     "        if params.get(\"stream\", False) and \"messages\" in params:\n",
     "            raise NotImplementedError(\"Local models do not support streaming.\")\n",
-    "    \n",
+    "\n",
     "        num_of_responses = params.get(\"n\", 1)\n",
     "        response = SimpleNamespace()\n",
     "        inputs = [self._to_chatml_format(m) for m in params[\"messages\"]]\n",
@@ -282,19 +280,23 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "os.environ[\"OAI_CONFIG_LIST\"] = json.dumps([{\n",
-    "    \"model\": \"microsoft/Phi-3.5-mini-instruct\",\n",
-    "    \"model_client_cls\": \"CustomModelClient\",\n",
-    "    \"device\": 0,\n",
-    "    \"n\": 1,\n",
-    "    \"params\": {\n",
-    "        \"max_new_tokens\": 100,\n",
-    "        \"top_k\": 50,\n",
-    "        \"temperature\": 0.1,\n",
-    "        \"do_sample\": True,\n",
-    "        \"return_full_text\": False\n",
-    "    }\n",
-    "}])"
+    "os.environ[\"OAI_CONFIG_LIST\"] = json.dumps(\n",
+    "    [\n",
+    "        {\n",
+    "            \"model\": \"mistralai/Mistral-7B-Instruct-v0.2\",\n",
+    "            \"model_client_cls\": \"CustomModelClient\",\n",
+    "            \"device\": 0,\n",
+    "            \"n\": 1,\n",
+    "            \"params\": {\n",
+    "                \"max_new_tokens\": 500,\n",
+    "                \"top_k\": 50,\n",
+    "                \"temperature\": 0.1,\n",
+    "                \"do_sample\": True,\n",
+    "                \"return_full_text\": False,\n",
+    "            },\n",
+    "        }\n",
+    "    ]\n",
+    ")"
    ]
   },
   {
@@ -309,6 +311,19 @@
     ")"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import getpass\n",
+    "from huggingface_hub import login\n",
+    "\n",
+    "# Mistral-7B-Instruct-v0.2 is a gated model and requires an API token to access\n",
+    "login(token=getpass.getpass(\"Enter your HuggingFace API Token\"))"
+   ]
+  },
   {
    "attachments": {},
    "cell_type": "markdown",
@@ -334,10 +349,7 @@
    ],
    "source": [
     "assistant = AssistantAgent(\"assistant\", llm_config={\"config_list\": config_list_custom})\n",
-    "user_proxy = UserProxyAgent(\n",
-    "    \"user_proxy\",\n",
-    "    code_execution_config=False\n",
-    ")"
+    "user_proxy = UserProxyAgent(\"user_proxy\", code_execution_config=False)"
    ]
   },
  {