set agent_should_exit in gc agents
This commit is contained in:
parent
53406f1ddf
commit
77856bfa0c
|
|
@ -19,6 +19,7 @@ from ra_aid.tools.expert import get_model
|
|||
from ra_aid.tools.reflection import get_function_info
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.console.formatting import print_warning, print_error
|
||||
from ra_aid.agent_context import should_exit
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
@ -247,6 +248,11 @@ class CiaynAgent:
|
|||
|
||||
def _execute_tool(self, msg: BaseMessage) -> str:
|
||||
"""Execute a tool call and return its result."""
|
||||
|
||||
# Check for should_exit before executing tool calls
|
||||
if should_exit():
|
||||
logger.debug("Agent should exit flag detected in _execute_tool")
|
||||
return "Tool execution aborted - agent should exit flag is set"
|
||||
|
||||
code = msg.content
|
||||
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
|
||||
|
|
@ -263,10 +269,20 @@ class CiaynAgent:
|
|||
|
||||
# If we have multiple valid bundleable calls, execute them in sequence
|
||||
if len(tool_calls) > 1:
|
||||
# Check for should_exit before executing bundled tool calls
|
||||
if should_exit():
|
||||
logger.debug("Agent should exit flag detected before executing bundled tool calls")
|
||||
return "Bundled tool execution aborted - agent should exit flag is set"
|
||||
|
||||
results = []
|
||||
result_strings = []
|
||||
|
||||
for call in tool_calls:
|
||||
# Check if agent should exit
|
||||
if should_exit():
|
||||
logger.debug("Agent should exit flag detected during bundled tool execution")
|
||||
return "Tool execution interrupted: agent_should_exit flag is set."
|
||||
|
||||
# Validate and fix each call if needed
|
||||
if validate_function_call_pattern(call):
|
||||
functions_list = "\n\n".join(self.available_functions)
|
||||
|
|
@ -431,6 +447,12 @@ class CiaynAgent:
|
|||
logger.debug(f"Failed to parse parameters for duplicate detection: {str(e)}")
|
||||
pass
|
||||
|
||||
# Before executing the call
|
||||
if should_exit():
|
||||
logger.debug("Agent should exit flag detected before tool execution")
|
||||
return "Tool execution interrupted: agent_should_exit flag is set."
|
||||
|
||||
# Execute the tool
|
||||
result = eval(code.strip(), globals_dict)
|
||||
return result
|
||||
except Exception as e:
|
||||
|
|
@ -589,6 +611,11 @@ class CiaynAgent:
|
|||
max_empty_responses = 3 # Maximum number of consecutive empty responses before giving up
|
||||
|
||||
while True:
|
||||
# Check for should_exit
|
||||
if should_exit():
|
||||
logger.debug("Agent should exit flag detected in stream loop")
|
||||
break
|
||||
|
||||
base_prompt = self._build_prompt(last_result)
|
||||
self.chat_history.append(HumanMessage(content=base_prompt))
|
||||
full_history = self._trim_chat_history(initial_messages, self.chat_history)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from rich.panel import Panel
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ra_aid.agent_context import mark_should_exit
|
||||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
|
@ -101,6 +102,9 @@ def delete_key_facts(fact_ids: List[int]) -> str:
|
|||
failed_msg = f"Failed to delete facts: {', '.join([f'#{fact_id}' for fact_id in failed_facts])}"
|
||||
result_parts.append(failed_msg)
|
||||
|
||||
# Mark that the agent should exit after completing this operation
|
||||
mark_should_exit()
|
||||
|
||||
return "\n".join(result_parts)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from ra_aid.database.repositories.config_repository import get_config_repository
|
|||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
from ra_aid.agent_context import mark_should_exit
|
||||
|
||||
|
||||
console = Console()
|
||||
|
|
@ -98,6 +99,9 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
|||
failed_msg = f"Failed to delete snippets: {', '.join([f'#{snippet_id}' for snippet_id in failed_snippets])}"
|
||||
result_parts.append(failed_msg)
|
||||
|
||||
# Mark that the agent should exit
|
||||
mark_should_exit()
|
||||
|
||||
return "Snippets deleted."
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ from rich.panel import Panel
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ra_aid.agent_context import mark_should_exit
|
||||
|
||||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
|
@ -103,6 +105,9 @@ def delete_research_notes(note_ids: List[int]) -> str:
|
|||
failed_msg = f"Failed to delete research notes: {', '.join([f'#{note_id}' for note_id in failed_notes])}"
|
||||
result_parts.append(failed_msg)
|
||||
|
||||
# Mark the agent to exit after performing the cleanup operation
|
||||
mark_should_exit()
|
||||
|
||||
return "\n".join(result_parts)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -20,8 +20,7 @@ RESEARCH_COMMON_PROMPT_HEADER = """Current Date: {current_date}
|
|||
|
||||
KEEP IT SIMPLE
|
||||
|
||||
Context from Previous Research (if available):
|
||||
|
||||
<previous research>
|
||||
<key facts>
|
||||
{key_facts}
|
||||
</key facts>
|
||||
|
|
@ -43,6 +42,8 @@ Work already done:
|
|||
<project info>
|
||||
{project_info}
|
||||
</project info>
|
||||
<caveat>You should make the most efficient use of this previous research possible, with the caveat that not all of it will be relevant to the current task you are assigned with. Use this previous research to save redudant research, and to inform what you are currently tasked with. Be as efficient as possible.</caveat>
|
||||
</previous research>
|
||||
|
||||
Role
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,190 @@
|
|||
"""Tests for CIAYN agent respecting should_exit flag."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
|
||||
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
||||
from ra_aid.agent_context import agent_context, mark_should_exit, should_exit
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
|
||||
|
||||
def test_ciayn_agent_stream_respects_should_exit():
    """Check that ``CiaynAgent.stream`` honours the should_exit flag.

    Two scenarios are exercised against the same mocked model:

    * flag set before streaming -> nothing is yielded and the model is
      never invoked;
    * fresh context without the flag -> pulling one item from the stream
      invokes the model exactly once.
    """
    # A single fake tool; the agent looks tools up by ``func.__name__``.
    tool = Mock()
    tool.func.__name__ = "mock_tool"

    # Fake model whose reply is a call to that tool.
    model = Mock()
    model.invoke.return_value = AIMessage(content="mock_tool()")

    agent = CiaynAgent(model, [tool])
    payload = {"messages": [HumanMessage(content="test")]}

    # Scenario 1: the flag is raised before the stream is consumed.
    with agent_context():
        mark_should_exit()
        assert should_exit()

        # Draining the generator must produce nothing at all.
        emitted = list(agent.stream(payload))
        assert not emitted

        # The loop has to bail out before it ever reaches the model.
        model.invoke.assert_not_called()

    # Scenario 2: a new context starts with the flag cleared.
    with agent_context():
        assert not should_exit()

        # Pull just the first item so exactly one model round-trip happens.
        next(agent.stream(payload))

        model.invoke.assert_called_once()
