set agent_should_exit in gc agents

This commit is contained in:
AI Christianson 2025-03-07 10:35:02 -05:00
parent 53406f1ddf
commit 77856bfa0c
6 changed files with 233 additions and 2 deletions

View File

@ -19,6 +19,7 @@ from ra_aid.tools.expert import get_model
from ra_aid.tools.reflection import get_function_info
from ra_aid.console.output import cpm
from ra_aid.console.formatting import print_warning, print_error
from ra_aid.agent_context import should_exit
logger = get_logger(__name__)
@ -248,6 +249,11 @@ class CiaynAgent:
def _execute_tool(self, msg: BaseMessage) -> str:
"""Execute a tool call and return its result."""
# Check for should_exit before executing tool calls
if should_exit():
logger.debug("Agent should exit flag detected in _execute_tool")
return "Tool execution aborted - agent should exit flag is set"
code = msg.content
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
@ -263,10 +269,20 @@ class CiaynAgent:
# If we have multiple valid bundleable calls, execute them in sequence
if len(tool_calls) > 1:
# Check for should_exit before executing bundled tool calls
if should_exit():
logger.debug("Agent should exit flag detected before executing bundled tool calls")
return "Bundled tool execution aborted - agent should exit flag is set"
results = []
result_strings = []
for call in tool_calls:
# Check if agent should exit
if should_exit():
logger.debug("Agent should exit flag detected during bundled tool execution")
return "Tool execution interrupted: agent_should_exit flag is set."
# Validate and fix each call if needed
if validate_function_call_pattern(call):
functions_list = "\n\n".join(self.available_functions)
@ -431,6 +447,12 @@ class CiaynAgent:
logger.debug(f"Failed to parse parameters for duplicate detection: {str(e)}")
pass
# Before executing the call
if should_exit():
logger.debug("Agent should exit flag detected before tool execution")
return "Tool execution interrupted: agent_should_exit flag is set."
# Execute the tool
result = eval(code.strip(), globals_dict)
return result
except Exception as e:
@ -589,6 +611,11 @@ class CiaynAgent:
max_empty_responses = 3 # Maximum number of consecutive empty responses before giving up
while True:
# Check for should_exit
if should_exit():
logger.debug("Agent should exit flag detected in stream loop")
break
base_prompt = self._build_prompt(last_result)
self.chat_history.append(HumanMessage(content=base_prompt))
full_history = self._trim_chat_history(initial_messages, self.chat_history)

View File

@ -16,6 +16,7 @@ from rich.panel import Panel
logger = logging.getLogger(__name__)
from ra_aid.agent_context import mark_should_exit
from ra_aid.agent_utils import create_agent, run_agent_with_retry
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
@ -101,6 +102,9 @@ def delete_key_facts(fact_ids: List[int]) -> str:
failed_msg = f"Failed to delete facts: {', '.join([f'#{fact_id}' for fact_id in failed_facts])}"
result_parts.append(failed_msg)
# Mark that the agent should exit after completing this operation
mark_should_exit()
return "\n".join(result_parts)

View File

@ -20,6 +20,7 @@ from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.llm import initialize_llm
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
from ra_aid.tools.memory import log_work_event
from ra_aid.agent_context import mark_should_exit
console = Console()
@ -98,6 +99,9 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
failed_msg = f"Failed to delete snippets: {', '.join([f'#{snippet_id}' for snippet_id in failed_snippets])}"
result_parts.append(failed_msg)
# Mark that the agent should exit
mark_should_exit()
return "Snippets deleted."

View File

@ -16,6 +16,8 @@ from rich.panel import Panel
logger = logging.getLogger(__name__)
from ra_aid.agent_context import mark_should_exit
from ra_aid.agent_utils import create_agent, run_agent_with_retry
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
@ -103,6 +105,9 @@ def delete_research_notes(note_ids: List[int]) -> str:
failed_msg = f"Failed to delete research notes: {', '.join([f'#{note_id}' for note_id in failed_notes])}"
result_parts.append(failed_msg)
# Mark the agent to exit after performing the cleanup operation
mark_should_exit()
return "\n".join(result_parts)

View File

