Include work log in agent context.

This commit is contained in:
user 2024-12-23 12:18:48 -05:00
parent 9a53940dcd
commit a0250855da
2 changed files with 40 additions and 11 deletions

View File

@ -8,7 +8,7 @@ ResearchResult = Dict[str, Union[str, bool, Dict[int, Any], List[Any], None]]
from rich.console import Console from rich.console import Console
from ra_aid.tools.memory import _global_memory from ra_aid.tools.memory import _global_memory
from ra_aid.console.formatting import print_error, print_interrupt from ra_aid.console.formatting import print_error, print_interrupt
from .memory import get_memory_value, get_related_files from .memory import get_memory_value, get_related_files, get_work_log, reset_work_log
from ..llm import initialize_llm from ..llm import initialize_llm
from ..console import print_task_header from ..console import print_task_header
@ -72,11 +72,17 @@ def request_research(query: str) -> ResearchResult:
# Get completion message if available # Get completion message if available
completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None) completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None)
# Get and reset work log if at root depth
work_log = get_work_log() if current_depth == 1 else None
if current_depth == 1:
reset_work_log()
# Clear completion state from global memory # Clear completion state from global memory
_global_memory['completion_message'] = '' _global_memory['completion_message'] = ''
_global_memory['task_completed'] = False _global_memory['task_completed'] = False
return { return {
"work_log": work_log,
"completion_message": completion_message, "completion_message": completion_message,
"key_facts": get_memory_value("key_facts"), "key_facts": get_memory_value("key_facts"),
"related_files": get_related_files(), "related_files": get_related_files(),
@ -123,12 +129,19 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
# Get completion message if available # Get completion message if available
completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None) completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None)
# Get and reset work log if at root depth
current_depth = _global_memory.get('agent_depth', 0)
work_log = get_work_log() if current_depth == 1 else None
if current_depth == 1:
reset_work_log()
# Clear completion state from global memory # Clear completion state from global memory
_global_memory['completion_message'] = '' _global_memory['completion_message'] = ''
_global_memory['task_completed'] = False _global_memory['task_completed'] = False
_global_memory['plan_completed'] = False _global_memory['plan_completed'] = False
return { return {
"work_log": work_log,
"completion_message": completion_message, "completion_message": completion_message,
"key_facts": get_memory_value("key_facts"), "key_facts": get_memory_value("key_facts"),
"related_files": get_related_files(), "related_files": get_related_files(),
@ -182,11 +195,18 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
# Get completion message if available # Get completion message if available
completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None) completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None)
# Get and reset work log if at root depth
current_depth = _global_memory.get('agent_depth', 0)
work_log = get_work_log() if current_depth == 1 else None
if current_depth == 1:
reset_work_log()
# Clear completion state from global memory # Clear completion state from global memory
_global_memory['completion_message'] = '' _global_memory['completion_message'] = ''
_global_memory['task_completed'] = False _global_memory['task_completed'] = False
return { return {
"work_log": work_log,
"key_facts": get_memory_value("key_facts"), "key_facts": get_memory_value("key_facts"),
"related_files": get_related_files(), "related_files": get_related_files(),
"key_snippets": get_memory_value("key_snippets"), "key_snippets": get_memory_value("key_snippets"),
@ -231,12 +251,19 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
# Get completion message if available # Get completion message if available
completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None) completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None)
# Get and reset work log if at root depth
current_depth = _global_memory.get('agent_depth', 0)
work_log = get_work_log() if current_depth == 1 else None
if current_depth == 1:
reset_work_log()
# Clear completion state from global memory # Clear completion state from global memory
_global_memory['completion_message'] = '' _global_memory['completion_message'] = ''
_global_memory['task_completed'] = False _global_memory['task_completed'] = False
_global_memory['plan_completed'] = False _global_memory['plan_completed'] = False
return { return {
"work_log": work_log,
"completion_message": completion_message, "completion_message": completion_message,
"key_facts": get_memory_value("key_facts"), "key_facts": get_memory_value("key_facts"),
"related_files": get_related_files(), "related_files": get_related_files(),

View File

@ -63,7 +63,7 @@ def emit_plan(plan: str) -> str:
""" """
_global_memory['plans'].append(plan) _global_memory['plans'].append(plan)
console.print(Panel(Markdown(plan), title="📋 Plan")) console.print(Panel(Markdown(plan), title="📋 Plan"))
log_work_event(f"Added plan step: {plan}") log_work_event(f"Added plan step:\n\n{plan}")
return plan return plan
@tool("emit_task") @tool("emit_task")
@ -84,7 +84,7 @@ def emit_task(task: str) -> str:
_global_memory['tasks'][task_id] = task _global_memory['tasks'][task_id] = task
console.print(Panel(Markdown(task), title=f"✅ Task #{task_id}")) console.print(Panel(Markdown(task), title=f"✅ Task #{task_id}"))
log_work_event(f"Task #{task_id} added: {task}") log_work_event(f"Task #{task_id} added:\n\n{task}")
return f"Task #{task_id} stored." return f"Task #{task_id} stored."
@ -114,7 +114,7 @@ def emit_key_facts(facts: List[str]) -> str:
# Add result message # Add result message
results.append(f"Stored fact #{fact_id}: {fact}") results.append(f"Stored fact #{fact_id}: {fact}")
log_work_event(f"Stored {len(facts)} key facts") log_work_event(f"Stored {len(facts)} key facts.")
return "Facts stored." return "Facts stored."
@ -138,7 +138,7 @@ def delete_key_facts(fact_ids: List[int]) -> str:
console.print(Panel(Markdown(success_msg), title="Fact Deleted", border_style="green")) console.print(Panel(Markdown(success_msg), title="Fact Deleted", border_style="green"))
results.append(success_msg) results.append(success_msg)
log_work_event(f"Deleted facts {fact_ids}") log_work_event(f"Deleted facts {fact_ids}.")
return "Facts deleted." return "Facts deleted."
@tool("delete_tasks") @tool("delete_tasks")
@ -163,7 +163,7 @@ def delete_tasks(task_ids: List[int]) -> str:
border_style="green")) border_style="green"))
results.append(success_msg) results.append(success_msg)
log_work_event(f"Deleted tasks {task_ids}") log_work_event(f"Deleted tasks {task_ids}.")
return "Tasks deleted." return "Tasks deleted."
@tool("request_implementation") @tool("request_implementation")
@ -181,7 +181,7 @@ def request_implementation() -> str:
""" """
_global_memory['implementation_requested'] = True _global_memory['implementation_requested'] = True
console.print(Panel("🚀 Implementation Requested", style="yellow", padding=0)) console.print(Panel("🚀 Implementation Requested", style="yellow", padding=0))
log_work_event("Implementation requested") log_work_event("Implementation requested.")
return "" return ""
@ -234,7 +234,7 @@ def emit_key_snippets(snippets: List[SnippetInfo]) -> str:
results.append(f"Stored snippet #{snippet_id}") results.append(f"Stored snippet #{snippet_id}")
log_work_event(f"Stored {len(snippets)} code snippets") log_work_event(f"Stored {len(snippets)} code snippets.")
return "Snippets stored." return "Snippets stored."
@tool("delete_key_snippets") @tool("delete_key_snippets")
@ -259,7 +259,7 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
border_style="green")) border_style="green"))
results.append(success_msg) results.append(success_msg)
log_work_event(f"Deleted snippets {snippet_ids}") log_work_event(f"Deleted snippets {snippet_ids}.")
return "Snippets deleted." return "Snippets deleted."
@tool("swap_task_order") @tool("swap_task_order")
@ -309,7 +309,7 @@ def one_shot_completed(message: str) -> str:
_global_memory['task_completed'] = True _global_memory['task_completed'] = True
_global_memory['completion_message'] = message _global_memory['completion_message'] = message
console.print(Panel(Markdown(message), title="✅ Task Completed")) console.print(Panel(Markdown(message), title="✅ Task Completed"))
log_work_event(f"Task completed: {message}") log_work_event(f"Task completed\n\n{message}")
return "Completion noted." return "Completion noted."
@tool("task_completed") @tool("task_completed")
@ -342,7 +342,7 @@ def plan_implementation_completed(message: str) -> str:
_global_memory['tasks'].clear() # Clear task list when plan is completed _global_memory['tasks'].clear() # Clear task list when plan is completed
_global_memory['task_id_counter'] = 1 _global_memory['task_id_counter'] = 1
console.print(Panel(Markdown(message), title="✅ Plan Executed")) console.print(Panel(Markdown(message), title="✅ Plan Executed"))
log_work_event(f"Plan execution completed: {message}") log_work_event(f"Plan execution completed:\n\n{message}")
return "Plan completion noted and task list cleared." return "Plan completion noted and task list cleared."
def get_related_files() -> List[str]: def get_related_files() -> List[str]:
@ -433,6 +433,7 @@ def get_work_log() -> str:
Example: Example:
## 2024-12-23T11:39:10 ## 2024-12-23T11:39:10
Task #1 added: Create login form Task #1 added: Create login form
""" """
if not _global_memory['work_log']: if not _global_memory['work_log']:
@ -442,6 +443,7 @@ def get_work_log() -> str:
for entry in _global_memory['work_log']: for entry in _global_memory['work_log']:
entries.extend([ entries.extend([
f"## {entry['timestamp']}", f"## {entry['timestamp']}",
"",
entry['event'], entry['event'],
"" # Blank line between entries "" # Blank line between entries
]) ])