|
||||
|
||||
|
||||
def test_ciayn_agent_bundled_tools_respects_should_exit():
    """Check that bundled tool calls are not run once should_exit is set.

    The model is primed with a response containing two bundleable tool
    calls, but the flag is raised before streaming starts. The test then
    verifies that the stream yields nothing, the model is never invoked,
    and neither tool function is called.
    """
    mock_model = Mock()

    # Two fake tools named after bundleable tools.
    mock_tool1 = Mock()
    mock_tool1.func.__name__ = "emit_key_facts"
    mock_tool1.func.return_value = "Tool 1 executed"

    mock_tool2 = Mock()
    mock_tool2.func.__name__ = "emit_key_snippet"
    mock_tool2.func.return_value = "Tool 2 executed"

    # Model response that bundles two tool calls in a single message.
    bundled_calls = (
        "\n"
        'emit_key_facts(facts=["fact1", "fact2"])\n'
        'emit_key_snippet(snippet_info={"filepath": "test.py", "line_number": 1, "snippet": "code", "description": "test"})\n'
    )
    mock_model.invoke.return_value = AIMessage(content=bundled_calls)

    agent = CiaynAgent(mock_model, [mock_tool1, mock_tool2])

    with agent_context():
        messages = {"messages": [HumanMessage(content="test")]}

        # Raise the flag before the stream is consumed.
        mark_should_exit()

        # Drain the generator; the stream loop is expected to stop at once.
        results = list(agent.stream(messages))

        # Nothing may be yielded (same expectation as the plain stream test).
        assert len(results) == 0

        # The model must never have been asked for a response...
        mock_model.invoke.assert_not_called()

        # ...and therefore neither bundled tool may have run.
        mock_tool1.func.assert_not_called()
        mock_tool2.func.assert_not_called()
|
||||
|
||||
|
||||
def test_ciayn_agent_single_tool_respects_should_exit():
    """Check that the stream stops after a tool raises the should_exit flag.

    The single tool's side effect calls ``mark_should_exit()`` while it is
    being executed. After the first streamed item the test expects the flag
    to be visible, the remaining stream to be empty, and the model to have
    been invoked only once.
    """
    mock_model = Mock()
    mock_tool = Mock()
    mock_tool.func.__name__ = "test_tool"
    mock_tool.func.return_value = "Tool executed"

    # The model always answers with one call to the tool.
    mock_model.invoke.return_value = AIMessage(content="test_tool()")

    agent = CiaynAgent(mock_model, [mock_tool])

    def execute_and_exit(*args, **kwargs):
        """Stand-in tool body that raises the exit flag, then returns."""
        # The flag is marked from inside a nested context; the assertions
        # below expect it to be observable from the enclosing context too.
        with agent_context():
            mark_should_exit()
        return "Tool executed"

    # side_effect takes precedence over return_value when the mock is called.
    mock_tool.func.side_effect = execute_and_exit

    with agent_context():
        # The flag must start out cleared.
        assert not should_exit()

        messages = {"messages": [HumanMessage(content="test")]}
        generator = agent.stream(messages)

        # First step: the model is queried and the tool runs (raising the flag).
        next(generator, None)

        mock_model.invoke.assert_called_once()
        mock_tool.func.assert_called_once()
        assert should_exit()

        # The rest of the stream must be empty now that the flag is set.
        results = list(generator)
        assert len(results) == 0

        # No second model round-trip may have happened.
        assert mock_model.invoke.call_count == 1
|
||||
|
||||
|
||||
def test_ciayn_agent_execute_tool_respects_should_exit():
    """Check that ``_execute_tool`` aborts early when should_exit is set.

    With the flag raised, calling ``_execute_tool`` directly must return the
    abort message and must not call the underlying tool function.
    """
    tool = Mock()
    tool.func.__name__ = "test_tool"
    model = Mock()

    agent = CiaynAgent(model, [tool])

    with agent_context():
        # Raise the flag before asking the agent to run the tool call.
        mark_should_exit()

        outcome = agent._execute_tool(HumanMessage(content="test_tool()"))

        # The early-exit path reports why nothing was executed...
        assert "agent should exit flag is set" in outcome

        # ...and the tool itself is left untouched.
        tool.func.assert_not_called()
|
||||
Loading…
Reference in New Issue