Skip to content

Commit aa2b501

Browse files
committed
single shot solution
1 parent 2150be5 commit aa2b501

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

examples/ai_society/role_playing_multiprocess.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@
2020

2121
from camel.configs import ChatGPTConfig
2222
from camel.societies import RolePlaying
23-
from camel.typing import TaskType
23+
from camel.typing import TaskType, ModelType
2424
from camel.utils import download_tasks
2525

2626

2727
def generate_data(assistant_idx: int, assistant_role_name: str, user_idx: int,
2828
user_role_name: str, task_idx: int, task_prompt: str,
2929
verbose: bool = False) -> None:
3030

31-
max_num_messages = 40
31+
max_num_messages = 100
3232

3333
original_task_prompt = task_prompt.replace(f"{task_idx+1}. ", "")
3434

@@ -38,6 +38,7 @@ def generate_data(assistant_idx: int, assistant_role_name: str, user_idx: int,
3838
task_prompt=original_task_prompt,
3939
with_task_specify=True,
4040
with_task_planner=False,
41+
model_type=ModelType.GPT_3_5_TURBO_16K,
4142
task_specify_agent_kwargs=dict(model_config=ChatGPTConfig(
4243
temperature=1.4)),
4344
)
@@ -204,7 +205,7 @@ def main() -> None:
204205
try:
205206
slurm_array_task_id = os.environ.get('SLURM_ARRAY_TASK_ID')
206207
if slurm_array_task_id is None:
207-
raise
208+
raise ValueError("SLURM_ARRAY_TASK_ID is not set")
208209
array_idx = int(slurm_array_task_id)
209210
except (TypeError, ValueError) as e:
210211
print(f"Error: {e}")
@@ -227,7 +228,7 @@ def main() -> None:
227228
roles_per_chunk:(array_idx + 1) *
228229
roles_per_chunk]
229230

230-
pool = multiprocessing.Pool()
231+
pool = multiprocessing.Pool(processes=10)
231232

232233
for assistant_idx, assistant_role_name in enumerate(assistant_roles):
233234
assistant_idx += array_idx * roles_per_chunk

examples/single_shot/pair_generator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from camel.typing import TaskType, RoleType
1818

1919

20-
def main(key: str = "generate_users", num_roles: int = 50):
20+
def main():
2121

2222
single_shot_template = SingleShotPromptTemplateDict()
2323
assistant_sys_msg_prompt = single_shot_template[RoleType.ASSISTANT]
@@ -39,5 +39,4 @@ def main(key: str = "generate_users", num_roles: int = 50):
3939

4040

4141
if __name__ == "__main__":
42-
main("generate_users", 50)
43-
main("generate_assistants", 50)
42+
main()

0 commit comments

Comments
 (0)