@ -20,8 +20,7 @@ RESEARCH_COMMON_PROMPT_HEADER = """Current Date: {current_date}
KEEP IT SIMPLE
Context from Previous Research (if available):
<previous research>
<key facts>
{key_facts}
</key facts>
@ -43,6 +42,8 @@ Work already done:
<project info>
{project_info}
</project info>
<caveat>You should make the most efficient use of this previous research possible, with the caveat that not all of it will be relevant to the current task you are assigned. Use this previous research to avoid redundant research, and to inform what you are currently tasked with. Be as efficient as possible.</caveat>
</previous research>
Role

View File

@ -0,0 +1,190 @@
"""Tests for CIAYN agent respecting should_exit flag."""
import pytest
from unittest.mock import Mock
from langchain_core.messages import HumanMessage, AIMessage
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
from ra_aid.agent_context import agent_context, mark_should_exit, should_exit
from ra_aid.exceptions import ToolExecutionError
def test_ciayn_agent_stream_respects_should_exit():
    """Verify CiaynAgent.stream terminates immediately when should_exit is set."""
    # A mock model that would ask for one tool call if the stream ran.
    model = Mock()
    tool = Mock()
    tool.func.__name__ = "mock_tool"
    model.invoke.return_value = AIMessage(content="mock_tool()")

    agent = CiaynAgent(model, [tool])
    messages = {"messages": [HumanMessage(content="test")]}

    # Positive case: with the flag raised, stream yields nothing and the
    # model is never invoked.
    with agent_context() as ctx:
        mark_should_exit()
        assert should_exit()
        produced = list(agent.stream(messages))
        assert len(produced) == 0
        model.invoke.assert_not_called()

    # Negative case: a fresh context clears the flag, so the stream runs
    # and the model is invoked once.
    with agent_context() as ctx:
        assert not should_exit()
        next(agent.stream(messages))
        model.invoke.assert_called_once()
def test_ciayn_agent_bundled_tools_respects_should_exit():
    """Verify should_exit prevents any processing of bundled tool calls."""
    model = Mock()

    # Two bundleable tools the model response would invoke back-to-back.
    fact_tool = Mock()
    fact_tool.func.__name__ = "emit_key_facts"
    fact_tool.func.return_value = "Tool 1 executed"

    snippet_tool = Mock()
    snippet_tool.func.__name__ = "emit_key_snippet"
    snippet_tool.func.return_value = "Tool 2 executed"

    # Single model response bundling both tool calls.
    model.invoke.return_value = AIMessage(content="""
emit_key_facts(facts=["fact1", "fact2"])
emit_key_snippet(snippet_info={"filepath": "test.py", "line_number": 1, "snippet": "code", "description": "test"})
""")

    agent = CiaynAgent(model, [fact_tool, snippet_tool])

    with agent_context() as ctx:
        messages = {"messages": [HumanMessage(content="test")]}

        # Raise the exit flag before streaming: the stream-level check
        # should fire before the model or either tool is touched.
        mark_should_exit()
        list(agent.stream(messages))

        model.invoke.assert_not_called()
        fact_tool.func.assert_not_called()
        snippet_tool.func.assert_not_called()
def test_ciayn_agent_single_tool_respects_should_exit():
    """Verify the stream stops after a tool sets should_exit mid-run.

    The mocked tool raises the exit flag as a side effect of executing, so
    the agent should finish that iteration and then terminate the stream
    without invoking the model a second time.

    Note: the original version defined an unused ``TestContext`` helper
    class and a ``return_value`` that the ``side_effect`` below shadowed;
    both were dead code and have been removed.
    """
    mock_model = Mock()
    mock_tool = Mock()
    mock_tool.func.__name__ = "test_tool"

    # Model always responds with a single tool call.
    mock_model.invoke.return_value = AIMessage(content="test_tool()")

    agent = CiaynAgent(mock_model, [mock_tool])

    def execute_and_exit(*args, **kwargs):
        # Mark should_exit from within a nested agent context — presumably
        # this propagates to the enclosing context (the outer
        # ``assert should_exit()`` below relies on it; confirm against
        # agent_context semantics).
        with agent_context() as ctx:
            mark_should_exit()
        return "Tool executed"

    mock_tool.func.side_effect = execute_and_exit

    with agent_context() as ctx:
        # Flag starts clear in a fresh context.
        assert not should_exit()

        messages = {"messages": [HumanMessage(content="test")]}
        generator = agent.stream(messages)

        # First iteration executes the tool, which flips the flag.
        next(generator, None)
        mock_model.invoke.assert_called_once()
        mock_tool.func.assert_called_once()
        assert should_exit()

        # Draining the generator must yield nothing more, and the model
        # must not be invoked again.
        results = list(generator)
        assert len(results) == 0
        assert mock_model.invoke.call_count == 1
def test_ciayn_agent_execute_tool_respects_should_exit():
    """Verify _execute_tool aborts up front when should_exit is set."""
    model = Mock()
    tool = Mock()
    tool.func.__name__ = "test_tool"
    agent = CiaynAgent(model, [tool])

    with agent_context() as ctx:
        # Raise the flag, then attempt a direct tool execution.
        mark_should_exit()
        outcome = agent._execute_tool(HumanMessage(content="test_tool()"))

        # The early-exit message comes back and the tool never runs.
        assert "agent should exit flag is set" in outcome
        tool.func.assert_not_called()