feat(agent_utils.py): add cpm function import for enhanced logging in run_agent_with_retry
fix(agent_utils.py): set default value of fallback_handler to None in run_agent_with_retry chore(ciayn_agent.py): comment out debug log for generated code to reduce verbosity fix(fallback_handler.py): change fallback_enabled config key to experimental_fallback_handler for better clarity refactor(test_ciayn_agent.py): update invoke method in DummyModel to return AIMessage instead of a custom Response class test(test_ciayn_agent.py): comment out test_retry_logic_with_failure_recovery for future implementation and focus on existing tests
This commit is contained in:
parent
c7712e0114
commit
ac13ce746a
|
|
@ -31,7 +31,7 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
from ra_aid.agents_alias import RAgents
|
from ra_aid.agents_alias import RAgents
|
||||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
||||||
from ra_aid.console.formatting import print_error, print_stage_header
|
from ra_aid.console.formatting import print_error, print_stage_header
|
||||||
from ra_aid.console.output import print_agent_output
|
from ra_aid.console.output import cpm, print_agent_output
|
||||||
from ra_aid.exceptions import (
|
from ra_aid.exceptions import (
|
||||||
AgentInterrupt,
|
AgentInterrupt,
|
||||||
FallbackToolExecutionError,
|
FallbackToolExecutionError,
|
||||||
|
|
@ -904,7 +904,7 @@ def run_agent_with_retry(
|
||||||
agent: RAgents,
|
agent: RAgents,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
config: dict,
|
config: dict,
|
||||||
fallback_handler: Optional[FallbackHandler],
|
fallback_handler: Optional[FallbackHandler] = None,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Run an agent with retry logic for API errors."""
|
"""Run an agent with retry logic for API errors."""
|
||||||
logger.debug("Running agent with prompt length: %d", len(prompt))
|
logger.debug("Running agent with prompt length: %d", len(prompt))
|
||||||
|
|
@ -933,10 +933,12 @@ def run_agent_with_retry(
|
||||||
original_prompt, config, test_attempts, auto_test
|
original_prompt, config, test_attempts, auto_test
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
cpm(f"res:{should_break, prompt, auto_test, test_attempts}")
|
||||||
if should_break:
|
if should_break:
|
||||||
break
|
break
|
||||||
if prompt != original_prompt:
|
if prompt != original_prompt:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.debug("Agent run completed successfully")
|
logger.debug("Agent run completed successfully")
|
||||||
return "Agent run completed successfully"
|
return "Agent run completed successfully"
|
||||||
except ToolExecutionError as e:
|
except ToolExecutionError as e:
|
||||||
|
|
|
||||||
|
|
@ -134,7 +134,6 @@ class CiaynAgent:
|
||||||
def _execute_tool(self, msg: BaseMessage) -> str:
|
def _execute_tool(self, msg: BaseMessage) -> str:
|
||||||
"""Execute a tool call and return its result."""
|
"""Execute a tool call and return its result."""
|
||||||
|
|
||||||
cpm(f"execute_tool msg: { msg }")
|
|
||||||
code = msg.content
|
code = msg.content
|
||||||
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
|
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
|
||||||
|
|
||||||
|
|
@ -285,7 +284,7 @@ class CiaynAgent:
|
||||||
response = self.model.invoke([self.sys_message] + full_history)
|
response = self.model.invoke([self.sys_message] + full_history)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Code generated by agent: {response.content}")
|
# logger.debug(f"Code generated by agent: {response.content}")
|
||||||
last_result = self._execute_tool(response)
|
last_result = self._execute_tool(response)
|
||||||
self.chat_history.append(response)
|
self.chat_history.append(response)
|
||||||
self.fallback_handler.reset_fallback_handler()
|
self.fallback_handler.reset_fallback_handler()
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ class FallbackHandler:
|
||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
self.tools: list[BaseTool] = tools
|
self.tools: list[BaseTool] = tools
|
||||||
self.fallback_enabled = config.get("fallback_tool_enabled", True)
|
self.fallback_enabled = config.get("experimental_fallback_handler", False)
|
||||||
self.fallback_tool_models = self._load_fallback_tool_models(config)
|
self.fallback_tool_models = self._load_fallback_tool_models(config)
|
||||||
self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
|
self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
|
||||||
self.tool_failure_consecutive_failures = 0
|
self.tool_failure_consecutive_failures = 0
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import unittest
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
|
|
||||||
from ra_aid.agents.ciayn_agent import CiaynAgent, validate_function_call_pattern
|
from ra_aid.agents.ciayn_agent import CiaynAgent, validate_function_call_pattern
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
|
|
@ -25,12 +25,9 @@ class DummyTool:
|
||||||
|
|
||||||
|
|
||||||
class DummyModel:
|
class DummyModel:
|
||||||
def invoke(self, messages):
|
def invoke(self, _messages: list[BaseMessage]):
|
||||||
# Always return a code snippet that calls dummy_tool()
|
|
||||||
class Response:
|
|
||||||
content = "dummy_tool()"
|
|
||||||
|
|
||||||
return Response()
|
return AIMessage("dummy_tool()")
|
||||||
|
|
||||||
def bind_tools(self, tools, tool_choice):
|
def bind_tools(self, tools, tool_choice):
|
||||||
pass
|
pass
|
||||||
|
|
@ -188,20 +185,25 @@ class TestCiaynAgentFallback(unittest.TestCase):
|
||||||
# Create a CiaynAgent with the dummy tool
|
# Create a CiaynAgent with the dummy tool
|
||||||
self.agent = CiaynAgent(self.model, [self.dummy_tool])
|
self.agent = CiaynAgent(self.model, [self.dummy_tool])
|
||||||
|
|
||||||
def test_retry_logic_with_failure_recovery(self):
|
# def test_retry_logic_with_failure_recovery(self):
|
||||||
# Test that _execute_tool retries and eventually returns success
|
# # Test that run_agent_with_retry retries until success
|
||||||
result = self.agent._execute_tool("dummy_tool()")
|
# from ra_aid.agent_utils import run_agent_with_retry
|
||||||
self.assertEqual(result, "dummy success")
|
#
|
||||||
|
# config = {"max_test_cmd_retries": 0, "auto_test": True}
|
||||||
|
# result = run_agent_with_retry(self.agent, "dummy_tool()", config)
|
||||||
|
# self.assertEqual(result, "Agent run completed successfully")
|
||||||
|
|
||||||
def test_switch_models_on_fallback(self):
|
def test_switch_models_on_fallback(self):
|
||||||
# Test fallback behavior by making dummy_tool always fail
|
# Test fallback behavior by making dummy_tool always fail
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
def always_fail():
|
def always_fail():
|
||||||
raise Exception("Persistent failure")
|
raise Exception("Persistent failure")
|
||||||
|
|
||||||
always_fail_tool = DummyTool(always_fail)
|
always_fail_tool = DummyTool(always_fail)
|
||||||
agent = CiaynAgent(self.model, [always_fail_tool])
|
agent = CiaynAgent(self.model, [always_fail_tool])
|
||||||
with self.assertRaises(ToolExecutionError):
|
with self.assertRaises(ToolExecutionError):
|
||||||
agent._execute_tool("always_fail()")
|
agent._execute_tool(HumanMessage("always_fail()"))
|
||||||
|
|
||||||
|
|
||||||
# Function call validation tests
|
# Function call validation tests
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue