diff --git a/data/emall/product.json b/data/emall/product.json new file mode 100644 index 0000000..f39d4ce --- /dev/null +++ b/data/emall/product.json @@ -0,0 +1,11 @@ +[ + { + "content": "product name: GlowPod\nPrice: $29.99\nšŸ’« Escape the chaos, one breath at a time.\nTurn your room into a sanctuary with GlowPod AromaDiffuser. šŸŒæ A gentle mist of your favorite essential oils, soft LED mood lighting, and a whisper-quiet designā€”perfect for unwinding after a long day.\nšŸŒ™ Your night routine just got better:\n- Relax after work.\n- Sleep soundly.\n- Wake up refreshed.\nāœØ Life's too busy not to find your calm. Shop now and bring GlowPod home." + }, + { + "content": "product name: Mistify\nPrice: $34.99\nšŸŒˆ Whatā€™s the vibe today? Calm? Focused? Energized?\nLet Mistify AirMist set the tone. šŸ§˜ Diffuse calming lavender, refreshing citrus, or your go-to essential oils. The adjustable mist modes and elegant design blend beautifully into any space.\nšŸ“Œ Perfect for:\n- WFH productivity boosts.\n- Cozy reading nooks.\n- Creating spa-like vibes at home.\nšŸŒŸ Start your self-care journeyā€”$34.99 well spent on *you*. Tap the link and feel the difference." + }, + { + "content": "product name: ZenCloud\nPrice: $24.99\nšŸ’Ø Breathe in calm, exhale stress.\nZenCloud VaporSphere isnā€™t just a diffuserā€”itā€™s your ticket to daily tranquility. šŸŒŗ With its portable size and minimalist design, you can create your personal zen zone anytime, anywhere.\nā¤ļø Why youā€™ll love it:\nāœ”ļø Enhances mood with aromatherapy.\nāœ”ļø Helps with dry air during winter.\nāœ”ļø Compact and travel-friendly.\nāœØ This isnā€™t just a productā€”itā€™s a vibe. Ready to elevate your space? šŸŒæ Click to shop ZenCloud now!" + } +] diff --git a/oasis/social_agent/agent.py b/oasis/social_agent/agent.py index 819f30f..c78978d 100644 --- a/oasis/social_agent/agent.py +++ b/oasis/social_agent/agent.py @@ -129,12 +129,12 @@ async def perform_action_by_llm(self): openai_messages, _ = self.memory.get_context() content = "" # sometimes self.memory.get_context() would lose system prompt - start_message = openai_messages[0] - if start_message["role"] != self.system_message.role_name: - openai_messages = [{ - "role": self.system_message.role_name, - "content": self.system_message.content, - }] + openai_messages + # start_message = openai_messages[0] + # if start_message["role"] != self.system_message.role_name: + # openai_messages = [{ + # "role": self.system_message.role_name, + # "content": self.system_message.content, + # }] + openai_messages if not openai_messages: openai_messages = [{ @@ -165,7 +165,12 @@ async def perform_action_by_llm(self): exec_functions = [] while retry > 0: - + start_message = openai_messages[0] + if start_message["role"] != self.system_message.role_name: + openai_messages = [{ + "role": self.system_message.role_name, + "content": self.system_message.content, + }] + openai_messages mes_id = await self.infe_channel.write_to_receive_queue( openai_messages) mes_id, content = await self.infe_channel.read_from_send_queue( diff --git a/oasis/social_agent/agent_action.py b/oasis/social_agent/agent_action.py index a6c3fc4..778c19b 100644 --- a/oasis/social_agent/agent_action.py +++ b/oasis/social_agent/agent_action.py @@ -48,6 +48,7 @@ def get_openai_function_list(self) -> list[OpenAIFunction]: self.unfollow, self.mute, self.unmute, + self.purchase_product, ] ] @@ -598,3 +599,18 @@ async def undo_dislike_comment(self, comment_id: int): """ return await self.perform_action(comment_id, ActionType.UNDO_DISLIKE_COMMENT.value) + + async def purchase_product(self, product_name: str, purchase_num: int): + r"""Purchase a product. + + Args: + product_name (str): The name of the product to be purchased. + purchase_num (int): The number of products to be purchased. + + Returns: + dict: A dictionary with 'success' indicating if the purchase was + successful. + """ + purchase_message = (product_name, purchase_num) + return await self.perform_action(purchase_message, + ActionType.PURCHASE_PRODUCT.value) diff --git a/oasis/social_platform/database.py b/oasis/social_platform/database.py index f4050ef..2c77144 100644 --- a/oasis/social_platform/database.py +++ b/oasis/social_platform/database.py @@ -33,6 +33,7 @@ COMMENT_SCHEMA_SQL = "comment.sql" COMMENT_LIKE_SCHEMA_SQL = "comment_like.sql" COMMENT_DISLIKE_SCHEMA_SQL = "comment_dislike.sql" +PRODUCT_SCHEMA_SQL = "product.sql" TABLE_NAMES = { "user", @@ -46,6 +47,7 @@ "comment.sql", "comment_like.sql", "comment_dislike.sql", + "product.sql", } @@ -146,6 +148,12 @@ def create_db(db_path: str | None = None): comment_dislike_sql_script = sql_file.read() cursor.executescript(comment_dislike_sql_script) + # Read and execute the product table SQL script: + product_sql_path = osp.join(schema_dir, PRODUCT_SCHEMA_SQL) + with open(product_sql_path, "r") as sql_file: + product_sql_script = sql_file.read() + cursor.executescript(product_sql_script) + # Commit the changes: conn.commit() diff --git a/oasis/social_platform/platform.py b/oasis/social_platform/platform.py index 15fe4a0..783fb9e 100644 --- a/oasis/social_platform/platform.py +++ b/oasis/social_platform/platform.py @@ -195,6 +195,56 @@ async def sign_up(self, agent_id, user_message): except Exception as e: return {"success": False, "error": str(e)} + async def sign_up_product(self, product_id: int, product_name: str): + # Note: do not sign up the product with the same product name + try: + product_insert_query = ( + "INSERT INTO product (product_id, product_name) VALUES (?, ?)") + self.pl_utils._execute_db_command(product_insert_query, + (product_id, product_name), + commit=True) + return {"success": True, "product_id": product_id} + except Exception as e: + return {"success": False, "error": str(e)} + + async def purchase_product(self, agent_id, purchase_message): + product_name, purchase_num = purchase_message + if self.recsys_type == RecsysType.REDDIT: + current_time = self.sandbox_clock.time_transfer( + datetime.now(), self.start_time) + else: + current_time = os.environ["SANDBOX_TIME"] + # try: + user_id = agent_id + # Check if a like record already exists + product_check_query = ( + "SELECT * FROM 'product' WHERE product_name = ?") + self.pl_utils._execute_db_command(product_check_query, + (product_name, )) + check_result = self.db_cursor.fetchone() + if not check_result: + # Product not found + return {"success": False, "error": "No such product."} + else: + product_id = check_result[0] + + product_update_query = ( + "UPDATE product SET sales = sales + ? WHERE product_name = ?") + self.pl_utils._execute_db_command(product_update_query, + (purchase_num, product_name), + commit=True) + + # Record the action in the trace table + action_info = { + "product_name": product_name, + "purchase_num": purchase_num + } + self.pl_utils._record_trace(user_id, ActionType.PURCHASE_PRODUCT.value, + action_info, current_time) + return {"success": True, "product_id": product_id} + # except Exception as e: + # return {"success": False, "error": str(e)} + async def refresh(self, agent_id: int): # Retrieve posts for a specific id from the rec table if self.recsys_type == RecsysType.REDDIT: diff --git a/oasis/social_platform/schema/product.sql b/oasis/social_platform/schema/product.sql new file mode 100644 index 0000000..ad8060b --- /dev/null +++ b/oasis/social_platform/schema/product.sql @@ -0,0 +1,6 @@ +-- This is the schema definition for the product table +CREATE TABLE product ( + product_id INTEGER PRIMARY KEY, + product_name TEXT, + sales INTEGER DEFAULT 0 +); diff --git a/oasis/social_platform/typing.py b/oasis/social_platform/typing.py index 5e39e48..c34e928 100644 --- a/oasis/social_platform/typing.py +++ b/oasis/social_platform/typing.py @@ -38,6 +38,7 @@ class ActionType(Enum): DISLIKE_COMMENT = "dislike_comment" UNDO_DISLIKE_COMMENT = "undo_dislike_comment" DO_NOTHING = "do_nothing" + PURCHASE_PRODUCT = "purchase_product" class RecsysType(Enum): diff --git a/scripts/reddit_emall_demo/action_space_prompt.txt b/scripts/reddit_emall_demo/action_space_prompt.txt new file mode 100644 index 0000000..ac7d26d --- /dev/null +++ b/scripts/reddit_emall_demo/action_space_prompt.txt @@ -0,0 +1,24 @@ +# OBJECTIVE +You're a Twitter user, and I'll present you with some posts. After you see the posts, choose some actions from the following functions. +Suppose you are a real Twitter user. Please simulate real behavior. + +- do_nothing: Most of the time, you just don't feel like reposting or liking a post, and you just want to look at it. In such cases, choose this action "do_nothing" +- repost: Repost a post. + - Arguments: "post_id" (integer) - The ID of the post to be reposted. You can `repost` when you want to spread it. +- like_post: Likes a specified post. + - Arguments: "post_id" (integer) - The ID of the tweet to be liked. You can `like` when you feel something interesting or you agree with. +- dislike_post: Dislikes a specified post. + - Arguments: "post_id" (integer) - The ID of the post to be disliked. You can use `dislike` when you disagree with a post or find it uninteresting. +- create_comment: Creates a comment on a specified post. + - Arguments: + "post_id" (integer) - The ID of the post to comment on. + "content" (str) - The content of the comment. + Use `create_comment` to engage in conversations or share your thoughts on a post. +- follow: Follow a user specified by 'followee_id'. You can `follow' when you respect someone, love someone, or care about someone. + - Arguments: "followee_id" (integer) - The ID of the user to be followed. +- mute: Mute a user specified by 'mutee_id'. You can `mute' when you hate someone, dislike someone, or disagree with someone. + - Arguments: "mutee_id" (integer) - The ID of the user to be followed. + - Arguments: "post_id" (integer) - The ID of the post to be disliked. You can use `dislike` when you disagree with a post or find it uninteresting. +- purchase_product: Purchase a product. + - Arguments: "product_name" (string) - The name of the product to be purchased. + - Arguments: "purchase_num" (integer) - The number of products to be purchased. diff --git a/scripts/reddit_emall_demo/emall.yaml b/scripts/reddit_emall_demo/emall.yaml new file mode 100644 index 0000000..30647ac --- /dev/null +++ b/scripts/reddit_emall_demo/emall.yaml @@ -0,0 +1,22 @@ +--- +data: + user_path: ./data/reddit/user_data_36.json # Path to the user profile file + pair_path: ./data/emall/product.json # Path to the initial post file + db_path: ./emall.db # Path for saving the social media database after the experiment +simulation: + recsys_type: reddit + controllable_user: true # Whether to use a controllable user, who posts prepared posts on the simulated social platform according to our instructions + allow_self_rating: false # Reddit feature: does not allow users to rate their own content + show_score: true # Reddit feature: users can only see scores, not separate upvote and downvote counts + activate_prob: 0.2 # Probability of each agent being activated to perform an action at each timestep + clock_factor: 10 # Magnification factor of the first timestep in real-world time, not recommended to change + num_timesteps: 1 # Number of timesteps the simulation experiment runs + max_rec_post_len: 50 # Number of posts in each user's recommendation list cache + round_post_num: 30 # Number of posts sent by controllable_user at each timestep + follow_post_agent: false # Whether all agents follow the controllable_user + mute_post_agent: false # Whether all agents mute the controllable_user + refresh_rec_post_count: 5 # Number of posts an agent sees each time they refresh + action_space_file_path: ./scripts/reddit_emall_demo/action_space_prompt.txt # Path to the action_space_prompt file +inference: + model_type: gpt-4o-mini # Name of the OpenAI model + is_openai_model: true # Whether it is an OpenAI model diff --git a/scripts/reddit_emall_demo/emall_simulation.py b/scripts/reddit_emall_demo/emall_simulation.py new file mode 100644 index 0000000..0bd4f79 --- /dev/null +++ b/scripts/reddit_emall_demo/emall_simulation.py @@ -0,0 +1,215 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the ā€œLicenseā€); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an ā€œAS ISā€ BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# flake8: noqa: E402 +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import os +import random +import sys +from datetime import datetime, timedelta +from typing import Any + +from colorama import Back +from yaml import safe_load + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) + +from oasis.clock.clock import Clock +from oasis.social_agent.agents_generator import (gen_control_agents_with_data, + generate_reddit_agents) +from oasis.social_platform.channel import Channel +from oasis.social_platform.platform import Platform +from oasis.social_platform.typing import ActionType + +social_log = logging.getLogger(name="social") +social_log.setLevel("DEBUG") +now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") +file_handler = logging.FileHandler(f"./log/social-{str(now)}.log", + encoding="utf-8") +file_handler.setLevel("DEBUG") +file_handler.setFormatter( + logging.Formatter("%(levelname)s - %(asctime)s - %(name)s - %(message)s")) +social_log.addHandler(file_handler) +stream_handler = logging.StreamHandler() +stream_handler.setLevel("DEBUG") +stream_handler.setFormatter( + logging.Formatter("%(levelname)s - %(asctime)s - %(name)s - %(message)s")) +social_log.addHandler(stream_handler) + +parser = argparse.ArgumentParser(description="Arguments for script.") +parser.add_argument( + "--config_path", + type=str, + help="Path to the YAML config file.", + required=False, + default="", +) + +DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data") +DEFAULT_DB_PATH = os.path.join(DATA_DIR, "mock_reddit.db") +DEFAULT_USER_PATH = os.path.join(DATA_DIR, "reddit", + "filter_user_results.json") +DEFAULT_PAIR_PATH = os.path.join(DATA_DIR, "reddit", "RS-RC-pairs.json") + +ROUND_POST_NUM = 20 + + +async def running( + db_path: str | None = DEFAULT_DB_PATH, + user_path: str | None = DEFAULT_USER_PATH, + pair_path: str | None = DEFAULT_PAIR_PATH, + round_post_num: str | None = ROUND_POST_NUM, + num_timesteps: int = 3, + clock_factor: int = 60, + recsys_type: str = "reddit", + controllable_user: bool = True, + allow_self_rating: bool = False, + show_score: bool = True, + max_rec_post_len: int = 20, + activate_prob: float = 0.1, + follow_post_agent: bool = False, + mute_post_agent: bool = True, + model_configs: dict[str, Any] | None = None, + inference_configs: dict[str, Any] | None = None, + refresh_rec_post_count: int = 10, + action_space_file_path: str = None, +) -> None: + db_path = DEFAULT_DB_PATH if db_path is None else db_path + user_path = DEFAULT_USER_PATH if user_path is None else user_path + pair_path = DEFAULT_PAIR_PATH if pair_path is None else pair_path + if os.path.exists(db_path): + os.remove(db_path) + + start_time = datetime(2024, 8, 6, 8, 0) + clock = Clock(k=clock_factor) + twitter_channel = Channel() + print(action_space_file_path) + with open(action_space_file_path, "r", encoding="utf-8") as file: + action_space_prompt = file.read() + + infra = Platform( + db_path, + twitter_channel, + clock, + start_time, + allow_self_rating=allow_self_rating, + show_score=show_score, + recsys_type=recsys_type, + max_rec_post_len=max_rec_post_len, + refresh_rec_post_count=refresh_rec_post_count, + ) + await infra.sign_up_product(1, "GlowPod") + await infra.sign_up_product(2, "Mistify") + await infra.sign_up_product(3, "ZenCloud") + + inference_channel = Channel() + + twitter_task = asyncio.create_task(infra.running()) + + if inference_configs["model_type"][:3] == "gpt": + is_openai_model = True + if not controllable_user: + raise ValueError("Uncontrollable user is not supported") + else: + agent_graph, id_mapping = await gen_control_agents_with_data( + twitter_channel, + 1, + ) + agent_graph = await generate_reddit_agents( + user_path, + twitter_channel, + inference_channel, + agent_graph, + id_mapping, + follow_post_agent, + mute_post_agent, + action_space_prompt, + inference_configs["model_type"], + is_openai_model, + ) + with open(pair_path, "r") as f: + pairs = json.load(f) + + for timestep in range(num_timesteps): + os.environ["TIME_STAMP"] = str(timestep + 1) + if timestep == 0: + start_time_0 = datetime.now() + print(Back.GREEN + f"timestep:{timestep}" + Back.RESET) + social_log.info(f"timestep:{timestep + 1}.") + + post_agent = agent_graph.get_agent(0) + if timestep == 0: + await post_agent.perform_action_by_data( + "create_post", content=pairs[0]["content"]) + await post_agent.perform_action_by_data( + "create_post", content=pairs[1]["content"]) + await post_agent.perform_action_by_data( + "create_post", content=pairs[2]["content"]) + + await infra.update_rec_table() + social_log.info("update rec table.") + tasks = [] + for _, agent in agent_graph.get_agents(): + if agent.user_info.is_controllable is False: + if random.random() < activate_prob: + tasks.append(agent.perform_action_by_llm()) + random.shuffle(tasks) + await asyncio.gather(*tasks) + + if timestep == 0: + time_difference = datetime.now() - start_time_0 + + # Convert two hours into seconds since time_difference is a + # timedelta object + two_hours_in_seconds = timedelta(hours=2).total_seconds() + + # Calculate two hours divided by the time difference (in seconds) + clock_factor = two_hours_in_seconds / \ + time_difference.total_seconds() + clock.k = clock_factor + social_log.info(f"clock_factor: {clock_factor}") + + await twitter_channel.write_to_receive_queue((None, None, ActionType.EXIT)) + + await twitter_task + + social_log.info("Simulation finish!") + + +if __name__ == "__main__": + args = parser.parse_args() + + if os.path.exists(args.config_path): + with open(args.config_path, "r") as f: + cfg = safe_load(f) + data_params = cfg.get("data") + simulation_params = cfg.get("simulation") + model_configs = cfg.get("model") + inference_params = cfg.get("inference") + asyncio.run( + running( + **data_params, + **simulation_params, + model_configs=model_configs, + inference_configs=inference_params, + ), + debug=True, + ) + else: + asyncio.run(running()) diff --git a/test/agent/test_action_docstring.py b/test/agent/test_action_docstring.py index 76b107c..2830a7b 100644 --- a/test/agent/test_action_docstring.py +++ b/test/agent/test_action_docstring.py @@ -42,6 +42,7 @@ def test_transfer_to_openai_function(): SocialAction.dislike_comment, SocialAction.undo_dislike_comment, SocialAction.do_nothing, + SocialAction.purchase_product, ] ] assert action_funcs is not None diff --git a/test/agent/test_twitter_user_agent_all_actions.py b/test/agent/test_twitter_user_agent_all_actions.py index 5215376..8ac3243 100644 --- a/test/agent/test_twitter_user_agent_all_actions.py +++ b/test/agent/test_twitter_user_agent_all_actions.py @@ -160,5 +160,10 @@ async def test_agents_actions(setup_twitter): assert return_message["success"] is True await asyncio.sleep(random.uniform(0, 0.1)) + await infra.sign_up_product(1, "apple") + return_message = await action_agent.env.action.purchase_product("apple", 1) + assert return_message["success"] is True + await asyncio.sleep(random.uniform(0, 0.1)) + await channel.write_to_receive_queue((None, None, ActionType.EXIT)) await task diff --git a/test/infra/database/test_product.py b/test/infra/database/test_product.py new file mode 100644 index 0000000..9b71f1e --- /dev/null +++ b/test/infra/database/test_product.py @@ -0,0 +1,129 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the ā€œLicenseā€); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an ā€œAS ISā€ BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +import os +import os.path as osp +import sqlite3 + +import pytest + +from oasis.social_platform.platform import Platform +from oasis.social_platform.typing import ActionType + +parent_folder = osp.dirname(osp.abspath(__file__)) +test_db_filepath = osp.join(parent_folder, "test.db") + + +class MockChannel: + + def __init__(self): + self.call_count = 0 + self.messages = [] # Used to store sent messages + + async def receive_from(self): + if self.call_count == 0: + self.call_count += 1 + return ("id_", (1, ('apple', 1), + ActionType.PURCHASE_PRODUCT.value)) + if self.call_count == 1: + self.call_count += 1 + return ("id_", (2, ('apple', 2), + ActionType.PURCHASE_PRODUCT.value)) + if self.call_count == 2: + self.call_count += 1 + return ("id_", (2, ('banana', 1), + ActionType.PURCHASE_PRODUCT.value)) + if self.call_count == 3: + self.call_count += 1 + return ("id_", (2, ('orange', 1), + ActionType.PURCHASE_PRODUCT.value)) + else: + return ("id_", (None, None, "exit")) + + async def send_to(self, message): + self.messages.append(message) # Store message for later assertion + if self.call_count == 1: + print(message[2]) + msg = "Purchase apple from user 1 failed" + assert message[2]["success"] is True, msg + assert message[2]["product_id"] == 1 + elif self.call_count == 2: + msg = "Purchase apple from user 2 failed" + assert message[2]["success"] is True, msg + assert message[2]["product_id"] == 1 + elif self.call_count == 3: + msg = "Purchase banana from user 2 failed" + assert message[2]["success"] is True, msg + assert message[2]["product_id"] == 2 + elif self.call_count == 4: + assert message[2]["success"] is False + + +@pytest.fixture +def setup_platform(): + # Ensure test.db does not exist before testing + if os.path.exists(test_db_filepath): + os.remove(test_db_filepath) + + # Create database and tables + db_path = test_db_filepath + + mock_channel = MockChannel() + instance = Platform(db_path, mock_channel) + return instance + + +@pytest.mark.asyncio +async def test_search_user(setup_platform): + try: + platform = setup_platform + + # Insert 1 user into the user table before the test starts + conn = sqlite3.connect(test_db_filepath) + cursor = conn.cursor() + cursor.execute( + ("INSERT INTO user " + "(user_id, agent_id, user_name, num_followings, num_followers) " + "VALUES (?, ?, ?, ?, ?)"), + (1, 1, "user1", 0, 0), + ) + cursor.execute( + ("INSERT INTO user " + "(user_id, agent_id, user_name, num_followings, num_followers) " + "VALUES (?, ?, ?, ?, ?)"), + (2, 2, "user2", 0, 0), + ) + conn.commit() + + conn = sqlite3.connect(test_db_filepath) + cursor = conn.cursor() + + await platform.sign_up_product(1, "apple") + await platform.sign_up_product(2, "banana") + # print_db_contents(test_db_filepath) + + await platform.running() + + # Verify that the trace table correctly recorded the operation + cursor.execute( + "SELECT * FROM product WHERE product_name='apple' and sales=3") + assert cursor.fetchone() is not None, "apple sales is not 3" + cursor.execute( + "SELECT * FROM product WHERE product_name='banana' and sales=1") + assert cursor.fetchone() is not None, "banana sales is not 1" + + finally: + conn.close() + # Cleanup + if os.path.exists(test_db_filepath): + os.remove(test_db_filepath)