Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions camel/agents/critic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from camel.utils import get_first_int, print_text_animated


class CriticAgent(ChatAgent):
class CriticAgent:
r"""A class for the critic agent that assists in selecting an option.

Args:
Expand Down Expand Up @@ -53,8 +53,8 @@ def __init__(
verbose: bool = False,
logger_color: Any = Fore.MAGENTA,
) -> None:
super().__init__(system_message, model, model_config,
message_window_size)
self.chat_agent = ChatAgent(system_message, model, model_config,
message_window_size)
self.options_dict: Dict[str, str] = dict()
self.retry_attempts = retry_attempts
self.verbose = verbose
Expand Down Expand Up @@ -96,15 +96,15 @@ def get_option(self, input_message: BaseMessage) -> str:
msg_content = input_message.content
i = 0
while i < self.retry_attempts:
critic_response = super().step(input_message)
critic_response = self.chat_agent.step(input_message)

if critic_response.msgs is None or len(critic_response.msgs) == 0:
raise RuntimeError("Got None critic messages.")
if critic_response.terminated:
raise RuntimeError("Critic step failed.")

critic_msg = critic_response.msg
self.update_messages('assistant', critic_msg)
self.chat_agent.update_messages('assistant', critic_msg)
if self.verbose:
print_text_animated(self.logger_color + "\n> Critic response: "
f"\x1b[3m{critic_msg.content}\x1b[0m\n")
Expand Down
28 changes: 15 additions & 13 deletions camel/agents/task_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from camel.typing import ModelType, RoleType, TaskType


class TaskSpecifyAgent(ChatAgent):
class TaskSpecifyAgent:
r"""An agent that specifies a given task prompt by prompting the user to
provide more details.

Expand Down Expand Up @@ -54,6 +54,7 @@ def __init__(
output_language: Optional[str] = None,
) -> None:

self.task_specify_prompt: Union[str, TextPrompt]
if task_specify_prompt is None:
task_specify_prompt_template = PromptTemplateGenerator(
).get_task_specify_prompt(task_type)
Expand All @@ -72,10 +73,10 @@ def __init__(
content="You can make a task more specific.",
)

super().__init__(system_message, model, model_config,
output_language=output_language)
self.chat_agent = ChatAgent(system_message, model, model_config,
output_language=output_language)

