refactor(agent_utils.py, ciayn_agent.py): remove unused import cpm to clean up code and improve readability

style(tests): format code for better readability and consistency in test files
test(tests): update assertions and test cases for better clarity and maintainability
Ariel Frischer 2025-02-14 13:04:41 -08:00
parent 7a2c766824
commit d5e2e0a9a0
6 changed files with 119 additions and 48 deletions

View File

@@ -31,7 +31,7 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.agents_alias import RAgents
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
from ra_aid.console.formatting import print_error, print_stage_header
from ra_aid.console.output import cpm, print_agent_output
from ra_aid.console.output import print_agent_output
from ra_aid.exceptions import (
AgentInterrupt,
FallbackToolExecutionError,

View File

@@ -7,7 +7,6 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, System
from langchain_core.tools import BaseTool
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
from ra_aid.console.output import cpm
from ra_aid.exceptions import ToolExecutionError
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.logging_config import get_logger
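
Both hunks above drop the now-unused cpm import. For illustration only (not part of this commit or the ra_aid codebase), a small standard-library check in the spirit of the F401 "unused import" lint warning that motivates this kind of cleanup might look like:

# Illustrative sketch: find module-level imports that are never referenced by
# name. A simplified stand-in for what flake8/ruff report as F401.
import ast
import sys

def unused_imports(source: str) -> list[str]:
    tree = ast.parse(source)
    imported: list[str] = []
    used: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            imported += [(a.asname or a.name).split(".")[0] for a in node.names]
        elif isinstance(node, ast.ImportFrom):
            imported += [a.asname or a.name for a in node.names]
        elif isinstance(node, ast.Name):
            used.add(node.id)
    return [name for name in imported if name not in used]

if __name__ == "__main__":
    print(unused_imports(open(sys.argv[1]).read()))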

View File

@@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
import litellm
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from ra_aid.agent_utils import (
AgentState,
@@ -128,9 +128,10 @@ def test_create_agent_openai(mock_model, mock_memory):
assert agent == "ciayn_agent"
mock_ciayn.assert_called_once_with(
mock_model, [],
mock_model,
[],
max_tokens=models_params["openai"]["gpt-4"]["token_limit"],
config={'provider': 'openai', 'model': 'gpt-4'}
config={"provider": "openai", "model": "gpt-4"},
)
@@ -144,9 +145,10 @@ def test_create_agent_no_token_limit(mock_model, mock_memory):
assert agent == "ciayn_agent"
mock_ciayn.assert_called_once_with(
mock_model, [],
mock_model,
[],
max_tokens=DEFAULT_TOKEN_LIMIT,
config={'provider': 'unknown', 'model': 'unknown-model'}
config={"provider": "unknown", "model": "unknown-model"},
)
@@ -163,7 +165,7 @@ def test_create_agent_missing_config(mock_model, mock_memory):
mock_model,
[],
max_tokens=DEFAULT_TOKEN_LIMIT,
config={'provider': 'openai'}
config={"provider": "openai"},
)
@@ -207,9 +209,10 @@ def test_create_agent_with_checkpointer(mock_model, mock_memory):
assert agent == "ciayn_agent"
mock_ciayn.assert_called_once_with(
mock_model, [],
mock_model,
[],
max_tokens=models_params["openai"]["gpt-4"]["token_limit"],
config={'provider': 'openai', 'model': 'gpt-4'}
config={"provider": "openai", "model": "gpt-4"},
)
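
These assertions all check the same construction pattern: CiaynAgent receives the model, an empty tool list, a max_tokens value, and a provider/model config dict, where max_tokens comes from models_params when the model is known and falls back to DEFAULT_TOKEN_LIMIT otherwise. A minimal sketch of the lookup the tests imply (an illustration, not ra_aid's create_agent; the concrete numbers are placeholders):

# Sketch of the token-limit resolution implied by the assertions above.
DEFAULT_TOKEN_LIMIT = 100_000  # placeholder value for illustration

models_params = {
    "openai": {"gpt-4": {"token_limit": 8192}},  # assumed shape and value
}

def resolve_max_tokens(provider: str | None, model: str | None) -> int:
    # Known provider/model pairs use their configured token_limit; a missing
    # provider, model, or limit falls back to the default.
    return (
        models_params.get(provider, {})
        .get(model, {})
        .get("token_limit", DEFAULT_TOKEN_LIMIT)
    )

assert resolve_max_tokens("openai", "gpt-4") == 8192
assert resolve_max_tokens("unknown", "unknown-model") == DEFAULT_TOKEN_LIMIT
assert resolve_max_tokens("openai", None) == DEFAULT_TOKEN_LIMIT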

View File

@@ -26,7 +26,6 @@ class DummyTool:
class DummyModel:
def invoke(self, _messages: list[BaseMessage]):
return AIMessage("dummy_tool()")
def bind_tools(self, tools, tool_choice):

View File

@@ -86,6 +86,7 @@ class TestFallbackHandler(unittest.TestCase):
def test_load_fallback_tool_models(self):
import ra_aid.fallback_handler as fh
original_supported = fh.supported_top_tool_models
fh.supported_top_tool_models = [
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
@@ -95,13 +96,16 @@ class TestFallbackHandler(unittest.TestCase):
fh.supported_top_tool_models = original_supported
def test_extract_failed_tool_name(self):
from ra_aid.exceptions import ToolExecutionError, FallbackToolExecutionError
from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError
# Case when tool_name is provided
error1 = ToolExecutionError("Error", base_message="dummy", tool_name="dummy_tool")
error1 = ToolExecutionError(
"Error", base_message="dummy", tool_name="dummy_tool"
)
name1 = self.fallback_handler.extract_failed_tool_name(error1)
self.assertEqual(name1, "dummy_tool")
# Case when tool_name is not provided but regex works
error2 = ToolExecutionError("error with name=\"test_tool\"")
error2 = ToolExecutionError('error with name="test_tool"')
name2 = self.fallback_handler.extract_failed_tool_name(error2)
self.assertEqual(name2, "test_tool")
# Case when regex fails and exception is raised
@@ -110,16 +114,13 @@ class TestFallbackHandler(unittest.TestCase):
self.fallback_handler.extract_failed_tool_name(error3)
def test_find_tool_to_bind(self):
# Create a dummy tool to be found
class DummyTool:
def invoke(self, args):
return "result"
class DummyWrapper:
def __init__(self, func):
self.func = func
def dummy_func(args):
def dummy_func(_args):
return "result"
dummy_tool = DummyTool()
dummy_wrapper = DummyWrapper(dummy_func)
self.agent.tools.append(dummy_wrapper)
tool = self.fallback_handler._find_tool_to_bind(self.agent, dummy_func.__name__)
@@ -134,53 +135,82 @@ class TestFallbackHandler(unittest.TestCase):
self.tools = tools
self.tool_choice = tool_choice
return self
def with_retry(self, stop_after_attempt):
return self
def invoke(self, msg_list):
return "dummy_response"
dummy_model = DummyModel()
# Set current tool for binding
class DummyTool:
def invoke(self, args):
return "result"
self.fallback_handler.current_tool_to_bind = DummyTool()
self.fallback_handler.current_failing_tool_name = "test_tool"
# Test with force calling ("fc") type
fallback_model_fc = {"type": "fc"}
bound_model_fc = self.fallback_handler._bind_tool_model(dummy_model, fallback_model_fc)
bound_model_fc = self.fallback_handler._bind_tool_model(
dummy_model, fallback_model_fc
)
self.assertTrue(hasattr(bound_model_fc, "tool_choice"))
self.assertEqual(bound_model_fc.tool_choice, "test_tool")
# Test with prompt type
fallback_model_prompt = {"type": "prompt"}
bound_model_prompt = self.fallback_handler._bind_tool_model(dummy_model, fallback_model_prompt)
bound_model_prompt = self.fallback_handler._bind_tool_model(
dummy_model, fallback_model_prompt
)
self.assertTrue(bound_model_prompt.tool_choice is None)
def test_invoke_fallback(self):
from unittest.mock import patch
import os
import ra_aid.llm as llm
from unittest.mock import patch
# Successful fallback scenario with proper API key set
with patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}), \
patch("ra_aid.fallback_handler.supported_top_tool_models", new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]), \
patch("ra_aid.fallback_handler.validate_provider_env", return_value=True), \
patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm:
with (
patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}),
patch(
"ra_aid.fallback_handler.supported_top_tool_models",
new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}],
),
patch("ra_aid.fallback_handler.validate_provider_env", return_value=True),
patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm,
):
class DummyModel:
def bind_tools(self, tools, tool_choice=None):
return self
def with_retry(self, stop_after_attempt):
return self
def invoke(self, msg_list):
return DummyResponse()
class DummyResponse:
additional_kwargs = {"tool_calls": [{"id": "1", "type": "test", "function": {"name": "dummy_tool", "arguments": "{\"a\":1}"}}]}
additional_kwargs = {
"tool_calls": [
{
"id": "1",
"type": "test",
"function": {"name": "dummy_tool", "arguments": '{"a":1}'},
}
]
}
def dummy_initialize_llm(provider, model_name):
return DummyModel()
mock_init_llm.side_effect = dummy_initialize_llm
# Set current tool for fallback
class DummyTool:
def invoke(self, args):
return "tool_result"
self.fallback_handler.current_tool_to_bind = DummyTool()
self.fallback_handler.current_failing_tool_name = "dummy_tool"
# Add dummy tool for lookup in invoke_prompt_tool_call
@@ -194,39 +224,57 @@ class TestFallbackHandler(unittest.TestCase):
},
)
)
result = self.fallback_handler.invoke_fallback({"provider": "dummy", "model": "dummy_model", "type": "prompt"})
result = self.fallback_handler.invoke_fallback(
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
)
self.assertIsInstance(result, list)
self.assertEqual(result[1], "tool_result")
# Failed fallback scenario due to missing API key (simulate by empty environment)
with patch.dict(os.environ, {}, clear=True), \
patch("ra_aid.fallback_handler.supported_top_tool_models", new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}]), \
patch("ra_aid.fallback_handler.validate_provider_env", return_value=False), \
patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm:
with (
patch.dict(os.environ, {}, clear=True),
patch(
"ra_aid.fallback_handler.supported_top_tool_models",
new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}],
),
patch("ra_aid.fallback_handler.validate_provider_env", return_value=False),
patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm,
):
class FailingDummyModel:
def bind_tools(self, tools, tool_choice=None):
return self
def with_retry(self, stop_after_attempt):
return self
def invoke(self, msg_list):
raise Exception("API key missing")
def failing_initialize_llm(provider, model_name):
return FailingDummyModel()
mock_init_llm.side_effect = failing_initialize_llm
fallback_result = self.fallback_handler.invoke_fallback({"provider": "dummy", "model": "dummy_model", "type": "prompt"})
fallback_result = self.fallback_handler.invoke_fallback(
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
)
self.assertIsNone(fallback_result)
# Test that the overall fallback mechanism raises FallbackToolExecutionError when all models fail
# Set failure count to trigger the fallback attempt in attempt_fallback
from ra_aid.exceptions import FallbackToolExecutionError
self.fallback_handler.tool_failure_consecutive_failures = self.fallback_handler.max_failures
self.fallback_handler.tool_failure_consecutive_failures = (
self.fallback_handler.max_failures
)
with self.assertRaises(FallbackToolExecutionError) as cm:
self.fallback_handler.attempt_fallback()
self.assertIn("All fallback models have failed", str(cm.exception))
def test_construct_prompt_msg_list(self):
msgs = self.fallback_handler.construct_prompt_msg_list()
from ra_aid.fallback_handler import SystemMessage, HumanMessage
from ra_aid.fallback_handler import HumanMessage, SystemMessage
self.assertTrue(any(isinstance(m, SystemMessage) for m in msgs))
self.assertTrue(any(isinstance(m, HumanMessage) for m in msgs))
# Test with failed_messages added
@@ -238,13 +286,17 @@ class TestFallbackHandler(unittest.TestCase):
# Create dummy tool function
def dummy_tool_func(args):
return "invoked_result"
dummy_tool_func.__name__ = "dummy_tool"
# Create wrapper class
class DummyToolWrapper:
def __init__(self, func):
self.func = func
def invoke(self, args):
return self.func(args)
dummy_wrapper = DummyToolWrapper(dummy_tool_func)
self.fallback_handler.tools = [dummy_wrapper]
tool_call_req = {"name": "dummy_tool", "arguments": {"x": 42}}
@@ -255,9 +307,13 @@ class TestFallbackHandler(unittest.TestCase):
dummy_tool_call = {
"id": "123",
"type": "test",
"function": {"name": "dummy_tool", "arguments": "{\"x\":42}"}
"function": {"name": "dummy_tool", "arguments": '{"x":42}'},
}
DummyResponse = type("DummyResponse", (), {"additional_kwargs": {"tool_calls": [dummy_tool_call]}})
DummyResponse = type(
"DummyResponse",
(),
{"additional_kwargs": {"tool_calls": [dummy_tool_call]}},
)
result = self.fallback_handler.base_message_to_tool_call_dict(DummyResponse)
self.assertEqual(result["id"], "123")
self.assertEqual(result["name"], "dummy_tool")
@@ -285,22 +341,30 @@ class TestFallbackHandler(unittest.TestCase):
def test_handle_failure_response(self):
from ra_aid.exceptions import ToolExecutionError
def dummy_handle_failure(error, agent):
return ["fallback_response"]
self.fallback_handler.handle_failure = dummy_handle_failure
response = self.fallback_handler.handle_failure_response(ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React")
from ra_aid.fallback_handler import SystemMessage
self.assertTrue(all(isinstance(m, SystemMessage) for m in response))
response_non = self.fallback_handler.handle_failure_response(ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other")
self.assertIsNone(response_non)
self.fallback_handler.handle_failure = dummy_handle_failure
response = self.fallback_handler.handle_failure_response(
ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React"
)
from ra_aid.fallback_handler import SystemMessage
self.assertTrue(all(isinstance(m, SystemMessage) for m in response))
response_non = self.fallback_handler.handle_failure_response(
ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other"
)
self.assertIsNone(response_non)
def test_init_msg_list_non_overlapping(self):
# Test when the first two and last two messages do not overlap.
full_list = ["msg1", "msg2", "msg3", "msg4", "msg5"]
self.fallback_handler.init_msg_list(full_list)
# Expected merged list: first two ("msg1", "msg2") plus last two ("msg4", "msg5")
self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"])
self.assertEqual(
self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"]
)
def test_init_msg_list_with_overlap(self):
# Test when the last two messages overlap with the first two.
@@ -309,5 +373,6 @@ class TestFallbackHandler(unittest.TestCase):
# Expected merged list: first two ("msg1", "msg2") plus "msg3" from the last two, since "msg1" was already present.
self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg3"])
if __name__ == "__main__":
unittest.main()
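
The init_msg_list tests near the end of this file pin down a simple truncation rule for the fallback message list: keep the first two messages, then append the last two, skipping any message already kept. A minimal sketch of that rule, inferred from the assertions above rather than taken from FallbackHandler itself:

# Sketch of the merge rule described by test_init_msg_list_non_overlapping and
# test_init_msg_list_with_overlap; not the actual FallbackHandler code.
def merge_first_and_last(full_list: list[str]) -> list[str]:
    head = full_list[:2]  # always keep the first two messages
    tail = [m for m in full_list[-2:] if m not in head]  # drop duplicates of the head
    return head + tail

assert merge_first_and_last(["msg1", "msg2", "msg3", "msg4", "msg5"]) == [
    "msg1", "msg2", "msg4", "msg5"
]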

View File

@@ -121,7 +121,9 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch
def test_initialize_expert_unsupported_provider(clean_env):
"""Test error handling for unsupported provider in expert mode."""
with pytest.raises(ValueError, match=r"Missing required environment variable for provider: unknown"):
with pytest.raises(
ValueError, match=r"Missing required environment variable for provider: unknown"
):
initialize_expert_llm("unknown", "model")
@@ -197,7 +199,10 @@ def test_initialize_unsupported_provider(clean_env):
"""Test initialization with unsupported provider raises ValueError"""
with pytest.raises(ValueError) as exc_info:
initialize_llm("unsupported", "model")
assert str(exc_info.value) == "Missing required environment variable for provider: unsupported"
assert (
str(exc_info.value)
== "Missing required environment variable for provider: unsupported"
)
def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini):
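
The assertion above pins the exact error text for an unsupported provider. A sketch of the kind of environment validation these tests exercise (inferred from the expected message, not ra_aid's initialize_llm; the variable names in REQUIRED_ENV are assumptions):

# Sketch only: map providers to the environment variable they require and
# raise the message the tests expect when it is absent.
import os

REQUIRED_ENV = {
    "openai": "OPENAI_API_KEY",        # assumed mapping
    "anthropic": "ANTHROPIC_API_KEY",  # assumed mapping
}

def check_provider_env(provider: str) -> None:
    var = REQUIRED_ENV.get(provider)
    if var is None or not os.environ.get(var):
        raise ValueError(
            f"Missing required environment variable for provider: {provider}"
        )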