From ba82f41fa0815dd467eb989fce2045f14a478975 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Thu, 9 Jan 2025 12:57:58 -0500 Subject: [PATCH] Only include work_log key if needed, to save tokens. --- ra_aid/agent_utils.py | 2 +- ra_aid/tools/agent.py | 30 ++++++++++++++++++++---------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 5a692fa..f402f3e 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -544,7 +544,7 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: _global_memory['task_completed'] = False _global_memory['completion_message'] = '' break - if _global_memory['task_completed'] or _global_memory['plan_completed']: + if _global_memory['task_completed']: _global_memory['task_completed'] = False _global_memory['completion_message'] = '' break diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 51618cf..91d0462 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -87,8 +87,7 @@ def request_research(query: str) -> ResearchResult: _global_memory['completion_message'] = '' _global_memory['task_completed'] = False - return { - "work_log": work_log, + response_data = { "completion_message": completion_message, "key_facts": get_memory_value("key_facts"), "related_files": get_related_files(), @@ -97,6 +96,9 @@ def request_research(query: str) -> ResearchResult: "success": success, "reason": reason } + if work_log is not None: + response_data["work_log"] = work_log + return response_data @tool("request_web_research") def request_web_research(query: str) -> ResearchResult: @@ -148,14 +150,16 @@ def request_web_research(query: str) -> ResearchResult: _global_memory['completion_message'] = '' _global_memory['task_completed'] = False - return { - "work_log": work_log, + response_data = { "completion_message": completion_message, "key_snippets": get_memory_value("key_snippets"), "research_notes": get_memory_value("research_notes"), "success": success, "reason": reason } + if work_log is not None: + response_data["work_log"] = work_log + return response_data @tool("request_research_and_implementation") def request_research_and_implementation(query: str) -> Dict[str, Any]: @@ -212,8 +216,7 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]: _global_memory['task_completed'] = False _global_memory['plan_completed'] = False - return { - "work_log": work_log, + response_data = { "completion_message": completion_message, "key_facts": get_memory_value("key_facts"), "related_files": get_related_files(), @@ -222,6 +225,9 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]: "success": success, "reason": reason } + if work_log is not None: + response_data["work_log"] = work_log + return response_data @tool("request_task_implementation") def request_task_implementation(task_spec: str) -> Dict[str, Any]: @@ -281,8 +287,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]: _global_memory['completion_message'] = '' _global_memory['task_completed'] = False - return { - "work_log": work_log, + response_data = { "key_facts": get_memory_value("key_facts"), "related_files": get_related_files(), "key_snippets": get_memory_value("key_snippets"), @@ -290,6 +295,9 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]: "success": success, "reason": reason } + if work_log is not None: + response_data["work_log"] = work_log + return response_data @tool("request_implementation") def request_implementation(task_spec: str) -> Dict[str, Any]: @@ -341,8 +349,7 @@ def request_implementation(task_spec: str) -> Dict[str, Any]: _global_memory['task_completed'] = False _global_memory['plan_completed'] = False - return { - "work_log": work_log, + response_data = { "completion_message": completion_message, "key_facts": get_memory_value("key_facts"), "related_files": get_related_files(), @@ -350,3 +357,6 @@ def request_implementation(task_spec: str) -> Dict[str, Any]: "success": success, "reason": reason } + if work_log is not None: + response_data["work_log"] = work_log + return response_data