From 5d899d3d1316b68e7559f4888b37e3808f34de02 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Tue, 11 Mar 2025 08:56:12 -0400 Subject: [PATCH] trajectory --- ra_aid/__main__.py | 24 ++++++++++++++++++++++++ ra_aid/agents/planning_agent.py | 14 ++++++++++++++ ra_aid/console/formatting.py | 28 ---------------------------- ra_aid/tools/agent.py | 21 +++++++++++++++++---- 4 files changed, 55 insertions(+), 32 deletions(-) diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index bf42ec8..472fceb 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -620,6 +620,18 @@ def main(): sys.exit(1) print_stage_header("Chat Mode") + + # Record stage transition in trajectory + trajectory_repo = get_trajectory_repository() + human_input_id = get_human_input_repository().get_most_recent_id() + trajectory_repo.create( + step_data={ + "stage": "chat_mode", + "display_title": "Chat Mode", + }, + record_type="stage_transition", + human_input_id=human_input_id + ) # Get project info try: @@ -769,6 +781,18 @@ def main(): # Run research stage print_stage_header("Research Stage") + + # Record stage transition in trajectory + trajectory_repo = get_trajectory_repository() + human_input_id = get_human_input_repository().get_most_recent_id() + trajectory_repo.create( + step_data={ + "stage": "research_stage", + "display_title": "Research Stage", + }, + record_type="stage_transition", + human_input_id=human_input_id + ) # Initialize research model with potential overrides research_provider = args.research_provider or args.provider diff --git a/ra_aid/agents/planning_agent.py b/ra_aid/agents/planning_agent.py index 42355b6..8b02a38 100644 --- a/ra_aid/agents/planning_agent.py +++ b/ra_aid/agents/planning_agent.py @@ -24,6 +24,8 @@ from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_ from ra_aid.database.repositories.research_note_repository import get_research_note_repository from ra_aid.database.repositories.config_repository import get_config_repository from ra_aid.database.repositories.work_log_repository import get_work_log_repository +from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository +from ra_aid.database.repositories.human_input_repository import get_human_input_repository from ra_aid.env_inv_context import get_env_inv from ra_aid.exceptions import AgentInterrupt from ra_aid.llm import initialize_expert_llm @@ -155,6 +157,18 @@ def run_planning_agent( # Display the planning stage header before any reasoning assistance print_stage_header("Planning Stage") + + # Record stage transition in trajectory + trajectory_repo = get_trajectory_repository() + human_input_id = get_human_input_repository().get_most_recent_id() + trajectory_repo.create( + step_data={ + "stage": "planning_stage", + "display_title": "Planning Stage", + }, + record_type="stage_transition", + human_input_id=human_input_id + ) # Initialize expert guidance section expert_guidance = "" diff --git a/ra_aid/console/formatting.py b/ra_aid/console/formatting.py index a10fffe..013a3ee 100644 --- a/ra_aid/console/formatting.py +++ b/ra_aid/console/formatting.py @@ -37,20 +37,6 @@ def print_stage_header(stage: str) -> None: # Create styled panel with icon panel_content = f" {icon} {stage_title}" console.print(Panel(panel_content, style="green bold", padding=0)) - - # Record trajectory event - focus on semantic meaning - trajectory_repo = get_trajectory_repository() - human_input_id = get_human_input_repository().get_most_recent_id() - - trajectory_repo.create( - step_data={ - "stage": stage_key, - "display_icon": icon, - "display_title": stage_title, - }, - record_type="stage_transition", - human_input_id=human_input_id - ) def print_task_header(task: str) -> None: @@ -60,20 +46,6 @@ def print_task_header(task: str) -> None: task: The task text to print (supports Markdown formatting) """ console.print(Panel(Markdown(task), title="🔧 Task", border_style="yellow bold")) - - # Record trajectory event - trajectory_repo = get_trajectory_repository() - human_input_id = get_human_input_repository().get_most_recent_id() - - trajectory_repo.create( - step_data={ - "task": task, - "display_title": "Task", - "display_icon": "🔧", - }, - record_type="task_display", - human_input_id=human_input_id - ) def print_error(message: str) -> None: diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 26190e3..6928713 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -14,11 +14,12 @@ from ra_aid.agent_context import ( is_crashed, reset_completion_flags, ) -from ra_aid.console.formatting import print_error -from ra_aid.database.repositories.human_input_repository import HumanInputRepository +from ra_aid.console.formatting import print_error, print_task_header +from ra_aid.database.repositories.human_input_repository import HumanInputRepository, get_human_input_repository from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository from ra_aid.database.repositories.config_repository import get_config_repository +from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository from ra_aid.database.repositories.related_files_repository import get_related_files_repository from ra_aid.database.repositories.research_note_repository import get_research_note_repository from ra_aid.exceptions import AgentInterrupt @@ -26,8 +27,7 @@ from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict -from ..console import print_task_header -from ..llm import initialize_llm +from ra_aid.llm import initialize_llm from .human import ask_human from .memory import get_related_files, get_work_log @@ -346,6 +346,19 @@ def request_task_implementation(task_spec: str) -> str: try: print_task_header(task_spec) + + # Record task display in trajectory + trajectory_repo = get_trajectory_repository() + human_input_id = get_human_input_repository().get_most_recent_id() + trajectory_repo.create( + step_data={ + "task": task_spec, + "display_title": "Task", + }, + record_type="task_display", + human_input_id=human_input_id + ) + # Run implementation agent from ..agents.implementation_agent import run_task_implementation_agent