From 12649d1f41f6b2b8c21b011f2654ec65c25470b2 Mon Sep 17 00:00:00 2001 From: Dmitrii Khizbullin Date: Mon, 3 Jul 2023 12:44:59 +0300 Subject: [PATCH 1/3] Removed inheritance of TaskSpecifier and TaskPlanner from ChatAgent --- camel/agents/task_agent.py | 28 +++++++++++++++------------- camel/societies/role_playing.py | 5 +++-- test/agents/test_task_agent.py | 8 ++++---- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/camel/agents/task_agent.py b/camel/agents/task_agent.py index 9deb4c572d..c14e24225f 100644 --- a/camel/agents/task_agent.py +++ b/camel/agents/task_agent.py @@ -20,7 +20,7 @@ from camel.typing import ModelType, RoleType, TaskType -class TaskSpecifyAgent(ChatAgent): +class TaskSpecifyAgent: r"""An agent that specifies a given task prompt by prompting the user to provide more details. @@ -54,6 +54,7 @@ def __init__( output_language: Optional[str] = None, ) -> None: + self.task_specify_prompt: Union[str, TextPrompt] if task_specify_prompt is None: task_specify_prompt_template = PromptTemplateGenerator( ).get_task_specify_prompt(task_type) @@ -72,10 +73,10 @@ def __init__( content="You can make a task more specific.", ) - super().__init__(system_message, model, model_config, - output_language=output_language) + self.chat_agent = ChatAgent(system_message, model, model_config, + output_language=output_language) - def step( + def specify_prompt( self, original_task_prompt: Union[str, TextPrompt], meta_dict: Optional[Dict[str, Any]] = None, @@ -92,7 +93,8 @@ def step( Returns: TextPrompt: The specified task prompt. """ - self.reset() + self.chat_agent.reset() + self.task_specify_prompt = self.task_specify_prompt.format( task=original_task_prompt) @@ -102,7 +104,7 @@ def step( task_msg = BaseMessage.make_user_message( role_name="Task Specifier", content=self.task_specify_prompt) - specifier_response = super().step(task_msg) + specifier_response = self.chat_agent.step(task_msg) if (specifier_response.msgs is None or len(specifier_response.msgs) == 0): raise RuntimeError("Task specification failed.") @@ -114,7 +116,7 @@ def step( return TextPrompt(specified_task_msg.content) -class TaskPlannerAgent(ChatAgent): +class TaskPlannerAgent: r"""An agent that helps divide a task into subtasks based on the input task prompt. @@ -148,10 +150,10 @@ def __init__( content="You are a helpful task planner.", ) - super().__init__(system_message, model, model_config, - output_language=output_language) + self.chat_agent = ChatAgent(system_message, model, model_config, + output_language=output_language) - def step( + def create_plan( self, task_prompt: Union[str, TextPrompt], ) -> TextPrompt: @@ -165,14 +167,14 @@ def step( TextPrompt: A prompt for the subtasks generated by the agent. """ # TODO: Maybe include roles information. - self.reset() + self.chat_agent.reset() + self.task_planner_prompt = self.task_planner_prompt.format( task=task_prompt) task_msg = BaseMessage.make_user_message( role_name="Task Planner", content=self.task_planner_prompt) - # sub_tasks_msgs, terminated, _ - task_response = super().step(task_msg) + task_response = self.chat_agent.step(task_msg) if task_response.msgs is None: raise RuntimeError("Got None Subtasks messages.") diff --git a/camel/societies/role_playing.py b/camel/societies/role_playing.py index d8d66efe71..d8be394bef 100644 --- a/camel/societies/role_playing.py +++ b/camel/societies/role_playing.py @@ -112,7 +112,7 @@ def __init__( output_language=output_language, **(task_specify_agent_kwargs or {}), ) - self.specified_task_prompt = task_specify_agent.step( + self.specified_task_prompt = task_specify_agent.specify_prompt( task_prompt, meta_dict=task_specify_meta_dict, ) @@ -127,7 +127,8 @@ def __init__( output_language=output_language, **(task_planner_agent_kwargs or {}), ) - self.planned_task_prompt = task_planner_agent.step(task_prompt) + self.planned_task_prompt = task_planner_agent.create_plan( + task_prompt) task_prompt = f"{task_prompt}\n{self.planned_task_prompt}" else: self.planned_task_prompt = None diff --git a/test/agents/test_task_agent.py b/test/agents/test_task_agent.py index c9945b20a0..470528a214 100644 --- a/test/agents/test_task_agent.py +++ b/test/agents/test_task_agent.py @@ -31,7 +31,7 @@ def test_task_specify_ai_society_agent(model: Optional[ModelType]): print(f"Original task prompt:\n{original_task_prompt}\n") task_specify_agent = TaskSpecifyAgent( model_config=ChatGPTConfig(temperature=1.0), model=model) - specified_task_prompt = task_specify_agent.step( + specified_task_prompt = task_specify_agent.specify_prompt( original_task_prompt, meta_dict=dict(assistant_role="Musician", user_role="Student")) assert ("{" and "}" not in task_specify_agent.task_specify_prompt) @@ -47,7 +47,7 @@ def test_task_specify_code_agent(model: Optional[ModelType]): model_config=ChatGPTConfig(temperature=1.0), model=model, ) - specified_task_prompt = task_specify_agent.step( + specified_task_prompt = task_specify_agent.specify_prompt( original_task_prompt, meta_dict=dict(domain="Chemistry", language="Python")) assert ("{" and "}" not in task_specify_agent.task_specify_prompt) @@ -63,11 +63,11 @@ def test_task_planner_agent(model: Optional[ModelType]): model_config=ChatGPTConfig(temperature=1.0), model=model, ) - specified_task_prompt = task_specify_agent.step( + specified_task_prompt = task_specify_agent.specify_prompt( original_task_prompt, meta_dict=dict(domain="Chemistry", language="Python")) print(f"Specified task prompt:\n{specified_task_prompt}\n") task_planner_agent = TaskPlannerAgent( model_config=ChatGPTConfig(temperature=1.0), model=model) - planned_task_prompt = task_planner_agent.step(specified_task_prompt) + planned_task_prompt = task_planner_agent.create_plan(specified_task_prompt) print(f"Planned task prompt:\n{planned_task_prompt}\n") From 04f449609b4997c71276ee81e1d5ac53ec5bab22 Mon Sep 17 00:00:00 2001 From: Dmitrii Khizbullin Date: Mon, 3 Jul 2023 13:51:04 +0300 Subject: [PATCH 2/3] Critic as well --- camel/agents/critic_agent.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/camel/agents/critic_agent.py b/camel/agents/critic_agent.py index e1021e0b15..a34c412b64 100644 --- a/camel/agents/critic_agent.py +++ b/camel/agents/critic_agent.py @@ -23,7 +23,7 @@ from camel.utils import get_first_int, print_text_animated -class CriticAgent(ChatAgent): +class CriticAgent: r"""A class for the critic agent that assists in selecting an option. Args: @@ -53,8 +53,8 @@ def __init__( verbose: bool = False, logger_color: Any = Fore.MAGENTA, ) -> None: - super().__init__(system_message, model, model_config, - message_window_size) + self.chat_agent = ChatAgent(system_message, model, model_config, + message_window_size) self.options_dict: Dict[str, str] = dict() self.retry_attempts = retry_attempts self.verbose = verbose @@ -96,7 +96,7 @@ def get_option(self, input_message: BaseMessage) -> str: msg_content = input_message.content i = 0 while i < self.retry_attempts: - critic_response = super().step(input_message) + critic_response = self.chat_agent.step(input_message) if critic_response.msgs is None or len(critic_response.msgs) == 0: raise RuntimeError("Got None critic messages.") @@ -104,7 +104,7 @@ def get_option(self, input_message: BaseMessage) -> str: raise RuntimeError("Critic step failed.") critic_msg = critic_response.msg - self.update_messages('assistant', critic_msg) + self.chat_agent.update_messages('assistant', critic_msg) if self.verbose: print_text_animated(self.logger_color + "\n> Critic response: " f"\x1b[3m{critic_msg.content}\x1b[0m\n") From be1d7009783d2fa3b277f1e116dc76bff59c820d Mon Sep 17 00:00:00 2001 From: Dmitrii Khizbullin Date: Mon, 3 Jul 2023 14:01:01 +0300 Subject: [PATCH 3/3] Cleanup mypy errors from tests --- camel/messages/base.py | 2 +- test/messages/test_chat_message.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/camel/messages/base.py b/camel/messages/base.py index 061c86bee6..20faf72184 100644 --- a/camel/messages/base.py +++ b/camel/messages/base.py @@ -221,7 +221,7 @@ def to_openai_assistant_message(self) -> OpenAIAssistantMessage: """ return {"role": "assistant", "content": self.content} - def to_dict(self) -> Dict: + def to_dict(self) -> Dict[str, Any]: r"""Converts the message to a dictionary. Returns: diff --git a/test/messages/test_chat_message.py b/test/messages/test_chat_message.py index a6f0416588..b3aa9bc8e6 100644 --- a/test/messages/test_chat_message.py +++ b/test/messages/test_chat_message.py @@ -11,6 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +from typing import Any, Dict + import pytest from camel.messages import BaseMessage @@ -59,12 +61,12 @@ def test_chat_message(chat_message: BaseMessage) -> None: assert chat_message.content == content dictionary = chat_message.to_dict() - assert dictionary == { + reference_dict: Dict[str, Any] = { "role_name": role_name, "role_type": role_type.name, - **(meta_dict or {}), "content": content, } + assert dictionary == reference_dict def test_assistant_chat_message(assistant_chat_message: BaseMessage) -> None: @@ -79,12 +81,12 @@ def test_assistant_chat_message(assistant_chat_message: BaseMessage) -> None: assert assistant_chat_message.content == content dictionary = assistant_chat_message.to_dict() - assert dictionary == { + reference_dict: Dict[str, Any] = { "role_name": role_name, "role_type": role_type.name, - **(meta_dict or {}), "content": content, } + assert dictionary == reference_dict def test_user_chat_message(user_chat_message: BaseMessage) -> None: @@ -99,9 +101,9 @@ def test_user_chat_message(user_chat_message: BaseMessage) -> None: assert user_chat_message.content == content dictionary = user_chat_message.to_dict() - assert dictionary == { + reference_dict: Dict[str, Any] = { "role_name": role_name, "role_type": role_type.name, - **(meta_dict or {}), "content": content, } + assert dictionary == reference_dict