def step(
def specify_prompt(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def specify_prompt(
def expand_prompt(

I'd propose expand as a basic self-descriptive word for what happens here, same could work for the class name probably

self,
original_task_prompt: Union[str, TextPrompt],
meta_dict: Optional[Dict[str, Any]] = None,
Expand All @@ -92,7 +93,8 @@ def step(
Returns:
TextPrompt: The specified task prompt.
"""
self.reset()
self.chat_agent.reset()

self.task_specify_prompt = self.task_specify_prompt.format(
task=original_task_prompt)

Expand All @@ -102,7 +104,7 @@ def step(

task_msg = BaseMessage.make_user_message(
role_name="Task Specifier", content=self.task_specify_prompt)
specifier_response = super().step(task_msg)
specifier_response = self.chat_agent.step(task_msg)
if (specifier_response.msgs is None
or len(specifier_response.msgs) == 0):
raise RuntimeError("Task specification failed.")
Expand All @@ -114,7 +116,7 @@ def step(
return TextPrompt(specified_task_msg.content)


class TaskPlannerAgent(ChatAgent):
class TaskPlannerAgent:
r"""An agent that helps divide a task into subtasks based on the input
task prompt.

Expand Down Expand Up @@ -148,10 +150,10 @@ def __init__(
content="You are a helpful task planner.",
)

super().__init__(system_message, model, model_config,
output_language=output_language)
self.chat_agent = ChatAgent(system_message, model, model_config,
output_language=output_language)

def step(
def create_plan(
self,
task_prompt: Union[str, TextPrompt],
) -> TextPrompt:
Expand All @@ -165,14 +167,14 @@ def step(
TextPrompt: A prompt for the subtasks generated by the agent.
"""
# TODO: Maybe include roles information.
self.reset()
self.chat_agent.reset()

self.task_planner_prompt = self.task_planner_prompt.format(
task=task_prompt)

task_msg = BaseMessage.make_user_message(
role_name="Task Planner", content=self.task_planner_prompt)
# sub_tasks_msgs, terminated, _
task_response = super().step(task_msg)
task_response = self.chat_agent.step(task_msg)

if task_response.msgs is None:
raise RuntimeError("Got None Subtasks messages.")
Expand Down
2 changes: 1 addition & 1 deletion camel/messages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def to_openai_assistant_message(self) -> OpenAIAssistantMessage:
"""
return {"role": "assistant", "content": self.content}

def to_dict(self) -> Dict:
def to_dict(self) -> Dict[str, Any]:
r"""Converts the message to a dictionary.

Returns:
Expand Down
5 changes: 3 additions & 2 deletions camel/societies/role_playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
output_language=output_language,
**(task_specify_agent_kwargs or {}),
)
self.specified_task_prompt = task_specify_agent.step(
self.specified_task_prompt = task_specify_agent.specify_prompt(
task_prompt,
meta_dict=task_specify_meta_dict,
)
Expand All @@ -127,7 +127,8 @@ def __init__(
output_language=output_language,
**(task_planner_agent_kwargs or {}),
)
self.planned_task_prompt = task_planner_agent.step(task_prompt)
self.planned_task_prompt = task_planner_agent.create_plan(
task_prompt)
task_prompt = f"{task_prompt}\n{self.planned_task_prompt}"
else:
self.planned_task_prompt = None
Expand Down
8 changes: 4 additions & 4 deletions test/agents/test_task_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_task_specify_ai_society_agent(model: Optional[ModelType]):
print(f"Original task prompt:\n{original_task_prompt}\n")
task_specify_agent = TaskSpecifyAgent(
model_config=ChatGPTConfig(temperature=1.0), model=model)
specified_task_prompt = task_specify_agent.step(
specified_task_prompt = task_specify_agent.specify_prompt(
original_task_prompt, meta_dict=dict(assistant_role="Musician",
user_role="Student"))
assert ("{" and "}" not in task_specify_agent.task_specify_prompt)
Expand All @@ -47,7 +47,7 @@ def test_task_specify_code_agent(model: Optional[ModelType]):
model_config=ChatGPTConfig(temperature=1.0),
model=model,
)
specified_task_prompt = task_specify_agent.step(
specified_task_prompt = task_specify_agent.specify_prompt(
original_task_prompt, meta_dict=dict(domain="Chemistry",
language="Python"))
assert ("{" and "}" not in task_specify_agent.task_specify_prompt)
Expand All @@ -63,11 +63,11 @@ def test_task_planner_agent(model: Optional[ModelType]):
model_config=ChatGPTConfig(temperature=1.0),
model=model,
)
specified_task_prompt = task_specify_agent.step(
specified_task_prompt = task_specify_agent.specify_prompt(
original_task_prompt, meta_dict=dict(domain="Chemistry",
language="Python"))
print(f"Specified task prompt:\n{specified_task_prompt}\n")
task_planner_agent = TaskPlannerAgent(
model_config=ChatGPTConfig(temperature=1.0), model=model)
planned_task_prompt = task_planner_agent.step(specified_task_prompt)
planned_task_prompt = task_planner_agent.create_plan(specified_task_prompt)
print(f"Planned task prompt:\n{planned_task_prompt}\n")
14 changes: 8 additions & 6 deletions test/messages/test_chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any, Dict

import pytest

from camel.messages import BaseMessage
Expand Down Expand Up @@ -59,12 +61,12 @@ def test_chat_message(chat_message: BaseMessage) -> None:
assert chat_message.content == content

dictionary = chat_message.to_dict()
assert dictionary == {
reference_dict: Dict[str, Any] = {
"role_name": role_name,
"role_type": role_type.name,
**(meta_dict or {}),
"content": content,
}
assert dictionary == reference_dict


def test_assistant_chat_message(assistant_chat_message: BaseMessage) -> None:
Expand All @@ -79,12 +81,12 @@ def test_assistant_chat_message(assistant_chat_message: BaseMessage) -> None:
assert assistant_chat_message.content == content

dictionary = assistant_chat_message.to_dict()
assert dictionary == {
reference_dict: Dict[str, Any] = {
"role_name": role_name,
"role_type": role_type.name,
**(meta_dict or {}),
"content": content,
}
assert dictionary == reference_dict


def test_user_chat_message(user_chat_message: BaseMessage) -> None:
Expand All @@ -99,9 +101,9 @@ def test_user_chat_message(user_chat_message: BaseMessage) -> None:
assert user_chat_message.content == content

dictionary = user_chat_message.to_dict()
assert dictionary == {
reference_dict: Dict[str, Any] = {
"role_name": role_name,
"role_type": role_type.name,
**(meta_dict or {}),
"content": content,
}
assert dictionary == reference_dict
Loading