From bba9e796062471becaf1cd7f91b6b37c376e42c4 Mon Sep 17 00:00:00 2001 From: lzl65825 Date: Mon, 1 Jul 2024 01:25:57 -0700 Subject: [PATCH] feat: AutoFlow on TravelPlanner --- .gitignore | 1 + README.md | 62 +- requirements_travel.txt | 9 + src/{auto_agi.py => auto_main.py} | 189 ++++-- src/flow/flow.py | 2 +- src/info/OpenAGI/OpenAGI_Flow_gpt4gpt.txt | 6 + src/info/OpenAGI/OpenAGI_Flow_gpt4mixtral.txt | 8 + ...enAGI_Flow.txt => OpenAGI_Flow_manual.txt} | 0 src/info/OpenAGI/OpenAGI_Flow_mixtral4gpt.txt | 17 + .../OpenAGI/OpenAGI_Flow_mixtral4mixtral.txt | 12 + .../TravelPlanner_Flow_manual.txt | 45 ++ src/info/TravelPlanner/task_instruction.txt | 1 + src/info/TravelPlanner/tools.txt | 7 + src/main.py | 183 ++++-- src/travel_api/__init__.py | 0 src/travel_api/accommodations/__init__.py | 0 src/travel_api/accommodations/apis.py | 31 + src/travel_api/attractions/apis.py | 34 ++ src/travel_api/cities/apis.py | 23 + src/travel_api/flights/__init__.py | 0 src/travel_api/flights/apis.py | 70 +++ src/travel_api/googleDistanceMatrix/apis.py | 123 ++++ src/travel_api/notebook/apis.py | 40 ++ src/travel_api/notebook/test.py | 0 src/travel_api/planner/apis.py | 393 ++++++++++++ src/travel_api/planner/env.py | 202 ++++++ src/travel_api/planner/sole_planning.py | 112 ++++ src/travel_api/planner/test.py | 1 + src/travel_api/restaurants/__init__.py | 0 src/travel_api/restaurants/apis.py | 50 ++ src/utils/travel_commonsense_constraint.py | 575 ++++++++++++++++++ src/utils/travel_evaluation.py | 229 +++++++ src/utils/travel_hard_constraint.py | 275 +++++++++ src/utils/travel_utils.py | 128 ++++ 34 files changed, 2717 insertions(+), 111 deletions(-) create mode 100644 requirements_travel.txt rename src/{auto_agi.py => auto_main.py} (69%) create mode 100644 src/info/OpenAGI/OpenAGI_Flow_gpt4gpt.txt create mode 100644 src/info/OpenAGI/OpenAGI_Flow_gpt4mixtral.txt rename src/info/OpenAGI/{OpenAGI_Flow.txt => OpenAGI_Flow_manual.txt} (100%) create mode 100644 src/info/OpenAGI/OpenAGI_Flow_mixtral4gpt.txt create mode 100644 src/info/OpenAGI/OpenAGI_Flow_mixtral4mixtral.txt create mode 100644 src/info/TravelPlanner/TravelPlanner_Flow_manual.txt create mode 100644 src/info/TravelPlanner/task_instruction.txt create mode 100644 src/info/TravelPlanner/tools.txt create mode 100644 src/travel_api/__init__.py create mode 100644 src/travel_api/accommodations/__init__.py create mode 100644 src/travel_api/accommodations/apis.py create mode 100644 src/travel_api/attractions/apis.py create mode 100644 src/travel_api/cities/apis.py create mode 100644 src/travel_api/flights/__init__.py create mode 100644 src/travel_api/flights/apis.py create mode 100644 src/travel_api/googleDistanceMatrix/apis.py create mode 100644 src/travel_api/notebook/apis.py create mode 100644 src/travel_api/notebook/test.py create mode 100644 src/travel_api/planner/apis.py create mode 100644 src/travel_api/planner/env.py create mode 100644 src/travel_api/planner/sole_planning.py create mode 100644 src/travel_api/planner/test.py create mode 100644 src/travel_api/restaurants/__init__.py create mode 100644 src/travel_api/restaurants/apis.py create mode 100644 src/utils/travel_commonsense_constraint.py create mode 100644 src/utils/travel_evaluation.py create mode 100644 src/utils/travel_hard_constraint.py create mode 100644 src/utils/travel_utils.py diff --git a/.gitignore b/.gitignore index ac59399..7bea27f 100644 --- a/.gitignore +++ b/.gitignore @@ -161,5 +161,6 @@ cython_debug/ log openagi_data travel_database +results .DS_Store diff --git a/README.md b/README.md index c63755b..a25bb56 100644 --- a/README.md +++ b/README.md @@ -33,11 +33,14 @@ conda install pytorch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 pytorch-cuda= ``` pip install -r requirements.txt +pip install -r requirements_travel.txt ``` 3. Download the OpenAGI data from this [Google Drive link](https://drive.google.com/drive/folders/1AjT6y7qLIMxcmHhUBG5IE1_5SnCPR57e?usp=share_link), unzip it to the `AutoFlow` directory and rename it as `openagi_data`. -4. Make sure you are in the *AutoFlow/src* folder before running the codes. Otherwise, +4. Download the [database](https://drive.google.com/file/d/1pF1Sw6pBmq2sFkJvm-LzJOqrmfWoQgxE/view?usp=drive_link) and unzip it to the `AutoFlow` directory (i.e., `your/path/AutoFlow`) and rename it as `travel_database`. + +5. Make sure you are in the *AutoFlow/src* folder before running the codes. Otherwise, ``` cd src @@ -45,28 +48,75 @@ cd src ## Running Command Examples -OpenAGI on gpt-4-1106-preview: +(Notice that --model_name can be different from --auto_model_name) + +OpenAGI task when using gpt-4-1106-preview as the workflow interpreter LLM: ```commandline -python auto_agi.py +python auto_main.py +--flow_name=OpenAGI_Flow.txt +--task=OpenAGI --model_name="gpt-4-1106-preview" --auto_model_name="gpt-4-1106-preview" --log_file_name=../log/autoagi_gpt4gpt.txt --output_dir=./gpt4gpt --auto_flow_name="autoagi_gpt4gpt_Flow.txt" +--auto_epoch=30 --openai_key="YOUR OPENAI KEY" +--max_round=20 ``` -OpenAGI on TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ: +OpenAGI task when using TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ as the workflow interpreter LLM: ```commandline -python auto_agi.py +python auto_main.py +--flow_name=OpenAGI_Flow.txt +--task=OpenAGI --model_name="TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ" --auto_model_name="mistralai/Mixtral-8x7B-Instruct-v0.1" --log_file_name=../log/autoagi_mixtral4mixtral.txt --output_dir=./mixtral4mixtral --auto_flow_name="autoagi_mixtral4mixtral_Flow.txt" +--auto_epoch=30 +--openai_key="YOUR OPENAI KEY" +--max_round=20 +``` + +TravelPlanner task when using gpt-4-1106-preview as the workflow interpreter LLM: +```commandline +python auto_main.py +--flow_name=TravelPlanner_Flow_manual.txt +--tool_name=tools.txt +--task=TravelPlanner +--model_name="gpt-4-1106-preview" +--auto_model_name="gpt-4-1106-preview" +--log_file_name=../log/auto_travel_gpt4gpt.txt +--auto_flow_name=TravelPlanner_Flow_gpt4gpt.txt +--auto_epoch=30 --openai_key="YOUR OPENAI KEY" +--max_round=100 ``` +TravelPlanner task when using TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ as the workflow interpreter LLM: +```commandline +python auto_main.py +--flow_name=TravelPlanner_Flow_manual.txt +--tool_name=tools.txt +--task=TravelPlanner +--auto_model_name=gpt-4-1106-preview +--model_name=TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ +--log_file_name=../log/auto_travel_gpt4mixtral.txt +--auto_flow_name=TravelPlanner_Flow_gpt4mixtral.txt +--auto_epoch=30 +--max_round=100 +``` + +## Known Issues + +[PPOTrainer](https://huggingface.co/docs/trl/main/en/ppo_trainer) class in the latest version of [trl](https://github.com/huggingface/trl) package (==0.9.4) has a known [issue](https://github.com/huggingface/trl/issues/1691) when using multiple GPUs for reinforcement learning. + +## Generated Workflow + +In the `AutoFlow/src/info/OpenAGI` folder, there are manually designed workflow and automatically generated workflows. The manual workflow is `OpenAGI_Flow_manual.txt`, and the file name of automatically generated workflows is in the form of `OpenAGI_Flow_manual_*4*.txt`. For example, `OpenAGI_Flow_manual_gpt4mixtral.txt` means the workflow is generated by GPT and used for Mixtral as the interpreter LLM. + ## Reference -- We leveraged the dataset of [OpenAGI](https://github.com/agiresearch/OpenAGI) projects and based on [CoRE language] (https://github.com/agiresearch/CoRE) to implement our experiment. +- We leveraged the dataset of [OpenAGI](https://github.com/agiresearch/OpenAGI) and [TravelPlanner](https://github.com/OSU-NLP-Group/TravelPlanner) projects and based on [CoRE language] (https://github.com/agiresearch/CoRE) to implement our experiment. diff --git a/requirements_travel.txt b/requirements_travel.txt new file mode 100644 index 0000000..1244029 --- /dev/null +++ b/requirements_travel.txt @@ -0,0 +1,9 @@ +langchain==0.1.4 +pandas==2.0.1 +tiktoken==0.4.0 +openai==1.13.3 +langchain_google_genai==0.0.4 +gradio==3.50.2 +datasets==2.15.0 +tiktoken==0.4.0 +func_timeout==4.3.5 \ No newline at end of file diff --git a/src/auto_agi.py b/src/auto_main.py similarity index 69% rename from src/auto_agi.py rename to src/auto_main.py index fa258db..e644b39 100644 --- a/src/auto_agi.py +++ b/src/auto_main.py @@ -4,12 +4,13 @@ from pathlib import Path from traceback import print_exc +from datasets import load_dataset, DownloadMode from openai import OpenAI from vllm import LLM # from openagi_main import openagi_main as openagi_gpt_main # from mixtral_main import mixtral_main as openagi_mixtral_main -from main import main as openagi_main +from main import main as exe_main from utils.flow_utils import set_logger, ReadLineFromFile, get_response_from_client import random @@ -40,6 +41,21 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, BitsAndBytesConfig, AutoFeatureExtractor, MixtralForCausalLM from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead +from utils.travel_evaluation import eval_score +from utils.travel_utils import get_result_file + + +def load_exe_LLM(args): + if 'gpt' in args.model_name: + openai_key = args.openai_key + exe_client = OpenAI(api_key=openai_key) + elif 'gptq' in args.model_name.lower(): + exe_client = LLM(model=args.model_name, download_dir=args.cache_dir, quantization='gptq', enforce_eager=True, dtype=torch.float16, tensor_parallel_size=8)#gpu_memory_utilization=0.9)# , + else: + raise NotImplementedError + + return exe_client + def autoagi_gpt(args): task_description_list = ReadLineFromFile("../openagi_data/task_description.txt") task_description = task_description_list[0] @@ -60,21 +76,15 @@ def autoagi_gpt(args): f'[Key], [Step Name], and [Step Instruction] are all in the string form.\n' \ f'[Branch Step Name] should be appear as a unique [Step Name] in the workflow.\n' \ - if 'gpt' in args.model_name: - openai_key = args.openai_key - exe_client = OpenAI(api_key=openai_key) - elif 'gptq' in args.model_name.lower(): - exe_client = LLM(model=args.model_name, download_dir=args.cache_dir, quantization='gptq', enforce_eager=True, dtype=torch.float16, tensor_parallel_size=8)#gpu_memory_utilization=0.9)# , - else: - raise NotImplementedError + exe_client = load_exe_LLM(args) chat_history = [{'role': 'system', 'content': autoagi_instruction}, {'role': 'user', 'content': user_instruction}] manual_flow = '\n'.join(ReadLineFromFile(args.flow_file)) chat_history.append({'role': 'assistant', 'content': manual_flow}) - args.dataset = 'train' - baseline = openagi_main(args, exe_client) - args.dataset = 'test' - reward = openagi_main(args, exe_client) + args.dataset = args.set_type = 'train' + baseline = exe_main(args, exe_client) + args.dataset = args.set_type = 'test' + reward = exe_main(args, exe_client) logging.info(f'```\nReward:\n{reward}```\n') chat_history.append({'role': 'user', 'content': f'The execution performance of given workflow is {baseline}. ' f'Provide a new workflow in the same form of previous one.'}) @@ -91,18 +101,18 @@ def autoagi_gpt(args): chat_history.append({'role': 'assistant', 'content': res}) logging.info(f'```\nFlows:\n{res}```\n') try: - args.dataset = 'train' - reward = openagi_main(args, exe_client) - chat_history.append({'role': 'user', 'content': f'The execution performance of given workflow is {reward}. ' + args.dataset = args.set_type = 'train' + reward = exe_main(args, exe_client) + chat_history.append({'role': 'user', 'content': f'The execution performance of given workflow is {reward}.\n' f'Provide a new workflow in the same form of previous one.'}) logging.info(f'```\nReward:\n{reward}```\n') if reward > baseline: baseline = reward logging.info(f'\n\nNew Testing:\n\n') - args.dataset = 'test' - reward = openagi_main(args, exe_client) + args.dataset = args.set_type = 'test' + reward = exe_main(args, exe_client) logging.info(f'```\nTesting Reward:\n{reward}```\n') - args.dataset = 'train' + args.dataset = args.set_type = 'train' except Exception as e: print_exc() @@ -237,13 +247,7 @@ def autoagi_mixtral(args): args.flow_file = args.auto_flow_file baseline = -1.0 - if 'gpt' in args.model_name: - openai_key = args.openai_key - exe_client = OpenAI(api_key=openai_key) - elif 'gptq' in args.model_name.lower(): - exe_client = LLM(model=args.model_name, download_dir=args.cache_dir, quantization='gptq', enforce_eager=True, dtype=torch.float16, tensor_parallel_size=8)#gpu_memory_utilization=0.9)# , - else: - raise NotImplementedError + exe_client = load_exe_LLM(args) for epoch in range(args.auto_epochs): input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda() @@ -263,8 +267,8 @@ def autoagi_mixtral(args): logging.info(f'```\nFlows:\n{out_flow}```\n') try: - args.dataset = 'train' - reward = openagi_main(args, exe_client) + args.dataset = args.set_type = 'train' + reward = exe_main(args, exe_client) except Exception as e: print_exc() @@ -276,10 +280,10 @@ def autoagi_mixtral(args): if reward > baseline: baseline = reward logging.info(f'\n\nNew Testing:\n\n') - args.dataset = 'test' - reward = openagi_main(args, exe_client) + args.dataset = args.set_type = 'test' + reward = exe_main(args, exe_client) logging.info(f'```\nTesting Reward:\n{reward}```\n') - args.dataset = 'train' + args.dataset = args.set_type = 'train' Path(args.output_dir + f"/model/step_{epoch}").mkdir(parents=True, exist_ok=True) Path(args.output_dir + f"/ppo_trainer/step_{epoch}").mkdir(parents=True, exist_ok=True) model.save_pretrained(args.output_dir + f"/model/step_{epoch}") @@ -290,6 +294,105 @@ def autoagi_mixtral(args): train_stat = ppo_trainer.step([input_ids[0]], [output[0]], reward) logging.info(f'\n\nFinish step {epoch}!!!!!\n\n') + +def travel_weighted_score(eval_result): + ret = 0.0 + for k, v in eval_result.items(): + if k == 'Delivery Rate': + ret += v * 10 + elif k == 'Final Pass Rate': + ret += v * 100 + elif 'macro' in k.lower(): + ret += v * 5 + else: + ret += v + logging.info(f'Weighted Score: {ret}') + return ret + + +def auto_travel_gpt(args): + exe_client = load_exe_LLM(args) + + task_description_list = load_dataset('osunlp/TravelPlanner', 'test', download_mode=DownloadMode.FORCE_REDOWNLOAD, cache_dir=args.cache_dir)['test'] # Assume we will not test the testset + first_task_description = task_description_list[0]['query'] + second_task_description = task_description_list[-1]['query'] + + auto_travel_instruction = f'You are a proficient expert in designing workflows for complex task planning and can revise existing workflows based on their execution performances.\n' \ + f'Example task descriptions:\n```{first_task_description}```\n```{second_task_description}```\n\n' \ + f'An execution large language model will receive the task description as query, and then follow your generated workflow for providing plans as the task solution.\n' \ + f'You must provide a workflow each time, and the User will reply to you with the performances from the execution large language model. ' \ + f'The execution performance is a float number between 0 and 1; the higher, the better. ' + + user_instruction = f'Provide a workflow with several steps. Each step is a one-line string. ' \ + f'Each step is in the form of: ```[Step Name]:::[Step Type]:::[Step Instruction]:::[Step Branch]```\n' \ + f'[Step Type] could be "process", "terminal", or "decision".\n' \ + f'[Step Branch] consists of several branches. Each branch is in the form of "[Key]::[Branch Step Name]" and separated by "::".\n' \ + f'Note: "process" step has exactly one branch, with "next" as the key; "decision" step has more than one branches; ' \ + f'"terminal" step has zero branch, indicating the end of working flow, but there could be multiple "terminal" steps.\n' \ + f'At least one "terminal" step exists, meaning the last step of the workflow!\n' \ + f'[Key], [Step Name], and [Step Instruction] are all in the string form.\n' \ + f'[Branch Step Name] should be appear as a unique [Step Name] in the workflow.\n' \ + + chat_history = [{'role': 'system', 'content': auto_travel_instruction}, {'role': 'user', 'content': user_instruction}] + manual_flow = '\n'.join(ReadLineFromFile(args.flow_file)) + chat_history.append({'role': 'assistant', 'content': manual_flow}) + + args.dataset = args.set_type = 'train' + args.results_name = f'auto_{args.auto_model_name}_0' + if not os.path.exists(get_result_file(args)): + exe_main(args, exe_client) + eval_result = eval_score(args.dataset, get_result_file(args), cache_dir=args.cache_dir)[0] + baseline_sentence = ', '.join([f'{k} is {v}' for k, v in eval_result.items()]) + baseline = travel_weighted_score(eval_result) + args.dataset = args.set_type = 'validation' + if not os.path.exists(get_result_file(args)): + exe_main(args, exe_client) + reward = eval_score(args.dataset, get_result_file(args), cache_dir=args.cache_dir) + logging.info(f'```\nReward:\n{reward}```\n') + travel_weighted_score(reward[0]) + chat_history.append({'role': 'user', 'content': f'The execution performance of given workflow is shown as below:\n{baseline_sentence}.\n' + f'Provide a new workflow in the same form of previous one.'}) + + openai_key = args.openai_key + client = OpenAI(api_key=openai_key) + args.flow_file = args.auto_flow_file + + for epoch in range(args.auto_epochs): + args.results_name = f'auto_{args.auto_model_name}_{epoch + 1}' + res = get_response_from_client(client, chat_history, args.auto_model_name)[0] + fout = open(args.flow_file, 'w') + fout.write(res) + fout.close() + chat_history.append({'role': 'assistant', 'content': res}) + logging.info(f'```\nFlows:\n{res}```\n') + try: + args.dataset = args.set_type = 'train' + exe_main(args, exe_client) + eval_result = eval_score(args.dataset, get_result_file(args), cache_dir=args.cache_dir) + logging.info(f'```\nReward:\n{eval_result}```\n') + eval_result = eval_result[0] + reward_sentence = ', '.join([f'{k} is {v}' for k, v in eval_result.items()]) + reward = travel_weighted_score(eval_result) + chat_history.append({'role': 'user', 'content': f'The execution performance of given workflow is {reward_sentence}. ' + f'Provide a new workflow in the same form of previous one.'}) + if reward > baseline: + baseline = reward + logging.info(f'\n\nNew Testing:\n\n') + args.dataset = args.set_type = 'validation' + exe_main(args, exe_client) + reward = eval_score(args.dataset, get_result_file(args), cache_dir=args.cache_dir) + logging.info(f'```\nReward:\n{reward}```\n') + travel_weighted_score(reward[0]) + args.dataset = args.set_type = 'train' + + except Exception as e: + print_exc() + chat_history.append({'role': 'user', 'content': f'When executing the workflow, there is an error: {e}. \n' + f'Please re-generate a new workflow in the same form of the first one.'}) + logging.info(f'```\nError:\n{e}```\n') + logging.info(f'Final Chat History:\n\n```{chat_history}\n\n```') + + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--openai_key", type=str, default='') @@ -298,10 +401,11 @@ def autoagi_mixtral(args): parser.add_argument("--cache_dir", type=str, default='../cache_dir/') parser.add_argument("--log_file_name", type=str, default='../log/autoagi.txt') parser.add_argument("--log_dir", type=str, default='') + parser.add_argument("--task", type=str, default='OpenAGI') # parser.add_argument("--tool_file", type=str, default='./info/OpenAGI/tools.txt') parser.add_argument("--tool_name", type=str, default='tools.txt') - # parser.add_argument("--flow_file", type=str, default='./info/OpenAGI/OpenAGI_Flow.txt') - parser.add_argument("--flow_name", type=str, default='OpenAGI_Flow.txt') + # parser.add_argument("--flow_file", type=str, default='./info/OpenAGI/OpenAGI_Flow_manual.txt') + parser.add_argument("--flow_name", type=str, default='OpenAGI_Flow_manual.txt') parser.add_argument("--auto_flow_name", type=str, default='auto_OpenAGI_Flow.txt') parser.add_argument("--seed", type=int, default=42) parser.add_argument("--auto_epochs", type=int, default=10) @@ -311,7 +415,6 @@ def autoagi_mixtral(args): parser.add_argument("--dataset", type=str, default='test') parser.add_argument("--auto_model_name", type=str, default='gpt-4-1106-preview', help='Flow Generator LLM name.') parser.add_argument("--model_name", type=str, default='gpt-4-1106-preview', help='Execution LLM name.') - parser.add_argument("--task", type=str, default='OpenAGI') parser.add_argument("--other_info_name", type=str, default='other_info.txt') parser.add_argument("--max_fail_times", type=int, default=2, help='Max allow fail times on tools arg choice') parser.add_argument("--get_observation", type=str, default='traverse', @@ -330,6 +433,8 @@ def autoagi_mixtral(args): parser.add_argument("--accumulate_steps", type=int, default=1) parser.add_argument("--warm_up_proportion", type=float, default=0.1) parser.add_argument("--output_dir", type=str, default='./') + parser.add_argument("--sample_query", type=int, default=1) + parser.add_argument("--random_sample_query", type=int, default=0) args = parser.parse_known_args()[0] args.device_list = ["cuda:0", "cpu"] @@ -342,9 +447,19 @@ def autoagi_mixtral(args): args.tool_file = os.path.join(args.info_dir, args.task, args.tool_name) args.other_file = os.path.join(args.info_dir, args.task, args.other_info_name) - if 'tral' in args.auto_model_name: - autoagi_mixtral(args) - elif 'gpt' in args.auto_model_name: - autoagi_gpt(args) + if args.task == 'OpenAGI': + if 'tral' in args.auto_model_name: + autoagi_mixtral(args) + elif 'gpt' in args.auto_model_name: + autoagi_gpt(args) + else: + raise NotImplementedError + elif args.task == 'TravelPlanner': + if 'tral' in args.auto_model_name: + raise NotImplementedError + elif 'gpt' in args.auto_model_name: + auto_travel_gpt(args) + else: + raise NotImplementedError else: raise NotImplementedError \ No newline at end of file diff --git a/src/flow/flow.py b/src/flow/flow.py index 302d778..34e7594 100644 --- a/src/flow/flow.py +++ b/src/flow/flow.py @@ -49,5 +49,5 @@ def __str__(self): return flow_str if __name__ == '__main__': - flow = Flow('../OpenAGI_Flow.txt') + flow = Flow('../OpenAGI_Flow_manual.txt') print(flow.__str__()) \ No newline at end of file diff --git a/src/info/OpenAGI/OpenAGI_Flow_gpt4gpt.txt b/src/info/OpenAGI/OpenAGI_Flow_gpt4gpt.txt new file mode 100644 index 0000000..4727b26 --- /dev/null +++ b/src/info/OpenAGI/OpenAGI_Flow_gpt4gpt.txt @@ -0,0 +1,6 @@ +Step 1:::Process:::Establish the main purpose of the project.:::next::Step 2 +Step 2:::Process:::Determine necessary data types for input-output based on the established purpose.:::next::Step 3 +Step 3:::Decision:::Examine the known list of models and select ones which cater to the identified data types.:::Matching Model Available::Step 4:::Matching Model Unavailable::Step 1 +Step 4:::Process:::Ingest the chosen model into the plan, ensuring a fluid transition of data from one model to the next.:::next::Step 5 +Step 5:::Decision:::Assess whether the model sequence in the plan enables a smooth transition of data types.:::Smooth Transition Confirmed::Step 6:::Smooth Transition Not Confirmed::Step 1 +Step 6:::Terminal:::Formulate and present the final workflow detailing the chosen models and their order of execution.::: \ No newline at end of file diff --git a/src/info/OpenAGI/OpenAGI_Flow_gpt4mixtral.txt b/src/info/OpenAGI/OpenAGI_Flow_gpt4mixtral.txt new file mode 100644 index 0000000..b95a2b2 --- /dev/null +++ b/src/info/OpenAGI/OpenAGI_Flow_gpt4mixtral.txt @@ -0,0 +1,8 @@ +Step 1:::Process:::Start by understanding the task requirement and identifying the input data type.:::next::Step 2 +Step 2:::Process:::Clearly define the desired output data type, in accordance with the task description.:::next::Step 3 +Step 3:::Process:::Conduct preliminary model selection from the available models list. The aim is to create a tentative sequence of models that starts with the input data type and ends with the output data type.:::next::Step 4 +Step 4:::Decision:::Confirm that all models in the tentative sequence are part of the provided models list.:::Yes::Step 5::No::Step 6 +Step 5:::Decision:::Verify the compatibility between model sequence outputs and inputs. Ensure the output from one model in the sequence seamlessly feeds into the next model in line.:::Confirmed compatibility::Step 7::Incompatibility found::Step 8 +Step 6:::Terminal:::If the tentative sequence includes models not in the provided list, terminate the process and notify the user of the task impossibility due to model restrictions.::: +Step 7:::Terminal:::On confirmation of model sequence compatibility, finalize the sequence and output it.::: +Step 8:::Process:::In case of incompatibilities between model outputs and inputs in the sequence, rearrange the models or add intermediary models to ensure smooth data transition. Return to Step 4 for validation.:::next::Step 4 \ No newline at end of file diff --git a/src/info/OpenAGI/OpenAGI_Flow.txt b/src/info/OpenAGI/OpenAGI_Flow_manual.txt similarity index 100% rename from src/info/OpenAGI/OpenAGI_Flow.txt rename to src/info/OpenAGI/OpenAGI_Flow_manual.txt diff --git a/src/info/OpenAGI/OpenAGI_Flow_mixtral4gpt.txt b/src/info/OpenAGI/OpenAGI_Flow_mixtral4gpt.txt new file mode 100644 index 0000000..16154ee --- /dev/null +++ b/src/info/OpenAGI/OpenAGI_Flow_mixtral4gpt.txt @@ -0,0 +1,17 @@ +Read Input Data:::process:::Read input data.:::next::Determine Input Format +Determine Input Format:::process:::Determine input format.:::next::Identify Algorithm +Identify Algorithm:::process:::Identify algorithm based on format.:::next::Check Algorithm Availability +Check Algorithm Availability:::process:::Check availability of algorithm in execution environment.:::next::Set Algorithm Parameters +Set Algorithm Parameters:::process:::Set up algorithm configuration parameters.:::next::Initialize Algorithm Environment +Initialize Algorithm Environment:::process:::Initialize algorithm environment.:::next::Run Algorithm +Run Algorithm:::process:::Run algorithm to process input data.:::next::Analyze Output Data +Analyze Output Data:::process:::Analyze output data.:::next::Format Output Data +Format Output Data:::process:::Format output data.:::next::Return Output Data +Return Output Data:::terminal:::Return output data.::: +Check Input Data Type:::process:::Check if input data is text.:::next::Identify Language +Identify Language:::process:::Identify language of input data.:::next::Identify User Intent +Identify User Intent:::process:::Identify intent of the user input query.:::next::Identify Query Source +Identify Query Source:::process:::Identify appropriate source for query.:::next::Execute Query +Execute Query:::process:::Execute query against identified source.:::next::Format Query Response +Format Query Response:::process:::Format query response.:::next::Return Response +Return Response:::terminal:::Return response.::: \ No newline at end of file diff --git a/src/info/OpenAGI/OpenAGI_Flow_mixtral4mixtral.txt b/src/info/OpenAGI/OpenAGI_Flow_mixtral4mixtral.txt new file mode 100644 index 0000000..faf8053 --- /dev/null +++ b/src/info/OpenAGI/OpenAGI_Flow_mixtral4mixtral.txt @@ -0,0 +1,12 @@ +Identify Source Data Type:::process:::identify the source data type based on the input.:::next::Identify Desired Output Data Type +Identify Desired Output Data Type:::process:::identify the desired output data type based on the input.:::next::Create Data Type Mapping +Create Data Type Mapping:::process:::Create a mapping between data types based on the source data type and desired output data type.:::next::Check Data Types Availability +Check Data Types Availability:::decision:::Check whether all data types in the mapping is available in the provided models.:::Yes::First Data Type Check::No::Create Data Type Mapping +First Data Type Check:::decision:::Check whether the first data type in the mapping is the same as the source data type.:::Yes::Last Data Type Check::No::Create Data Type Mapping +Last Data Type Check:::decision:::Check whether the last data type in the mapping is the same as the desired output data type.:::Yes::Output Mapping::No::Create Data Type Mapping +Output Mapping:::terminal:::Output the mapping.::: +Select First Model:::process:::Select the first model in the mapping, and create a mapping between data types.:::next::Check Model Data Types Availability +Check Model Data Types Availability:::decision:::Check whether all data types in the mapping is available in the provided models.:::Yes::Select Next Model::No::Select First Model +Select Next Model:::process:::Select the next model in the mapping, and create a new mapping between data types.:::next::Check Model Compatibility +Check Model Compatibility:::decision:::Check whether the current model's output data type is the same as the next model's input data type.:::Yes::Select First Model::No::Return Model Sequence +Return Model Sequence:::terminal:::Return the sequence of models and the corresponding mapping between data types.::: \ No newline at end of file diff --git a/src/info/TravelPlanner/TravelPlanner_Flow_manual.txt b/src/info/TravelPlanner/TravelPlanner_Flow_manual.txt new file mode 100644 index 0000000..15c2a23 --- /dev/null +++ b/src/info/TravelPlanner/TravelPlanner_Flow_manual.txt @@ -0,0 +1,45 @@ +Step 1:::Process:::Determine date range, deprature city and destination in task description.:::next::Step 2 +Step 2:::Decision:::Is the destination a state or a city?:::city::step 7::state::step 3 +Step 3:::Process:::Find city list from the destination state:::next::step 4 +Step 4:::Decision:::Based on the input query, determine the duration:::3 days::Step 7::5 days::Step 5::7 days::Step 6 +Step 5:::Process:::Select two cities from the destination state for further exploration in the travel itinerary, ensuring they fit within the travel plans:::next::Step 7 +Step 6:::Process:::Select three cities from the destination state for further exploration in the travel itinerary, ensuring they fit within the travel plans:::next::Step 7 +Step 7:::Process:::Estimate the cost of taking a taxi from departure city to the first destination city.:::next::Step 8 +Step 8:::Process:::Estimate the cost of self-driving from departure city to the first destination city.:::next::Step 9 +Step 9:::Process:::Estimate the cost of taking a flight on the first date from departure city to the first destination city.:::next::Step 10 +Step 10:::Process:::Select the most suitable transportation among taxi, self-driving, and flight to the first destination city for the first two days of this trip. The selection should be constrained by preferences and budget detailed in the task description, avoiding scheduling conflicts.:::next::step 11 +Step 11:::Process:::Find restaurant list at the first destination city.:::next::Step 12 +Step 12:::Process:::Select suitable restaurants of breakfast, lunch, and dinner in the first destination city for first two days in this trip. The selection should be constrained by the budget and preferences detailed in the task description, avoiding duplicates and scheduling conflicts.:::next::step 13 +Step 13:::Process:::Find attraction list at the first destination city.:::next::Step 14 +Step 14:::Process:::Select one attraction for each day in the first destination city for first two days in this trip. The selection should be constrained by the budget and preferences detailed in the task description, avoiding duplicates and scheduling conflicts.:::next::step 15 +Step 15:::Process:::Find accommodation list at at the first destination city.:::next::Step 16 +Step 16:::Process:::Select accommodation for the first destination city for first two days in this trip. The selection should be constrained by the budget and preferences detailed in the task description.:::next::step 17 +Step 17:::Decision:::Is this a 3-days trip?:::Yes::Step 39::No::Step 18 +Step 18:::Process:::Estimate the cost of taking a taxi from the first destination city to the second destination city.:::next::Step 19 +Step 19:::Process:::Estimate the cost of self-driving from the first destination city to the second destination city.:::next::Step 20 +Step 20:::Process:::Estimate the cost of taking a flight on the third date from the first destination city to the second destination city.:::next::Step 21 +Step 21:::Process:::Select the most suitable transportation among taxi, self-driving, and flight to the second destination city for day 3 and day 4 of this trip. The selection should be constrained by preferences and budget detailed in the task description, avoiding scheduling conflicts.:::next::step 22 +Step 22:::Process:::Find restaurant list at the second destination city.:::next::Step 23 +Step 23:::Process:::Select suitable restaurants of breakfast, lunch, and dinner in the second destination city for day 3 and day 4 in this trip. The selection should be constrained by the budget and preferences detailed in the task description, avoiding duplicates and scheduling conflicts.:::next::step 24 +Step 24:::Process:::Find attraction list at the second destination city.:::next::Step 25 +Step 25:::Process:::Select one attraction for each day in the second destination city for day 3 and day 4 in this trip. The selection should be constrained by the budget and preferences detailed in the task description, avoiding duplicates and scheduling conflicts.:::next::step 26 +Step 26:::Process:::Find accommodation list at at the second destination city.:::next::Step 27 +Step 27:::Process:::Select accommodation for the second destination city for day 3 and day 4 in this trip. The selection should be constrained by the budget and preferences detailed in the task description.:::next::step 28 +Step 28:::Decision:::Is this a 5-days trip?:::Yes::Step 39::No::Step 29 +Step 29:::Process:::Estimate the cost of taking a taxi from the second destination city to the third destination city.:::next::Step 30 +Step 30:::Process:::Estimate the cost of self-driving from the second destination city to the third destination city.:::next::Step 31 +Step 31:::Process:::Estimate the cost of taking a flight on the fifth date from the second destination city to the third destination city.:::next::Step 32 +Step 32:::Process:::Select the most suitable transportation among taxi, self-driving, and flight to the third destination city for day 5 and day 6 of this trip. The selection should be constrained by preferences and budget detailed in the task description, avoiding scheduling conflicts.:::next::step 33 +Step 33:::Process:::Find restaurant list at the third destination city.:::next::Step 34 +Step 34:::Process:::Select suitable restaurants of breakfast, lunch, and dinner in the third destination city for day 5 and day 6 in this trip. The selection should be constrained by the budget and preferences detailed in the task description, avoiding duplicates and scheduling conflicts.:::next::step 35 +Step 35:::Process:::Find attraction list at the third destination city.:::next::Step 36 +Step 36:::Process:::Select one attraction for each day in the third destination city for day 5 and day 6 in this trip. The selection should be constrained by the budget and preferences detailed in the task description, avoiding duplicates and scheduling conflicts.:::next::step 37 +Step 37:::Process:::Find accommodation list at at the third destination city.:::next::Step 38 +Step 38:::Process:::Select accommodation for the third destination city for day 5 and day 6 in this trip. The selection should be constrained by the budget and preferences detailed in the task description.:::next::step 39 +Step 39:::Process:::Estimate the cost of taking a taxi from the last destination city back to the departure city.:::next::Step 40 +Step 40:::Process:::Estimate the cost of self-driving from the last destination city back to the departure city.:::next::Step 41 +Step 41:::Process:::Estimate the cost of taking a flight on the first date from the last destination city back to the departure city.:::next::Step 42 +Step 42:::Process:::Select the most suitable transportation among taxi, self-driving, and flight to the departure city for last day of this trip. The selection should be constrained by preferences and budget detailed in the task description, avoiding scheduling conflicts.:::next::step 43 +Step 43:::Process:::Select suitable restaurants of breakfast, lunch, and dinner in the last destination city for the last day in this trip. The selection should be constrained by the budget and preferences detailed in the task description, avoiding duplicates and scheduling conflicts.:::next::step 44 +Step 44:::Process:::Select one attraction in the last destination city for the last day in this trip. The selection should be constrained by the budget and preferences detailed in the task description, avoiding duplicates and scheduling conflicts.:::next::step 45 +Step 45:::Terminal:::Output the whole plans for all days.::: \ No newline at end of file diff --git a/src/info/TravelPlanner/task_instruction.txt b/src/info/TravelPlanner/task_instruction.txt new file mode 100644 index 0000000..fdbd70e --- /dev/null +++ b/src/info/TravelPlanner/task_instruction.txt @@ -0,0 +1 @@ +You are a proficient planner . Based on the provided information and query , please give me a detailed plan , including specifics such as flight numbers , restaurant names , and hotel names . Note that all the information in your plan should be derived from the provided data . You must adhere to the format given in the example . Additionally , all details should align with common sense . Attraction visits and meals are expected to be diverse . The symbol '-' indicates that information is unnecessary . For example , in the provided sample , you do not need to plan after returning to the departure city . When you travel to two cities in one day , you should note it in the 'Current City ' section as in the example ( i . e . , from A to B ) . \ No newline at end of file diff --git a/src/info/TravelPlanner/tools.txt b/src/info/TravelPlanner/tools.txt new file mode 100644 index 0000000..f61c7ed --- /dev/null +++ b/src/info/TravelPlanner/tools.txt @@ -0,0 +1,7 @@ +Provided tools: +FlightSearch [ Departure City , Destination City , Date ]: Description : A flight information retrieval tool . Parameters : Departure City : The city you ' ll be flying out from . Destination City : The city you aim to reach . Date : The date of your travel in YYYY -MM - DD format . Example : FlightSearch [ New York , London , 2022 -10 -01] would fetch flights from New York to London on October 1, 2022. +GoogleDistanceMatrix [ Origin , Destination , Mode ]: Description : Estimate the distance , time and cost between two cities . Parameters : Origin : The departure city of your journey . Destination : The destination city of your journey . Mode : The method of transportation . Choices include 'self - driving ' and 'taxi '. Example : DistanceMatrix [ Paris , Lyon , self - driving ] would provide driving distance , time and cost between Paris and Lyon . +AccommodationSearch [ City ]: Description : Discover accommodations in your desired city . Parameter : City - The name of the city where you 're seeking accommodation . Example : AccommodationSearch [ Rome ] would present a list of hotel rooms in Rome . +RestaurantSearch [ City ]: Description : Explore dining options in a city of your choice . Parameter : City - The name of the city where you 're seeking restaurant . Example : RestaurantSearch [ Tokyo ] would show a curated list of restaurants in Tokyo . +CitySearch [ State ]: Description : Find cities in a state of your choice . Parameter : State - The name of the city where you 're seeking cities . Example : CitySearch [ California ] would return cities in California . +AttractionSearch [ City ]: Description : Find attractions in a city of your choice . Parameter : City - The name of the city where you 're seeking attractions . Example : AttractionSearch [ London ] would return attractions in London . \ No newline at end of file diff --git a/src/main.py b/src/main.py index f337f55..35773d5 100644 --- a/src/main.py +++ b/src/main.py @@ -3,7 +3,7 @@ import os import numpy as np -from datasets import load_dataset +from datasets import load_dataset, DownloadMode from tqdm import tqdm @@ -11,6 +11,7 @@ from utils.notebook import Notebook +from utils.travel_utils import convert_to_json_with_gpt, get_result_file, write_result_into_file, get_baseline_result from utils.flow_utils import ReadLineFromFile, get_prompt, get_observation, notebook_summarize, get_response_from_client, check_tool_use, check_tool_name, \ get_tool_arg, check_branch, set_logger from flow.flow import Flow @@ -23,6 +24,12 @@ from torchvision import transforms from torchmetrics.multimodal import CLIPScore +from travel_api.flights.apis import FlightSearch +from travel_api.accommodations.apis import AccommodationSearch +from travel_api.restaurants.apis import RestaurantSearch +from travel_api.googleDistanceMatrix.apis import GoogleDistanceMatrix +from travel_api.attractions.apis import AttractionSearch +from travel_api.cities.apis import CitySearch from vllm import LLM, SamplingParams import torch @@ -30,10 +37,7 @@ from transformers import AutoModel, AutoFeatureExtractor - - -# from evaluation import OpenAGI_evaluate - +from utils.travel_evaluation import eval_score def global_args(): parser = argparse.ArgumentParser() @@ -50,22 +54,27 @@ def global_args(): parser.add_argument("--tool_name", type=str, default='tools.txt') parser.add_argument("--other_info_name", type=str, default='other_info.txt') parser.add_argument("--log_dir", type=str, default='../log/') - parser.add_argument("--set_type", type=str, default='validation') + parser.add_argument("--dataset", type=str, default='validation') + parser.add_argument("--avoid_dup_tool_call", action='store_true') parser.add_argument("--seed", type=int, default=42) parser.add_argument("--get_observation", type=str, default='traverse', help='How to get observations, "traverse" stands for asking one by one, "direct" stands for directly asking.') parser.add_argument("--batch_size", type=int, default=5) parser.add_argument("--max_fail_times", type=int, default=2, help='Max allow fail times on tools arg choice') parser.add_argument("--max_round", type=int, default=100, help='Max allow round of executions') parser.add_argument("--log_file_name", type=str, default='travelplanner.txt') + parser.add_argument("--sample_query", type=int, default=0) + parser.add_argument("--random_sample_query", type=int, default=0) args = parser.parse_known_args()[0] return args -def finish_one_task(instruction, tool_info, other_info, flow, task_idx, query, tool_list, notebook, args, client): - +def finish_one_task(client, instruction, tool_info, other_info, flow, task_idx, query, tool_list, notebook, args): + notebook.reset() + if args.task == "TravelPlanner": + result_file = get_result_file(args) - plan_round = 0 + plan_round = 1 flow_ptr = flow.header logging.info(f'```\ntask id:\n{task_idx}```\n') logging.info(f'```\nquery:\n{query}```\n') @@ -78,10 +87,13 @@ def finish_one_task(instruction, tool_info, other_info, flow, task_idx, query, t return_res = dict() while True: if plan_round >= args.max_round: - return_res['reward'] = 0.0 + if args.task == "TravelPlanner": + current_interaction = '\n'.join(current_progress) + '\n' + '\n'.join(notebook.list_all_str()) + result, price = convert_to_json_with_gpt(current_interaction, args.openai_key) + total_price += price + submit_result = {"idx":task_idx,"query":query,"plan":result} + write_result_into_file(submit_result, result_file) break - plan_round += 1 - chat_history = [] if isinstance(instruction, str): chat_history.append({ @@ -95,7 +107,7 @@ def finish_one_task(instruction, tool_info, other_info, flow, task_idx, query, t # generate prompt prompt = get_prompt(tool_info, flow_ptr, query, current_progress, observations, args.model_name, other_info) - logging.info(f'Prompt: \n```\n{prompt}\n```') + logging.info(f'Input Prompt: \n```\n{prompt}\n```') chat_history.append({ 'role': 'user', 'content': prompt @@ -115,21 +127,24 @@ def finish_one_task(instruction, tool_info, other_info, flow, task_idx, query, t # current_progress.append(f'Answer {plan_round}: ```{res}```') # check tool use - if len(tool_list) > 0: + try: + tool_use, price = check_tool_use(client, '\n'.join(tool_calling_list), flow_ptr, str(res), tool_info, model_name=args.model_name) + total_price += price + except Exception as e: + logging.error(f"Error when checking tool use: {e}") + tool_use = False + + if tool_use: try: - tool_use, price = check_tool_use(client, '\n'.join(tool_calling_list), flow_ptr, str(res), tool_info, model_name=args.model_name) + tool_name, price= check_tool_name(client, flow_ptr, str(res), list(tool_list.keys()), model_name=args.model_name) total_price += price + tool = tool_list[tool_name] except Exception as e: - logging.error(f"Error when checking tool use: {e}") + logging.error(f"Error when getting tool name: {e}") tool_use = False - - if tool_use: + else: for k in range(args.max_fail_times): try: - tool_name, price = check_tool_name(client, flow_ptr, str(res), list(tool_list.keys()), - model_name=args.model_name) - total_price += price - tool = tool_list[tool_name] param, price = get_tool_arg(client, flow_ptr, str(res), tool_info, tool_name, model_name=args.model_name) total_price += price if param == 'None': @@ -138,9 +153,10 @@ def finish_one_task(instruction, tool_info, other_info, flow, task_idx, query, t param_sep = [p.strip() for p in param.strip().split(',')] tool_result = tool.run(*param_sep) tool_calling = f'{tool_name} [ {param} ]' - if tool_calling in tool_calling_list: - current_progress.append(f'Answer {plan_round}: ```{res}```') - break + if args.avoid_dup_tool_call: + if tool_calling in tool_calling_list: + current_progress.append(f'Answer {plan_round}: ```{res}```') + break tool_calling_list.append(tool_calling) short_summary, price = notebook_summarize(client, tool_info, tool_calling, args.model_name) total_price += price @@ -153,18 +169,21 @@ def finish_one_task(instruction, tool_info, other_info, flow, task_idx, query, t if k + 1 == args.max_fail_times: # Max Fail attempts logging.error('Reach Max fail attempts on Get Tool Parameters.') # if reach max fail attempts, do not use tool in this step. - current_progress.append(f'Answer {plan_round}: ```{res}```') - return_res['reward'] = 0.0 - return total_price, return_res - # exit(1) + # current_progress.append(f'Answer {plan_round}: ```{res}```') + tool_use = False + break else: continue - else: - current_progress.append(f'Answer {plan_round}: ```{res}```') - # output_record.append(None) + if not tool_use: + current_progress.append(f'Answer {plan_round}: ```{res}```') # terminate condition if len(flow_ptr.branch) == 0 and flow_ptr.type.lower() == 'terminal': + if args.task == 'TravelPlanner': + result, price = convert_to_json_with_gpt(str(res), args.openai_key) + total_price += price + submit_result = {"idx":task_idx,"query":query,"plan":result} + write_result_into_file(submit_result, result_file) if args.task == 'OpenAGI': eval_device = "cuda:0" clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16") @@ -180,7 +199,7 @@ def finish_one_task(instruction, tool_info, other_info, flow, task_idx, query, t dataset = GeneralDataset(task_idx, data_path) dataloader = DataLoader(dataset, batch_size=args.batch_size) seq_com = SeqCombine(args) - module_list = parse_module_list_with_gpt(args, res).split(',') + module_list = parse_module_list_with_gpt(client, res).split(',') sentence_model = SentenceTransformer('all-MiniLM-L6-v2', device="cpu") module_list = match_module_seq(module_list, sentence_model).split(',') print(module_list) @@ -217,52 +236,74 @@ def finish_one_task(instruction, tool_info, other_info, flow, task_idx, query, t else: try: branch, price = check_branch(client, res, flow_ptr, model_name=args.model_name) - except AssertionError: - return_res['reward'] = 0.0 - break - total_price += price + total_price += price + except Exception as e: + logging.error(f"Error when checking branch: {e}") + branch = list(flow_ptr.branch.keys())[0] flow_ptr = flow_ptr.branch[branch] logging.info(f'Current Block: \n```\n{flow_ptr}```') - + plan_round += 1 + logging.info(f'The price for task {task_idx} is {total_price}') return total_price, return_res + def load_query(args): if args.task == 'OpenAGI': task_description = ReadLineFromFile("../openagi_data/task_description.txt") - if 'test' in args.dataset: - idx_list = [2,3,10,15,20,35,45,55,65,70,90,106,107] - elif 'train' in args.dataset: - idx_list = [5, 7, 13, 25, 40, 50, 60, 75, 80, 95, 100, 105, 109] + return [(i, task_description[i+1]) for i in range(len(task_description))] + + elif args.task == 'TravelPlanner': + query_data_list = load_dataset('osunlp/TravelPlanner', args.set_type, download_mode=DownloadMode.FORCE_REDOWNLOAD, cache_dir=args.cache_dir)[args.set_type] + if args.sample_query: + assert args.random_sample_query == 0 # Not Implement random sampling + levels, days = ['easy', 'medium', 'hard'], [3, 5, 7] + task_ids = [] + for level in levels: + for day in days: + for idx, query_data in enumerate(query_data_list): + if query_data['level'] == level and query_data['days'] == day: + task_ids.append(idx) + break + print(f'Sampled Task IDs: {task_ids}') + ret = [(i, query_data_list[i]['query']) if i in task_ids else (i, None) for i in range(len(query_data_list))] else: - raise NotImplementedError - return [(i, task_description[i+1]) for i in idx_list] + ret = [(i, query_data_list[i]['query']) for i in range(len(query_data_list))] + return ret + else: raise NotImplementedError - + + def load_tool(args): if args.task == 'OpenAGI': return "", {} - else: - return "", {} + elif args.task == 'TravelPlanner': + tool_info_list = ReadLineFromFile(args.tool_file) + tool_name_list = [tool_description.split()[0] for tool_description in tool_info_list[1:]] + tool_info = '\n'.join(tool_info_list) + + # create tool_list, tool name as the key and tool as value + tool_list = dict() + for tool_name in tool_name_list: + try: + tool_list[tool_name] = globals()[tool_name]() + except: + raise Exception(f"{tool_name} is not found") + return tool_info, tool_list + def load_other_info(args): if args.task == 'OpenAGI': - other_info_list = ReadLineFromFile(args.other_file) + other_info_list = ReadLineFromFile(args.tool_file) other_info = '\n'.join(other_info_list) return other_info - else: + elif args.task == 'TravelPlanner': return "" -def main(args, client): - # args = global_args() - # args.log_name = os.path.join(args.log_dir, args.log_file_name) - # set_logger(args) - - # load flow - # args.flow_file = os.path.join(args.info_dir, args.task, args.flow_name) +def main(args, client): flow = Flow(args.flow_file) logging.info(f'```\nFlows:\n{flow}```\n') @@ -283,7 +324,7 @@ def main(args, client): tool_info, tool_list = load_tool(args) else: tool_info, tool_list = "", dict() - logging.info(f'```\ntool_info:\n{tool_info}```\n') + logging.info(f'```\ntool_info:\n{tool_info}\n```\n') # load other_info args.other_file = os.path.join(args.info_dir, args.task, args.other_info_name) @@ -291,7 +332,7 @@ def main(args, client): other_info = load_other_info(args) else: other_info = "" - logging.info(f'```\nother_info:\n{tool_info}```\n') + logging.info(f'```\nother_info:\n{other_info}\n```\n') # Create a notebook to save observations notebook = Notebook() @@ -305,20 +346,25 @@ def main(args, client): similairies = [] valid = [] - return_res = None # Answer every query for idx, query in task_query: + if query is None: + assert args.task == 'TravelPlanner' + result_file = get_result_file(args) + copied_baseline = get_baseline_result(args, idx) + write_result_into_file(copied_baseline, result_file, is_string=True) + continue try: - price, return_res = finish_one_task(instruction, tool_info, other_info, flow, idx, query, tool_list, notebook, args, client) + price, return_res = finish_one_task(client, instruction, tool_info, other_info, flow, idx, query, tool_list, notebook, args) total_price += price except Exception as e: logging.error(f"Error when answering {query}: {e}") + if args.task == 'TravelPlanner': + result_file = get_result_file(args) + submit_result = {"idx":idx,"query":query,"plan":None} + write_result_into_file(submit_result, result_file) if args.task == 'OpenAGI': - if return_res is None: - ave_task_reward = 0.0 - else: - ave_task_reward = return_res['reward'] - + ave_task_reward = return_res['reward'] if 0 <= idx <= 14: similairies.append(ave_task_reward) elif 15 <= idx <= 104 or 107 <= idx <= 184: @@ -336,7 +382,9 @@ def main(args, client): logging.info(f'The price for {args.task} is {total_price}') if args.task == 'OpenAGI': logging.info(f'Clips: {np.mean(clips)}, BERTS: {np.mean(berts)}, ViT: {np.mean(similairies)}, Rewards: {np.mean(rewards)}, Valid: {np.mean(valid)}') - return np.mean(rewards) + return np.mean(rewards) + else: + return None if __name__ == '__main__': @@ -350,8 +398,9 @@ def main(args, client): openai_key = args.openai_key client = OpenAI(api_key=openai_key) elif 'gptq' in args.model_name.lower(): - client = LLM(model=args.model_name, download_dir=args.cache_dir, quantization='gptq', enforce_eager=True, dtype=torch.float16, gpu_memory_utilization=0.9)# , tensor_parallel_size=8)# + client = LLM(model=args.model_name, download_dir=args.cache_dir, quantization='gptq', enforce_eager=True, dtype=torch.float16, tensor_parallel_size=8)#, gpu_memory_utilization=0.7) else: raise NotImplementedError main(args, client) + diff --git a/src/travel_api/__init__.py b/src/travel_api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/travel_api/accommodations/__init__.py b/src/travel_api/accommodations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/travel_api/accommodations/apis.py b/src/travel_api/accommodations/apis.py new file mode 100644 index 0000000..fa3a3b8 --- /dev/null +++ b/src/travel_api/accommodations/apis.py @@ -0,0 +1,31 @@ +import pandas as pd +from pandas import DataFrame +from typing import Optional +from utils.flow_utils import extract_before_parenthesis + + +class AccommodationSearch: + def __init__(self, path="../travel_database/accommodations/clean_accommodations_2022.csv"): + self.path = path + self.data = pd.read_csv(self.path).dropna()[['NAME','price','room type', 'house_rules', 'minimum nights', 'maximum occupancy', 'review rate number', 'city']] + print("Accommodations loaded.") + + def load_db(self): + self.data = pd.read_csv(self.path).dropna() + + def run(self, + city: str, + ) -> DataFrame or str: + """Search for accommodations by city.""" + results = self.data[self.data["city"] == city] + if len(results) == 0: + return "There is no attraction in this city." + + return results + + def run_for_annotation(self, + city: str, + ) -> DataFrame: + """Search for accommodations by city.""" + results = self.data[self.data["city"] == extract_before_parenthesis(city)] + return results \ No newline at end of file diff --git a/src/travel_api/attractions/apis.py b/src/travel_api/attractions/apis.py new file mode 100644 index 0000000..d7ac123 --- /dev/null +++ b/src/travel_api/attractions/apis.py @@ -0,0 +1,34 @@ +import pandas as pd +from pandas import DataFrame +from typing import Optional +from utils.flow_utils import extract_before_parenthesis + + +class AttractionSearch: + def __init__(self, path="../travel_database/attractions/attractions.csv"): + self.path = path + self.data = pd.read_csv(self.path).dropna()[['Name','Latitude','Longitude','Address','Phone','Website',"City"]] + print("Attractions loaded.") + + def load_db(self): + self.data = pd.read_csv(self.path) + + def run(self, + city: str, + ) -> DataFrame or str: + """Search for Accommodations by city and date.""" + results = self.data[self.data["City"] == city] + # the results should show the index + results = results.reset_index(drop=True) + if len(results) == 0: + return "There is no attraction in this city." + return results + + def run_for_annotation(self, + city: str, + ) -> DataFrame: + """Search for Accommodations by city and date.""" + results = self.data[self.data["City"] == extract_before_parenthesis(city)] + # the results should show the index + results = results.reset_index(drop=True) + return results \ No newline at end of file diff --git a/src/travel_api/cities/apis.py b/src/travel_api/cities/apis.py new file mode 100644 index 0000000..33e4161 --- /dev/null +++ b/src/travel_api/cities/apis.py @@ -0,0 +1,23 @@ +from pandas import DataFrame + +class CitySearch: + def __init__(self ,path="../travel_database/background/citySet_with_states.txt") -> None: + self.path = path + self.load_data() + print("Cities loaded.") + + def load_data(self): + cityStateMapping = open(self.path, "r").read().strip().split("\n") + self.data = {} + for unit in cityStateMapping: + city, state = unit.split("\t") + if state not in self.data: + self.data[state] = [city] + else: + self.data[state].append(city) + + def run(self, state) -> dict or str: + if state not in self.data: + return ValueError("Invalid State") + else: + return self.data[state] \ No newline at end of file diff --git a/src/travel_api/flights/__init__.py b/src/travel_api/flights/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/travel_api/flights/apis.py b/src/travel_api/flights/apis.py new file mode 100644 index 0000000..880c80d --- /dev/null +++ b/src/travel_api/flights/apis.py @@ -0,0 +1,70 @@ +import pandas as pd +from pandas import DataFrame +from typing import Optional +from utils.flow_utils import extract_before_parenthesis + +class FlightSearch: + + def __init__(self, path="../travel_database/flights/clean_Flights_2022.csv"): + self.path = path + self.data = None + + self.data = pd.read_csv(self.path).dropna()[['Flight Number', 'Price', 'DepTime', 'ArrTime', 'ActualElapsedTime','FlightDate','OriginCityName','DestCityName','Distance']] + print("Flights API loaded.") + + def load_db(self): + self.data = pd.read_csv(self.path).dropna().rename(columns={'Unnamed: 0': 'Flight Number'}) + + def run(self, + origin: str, + destination: str, + departure_date: str, + ) -> DataFrame or str: + """Search for flights by origin, destination, and departure date.""" + results = self.data[self.data["OriginCityName"] == origin] + results = results[results["DestCityName"] == destination] + results = results[results["FlightDate"] == departure_date] + # if order == "ascPrice": + # results = results.sort_values(by=["Price"], ascending=True) + # elif order == "descPrice": + # results = results.sort_values(by=["Price"], ascending=False) + # elif order == "ascDepTime": + # results = results.sort_values(by=["DepTime"], ascending=True) + # elif order == "descDepTime": + # results = results.sort_values(by=["DepTime"], ascending=False) + # elif order == "ascArrTime": + # results = results.sort_values(by=["ArrTime"], ascending=True) + # elif order == "descArrTime": + # results = results.sort_values(by=["ArrTime"], ascending=False) + if len(results) == 0: + return "There is no flight from {} to {} on {}.".format(origin, destination, departure_date) + return results + + def run_for_annotation(self, + origin: str, + destination: str, + departure_date: str, + ) -> DataFrame: + """Search for flights by origin, destination, and departure date.""" + results = self.data[self.data["OriginCityName"] == extract_before_parenthesis(origin)] + results = results[results["DestCityName"] == extract_before_parenthesis(destination)] + results = results[results["FlightDate"] == departure_date] + # if order == "ascPrice": + # results = results.sort_values(by=["Price"], ascending=True) + # elif order == "descPrice": + # results = results.sort_values(by=["Price"], ascending=False) + # elif order == "ascDepTime": + # results = results.sort_values(by=["DepTime"], ascending=True) + # elif order == "descDepTime": + # results = results.sort_values(by=["DepTime"], ascending=False) + # elif order == "ascArrTime": + # results = results.sort_values(by=["ArrTime"], ascending=True) + # elif order == "descArrTime": + # results = results.sort_values(by=["ArrTime"], ascending=False) + return results.to_string(index=False) + + def get_city_set(self): + city_set = set() + for unit in self.data['data']: + city_set.add(unit[5]) + city_set.add(unit[6]) \ No newline at end of file diff --git a/src/travel_api/googleDistanceMatrix/apis.py b/src/travel_api/googleDistanceMatrix/apis.py new file mode 100644 index 0000000..c292d87 --- /dev/null +++ b/src/travel_api/googleDistanceMatrix/apis.py @@ -0,0 +1,123 @@ +import requests +from utils.flow_utils import extract_before_parenthesis +import os +from requests.exceptions import SSLError +import time +import sys +import pandas as pd +import numpy as np + +# This tool refers to the "DistanceMatrix" in the paper. Considering this data obtained from Google API, we consistently use this name in the code. +# Please be assured that this will not influence the experiment results shown in the paper. + +class GoogleDistanceMatrix: + def __init__(self, subscription_key: str="") -> None: + self.gplaces_api_key: str = subscription_key + self.data = pd.read_csv('../travel_database/googleDistanceMatrix/distance.csv') + print("GoogleDistanceMatrix loaded.") + + def run(self, origin, destination, mode='driving'): + origin = extract_before_parenthesis(origin) + destination = extract_before_parenthesis(destination) + info = {"origin": origin, "destination": destination,"cost": None, "duration": None, "distance": None} + response = self.data[(self.data['origin'] == origin) & (self.data['destination'] == destination)] + if len(response) > 0: + if response['duration'].values[0] is None or response['distance'].values[0] is None or response['duration'].values[0] is np.nan or response['distance'].values[0] is np.nan: + return "No valid information." + info["duration"] = response['duration'].values[0] + info["distance"] = response['distance'].values[0] + if 'driving' in mode: + info["cost"] = int(eval(info["distance"].replace("km","").replace(",","")) * 0.05) + elif mode == "taxi": + info["cost"] = int(eval(info["distance"].replace("km","").replace(",",""))) + if 'day' in info["duration"]: + return "No valid information." + return f"{mode}, from {origin} to {destination}, duration: {info['duration']}, distance: {info['distance']}, cost: {info['cost']}" + + return f"{mode}, from {origin} to {destination}, no valid information." + + def run_for_evaluation(self, origin, destination, mode='driving'): + origin = extract_before_parenthesis(origin) + destination = extract_before_parenthesis(destination) + info = {"origin": origin, "destination": destination,"cost": None, "duration": None, "distance": None} + response = self.data[(self.data['origin'] == origin) & (self.data['destination'] == destination)] + if len(response) > 0: + if response['duration'].values[0] is None or response['distance'].values[0] is None or response['duration'].values[0] is np.nan or response['distance'].values[0] is np.nan: + return info + info["duration"] = response['duration'].values[0] + info["distance"] = response['distance'].values[0] + + if 'day' not in info["duration"]: + if 'driving' in mode: + info["cost"] = int(eval(info["distance"].replace("km","").replace(",","")) * 0.05) + elif mode == "taxi": + info["cost"] = int(eval(info["distance"].replace("km","").replace(",",""))) + + return info + + return info + + + def run_online(self, origin, destination, mode="driving"): + # mode in ['driving','taxi','walking', 'distance','transit'] + endpoint = "https://maps.googleapis.com/maps/api/distancematrix/json" + + params = { + "origins": origin, + "destinations": destination, + "mode": mode if mode=="taxi" else "driving", + "key": self.gplaces_api_key + } + + while True: + try: + response = requests.get(endpoint, params=params) + break + except SSLError: + time.sleep(30) + + data = response.json() + info = {"origin": origin, "destination": destination,"cost": None, "duration": None, "distance": None} + if data['status'] == "OK": + element = data['rows'][0]['elements'][0] + if element['status'] == "OK": + info["duration"] = element['duration']['text'] + info["distance"] = element['distance']['text'] + if 'driving' in mode: + info["cost"] = int(eval(info["distance"].replace("km","").replace(",","")) * 0.05) + elif mode == "taxi": + info["cost"] = int(eval(info["distance"].replace("km","").replace(",",""))) + # if 'day' in info["duration"]: + # return "No valid information." + return f"{mode}, from {origin} to {destination}, duration: {info['duration']}, distance: {info['distance']}, cost: {info['cost']}" + + return "No valid information." + + def run_for_annotation(self, origin, destination, mode="driving"): + # mode in ['driving','taxi','walking', 'distance','transit'] + endpoint = "https://maps.googleapis.com/maps/api/distancematrix/json" + + params = { + "origins": extract_before_parenthesis(origin), + "destinations": extract_before_parenthesis(destination), + "mode": mode if mode!="taxi" else "driving", + "key": self.gplaces_api_key + } + + response = requests.get(endpoint, params=params) + data = response.json() + info = {} + if data['status'] == "OK": + element = data['rows'][0]['elements'][0] + if element['status'] == "OK": + info["duration"] = element['duration']['text'] + info["distance"] = element['distance']['text'] + info["cost"] = None + if 'driving' in mode: + info["cost"] = int(eval(info["distance"].replace("km","").replace(",","")) * 0.05) + elif mode == "taxi": + info["cost"] = int(eval(info["distance"].replace("km","").replace(",",""))) + else: + info = {"duration": "N/A", "distance": "N/A", "cost": "N/A", "Hint":"Please check the input."} + return info + diff --git a/src/travel_api/notebook/apis.py b/src/travel_api/notebook/apis.py new file mode 100644 index 0000000..b818be6 --- /dev/null +++ b/src/travel_api/notebook/apis.py @@ -0,0 +1,40 @@ +from pandas import DataFrame + +class Notebook: + def __init__(self) -> None: + self.data = [] + + def write(self, input_data: DataFrame, short_description: str): + self.data.append({"Short Description": short_description, "Content":input_data}) + return f"The information has been recorded in Notebook, and its index is {len(self.data)-1}." + + def update(self, input_data: DataFrame, index: int, short_decription: str): + self.data[index]["Content"] = input_data + self.data[index]["Short Description"] = short_decription + + return f"The information has been updated in Notebook." + + def list(self): + results = [] + for idx, unit in enumerate(self.data): + results.append({"index":idx, "Short Description":unit['Short Description']}) + + return results + + def list_all(self): + results = [] + for idx, unit in enumerate(self.data): + if type(unit['Content']) == DataFrame: + results.append({"index":idx, "Short Description":unit['Short Description'], "Content":unit['Content'].to_string(index=False)}) + else: + results.append({"index":idx, "Short Description":unit['Short Description'], "Content":unit['Content']}) + + return results + + def read(self, index): + return self.data[index] + + def reset(self): + self.data = [] + + diff --git a/src/travel_api/notebook/test.py b/src/travel_api/notebook/test.py new file mode 100644 index 0000000..e69de29 diff --git a/src/travel_api/planner/apis.py b/src/travel_api/planner/apis.py new file mode 100644 index 0000000..d4e083f --- /dev/null +++ b/src/travel_api/planner/apis.py @@ -0,0 +1,393 @@ +import sys +import os +sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".."))) +from langchain.prompts import PromptTemplate +from agents.prompts import planner_agent_prompt, cot_planner_agent_prompt, react_planner_agent_prompt,reflect_prompt,react_reflect_planner_agent_prompt, REFLECTION_HEADER +from langchain.chat_models import ChatOpenAI +from langchain.llms.base import BaseLLM +from langchain.schema import ( + AIMessage, + HumanMessage, + SystemMessage +) +from env import ReactEnv,ReactReflectEnv +import tiktoken +import re +import openai +import time +from enum import Enum +from typing import List, Union, Literal +from langchain_google_genai import ChatGoogleGenerativeAI +import argparse + + +OPENAI_API_KEY = os.environ['OPENAI_API_KEY'] +GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY'] + + +def catch_openai_api_error(): + error = sys.exc_info()[0] + if error == openai.error.APIConnectionError: + print("APIConnectionError") + elif error == openai.error.RateLimitError: + print("RateLimitError") + time.sleep(60) + elif error == openai.error.APIError: + print("APIError") + elif error == openai.error.AuthenticationError: + print("AuthenticationError") + else: + print("API error:", error) + + +class ReflexionStrategy(Enum): + """ + REFLEXION: Apply reflexion to the next reasoning trace + """ + REFLEXION = 'reflexion' + + +class Planner: + def __init__(self, + # args, + agent_prompt: PromptTemplate = planner_agent_prompt, + model_name: str = 'gpt-3.5-turbo-1106', + ) -> None: + + self.agent_prompt = agent_prompt + self.scratchpad: str = '' + self.model_name = model_name + self.enc = tiktoken.encoding_for_model("gpt-3.5-turbo") + + if model_name in ['mistral-7B-32K']: + self.llm = ChatOpenAI(temperature=0, + max_tokens=4096, + openai_api_key="EMPTY", + openai_api_base="http://localhost:8301/v1", + model_name="gpt-3.5-turbo") + + if model_name in ['ChatGLM3-6B-32K']: + self.llm = ChatOpenAI(temperature=0, + max_tokens=4096, + openai_api_key="EMPTY", + openai_api_base="http://localhost:8501/v1", + model_name="gpt-3.5-turbo") + + elif model_name in ['mixtral']: + self.max_token_length = 30000 + self.llm = ChatOpenAI(temperature=0, + max_tokens=4096, + openai_api_key="EMPTY", + openai_api_base="http://10.176.40.135:8000/v1", + model_name="YOUR/MODEL/PATH") + + elif model_name in ['gemini']: + self.llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key=GOOGLE_API_KEY) + else: + self.llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=4096, openai_api_key=OPENAI_API_KEY) + + + print(f"PlannerAgent {model_name} loaded.") + + def run(self, text, query, log_file=None) -> str: + if log_file: + log_file.write('\n---------------Planner\n'+self._build_agent_prompt(text, query)) + # print(self._build_agent_prompt(text, query)) + if self.model_name in ['gemini']: + return str(self.llm.invoke(self._build_agent_prompt(text, query)).content) + else: + if len(self.enc.encode(self._build_agent_prompt(text, query))) > 12000: + return 'Max Token Length Exceeded.' + else: + return self.llm([HumanMessage(content=self._build_agent_prompt(text, query))]).content + + def _build_agent_prompt(self, text, query) -> str: + return self.agent_prompt.format( + text=text, + query=query) + + +class ReactPlanner: + """ + A question answering ReAct Agent. + """ + def __init__(self, + agent_prompt: PromptTemplate = react_planner_agent_prompt, + model_name: str = 'gpt-3.5-turbo-1106', + ) -> None: + + self.agent_prompt = agent_prompt + self.react_llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=1024, openai_api_key=OPENAI_API_KEY,model_kwargs={"stop": ["Action","Thought","Observation"]}) + self.env = ReactEnv() + self.query = None + self.max_steps = 30 + self.reset() + self.finished = False + self.answer = '' + self.enc = tiktoken.encoding_for_model("gpt-3.5-turbo") + + def run(self, text, query, reset = True) -> None: + + self.query = query + self.text = text + + if reset: + self.reset() + + + while not (self.is_halted() or self.is_finished()): + self.step() + + return self.answer, self.scratchpad + + + def step(self) -> None: + # Think + self.scratchpad += f'\nThought {self.curr_step}:' + self.scratchpad += ' ' + self.prompt_agent() + print(self.scratchpad.split('\n')[-1]) + + # Act + self.scratchpad += f'\nAction {self.curr_step}:' + action = self.prompt_agent() + self.scratchpad += ' ' + action + print(self.scratchpad.split('\n')[-1]) + + # Observe + self.scratchpad += f'\nObservation {self.curr_step}: ' + + action_type, action_arg = parse_action(action) + + if action_type == 'CostEnquiry': + try: + input_arg = eval(action_arg) + if type(input_arg) != dict: + raise ValueError('The sub plan can not be parsed into json format, please check. Only one day plan is supported.') + observation = f'Cost: {self.env.run(input_arg)}' + except SyntaxError: + observation = f'The sub plan can not be parsed into json format, please check.' + except ValueError as e: + observation = str(e) + + elif action_type == 'Finish': + self.finished = True + observation = f'The plan is finished.' + self.answer = action_arg + + else: + observation = f'Action {action_type} is not supported.' + + self.curr_step += 1 + + self.scratchpad += observation + print(self.scratchpad.split('\n')[-1]) + + def prompt_agent(self) -> str: + while True: + try: + return format_step(self.react_llm([HumanMessage(content=self._build_agent_prompt())]).content) + except: + catch_openai_api_error() + print(self._build_agent_prompt()) + print(len(self.enc.encode(self._build_agent_prompt()))) + time.sleep(5) + + def _build_agent_prompt(self) -> str: + return self.agent_prompt.format( + query = self.query, + text = self.text, + scratchpad = self.scratchpad) + + def is_finished(self) -> bool: + return self.finished + + def is_halted(self) -> bool: + return ((self.curr_step > self.max_steps) or ( + len(self.enc.encode(self._build_agent_prompt())) > 14000)) and not self.finished + + def reset(self) -> None: + self.scratchpad = '' + self.answer = '' + self.curr_step = 1 + self.finished = False + + +class ReactReflectPlanner: + """ + A question answering Self-Reflecting React Agent. + """ + def __init__(self, + agent_prompt: PromptTemplate = react_reflect_planner_agent_prompt, + reflect_prompt: PromptTemplate = reflect_prompt, + model_name: str = 'gpt-3.5-turbo-1106', + ) -> None: + + self.agent_prompt = agent_prompt + self.reflect_prompt = reflect_prompt + if model_name in ['gemini']: + self.react_llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key=GOOGLE_API_KEY) + self.reflect_llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key=GOOGLE_API_KEY) + else: + self.react_llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=1024, openai_api_key=OPENAI_API_KEY,model_kwargs={"stop": ["Action","Thought","Observation,'\n"]}) + self.reflect_llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=1024, openai_api_key=OPENAI_API_KEY,model_kwargs={"stop": ["Action","Thought","Observation,'\n"]}) + self.model_name = model_name + self.env = ReactReflectEnv() + self.query = None + self.max_steps = 30 + self.reset() + self.finished = False + self.answer = '' + self.reflections: List[str] = [] + self.reflections_str: str = '' + self.enc = tiktoken.encoding_for_model("gpt-3.5-turbo") + + def run(self, text, query, reset = True) -> None: + + self.query = query + self.text = text + + if reset: + self.reset() + + + while not (self.is_halted() or self.is_finished()): + self.step() + if self.env.is_terminated and not self.finished: + self.reflect(ReflexionStrategy.REFLEXION) + + + return self.answer, self.scratchpad + + + def step(self) -> None: + # Think + self.scratchpad += f'\nThought {self.curr_step}:' + self.scratchpad += ' ' + self.prompt_agent() + print(self.scratchpad.split('\n')[-1]) + + # Act + self.scratchpad += f'\nAction {self.curr_step}:' + action = self.prompt_agent() + self.scratchpad += ' ' + action + print(self.scratchpad.split('\n')[-1]) + + # Observe + self.scratchpad += f'\nObservation {self.curr_step}: ' + + action_type, action_arg = parse_action(action) + + if action_type == 'CostEnquiry': + try: + input_arg = eval(action_arg) + if type(input_arg) != dict: + raise ValueError('The sub plan can not be parsed into json format, please check. Only one day plan is supported.') + observation = f'Cost: {self.env.run(input_arg)}' + except SyntaxError: + observation = f'The sub plan can not be parsed into json format, please check.' + except ValueError as e: + observation = str(e) + + elif action_type == 'Finish': + self.finished = True + observation = f'The plan is finished.' + self.answer = action_arg + + else: + observation = f'Action {action_type} is not supported.' + + self.curr_step += 1 + + self.scratchpad += observation + print(self.scratchpad.split('\n')[-1]) + + def reflect(self, strategy: ReflexionStrategy) -> None: + print('Reflecting...') + if strategy == ReflexionStrategy.REFLEXION: + self.reflections += [self.prompt_reflection()] + self.reflections_str = format_reflections(self.reflections) + else: + raise NotImplementedError(f'Unknown reflection strategy: {strategy}') + print(self.reflections_str) + + def prompt_agent(self) -> str: + while True: + try: + if self.model_name in ['gemini']: + return format_step(self.react_llm.invoke(self._build_agent_prompt()).content) + else: + return format_step(self.react_llm([HumanMessage(content=self._build_agent_prompt())]).content) + except: + catch_openai_api_error() + print(self._build_agent_prompt()) + print(len(self.enc.encode(self._build_agent_prompt()))) + time.sleep(5) + + def prompt_reflection(self) -> str: + while True: + try: + if self.model_name in ['gemini']: + return format_step(self.reflect_llm.invoke(self._build_reflection_prompt()).content) + else: + return format_step(self.reflect_llm([HumanMessage(content=self._build_reflection_prompt())]).content) + except: + catch_openai_api_error() + print(self._build_reflection_prompt()) + print(len(self.enc.encode(self._build_reflection_prompt()))) + time.sleep(5) + + def _build_agent_prompt(self) -> str: + return self.agent_prompt.format( + query = self.query, + text = self.text, + scratchpad = self.scratchpad, + reflections = self.reflections_str) + + def _build_reflection_prompt(self) -> str: + return self.reflect_prompt.format( + query = self.query, + text = self.text, + scratchpad = self.scratchpad) + + def is_finished(self) -> bool: + return self.finished + + def is_halted(self) -> bool: + return ((self.curr_step > self.max_steps) or ( + len(self.enc.encode(self._build_agent_prompt())) > 14000)) and not self.finished + + def reset(self) -> None: + self.scratchpad = '' + self.answer = '' + self.curr_step = 1 + self.finished = False + self.reflections = [] + self.reflections_str = '' + self.env.reset() + +def format_step(step: str) -> str: + return step.strip('\n').strip().replace('\n', '') + +def parse_action(string): + pattern = r'^(\w+)\[(.+)\]$' + match = re.match(pattern, string) + + try: + if match: + action_type = match.group(1) + action_arg = match.group(2) + return action_type, action_arg + else: + return None, None + + except: + return None, None + +def format_reflections(reflections: List[str], + header: str = REFLECTION_HEADER) -> str: + if reflections == []: + return '' + else: + return header + 'Reflections:\n- ' + '\n- '.join([r.strip() for r in reflections]) + +# if __name__ == '__main__': + \ No newline at end of file diff --git a/src/travel_api/planner/env.py b/src/travel_api/planner/env.py new file mode 100644 index 0000000..5b8d1ba --- /dev/null +++ b/src/travel_api/planner/env.py @@ -0,0 +1,202 @@ +from tools.flights.apis import Flights +from tools.accommodations.apis import Accommodations +from tools.restaurants.apis import Restaurants +from tools.googleDistanceMatrix.apis import GoogleDistanceMatrix +from tools.attractions.apis import Attractions +from evaluation.hard_constraint import extract_from_to,get_valid_name_city +import math + +class ReactEnv: + def __init__(self): + + self.flight = Flights() + self.accommodation = Accommodations() + self.restaurants = Restaurants() + self.googleDistanceMatrix = GoogleDistanceMatrix() + self.attractions = Attractions() + + def run(self, tested_data): + + total_cost = 0 + unit = tested_data + people_number = tested_data['people_number'] + returned_info = [] + + if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-': + value = unit['transportation'] + org_city, dest_city = extract_from_to(value) + if org_city == None or dest_city == None: + org_city, dest_city = extract_from_to(unit['current_city']) + if 'flight number' in value.lower(): + try: + res = self.flight.data[self.flight.data['Flight Number'] == value.split('Flight Number: ')[1].split(',')[0]] + if len(res) > 0: + total_cost += res['Price'].values[0] * people_number + else: + returned_info.append('The filght information is not valid') + except: + returned_info.append('The filght information is not valid') + + elif 'self-driving' in value.lower() or 'taxi' in value.lower(): + try: + if 'self-driving' in value.lower(): + # print(org_city,dest_city) + cost = self.googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'self-driving')['cost'] + if cost == None: + returned_info.append('The transporation information is not valid, please check.') + else: + total_cost += cost * math.ceil(people_number * 1.0 / 5) + else: + cost = self.googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'taxi')['cost'] + if cost == None: + returned_info.append('The transporation information is not valid, please check.') + else: + total_cost += cost * math.ceil(people_number * 1.0 / 4) + except: + returned_info.append('The transporation information is not valid, please check. You have to make sure there are two cities (from A to B) in your transportation plan.') + + if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-': + name, city = get_valid_name_city(unit['breakfast']) + if name != '-' and city != '-': + res = self.restaurants.data[(self.restaurants.data['Name'] == name) & (self.restaurants.data['City'] == city)] + if len(res) > 0: + total_cost += res['Average Cost'].values[0] * people_number + else: + returned_info.append('The breakfase information is not valid, please check.') + + if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-': + name, city = get_valid_name_city(unit['lunch']) + if name != '-' and city != '-': + res = self.restaurants.data[(self.restaurants.data['Name'] == name) & (self.restaurants.data['City'] == city)] + if len(res) > 0: + total_cost += res['Average Cost'].values[0] * people_number + else: + returned_info.append('The lunch information is not valid, please check.') + + if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-': + name, city = get_valid_name_city(unit['dinner']) + if name != '-' and city != '-': + res = self.restaurants.data[(self.restaurants.data['Name'] == name) & (self.restaurants.data['City'] == city)] + if len(res) > 0: + total_cost += res['Average Cost'].values[0] * people_number + else: + returned_info.append('The dinner information is not valid, please check.') + + if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-': + name, city = get_valid_name_city(unit['accommodation']) + if name != '-' and city != '-': + res = self.accommodation.data[(self.accommodation.data['NAME'] == name) & (self.accommodation.data['city'] == city)] + if len(res) > 0: + total_cost += res['price'].values[0] * math.ceil(people_number * 1.0 / res['maximum occupancy'].values[0]) + else: + returned_info.append('The accommodation information is not valid, please check.') + + if len(returned_info) == 0: + return "The cost of your plan is " + str(total_cost) + " dollars." + else: + message = "Sorry, the cost of your plan is not available because of the following reasons:" + for idx, info in enumerate(returned_info): + message += str(idx + 1) + ". " + info + " " + '\t' + return message + +class ReactReflectEnv(ReactEnv): + def __init__(self): + super().__init__() + self.is_terminated = False + self.max_retry_step = 3 + self.retry_step = 0 + + def reset(self): + self.is_terminated = False + self.retry_step = 0 + + def run(self, tested_data): + total_cost = 0 + unit = tested_data + people_number = tested_data['people_number'] + returned_info = [] + + if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-': + value = unit['transportation'] + org_city, dest_city = extract_from_to(value) + if org_city == None or dest_city == None: + org_city, dest_city = extract_from_to(unit['current_city']) + + + if org_city == None or dest_city == None: + returned_info.append('The transporation information is not valid, please check.') + + else: + if 'flight number' in value.lower(): + try: + res = self.flight.data[self.flight.data['Flight Number'] == value.split('Flight Number: ')[1].split(',')[0]] + if len(res) > 0: + total_cost += res['Price'].values[0] * people_number + else: + returned_info.append('The filght information is not valid') + except: + returned_info.append('The filght information is not valid') + + elif 'self-driving' in value.lower() or 'taxi' in value.lower(): + if 'self-driving' in value.lower(): + cost = self.googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'self-driving')['cost'] + if cost == None: + returned_info.append('The transporation information is not valid, please check.') + else: + total_cost += cost * math.ceil(people_number * 1.0 / 5) + else: + cost = self.googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'taxi')['cost'] + if cost == None: + returned_info.append('The transporation information is not valid, please check.') + else: + total_cost += cost * math.ceil(people_number * 1.0 / 4) + + if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-': + name, city = get_valid_name_city(unit['breakfast']) + if name != '-' and city != '-': + res = self.restaurants.data[(self.restaurants.data['Name'] == name) & (self.restaurants.data['City'] == city)] + if len(res) > 0: + total_cost += res['Average Cost'].values[0] * people_number + else: + returned_info.append('The breakfase information is not valid, please check.') + + if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-': + name, city = get_valid_name_city(unit['lunch']) + if name != '-' and city != '-': + res = self.restaurants.data[(self.restaurants.data['Name'] == name) & (self.restaurants.data['City'] == city)] + if len(res) > 0: + total_cost += res['Average Cost'].values[0] * people_number + else: + returned_info.append('The lunch information is not valid, please check.') + + if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-': + name, city = get_valid_name_city(unit['dinner']) + if name != '-' and city != '-': + res = self.restaurants.data[(self.restaurants.data['Name'] == name) & (self.restaurants.data['City'] == city)] + if len(res) > 0: + total_cost += res['Average Cost'].values[0] * people_number + else: + returned_info.append('The dinner information is not valid, please check.') + + if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-': + name, city = get_valid_name_city(unit['accommodation']) + if name != '-' and city != '-': + res = self.accommodation.data[(self.accommodation.data['NAME'] == name) & (self.accommodation.data['city'] == city)] + if len(res) > 0: + total_cost += res['price'].values[0] * math.ceil(people_number * 1.0 / res['maximum occupancy'].values[0]) + else: + returned_info.append('The accommodation information is not valid, please check.') + + if len(returned_info) == 0: + self.retry_step = 0 + self.is_terminated = False + return "The cost of your plan is " + str(total_cost) + " dollars." + else: + message = "Sorry, the cost of your plan is not available because of the following reasons:" + for idx, info in enumerate(returned_info): + message += str(idx + 1) + ". " + info + " " + '\t' + self.retry_step += 1 + if self.retry_step >= self.max_retry_step: + self.is_terminated = True + return message + diff --git a/src/travel_api/planner/sole_planning.py b/src/travel_api/planner/sole_planning.py new file mode 100644 index 0000000..26d135a --- /dev/null +++ b/src/travel_api/planner/sole_planning.py @@ -0,0 +1,112 @@ +import os +import re +import sys +sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../.."))) +os.chdir(os.path.dirname(os.path.abspath(__file__))) +from agents.prompts import planner_agent_prompt, cot_planner_agent_prompt, react_planner_agent_prompt,react_reflect_planner_agent_prompt,reflect_prompt +# from utils.func import get_valid_name_city,extract_before_parenthesis, extract_numbers_from_filenames +import json +import time +from langchain.callbacks import get_openai_callback + +from tqdm import tqdm +from tools.planner.apis import Planner, ReactPlanner, ReactReflectPlanner +import openai +import argparse +from datasets import load_dataset + + + + +def load_line_json_data(filename): + data = [] + with open(filename, 'r', encoding='utf-8') as f: + for line in f.read().strip().split('\n'): + unit = json.loads(line) + data.append(unit) + return data + +def extract_numbers_from_filenames(directory): + # Define the pattern to match files + pattern = r'annotation_(\d+).json' + + # List all files in the directory + files = os.listdir(directory) + + # Extract numbers from filenames that match the pattern + numbers = [int(re.search(pattern, file).group(1)) for file in files if re.match(pattern, file)] + + return numbers + + +def catch_openai_api_error(): + error = sys.exc_info()[0] + if error == openai.error.APIConnectionError: + print("APIConnectionError") + elif error == openai.error.RateLimitError: + print("RateLimitError") + time.sleep(60) + elif error == openai.error.APIError: + print("APIError") + elif error == openai.error.AuthenticationError: + print("AuthenticationError") + else: + print("API error:", error) + + +if __name__ == "__main__": + + # model_name= ['gpt-3.5-turbo-1106','gpt-4-1106-preview','gemini','mixtral'][1] + # set_type = ['dev','test'][0] + # strategy = ['direct','cot','react','reflexion'][0] + + parser = argparse.ArgumentParser() + parser.add_argument("--set_type", type=str, default="validation") + parser.add_argument("--model_name", type=str, default="gpt-3.5-turbo-1106") + parser.add_argument("--output_dir", type=str, default="./") + parser.add_argument("--strategy", type=str, default="direct") + args = parser.parse_args() + directory = f'{args.output_dir}/{args.set_type}' + if args.set_type == 'validation': + query_data_list = load_dataset('osunlp/TravelPlanner','validation')['validation'] + elif args.set_type == 'test': + query_data_list = load_dataset('osunlp/TravelPlanner','test')['test'] + numbers = [i for i in range(1,len(query_data_list)+1)] + + if args.strategy == 'direct': + planner = Planner(model_name=args.model_name, agent_prompt=planner_agent_prompt) + elif args.strategy == 'cot': + planner = Planner(model_name=args.model_name, agent_prompt=cot_planner_agent_prompt) + elif args.strategy == 'react': + planner = ReactPlanner(model_name=args.model_name, agent_prompt=react_planner_agent_prompt) + elif args.strategy == 'reflexion': + planner = ReactReflectPlanner(model_name=args.model_name, agent_prompt=react_reflect_planner_agent_prompt,reflect_prompt=reflect_prompt) + + + with get_openai_callback() as cb: + for number in tqdm(numbers[:]): + + query_data = query_data_list[number-1] + reference_information = query_data['reference_information'] + while True: + if args.strategy in ['react','reflexion']: + planner_results, scratchpad = planner.run(reference_information, query_data['query']) + else: + planner_results = planner.run(reference_information, query_data['query']) + if planner_results != None: + break + print(planner_results) + # check if the directory exists + if not os.path.exists(os.path.join(f'{args.output_dir}/{args.set_type}')): + os.makedirs(os.path.join(f'{args.output_dir}/{args.set_type}')) + if not os.path.exists(os.path.join(f'{args.output_dir}/{args.set_type}/generated_plan_{number}.json')): + result = [{}] + else: + result = json.load(open(os.path.join(f'{args.output_dir}/{args.set_type}/generated_plan_{number}.json'))) + if args.strategy in ['react','reflexion']: + result[-1][f'{args.model_name}_{args.strategy}_sole-planning_results_logs'] = scratchpad + result[-1][f'{args.model_name}_{args.strategy}_sole-planning_results'] = planner_results + # write to json file + with open(os.path.join(f'{args.output_dir}/{args.set_type}/generated_plan_{number}.json'), 'w') as f: + json.dump(result, f, indent=4) + print(cb) diff --git a/src/travel_api/planner/test.py b/src/travel_api/planner/test.py new file mode 100644 index 0000000..fedfdd7 --- /dev/null +++ b/src/travel_api/planner/test.py @@ -0,0 +1 @@ +print(eval("[ddd")) \ No newline at end of file diff --git a/src/travel_api/restaurants/__init__.py b/src/travel_api/restaurants/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/travel_api/restaurants/apis.py b/src/travel_api/restaurants/apis.py new file mode 100644 index 0000000..10adb2a --- /dev/null +++ b/src/travel_api/restaurants/apis.py @@ -0,0 +1,50 @@ +import pandas as pd +from pandas import DataFrame +from typing import Optional +from utils.flow_utils import extract_before_parenthesis + +class RestaurantSearch: + def __init__(self, path="../travel_database/restaurants/clean_restaurant_2022.csv"): + self.path = path + self.data = pd.read_csv(self.path).dropna()[['Name','Average Cost','Cuisines','Aggregate Rating','City']] + print("Restaurants loaded.") + + def load_db(self): + self.data = pd.read_csv(self.path).dropna() + + def run(self, + city: str, + ) -> DataFrame or str: + """Search for restaurant .""" + results = self.data[self.data["City"] == city] + # results = results[results["date"] == date] + # if price_order == "asc": + # results = results.sort_values(by=["Average Cost"], ascending=True) + # elif price_order == "desc": + # results = results.sort_values(by=["Average Cost"], ascending=False) + + # if rating_order == "asc": + # results = results.sort_values(by=["Aggregate Rating"], ascending=True) + # elif rating_order == "desc": + # results = results.sort_values(by=["Aggregate Rating"], ascending=False) + if len(results) == 0: + return "There is no restaurant in this city." + return results + + def run_for_annotation(self, + city: str, + ) -> DataFrame: + """Search for restaurant .""" + results = self.data[self.data["City"] == extract_before_parenthesis(city)] + # results = results[results["date"] == date] + # if price_order == "asc": + # results = results.sort_values(by=["Average Cost"], ascending=True) + # elif price_order == "desc": + # results = results.sort_values(by=["Average Cost"], ascending=False) + + # if rating_order == "asc": + # results = results.sort_values(by=["Aggregate Rating"], ascending=True) + # elif rating_order == "desc": + # results = results.sort_values(by=["Aggregate Rating"], ascending=False) + + return results \ No newline at end of file diff --git a/src/utils/travel_commonsense_constraint.py b/src/utils/travel_commonsense_constraint.py new file mode 100644 index 0000000..4119143 --- /dev/null +++ b/src/utils/travel_commonsense_constraint.py @@ -0,0 +1,575 @@ +from utils.travel_utils import get_valid_name_city,extract_before_parenthesis,extract_numbers_from_filenames +from travel_api.flights.apis import FlightSearch +from travel_api.accommodations.apis import AccommodationSearch +from travel_api.restaurants.apis import RestaurantSearch +from travel_api.googleDistanceMatrix.apis import GoogleDistanceMatrix +from travel_api.attractions.apis import AttractionSearch +import math +import json +import re +import os +import sys +from tqdm import tqdm +import argparse + +# sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".."))) +# os.chdir(os.path.dirname(os.path.abspath(__file__))) + +flight = FlightSearch() +accommodation = AccommodationSearch() +restaurants = RestaurantSearch() +googleDistanceMatrix = GoogleDistanceMatrix() +attractions = AttractionSearch() + +city_state_set = open('../travel_database/background/citySet_with_states.txt','r').read().split('\n') +city_state_map = {x:y for x,y in [unit.split('\t') for unit in city_state_set]} + + +def load_line_json_data(filename): + data = [] + with open(filename, 'r', encoding='utf-8') as f: + for line in f.read().strip().split('\n'): + unit = json.loads(line) + data.append(unit) + return data + + +def count_consecutive_values(lst): + if not lst: + return [] + + result = [] + current_string = lst[0] + count = 1 + + for i in range(1, len(lst)): + if lst[i] == current_string: + count += 1 + else: + result.append((current_string, count)) + current_string = lst[i] + count = 1 + + result.append((current_string, count)) # Add the last group of values + return result + + +def transportation_match(text: str): + + if 'taxi' in text.lower(): + return 'Taxi' + + elif 'self-driving' in text.lower(): + return 'Self-driving' + + elif 'flight' in text.lower(): + return 'Flight' + + +def extract_from_to(text: str): + """ + Extracts 'A' and 'B' from the format "from A to B" in the given text, with B ending at a comma or the end of the string. + + Args: + - text (str): The input string. + + Returns: + - tuple: A tuple containing 'A' and 'B'. If no match is found, returns (None, None). + """ + pattern = r"from\s+(.+?)\s+to\s+([^,]+)(?=[,\s]|$)" + matches = re.search(pattern, text) + return matches.groups() if matches else (None, None) + + + +def is_valid_city_sequence(city_list): + """ + Checks if the city sequence is valid. A valid sequence has every city (except the first and last) + appearing consecutively, and no city should appear again once its sequence is over. + + Args: + - city_list (list): List of cities. + + Returns: + - bool: True if the sequence is valid, False otherwise. + """ + + # If the list has less than 3 cities, it's invalid. + if len(city_list) < 3: + return False + + # Set to keep track of visited cities + visited_cities = set() + + i = 0 + while i < len(city_list): + city = city_list[i] + + # If the city was already visited, it's invalid. + if city in visited_cities and (i != 0 and i != len(city_list) - 1): + return False + + # Count the consecutive occurrences of the city + count = 0 + while i < len(city_list) and city_list[i] == city: + count += 1 + i += 1 + + # If the city appeared only once in the medium, it's invalid. + if count == 1 and 0 < i - 1 < len(city_list) - 1: + return False + + visited_cities.add(city) + + return True + + + +def is_reasonalbe_visiting_city(question, tested_data): + + city_list = [] + + # print(tested_data) + for i in range(min(question['days'],len(tested_data))): + if type(tested_data) != list or 'current_city' not in tested_data[i]: + return False, "No current city exists." + + city_value = tested_data[i]['current_city'] + + if 'from' in city_value: + city1, city2 = extract_from_to(city_value) + city1 = extract_before_parenthesis(city1) + city2 = extract_before_parenthesis(city2) + if i == 0 and city1 != question['org']: + return False, f"The first day's city should be {question['org']}." + + city_list += [city1, city2] + + else: + city_list.append(extract_before_parenthesis(city_value)) + + if city_list[0] != city_list[-1]: + return False, "The trip should be a closed circle." + + if not is_valid_city_sequence(city_list): + return False, "The city sequence is invalid." + + for idx, city in enumerate(city_list): + if city not in city_state_map: + return False, f"{city} is not a valid city." + if idx not in [0,len(city_list)-1] and question['days'] >3 and city_state_map[city] != question['dest']: + return False, f"{city} is not in {question['dest']}." + + return True, None + + +def is_valid_restaurants(question, tested_data): + if type(tested_data) != list: + return False, "KeyError" + + restaurants_list = [] + + for i in range(min(question['days'],len(tested_data))): + unit = tested_data[i] + + if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-': + if unit['breakfast'] not in restaurants_list: + restaurants_list.append(unit['breakfast']) + else: + return False, f"The restaurant in day {i+1} breakfast is repeated." + # elif 'breakfast' not in unit : + # return False, f"No Breakfast Info." + + if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-': + if unit['lunch'] not in restaurants_list: + restaurants_list.append(unit['lunch']) + else: + return False, f"The restaurant in day {i+1} lunch {unit['lunch']} is repeated." + # elif 'lunch' not in unit: + # return False, f"No Lunch Info." + + if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-': + if unit['dinner'] not in restaurants_list: + restaurants_list.append(unit['dinner']) + else: + return False, f"The restaurant in day {i+1} dinner is repeated." + # elif 'dinner' not in unit: + # return False, f"No Dinner Info." + + return True, None + +def is_valid_attractions(question, tested_data): + if type(tested_data) != list: + return False, "KeyError" + + attractions_list = [] + + for i in range(min(question['days'],len(tested_data))): + + unit = tested_data[i] + + if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-': + for attraction in unit['attraction'].split(';')[:-1]: + if attraction not in attractions_list: + attractions_list.append(attraction) + else: + return False, f"The attraction '{attraction}' in day {i+1} is repeated." + + # elif 'attraction' not in unit: + # return False, f"No Attraction Info." + + return True, None + +def is_valid_transportation(question, tested_data): + if type(tested_data) != list: + return False, "KeyError" + + if 'transportation' in tested_data[0] and tested_data[0]['transportation'] and tested_data[0]['transportation'] != '-': + transportation_list = [transportation_match(tested_data[0]['transportation'])] + + else: + return False, "The transportation in day 1 should not be empty." + + for i in range(min(question['days'],len(tested_data))): + unit = tested_data[i] + + if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-': + transportation_list.append(transportation_match(unit['transportation'])) + # elif 'transportation' not in unit: + # return False, f"No Transportation Info." + + if (('Self-driving' in transportation_list) and ('Flight' in transportation_list)) or (('Taxi' in transportation_list) and ('Self-driving' in transportation_list)): + return False, "The transportation is conflicting." + + return True, None + +def is_valid_information_in_current_city(question, tested_data): + if type(tested_data) != list: + return False, "KeyError" + + for i in range(min(question['days'],len(tested_data))): + unit = tested_data[i] + if 'current_city' not in unit: + return False, "No current city exists." + current_city = unit['current_city'] + final_city_list = [] + + if 'from' in current_city: + city1, city2 = extract_from_to(current_city) + city1 = extract_before_parenthesis(city1) + city2 = extract_before_parenthesis(city2) + final_city_list = [city1, city2] + else: + final_city_list = extract_before_parenthesis(current_city) + + if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-': + for city in final_city_list: + if city not in unit['transportation']: + # print(city) + return False, f"The transportation in day {i+1} is invalid city choice." + # elif 'transportation' not in unit: + # return False, f"No Transportation Info." + + if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-': + + flag = False + + for city in final_city_list: + if city in unit['breakfast']: + flag = True + + if not flag: + return False, f"The breakfast in day {i+1} is invalid city choice." + # elif 'breakfast' not in unit: + # return False, f"No Breakfast Info." + + if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-': + flag = False + + for city in final_city_list: + if city in unit['lunch']: + flag = True + + if not flag: + return False, f"The lunch in day {i+1} is invalid city choice." + # elif 'lunch' not in unit: + # return False, f"No Lunch Info." + + if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-': + flag = False + + for city in final_city_list: + if city in unit['dinner']: + flag = True + + if not flag: + return False, f"The dinner in day {i+1} is invalid city choice." + # elif 'dinner' not in unit: + # return False, f"No Dinner Info." + + if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-': + + attraction_list = unit['attraction'].split(';')[:-1] + + for attraction in attraction_list: + flag = False + for city in final_city_list: + if city in attraction: + flag = True + if not flag: + return False, f"The attraction in day {i+1} is invalid city choice." + + # elif 'attraction' not in unit: + # return False, f"No Attraction Info." + + + if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-': + + if final_city_list[-1] not in unit['accommodation']: + return False, f"The accommodation in day {i+1} is invalid city choice." + + # elif 'accommodation' not in unit: + # return False, f"No Accommodation Info." + + return True, None + +# hallucination +def is_valid_information_in_sandbox(question, tested_data): + if type(tested_data) != list: + return False, "KeyError" + + for i in range(min(question['days'],len(tested_data))): + unit = tested_data[i] + + if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-': + value = unit['transportation'] + org_city, dest_city = extract_from_to(value) + if org_city == None or dest_city == None: + org_city, dest_city = extract_from_to(unit['current_city']) + if 'flight number' in value.lower(): + try: + org_city = extract_before_parenthesis(org_city) + dest_city = extract_before_parenthesis(dest_city) + except TypeError: + raise ValueError("The transportation {} in day {} can not be parsed.".format(value,i+1)) + # print(value) + try: + if len(flight.data[(flight.data['Flight Number'] == value.split('Flight Number: ')[1].split(',')[0]) & (flight.data['OriginCityName']==org_city) & (flight.data['DestCityName']==dest_city)]) < 1: + return False, f"The flight number in day {i+1} is invalid in the sandbox." + except IndexError: + return False, f"The flight number in day {i + 1} is invalid in the sandbox." + + elif 'self-driving' in value.lower() or 'taxi' in value.lower(): + try: + org_city = extract_before_parenthesis(org_city) + dest_city = extract_before_parenthesis(dest_city) + except TypeError: + org_city = '-' + dest_city = '-' + print("The transportation {} in day {} can not be parsed and '-' will be used instead.".format(value,i+1)) + + if 'self-driving' in value.lower(): + if googleDistanceMatrix.run_for_evaluation(org_city, dest_city, mode='self-driving')['cost'] == None: + return False, f"The self-driving in day {i+1} is invalid in the sandbox." + else: + if googleDistanceMatrix.run_for_evaluation(org_city, dest_city, mode='taxi')['cost'] == None: + return False, f"The taxi in day {i+1} is invalid in the sandbox." + + if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-': + name, city = get_valid_name_city(unit['breakfast']) + if len(restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]) < 1: + return False, f"The breakfast in day {i+1} is invalid in the sandbox." + # elif 'breakfast' not in unit: + # return False, f"No Breakfast Info." + + if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-': + name, city = get_valid_name_city(unit['lunch']) + if len(restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]) < 1: + return False, f"The lunch in day {i+1} is invalid in the sandbox." + # elif 'lunch' not in unit: + # return False, f"No Lunch Info." + + if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-': + name, city = get_valid_name_city(unit['dinner']) + if len(restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)]) < 1: + return False, f"The dinner in day {i+1} is invalid in the sandbox." + # elif 'dinner' not in unit: + # return False, f"No Dinner Info." + + if 'attraction' in unit and unit['attraction'] and unit['attraction'] != '-': + attractions_list = unit['attraction'].split(';')[:-1] + for attraction in attractions_list: + name, city = get_valid_name_city(attraction) + if len(attractions.data[(attractions.data['Name'].astype(str).str.contains(re.escape(name))) & (attractions.data['City'] == city)]) < 1: + return False, f"The attraction {attraction} in day {i+1} is invalid in the sandbox." + # elif 'attraction' not in unit: + # return False, f"No Attraction Info." + + if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-': + name, city = get_valid_name_city(unit['accommodation']) + # print(name,city) + # print(accommodation.data[accommodation.data['NAME'].astype(str).str.contains(re.escape(name))]) + if len(accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]) < 1: + return False, f"The accommodation in day {i+1} is invalid in the sandbox." + # elif 'accommodation' not in unit: + # return False, f"No Accommodation Info." + + return True, None + + +def is_valid_accommodaton(question, tested_data): + if type(tested_data) != list: + return False, "KeyError" + + data = [] + for i in range(min(question['days'],len(tested_data))): + unit = tested_data[i] + + if 'accommodation' not in unit: + return False, f"No Accommodation Info." + + data.append(unit['accommodation']) + # data = [unit['accommodation'] for unit in tested_data] + consectutive_accommodation = count_consecutive_values(data) + for unit in consectutive_accommodation: + # print(unit) + if unit and unit[0] not in ['-',''] : + name, city = get_valid_name_city(unit[0]) + # print(unit[0],name,city) + # try: + if len(accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)]) == 1 and unit[1] < accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)].iloc[0]['minimum nights']: + return False, f"The accommodation {unit[0]} do not obey the minumum nights rule." + # can not parse data + # except re.error: + # continue + + return True, None + +def is_valid_visiting_city_number(question, tested_data): + + city_set = set() + if type(tested_data) != list: + return False, "KeyError" + + for i in range(min(question['days'],len(tested_data))): + city_value = tested_data[i]['current_city'] + + if 'from' in city_value: + city1, city2 = extract_from_to(city_value) + city1 = extract_before_parenthesis(city1) + city2 = extract_before_parenthesis(city2) + if i==0 and city1 != question['org']: + return False, f"The first day's city should be {question['org']}." + + city_set.add(city1) + city_set.add(city2) + + else: + city_set.add(extract_before_parenthesis(city_value)) + + city_set.discard(question['org']) + + if len(city_set) != question['visiting_city_number']: + return False, f"The number of visiting cities should be {question['visiting_city_number']}." + + return True, None + +def is_valid_days(question, tested_data): + if type(tested_data) != list: + return False, "KeyError" + + lens = 0 + for i in range(min(question['days'],len(tested_data))): + if tested_data[i] != {} and 'current_city' in tested_data[i] and \ + tested_data[i]['current_city'] != "You don't need to fill in the information for this or later days.": + lens += 1 + + if lens != question['days']: + # print(lens) + return False, f"The number of days should be {question['days']}." + else: + return True, None + +def is_not_absent(question, tested_data): + needed_info = 6 * question['days'] + total_valid_info = 0 + + if not is_valid_days(question, tested_data)[0]: + return False, "Invalid Days" + + if not is_valid_visiting_city_number(question, tested_data)[0]: + return False, "Invalid City Number" + if type(tested_data) != list: + return False, "KeyError" + for i in range(min(question['days'],len(tested_data))): + unit = tested_data[i] + + if 'transportation' not in unit: + return False, f"No Transportation Info." + + if 'breakfast' not in unit: + return False, f"No Breakfast Info." + + if 'lunch' not in unit: + return False, f"No Lunch Info." + + if 'dinner' not in unit: + return False, f"No Dinner Info." + + if 'attraction' not in unit: + return False, f"No Attraction Info." + + if 'accommodation' not in unit: + return False, f"No Accommodation Info." + + if ('from ' in unit['current_city'] or 'to ' in unit['current_city']) and unit['transportation'] in ['','-']: + return False, f"No transportation in day {i+1} is not allowed." + + if ('from ' not in unit['current_city'] and ' to ' not in unit['current_city']) and unit['attraction'] in ['','-']: + return False, f"No attaction in day {i+1} is not allowed." + + if i != question['days'] - 1 and unit['accommodation'] in ['','-']: + return False, f"No accommodation in day {i+1} is not allowed." + + if (unit['breakfast'] in ['','-'] or unit['lunch'] in ['','-'] or unit['dinner'] in ['','-']) and 'from ' not in unit['current_city']: + return False, f"No meal in day {i+1} is not allowed." + + + for key in unit: + if unit[key] and unit[key] != '-': + total_valid_info += 1 + + + if total_valid_info * 1.0 / needed_info < 0.5: + return False, f"The absent information is more than 50%." + + return True, None + + +def evaluation(query_data, tested_data): + return_info = {} + return_info['is_reasonalbe_visiting_city'] = is_reasonalbe_visiting_city(query_data, tested_data) + return_info['is_valid_restaurants'] = is_valid_restaurants(query_data, tested_data) + return_info['is_valid_attractions'] = is_valid_attractions(query_data, tested_data) + return_info['is_valid_accommodation'] = is_valid_accommodaton(query_data, tested_data) + return_info['is_valid_transportation'] = is_valid_transportation(query_data, tested_data) + return_info['is_valid_information_in_current_city'] = is_valid_information_in_current_city(query_data, tested_data) + return_info['is_valid_information_in_sandbox'] = is_valid_information_in_sandbox(query_data, tested_data) + return_info['is_not_absent'] = is_not_absent(query_data, tested_data) + return return_info + +def boolean_evaluation(query_data, tested_data): + return_info = {} + return_info['is_reasonalbe_visiting_city'] = is_reasonalbe_visiting_city(query_data, tested_data) + return_info['is_valid_restaurants'] = is_valid_restaurants(query_data, tested_data) + return_info['is_valid_accommodation'] = is_valid_accommodaton(query_data, tested_data) + return_info['is_valid_attractions'] = is_valid_attractions(query_data, tested_data) + return_info['is_valid_transportation'] = is_valid_transportation(query_data, tested_data) + return_info['is_valid_information_in_current_city'] = is_valid_information_in_current_city(query_data, tested_data) + return_info['is_valid_information_in_sandbox'] = is_valid_information_in_sandbox(query_data, tested_data) + return_info['is_not_absent'] = is_not_absent(query_data, tested_data) + for key in return_info: + if return_info[key][0] == False: + print(return_info[key][1]) + return False + return True diff --git a/src/utils/travel_evaluation.py b/src/utils/travel_evaluation.py new file mode 100644 index 0000000..3b0153b --- /dev/null +++ b/src/utils/travel_evaluation.py @@ -0,0 +1,229 @@ +import json +import os + +import numpy as np +from datasets import load_dataset + +from tqdm import tqdm + +from evaluate import load +from utils.travel_commonsense_constraint import evaluation as commonsense_eval +from utils.travel_hard_constraint import evaluation as hard_eval + +def load_line_json_data(filename): + data = [] + with open(filename, 'r', encoding='utf-8') as f: + for line in f.read().strip().split('\n'): + unit = json.loads(line) + data.append(unit) + return data + +def load_file_json_data(filepath, mode): + data = [] + if mode == 'train': + file_cnt = 45 + elif mode == 'test': + file_cnt = 1000 + else: + file_cnt = 180 + for idx in range(1, file_cnt + 1): + plan = load_line_json_data(os.path.join(filepath, mode, f'plan_{idx}.json'))[0][0] + if 'plan' in plan: + data.append({'plan': plan['plan']}) + return data + + +def count_true_false(data): + """Count the number of true and false values in a list.""" + true_count = data.count(True) + false_count = data.count(False) + return true_count, false_count + + +def statistics(commonsense_statistic): + """Generate statistics for each level and day in the given data with a different structure.""" + result = {level: {day: {} for day in commonsense_statistic[level]} for level in commonsense_statistic} + + for level, days in commonsense_statistic.items(): + for day, dicts in days.items(): + for dct in dicts: + if dct: + for key, data in dct.items(): + true_count, false_count = count_true_false(data) + if key not in result[level][day]: + result[level][day][key] = {"true": 0, "false": 0} + result[level][day][key]["true"] += true_count + result[level][day][key]["false"] += false_count + + return result + +def paper_term_mapping(commonsense_constraint_record, hard_constraint_record): + mapping_dict = {'is_valid_information_in_current_city':'Within Current City','is_valid_information_in_sandbox':'Within Sandbox','is_reasonalbe_visiting_city':'Reasonable City Route','is_valid_restaurants':'Diverse Restaurants','is_valid_transportation':'Non-conf. Transportation','is_valid_attractions':'Diverse Attractions','is_valid_accommodation':'Minimum Nights Stay','is_not_absent':'Complete Information','valid_cost':'Budget','valid_room_rule':'Room Rule','valid_cuisine':'Cuisine','valid_room_type':'Room Type','valid_transportation':'Transportation'} + remap_commonsense_constraint_record = {level:{day:{} for day in [3,5,7]} for level in ['easy','medium','hard']} + remap_hard_constraint_record = {level:{day:{} for day in [3,5,7]} for level in ['easy','medium','hard']} + for level in commonsense_constraint_record: + for day in commonsense_constraint_record[level]: + remap_commonsense_constraint_record[level][day] = {mapping_dict[key] : val for key,val in commonsense_constraint_record[level][day].items()} + remap_hard_constraint_record[level][day] = {mapping_dict[key] : val for key,val in hard_constraint_record[level][day].items()} + return remap_commonsense_constraint_record, remap_hard_constraint_record + + +def eval_score(set_type: str, file_path: str, cache_dir: str): + + if set_type == 'train': + query_data_list = load_dataset('osunlp/TravelPlanner','train',download_mode="force_redownload", cache_dir=cache_dir)['train'] + elif set_type == 'validation': + query_data_list = load_dataset('osunlp/TravelPlanner','validation',download_mode="force_redownload", cache_dir=cache_dir)['validation'] + else: + raise NotImplementedError + + query_data_list = [x for x in query_data_list] + hardConstraint_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']} + commonsenseConstraint_statistic = {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']} + if file_path.endswith('json') or file_path.endswith('jsonl'): + tested_plans = load_line_json_data(file_path) + else: + tested_plans = load_file_json_data(file_path, set_type) + delivery_cnt = 0 + plan_constraint_store = [] + for idx in tqdm(range(0,len(query_data_list))): + query_data = query_data_list[idx] + tested_plan = tested_plans[idx] + if type(query_data) == str: + query_data = eval(query_data) + if type(tested_plan) == str: + tested_plan = eval(tested_plan) + if type(query_data['local_constraint']) == str: + query_data['local_constraint'] = eval(query_data['local_constraint']) + + if tested_plan['plan']: + delivery_cnt += 1 + commonsense_info_box = commonsense_eval(query_data,tested_plan['plan']) + else: + commonsense_info_box = None + + if commonsense_info_box and commonsense_info_box['is_not_absent'][0] and commonsense_info_box['is_valid_information_in_sandbox'][0]: + hard_info_box = hard_eval(query_data,tested_plan['plan']) + else: + hard_info_box = None + + plan_constraint_store.append({'commonsense_constraint':commonsense_info_box,'hard_constraint':hard_info_box}) + + commonsenseConstraint_statistic[query_data['level']][query_data['days']].append(commonsense_info_box) + hardConstraint_statistic[query_data['level']][query_data['days']].append(hard_info_box) + + constraint_record = {key: {day: {'house rule':0, 'cuisine':0, 'room type':0, 'transportation':0} for day in [3,5,7]} for key in ['medium','hard']} + constraint_mapping = {'house rule':'valid_room_rule','cuisine':'valid_cuisine','room type':'valid_room_type','transportation':'valid_transportation'} + mapping_constraint_record = {key: {day: {'valid_room_rule':0, 'valid_cuisine':0, 'valid_room_type':0, 'valid_transportation':0} for day in [3,5,7]} for key in ['medium','hard']} + count_record = {key:{day:0 for day in [3,5,7]} for key in ['easy','medium','hard']} + + for unit in query_data_list: + count_record[unit['level']][unit['days']] += 1 + for key in constraint_record['medium'][3]: + if unit['local_constraint'][key] != None: + constraint_record[unit['level']][unit['days']][key] += 1 + mapping_constraint_record[unit['level']][unit['days']][constraint_mapping[key]] += 1 + + commonsenseConstraint_statistic_processed = statistics(commonsenseConstraint_statistic) + hardConstraint_statistic_processed = statistics(hardConstraint_statistic) + + + data_record = {key:{day:[] for day in [3,5,7]} for key in ['easy','medium','hard']} + + constraint_dis_record = {"commonsense":{"pass":0,"total":0},"hard":{"pass":0,"total":0}} + constraint_count = {key:{day:{} for day in [3,5,7]} for key in ['easy','medium','hard']} + + for constraint in ['commonsense','hard']: + if constraint == 'commonsense': + constraint_statistic = commonsenseConstraint_statistic_processed + elif constraint == 'hard': + constraint_statistic = hardConstraint_statistic_processed + + key_dict = {'commonsense':['is_valid_information_in_current_city','is_valid_information_in_sandbox','is_reasonalbe_visiting_city','is_valid_restaurants','is_valid_transportation','is_valid_attractions','is_valid_accommodation','is_not_absent'],'hard':['valid_cost','valid_room_rule','valid_cuisine','valid_room_type','valid_transportation']} + + for key in constraint_statistic: + for key2 in constraint_statistic[key]: + if key2 == -1: + print(constraint_statistic[key]) + exit(0) + for key3 in key_dict[constraint]: + data_record[key][key2].append('0/0') + if key3 in constraint_statistic[key][key2]: + constraint_dis_record[constraint]['pass'] += constraint_statistic[key][key2][key3]['true'] + if constraint == 'hard': + if key == 'hard' and key3 in ['valid_room_rule','valid_cuisine','valid_room_type','valid_transportation']: + data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{mapping_constraint_record[key][key2][key3]}" + constraint_dis_record[constraint]['total'] += mapping_constraint_record[key][key2][key3] + hardConstraint_statistic_processed[key][key2][key3]['total'] = mapping_constraint_record[key][key2][key3] + elif key == 'medium' and key3 in ['valid_room_rule','valid_cuisine','valid_room_type']: + data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{mapping_constraint_record[key][key2][key3]}" + constraint_dis_record[constraint]['total'] += mapping_constraint_record[key][key2][key3] + hardConstraint_statistic_processed[key][key2][key3]['total'] = mapping_constraint_record[key][key2][key3] + else: + data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{count_record[key][key2]}" + if key3 in ['valid_cost','valid_visitng_city_number','valid_days']: + constraint_dis_record[constraint]['total'] += count_record[key][key2] + constraint_count[key][key2][key3] = count_record[key][key2] + hardConstraint_statistic_processed[key][key2][key3]['total'] = count_record[key][key2] + else: + data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{count_record[key][key2]}" + constraint_dis_record[constraint]['total'] += count_record[key][key2] + constraint_count[key][key2][key3] = count_record[key][key2] + commonsenseConstraint_statistic_processed[key][key2][key3]['total'] = count_record[key][key2] + final_all_cnt = 0 + final_commonsense_cnt = 0 + final_hardConstraint_cnt = 0 + final_all_cnt_map = {level:0 for level in ['easy','medium','hard']} + for idx in (range(0,len(query_data_list))): + if plan_constraint_store[idx]['commonsense_constraint']: + final_commonsense_pass = True + final_hardConstraint_pass = True + for item in plan_constraint_store[idx]['commonsense_constraint']: + if plan_constraint_store[idx]['commonsense_constraint'][item][0] is not None and not plan_constraint_store[idx]['commonsense_constraint'][item][0]: + final_commonsense_pass = False + break + if plan_constraint_store[idx]['hard_constraint'] is None: + continue + for item in plan_constraint_store[idx]['hard_constraint']: + if plan_constraint_store[idx]['hard_constraint'][item][0] is not None and plan_constraint_store[idx]['hard_constraint'][item][0] == False: + final_hardConstraint_pass = False + break + + if final_commonsense_pass: + final_commonsense_cnt += 1 + if final_hardConstraint_pass: + final_hardConstraint_cnt += 1 + if final_commonsense_pass and final_hardConstraint_pass: + final_all_cnt += 1 + final_all_cnt_map[query_data_list[idx]['level']] += 1 + + result = {} + + remap_commonsense_constraint_record, remap_hard_constraint_record = paper_term_mapping(commonsenseConstraint_statistic_processed, hardConstraint_statistic_processed) + + if set_type == 'train': + result['Delivery Rate'] = delivery_cnt / 45 + result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / 360 + result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / 45 + result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / 105 + result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / 45 + result['Final Pass Rate'] = final_all_cnt / 45 + + elif set_type == 'validation': + result['Delivery Rate'] = delivery_cnt / 180 + result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / 1440 + result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / 180 + result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / 420 + result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / 180 + result['Final Pass Rate'] = final_all_cnt / 180 + + elif set_type == 'test': + result['Delivery Rate'] = delivery_cnt / 1000 + result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / 8000 + result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / 1000 + result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / 2290 + result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / 1000 + result['Final Pass Rate'] = final_all_cnt / 1000 + + + return result, {"Commonsense Constraint":remap_commonsense_constraint_record, "Hard Constraint":remap_hard_constraint_record} diff --git a/src/utils/travel_hard_constraint.py b/src/utils/travel_hard_constraint.py new file mode 100644 index 0000000..aef2d22 --- /dev/null +++ b/src/utils/travel_hard_constraint.py @@ -0,0 +1,275 @@ +from utils.travel_utils import get_valid_name_city,extract_before_parenthesis,extract_numbers_from_filenames +from travel_api.flights.apis import FlightSearch +from travel_api.accommodations.apis import AccommodationSearch +from travel_api.restaurants.apis import RestaurantSearch +from travel_api.googleDistanceMatrix.apis import GoogleDistanceMatrix +from travel_api.attractions.apis import AttractionSearch +import math +import json +import re +import numpy as np +import os +import sys +from tqdm import tqdm +import argparse + +# sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".."))) +# os.chdir(os.path.dirname(os.path.abspath(__file__))) + +flight = FlightSearch() +accommodation = AccommodationSearch() +restaurants = RestaurantSearch() +googleDistanceMatrix = GoogleDistanceMatrix() +attractions = AttractionSearch() + + +def load_line_json_data(filename): + data = [] + with open(filename, 'r', encoding='utf-8') as f: + for line in f.read().strip().split('\n'): + unit = json.loads(line) + data.append(unit) + return data + + +def convert_bool_values(item): + if isinstance(item, dict): + # If the item is a dictionary, recurse on each value + return {key: convert_bool_values(value) for key, value in item.items()} + elif isinstance(item, list): + # If the item is a list, recurse on each item in the list + return [convert_bool_values(value) for value in item] + elif isinstance(item, tuple): + # If the item is a tuple, recurse on each item in the tuple and repackage as a tuple + return tuple(convert_bool_values(value) for value in item) + elif isinstance(item, np.bool_): # Here we check for numpy's bool_ type + # If the item is a numpy bool_, convert it to a standard Python bool + return bool(item) + else: + # If the item is any other type, return it unchanged + return item + + + + +def extract_from_to(text: str): + """ + Extracts 'A' and 'B' from the format "from A to B" in the given text, with B ending at a comma or the end of the string. + + Args: + - text (str): The input string. + + Returns: + - tuple: A tuple containing 'A' and 'B'. If no match is found, returns (None, None). + """ + pattern = r"from\s+(.+?)\s+to\s+([^,]+)(?=[,\s]|$)" + matches = re.search(pattern, text) + return matches.groups() if matches else (None, None) + + +def get_total_cost(question, tested_data): + if type(tested_data) != list: + return 1e100 + + total_cost = 0 + for i in range(min(question['days'],len(tested_data))): + unit = tested_data[i] + # transporation + if unit['transportation'] and unit['transportation'] != '-': + value = unit['transportation'] + org_city, dest_city = extract_from_to(value) + if org_city == None or dest_city == None: + org_city, dest_city = extract_from_to(unit['current_city']) + + if org_city == None or dest_city == None: + pass + else: + if 'flight number' in value.lower(): + res = flight.data[flight.data['Flight Number'] == value.split('Flight Number: ')[1].split(',')[0]] + if len(res) > 0: + total_cost += res['Price'].values[0] * question['people_number'] + + elif 'self-driving' in value.lower() or 'taxi' in value.lower(): + if 'self-driving' in value.lower(): + # print(org_city,dest_city) + cost = googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'self-driving')['cost'] + total_cost += cost * math.ceil(question['people_number'] * 1.0 / 5) + else: + cost = googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'taxi')['cost'] + total_cost += cost * math.ceil(question['people_number'] * 1.0 / 4) + + # breakfast + if unit['breakfast'] and unit['breakfast'] != '-': + name, city = get_valid_name_city(unit['breakfast']) + res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)] + if len(res) > 0: + total_cost += res['Average Cost'].values[0] * question['people_number'] + + + # lunch + if unit['lunch'] and unit['lunch'] != '-': + name, city = get_valid_name_city(unit['lunch']) + res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)] + if len(res) > 0: + total_cost += res['Average Cost'].values[0] * question['people_number'] + + # dinner + if unit['dinner'] and unit['dinner'] != '-': + name, city = get_valid_name_city(unit['dinner']) + res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)] + if len(res) > 0: + total_cost += res['Average Cost'].values[0] * question['people_number'] + + # accommodation + if unit['accommodation'] and unit['accommodation'] != '-': + name, city = get_valid_name_city(unit['accommodation']) + res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)] + if len(res) > 0: + total_cost += res['price'].values[0] * math.ceil(question['people_number'] * 1.0 / res['maximum occupancy'].values[0]) + # print(total_cost) + return total_cost + + +def is_valid_room_rule(question, tested_data): + + if question['local_constraint']['house rule'] is None: + return None,None + if type(tested_data) != list: + return False, "KeyError" + for i in range(min(question['days'],len(tested_data))): + unit = tested_data[i] + if unit['accommodation'] and unit['accommodation'] != '-': + name, city = get_valid_name_city(unit['accommodation']) + res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)] + if len(res) > 0: + if question['local_constraint']['house rule'] == 'smoking' and 'No smoking' in str(res['house_rules'].values[0]): + return False, f"The house rule should be {question['local_constraint']['house rule']}." + if question['local_constraint']['house rule'] == 'parties' and 'No parties' in str(res['house_rules'].values[0]): + return False, f"The house rule should be {question['local_constraint']['house rule']}." + if question['local_constraint']['house rule'] == 'children under 10' and 'No children under 10' in str(res['house_rules'].values[0]): + return False, f"The house rule should be {question['local_constraint']['house rule']}." + if question['local_constraint']['house rule'] == 'visitors' and 'No visitors' in str(res['house_rules'].values[0]): + return False, f"The house rule should be {question['local_constraint']['house rule']}." + if question['local_constraint']['house rule'] == 'pets' and 'No pets' in str(res['house_rules'].values[0]): + return False, f"The house rule should be {question['local_constraint']['house rule']}." + + + return True, None + + + +def is_valid_cuisine(question, tested_data): + cuisine_set = set() + if question['local_constraint']['cuisine']: + if type(tested_data) != list: + return False, "KeyError" + for i in range(min(question['days'],len(tested_data))): + unit = tested_data[i] + + if unit['breakfast'] and unit['breakfast'] != '-': + name, city = get_valid_name_city(unit['breakfast']) + if city == question['org']: + continue + res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)] + if len(res) > 0: + for cuisine in question['local_constraint']['cuisine']: + if cuisine in res.iloc[0]['Cuisines']: + cuisine_set.add(cuisine) + + if unit['lunch'] and unit['lunch'] != '-': + name, city = get_valid_name_city(unit['lunch']) + if city == question['org']: + continue + res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)] + if len(res) > 0: + for cuisine in question['local_constraint']['cuisine']: + if cuisine in res.iloc[0]['Cuisines']: + cuisine_set.add(cuisine) + + if unit['dinner'] and unit['dinner'] != '-': + name, city = get_valid_name_city(unit['dinner']) + if city == question['org']: + continue + res = restaurants.data[(restaurants.data['Name'].astype(str).str.contains(re.escape(name))) & (restaurants.data['City'] == city)] + if len(res) > 0: + for cuisine in question['local_constraint']['cuisine']: + if cuisine in res.iloc[0]['Cuisines']: + cuisine_set.add(cuisine) + + if len(cuisine_set) == len(question['local_constraint']['cuisine']): + return True, None + else: + # judge which cuisine is not satisfied + for cuisine in question['local_constraint']['cuisine']: + if cuisine not in cuisine_set: + return False, f"The cuisine {cuisine} is not satisfied." + # return False, f"The cuisine should be {question['local_constraint']['cuisine']}." + else: + return None,None + + +def is_valid_transportation(question, tested_data): + if question['local_constraint']['transportation'] is None: + return None,None + if type(tested_data) != list: + return False, "KeyError" + for i in range(min(question['days'],len(tested_data))): + unit = tested_data[i] + if unit['transportation'] and unit['transportation'] != '-': + value = unit['transportation'] + if question['local_constraint']['transportation'] == 'no flight' and 'Flight' in value: + return False, f"The transportation should not be {question['local_constraint']['transportation']}." + elif question['local_constraint']['transportation'] == 'no self-driving' and 'Self-driving' in value: + return False, f"The transportation should not be {question['local_constraint']['transportation']}." + + return True, None + + +def is_valid_room_type(question, tested_data): + if question['local_constraint']['room type'] is None: + return None,None + if type(tested_data) != list: + return False, "KeyError" + for i in range(min(question['days'],len(tested_data))): + unit = tested_data[i] + if unit['accommodation'] and unit['accommodation'] != '-': + name, city = get_valid_name_city(unit['accommodation']) + res = accommodation.data[(accommodation.data['NAME'].astype(str).str.contains(re.escape(name))) & (accommodation.data['city'] == city)] + if len(res) > 0: + if question['local_constraint']['room type'] == 'not shared room' and res['room type'].values[0] == 'Shared room': + return False, f"The room type should be {question['local_constraint']['room type']}." + # "shared room", "not shared room", "private room", "entire room" + elif question['local_constraint']['room type'] == 'shared room' and res['room type'].values[0] != 'Shared room': + return False, f"The room type should be {question['local_constraint']['room type']}." + + elif question['local_constraint']['room type'] == 'private room' and res['room type'].values[0] != 'Private room': + return False, f"The room type should be {question['local_constraint']['room type']}." + + elif question['local_constraint']['room type'] == 'entire room' and res['room type'].values[0] != 'Entire home/apt': + return False, f"The room type should be {question['local_constraint']['room type']}." + + return True, None + + +def evaluation(query_data, tested_data): + return_info = {} + return_info['valid_cuisine'] = is_valid_cuisine(query_data, tested_data) + return_info['valid_room_rule'] = is_valid_room_rule(query_data, tested_data) + return_info['valid_transportation'] = is_valid_transportation(query_data, tested_data) + return_info['valid_room_type'] = is_valid_room_type(query_data, tested_data) + return_info['valid_cost'] = (bool(get_total_cost(query_data, tested_data) <= query_data['budget']), None) + return return_info + +def boolean_evaluation(query_data, tested_data): + return_info = {} + return_info['valid_cuisine'] = is_valid_cuisine(query_data, tested_data) + return_info['valid_room_rule'] = is_valid_room_rule(query_data, tested_data) + return_info['valid_transportation'] = is_valid_transportation(query_data, tested_data) + return_info['valid_room_type'] = is_valid_room_type(query_data, tested_data) + return_info['valid_cost'] = (bool(get_total_cost(query_data, tested_data) <= query_data['budget']), None) + for key in return_info: + if return_info[key][0] == False: + print(key) + return False + return True + diff --git a/src/utils/travel_utils.py b/src/utils/travel_utils.py new file mode 100644 index 0000000..701d6b6 --- /dev/null +++ b/src/utils/travel_utils.py @@ -0,0 +1,128 @@ +import re + +from openai import OpenAI +import json +import os +import logging +from utils.flow_utils import get_response_from_client, ReadLineFromFile + + +def convert_to_json_with_gpt(text, openai_key, max_fail=3, model_name='gpt-4-1106-preview'): + todo_prompt = """Please assist me in extracting valid information from a given natural language text and reconstructing it in JSON format, as demonstrated in the following example. If transportation details indicate a journey from one city to another (e.g., from A to B), the 'current_city' should be updated to the destination city (in this case, B). Use a ';' to separate different attractions, with each attraction formatted as 'Name, City'. If there's information about transportation, ensure that the 'current_city' aligns with the destination mentioned in the transportation details (i.e., the current city should follow the format 'from A to B'). Also, ensure that all flight numbers and costs are followed by a colon (i.e., 'Flight Number:' and 'Cost:'), consistent with the provided example. Each item should include ['day', 'current_city', 'transportation', 'breakfast', 'attraction', 'lunch', 'dinner', 'accommodation']. Replace non-specific information like 'eat at home/on the road' with '-'. Additionally, delete any '$' symbols. + -----EXAMPLE----- + [{ + "days": 1, + "current_city": "from Dallas to Peoria", + "transportation": "Flight Number: 4044830, from Dallas to Peoria, Departure Time: 13:10, Arrival Time: 15:01", + "breakfast": "-", + "attraction": "Peoria Historical Society, Peoria;Peoria Holocaust Memorial, Peoria;", + "lunch": "-", + "dinner": "Tandoor Ka Zaika, Peoria", + "accommodation": "Bushwick Music Mansion, Peoria" + }, + { + "days": 2, + "current_city": "Peoria", + "transportation": "-", + "breakfast": "Tandoor Ka Zaika, Peoria", + "attraction": "Peoria Riverfront Park, Peoria;The Peoria PlayHouse, Peoria;Glen Oak Park, Peoria;", + "lunch": "Cafe Hashtag LoL, Peoria", + "dinner": "The Curzon Room - Maidens Hotel, Peoria", + "accommodation": "Bushwick Music Mansion, Peoria" + }, + { + "days": 3, + "current_city": "from Peoria to Dallas", + "transportation": "Flight Number: 4045904, from Peoria to Dallas, Departure Time: 07:09, Arrival Time: 09:20", + "breakfast": "-", + "attraction": "-", + "lunch": "-", + "dinner": "-", + "accommodation": "-" + }] + -----EXAMPLE END----- + """ + + client = OpenAI(api_key=openai_key) + + total_price = 0.0 + attempt = 1 + while attempt <= max_fail: + prompt = todo_prompt + f"text: {text}\njson:" + response, price = get_response_from_client(client, [{'role': 'user', 'content': prompt}], model_name, 1.) + total_price += price + + logging.info((f'Generated JSON: \n```\n{response}\n```')) + + try: + result = response.split('```json')[1].split('```')[0] + except: + attempt += 1 + todo_prompt += f"Previous generated plan: {response}\nThis plan cannot be parsed. The plan has to follow the format ```json [The generated json format plan]```\n" + continue + + try: + result = eval(result) + except: + attempt += 1 + todo_prompt += f"Previous generated plan: {response}\nThis is an illegal json format.\n" + + break + + if attempt > max_fail: + result = None + + return result, total_price + + +def get_baseline_result(args, idx): + baseline_result_path = os.path.join(args.results_dir, args.task, f"baseline_{args.set_type}.jsonl") + baseline_results = ReadLineFromFile(baseline_result_path) + return baseline_results[idx] + + +def get_result_file(args): + result_file = os.path.join(args.results_dir, args.task, f"{args.set_type}_{args.model_name.replace('/','_')}_{args.get_observation}_{args.results_name}.jsonl") + if not os.path.exists(os.path.join(args.results_dir, args.task)): + os.makedirs(os.path.join(args.results_dir, args.task)) + return result_file + + +def write_result_into_file(result, result_file, is_string=False): + with open(result_file, 'a') as w: + if is_string: + w.write(result + '\n') + else: + output = json.dumps(result) + w.write(output + '\n') + w.close() + return + + +def extract_numbers_from_filenames(directory): + # Define the pattern to match files + pattern = r'annotation_(\d+).json' + + # List all files in the directory + files = os.listdir(directory) + + # Extract numbers from filenames that match the pattern + numbers = [int(re.search(pattern, file).group(1)) for file in files if re.match(pattern, file)] + + return numbers + + +def get_valid_name_city(info): + # Modified the pattern to preserve spaces at the end of the name + pattern = r'(.*?),\s*([^,]+)(\(\w[\w\s]*\))?$' + match = re.search(pattern, info) + if match: + return match.group(1).strip(), extract_before_parenthesis(match.group(2).strip()).strip() + else: + print(f"{info} can not be parsed, '-' will be used instead.") + return "-","-" + + +def extract_before_parenthesis(s): + match = re.search(r'^(.*?)\([^)]*\)', s) + return match.group(1) if match else s