From 77856bfa0c26991bccb587420ac59a7b12bcf12f Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Fri, 7 Mar 2025 10:35:02 -0500 Subject: [PATCH] set agent_should_exit in gc agents --- ra_aid/agent_backends/ciayn_agent.py | 27 +++ ra_aid/agents/key_facts_gc_agent.py | 4 + ra_aid/agents/key_snippets_gc_agent.py | 4 + ra_aid/agents/research_notes_gc_agent.py | 5 + ra_aid/prompts/research_prompts.py | 5 +- tests/ra_aid/test_agent_should_exit_ciayn.py | 190 +++++++++++++++++++ 6 files changed, 233 insertions(+), 2 deletions(-) create mode 100644 tests/ra_aid/test_agent_should_exit_ciayn.py diff --git a/ra_aid/agent_backends/ciayn_agent.py b/ra_aid/agent_backends/ciayn_agent.py index 0ed8c9a..7a5f70e 100644 --- a/ra_aid/agent_backends/ciayn_agent.py +++ b/ra_aid/agent_backends/ciayn_agent.py @@ -19,6 +19,7 @@ from ra_aid.tools.expert import get_model from ra_aid.tools.reflection import get_function_info from ra_aid.console.output import cpm from ra_aid.console.formatting import print_warning, print_error +from ra_aid.agent_context import should_exit logger = get_logger(__name__) @@ -247,6 +248,11 @@ class CiaynAgent: def _execute_tool(self, msg: BaseMessage) -> str: """Execute a tool call and return its result.""" + + # Check for should_exit before executing tool calls + if should_exit(): + logger.debug("Agent should exit flag detected in _execute_tool") + return "Tool execution aborted - agent should exit flag is set" code = msg.content globals_dict = {tool.func.__name__: tool.func for tool in self.tools} @@ -263,10 +269,20 @@ class CiaynAgent: # If we have multiple valid bundleable calls, execute them in sequence if len(tool_calls) > 1: + # Check for should_exit before executing bundled tool calls + if should_exit(): + logger.debug("Agent should exit flag detected before executing bundled tool calls") + return "Bundled tool execution aborted - agent should exit flag is set" + results = [] result_strings = [] for call in tool_calls: + # Check if agent should exit + if 
should_exit(): + logger.debug("Agent should exit flag detected during bundled tool execution") + return "Tool execution interrupted: agent_should_exit flag is set." + # Validate and fix each call if needed if validate_function_call_pattern(call): functions_list = "\n\n".join(self.available_functions) @@ -431,6 +447,12 @@ class CiaynAgent: logger.debug(f"Failed to parse parameters for duplicate detection: {str(e)}") pass + # Before executing the call + if should_exit(): + logger.debug("Agent should exit flag detected before tool execution") + return "Tool execution interrupted: agent_should_exit flag is set." + + # Execute the tool result = eval(code.strip(), globals_dict) return result except Exception as e: @@ -589,6 +611,11 @@ class CiaynAgent: max_empty_responses = 3 # Maximum number of consecutive empty responses before giving up while True: + # Check for should_exit + if should_exit(): + logger.debug("Agent should exit flag detected in stream loop") + break + base_prompt = self._build_prompt(last_result) self.chat_history.append(HumanMessage(content=base_prompt)) full_history = self._trim_chat_history(initial_messages, self.chat_history) diff --git a/ra_aid/agents/key_facts_gc_agent.py b/ra_aid/agents/key_facts_gc_agent.py index ece3c3d..a2b1115 100644 --- a/ra_aid/agents/key_facts_gc_agent.py +++ b/ra_aid/agents/key_facts_gc_agent.py @@ -16,6 +16,7 @@ from rich.panel import Panel logger = logging.getLogger(__name__) +from ra_aid.agent_context import mark_should_exit from ra_aid.agent_utils import create_agent, run_agent_with_retry from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository from ra_aid.database.repositories.human_input_repository import get_human_input_repository @@ -101,6 +102,9 @@ def delete_key_facts(fact_ids: List[int]) -> str: failed_msg = f"Failed to delete facts: {', '.join([f'#{fact_id}' for fact_id in failed_facts])}" result_parts.append(failed_msg) + # Mark that the agent should exit after completing this 
operation + mark_should_exit() + return "\n".join(result_parts) diff --git a/ra_aid/agents/key_snippets_gc_agent.py b/ra_aid/agents/key_snippets_gc_agent.py index 72e4c38..43e0473 100644 --- a/ra_aid/agents/key_snippets_gc_agent.py +++ b/ra_aid/agents/key_snippets_gc_agent.py @@ -20,6 +20,7 @@ from ra_aid.database.repositories.config_repository import get_config_repository from ra_aid.llm import initialize_llm from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT from ra_aid.tools.memory import log_work_event +from ra_aid.agent_context import mark_should_exit console = Console() @@ -98,6 +99,9 @@ def delete_key_snippets(snippet_ids: List[int]) -> str: failed_msg = f"Failed to delete snippets: {', '.join([f'#{snippet_id}' for snippet_id in failed_snippets])}" result_parts.append(failed_msg) + # Mark that the agent should exit + mark_should_exit() + return "Snippets deleted." diff --git a/ra_aid/agents/research_notes_gc_agent.py b/ra_aid/agents/research_notes_gc_agent.py index 9fe168e..70bb91f 100644 --- a/ra_aid/agents/research_notes_gc_agent.py +++ b/ra_aid/agents/research_notes_gc_agent.py @@ -16,6 +16,8 @@ from rich.panel import Panel logger = logging.getLogger(__name__) +from ra_aid.agent_context import mark_should_exit + from ra_aid.agent_utils import create_agent, run_agent_with_retry from ra_aid.database.repositories.research_note_repository import get_research_note_repository from ra_aid.database.repositories.human_input_repository import get_human_input_repository @@ -103,6 +105,9 @@ def delete_research_notes(note_ids: List[int]) -> str: failed_msg = f"Failed to delete research notes: {', '.join([f'#{note_id}' for note_id in failed_notes])}" result_parts.append(failed_msg) + # Mark the agent to exit after performing the cleanup operation + mark_should_exit() + return "\n".join(result_parts) diff --git a/ra_aid/prompts/research_prompts.py b/ra_aid/prompts/research_prompts.py index 8bd38f2..3540505 100644 --- 
a/ra_aid/prompts/research_prompts.py +++ b/ra_aid/prompts/research_prompts.py @@ -20,8 +20,7 @@ RESEARCH_COMMON_PROMPT_HEADER = """Current Date: {current_date} KEEP IT SIMPLE -Context from Previous Research (if available): - + {key_facts} @@ -43,6 +42,8 @@ Work already done: {project_info} +You should make the most efficient use of this previous research possible, with the caveat that not all of it will be relevant to the current task you are assigned with. Use this previous research to save redundant research, and to inform what you are currently tasked with. Be as efficient as possible. + Role diff --git a/tests/ra_aid/test_agent_should_exit_ciayn.py b/tests/ra_aid/test_agent_should_exit_ciayn.py new file mode 100644 index 0000000..150b184 --- /dev/null +++ b/tests/ra_aid/test_agent_should_exit_ciayn.py @@ -0,0 +1,190 @@ +"""Tests for CIAYN agent respecting should_exit flag.""" + +import pytest +from unittest.mock import Mock +from langchain_core.messages import HumanMessage, AIMessage + +from ra_aid.agent_backends.ciayn_agent import CiaynAgent +from ra_aid.agent_context import agent_context, mark_should_exit, should_exit +from ra_aid.exceptions import ToolExecutionError + + +def test_ciayn_agent_stream_respects_should_exit(): + """Test that the CIAYN agent's stream method respects should_exit.""" + # Create mock model and tool + mock_model = Mock() + mock_tool = Mock() + mock_tool.func.__name__ = "mock_tool" + + # Configure mock model to return a response + mock_model.invoke.return_value = AIMessage(content="mock_tool()") + + # Create agent + agent = CiaynAgent(mock_model, [mock_tool]) + + # Set up test messages + messages = {"messages": [HumanMessage(content="test")]} + + # Test stream exits when should_exit is set + with agent_context() as ctx: + # Set should_exit + mark_should_exit() + + # Verify should_exit is set + assert should_exit() + + # Call stream - should exit immediately without calling model.invoke + results = list(agent.stream(messages)) + + # 
Verify stream exited without processing (empty results) + assert len(results) == 0 + + # Verify model was not called + mock_model.invoke.assert_not_called() + + # Test negative case - stream should continue when should_exit is not set + with agent_context() as ctx: + # Verify should_exit is not set + assert not should_exit() + + # Execute stream + next(agent.stream(messages)) + + # Verify model was called + mock_model.invoke.assert_called_once() + + +def test_ciayn_agent_bundled_tools_respects_should_exit(): + """Test that the CIAYN agent respects should_exit when executing bundled tools.""" + # Create mock model and tools + mock_model = Mock() + mock_tool1 = Mock() + mock_tool1.func.__name__ = "emit_key_facts" + mock_tool1.func.return_value = "Tool 1 executed" + mock_tool2 = Mock() + mock_tool2.func.__name__ = "emit_key_snippet" + mock_tool2.func.return_value = "Tool 2 executed" + + # Configure model to return multiple tool calls + # Create a response with two bundled tool calls + mock_model.invoke.return_value = AIMessage(content=""" + emit_key_facts(facts=["fact1", "fact2"]) + emit_key_snippet(snippet_info={"filepath": "test.py", "line_number": 1, "snippet": "code", "description": "test"}) + """) + + # Create agent with bundleable tools + agent = CiaynAgent(mock_model, [mock_tool1, mock_tool2]) + + # Test executing bundled tools when should_exit is set + with agent_context() as ctx: + # Set up test messages + messages = {"messages": [HumanMessage(content="test")]} + + # Set should_exit before calling stream + mark_should_exit() + + # Call stream + generator = agent.stream(messages) + # Stream should exit without processing due to should_exit in the stream method + result = list(generator) + + # Model should not have been called + mock_model.invoke.assert_not_called() + + # Tools should not have been executed + mock_tool1.func.assert_not_called() + mock_tool2.func.assert_not_called() + + +def test_ciayn_agent_single_tool_respects_should_exit(): + """Test that the 
CIAYN agent respects should_exit when executing a single tool.""" + # Create mock model and tool + mock_model = Mock() + mock_tool = Mock() + mock_tool.func.__name__ = "test_tool" + mock_tool.func.return_value = "Tool executed" + + # Configure model to return a single tool call + mock_model.invoke.return_value = AIMessage(content="test_tool()") + + # Create agent + agent = CiaynAgent(mock_model, [mock_tool]) + + # Set up a context that manipulates should_exit during execution + class TestContext: + def __init__(self): + self.should_exit_flag = False + + def set_exit_flag(self): + self.should_exit_flag = True + return "Exit flag set" + + def get_exit_flag(self): + return self.should_exit_flag + + test_context = TestContext() + + # Replace the actual tool execution with one that sets should_exit + def execute_and_exit(*args, **kwargs): + # Set should_exit flag + with agent_context() as ctx: + mark_should_exit() + return "Tool executed" + + # Override the mock tool to set should_exit + mock_tool.func.side_effect = execute_and_exit + + # Execute tool call that sets should_exit + with agent_context() as ctx: + # Verify should_exit is initially False + assert not should_exit() + + # Set up test messages + messages = {"messages": [HumanMessage(content="test")]} + + # Call stream + generator = agent.stream(messages) + # Get first response (which will execute the tool and set should_exit) + next(generator, None) + + # Verify model was called + mock_model.invoke.assert_called_once() + + # Verify tool was executed + mock_tool.func.assert_called_once() + + # Verify should_exit was set + assert should_exit() + + # Get next response (should be empty because stream should exit) + results = list(generator) + assert len(results) == 0 + + # Verify model wasn't called again + assert mock_model.invoke.call_count == 1 + + +def test_ciayn_agent_execute_tool_respects_should_exit(): + """Test that _execute_tool respects should_exit.""" + # Create mock model and tool + mock_model = Mock() 
+ mock_tool = Mock() + mock_tool.func.__name__ = "test_tool" + + # Create agent + agent = CiaynAgent(mock_model, [mock_tool]) + + # Test _execute_tool exits when should_exit is set + with agent_context() as ctx: + # Set should_exit + mark_should_exit() + + # Call _execute_tool + message = HumanMessage(content="test_tool()") + result = agent._execute_tool(message) + + # Verify early exit message + assert "agent should exit flag is set" in result + + # Verify tool was not executed + mock_tool.func.assert_not_called() \ No newline at end of file