326 lines
11 KiB
Python
326 lines
11 KiB
Python
import unittest
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
from langchain_core.messages import AIMessage, HumanMessage
|
|
|
|
from ra_aid.agents.ciayn_agent import CiaynAgent, validate_function_call_pattern
|
|
from ra_aid.exceptions import ToolExecutionError
|
|
|
|
|
|
# Dummy tool function for testing retry and fallback behavior
|
|
def dummy_tool():
|
|
dummy_tool.attempt += 1
|
|
if dummy_tool.attempt < 3:
|
|
raise Exception("Simulated failure")
|
|
return "dummy success"
|
|
|
|
|
|
dummy_tool.attempt = 0
|
|
|
|
|
|
class DummyTool:
|
|
def __init__(self, func):
|
|
self.func = func
|
|
|
|
|
|
class DummyModel:
|
|
def invoke(self, messages):
|
|
# Always return a code snippet that calls dummy_tool()
|
|
class Response:
|
|
content = "dummy_tool()"
|
|
|
|
return Response()
|
|
|
|
def bind_tools(self, tools, tool_choice):
|
|
pass
|
|
|
|
|
|
# Fixtures from the source file
|
|
@pytest.fixture
|
|
def mock_model():
|
|
"""Create a mock language model."""
|
|
model = Mock()
|
|
model.invoke = Mock()
|
|
return model
|
|
|
|
|
|
@pytest.fixture
|
|
def agent(mock_model):
|
|
"""Create a CiaynAgent instance with mock model."""
|
|
tools = [] # Empty tools list for testing trimming functionality
|
|
return CiaynAgent(mock_model, tools, max_history_messages=3)
|
|
|
|
|
|
# Trimming test functions
|
|
def test_trim_chat_history_preserves_initial_messages(agent):
|
|
"""Test that initial messages are preserved during trimming."""
|
|
initial_messages = [
|
|
HumanMessage(content="Initial 1"),
|
|
AIMessage(content="Initial 2"),
|
|
]
|
|
chat_history = [
|
|
HumanMessage(content="Chat 1"),
|
|
AIMessage(content="Chat 2"),
|
|
HumanMessage(content="Chat 3"),
|
|
AIMessage(content="Chat 4"),
|
|
]
|
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
|
# Verify initial messages are preserved
|
|
assert result[:2] == initial_messages
|
|
# Verify only last 3 chat messages are kept (due to max_history_messages=3)
|
|
assert len(result[2:]) == 3
|
|
assert result[2:] == chat_history[-3:]
|
|
|
|
|
|
def test_trim_chat_history_under_limit(agent):
|
|
"""Test trimming when chat history is under the maximum limit."""
|
|
initial_messages = [HumanMessage(content="Initial")]
|
|
chat_history = [HumanMessage(content="Chat 1"), AIMessage(content="Chat 2")]
|
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
|
# Verify no trimming occurred
|
|
assert len(result) == 3
|
|
assert result == initial_messages + chat_history
|
|
|
|
|
|
def test_trim_chat_history_over_limit(agent):
|
|
"""Test trimming when chat history exceeds the maximum limit."""
|
|
initial_messages = [HumanMessage(content="Initial")]
|
|
chat_history = [
|
|
HumanMessage(content="Chat 1"),
|
|
AIMessage(content="Chat 2"),
|
|
HumanMessage(content="Chat 3"),
|
|
AIMessage(content="Chat 4"),
|
|
HumanMessage(content="Chat 5"),
|
|
]
|
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
|
# Verify correct trimming
|
|
assert len(result) == 4 # initial + max_history_messages
|
|
assert result[0] == initial_messages[0] # Initial message preserved
|
|
assert result[1:] == chat_history[-3:] # Last 3 chat messages kept
|
|
|
|
|
|
def test_trim_chat_history_empty_initial(agent):
|
|
"""Test trimming with empty initial messages."""
|
|
initial_messages = []
|
|
chat_history = [
|
|
HumanMessage(content="Chat 1"),
|
|
AIMessage(content="Chat 2"),
|
|
HumanMessage(content="Chat 3"),
|
|
AIMessage(content="Chat 4"),
|
|
]
|
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
|
# Verify only last 3 messages are kept
|
|
assert len(result) == 3
|
|
assert result == chat_history[-3:]
|
|
|
|
|
|
def test_trim_chat_history_empty_chat(agent):
|
|
"""Test trimming with empty chat history."""
|
|
initial_messages = [
|
|
HumanMessage(content="Initial 1"),
|
|
AIMessage(content="Initial 2"),
|
|
]
|
|
chat_history = []
|
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
|
# Verify initial messages are preserved and no trimming occurred
|
|
assert result == initial_messages
|
|
assert len(result) == 2
|
|
|
|
|
|
def test_trim_chat_history_token_limit():
|
|
"""Test trimming based on token limit."""
|
|
agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=25)
|
|
initial_messages = [HumanMessage(content="Initial")] # ~2 tokens
|
|
chat_history = [
|
|
HumanMessage(content="A" * 40), # ~10 tokens
|
|
AIMessage(content="B" * 40), # ~10 tokens
|
|
HumanMessage(content="C" * 40), # ~10 tokens
|
|
]
|
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
|
# Should keep initial message (~2 tokens) and last message (~10 tokens)
|
|
assert len(result) == 2
|
|
assert result[0] == initial_messages[0]
|
|
assert result[1] == chat_history[-1]
|
|
|
|
|
|
def test_trim_chat_history_no_token_limit():
|
|
"""Test trimming with no token limit set."""
|
|
agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None)
|
|
initial_messages = [HumanMessage(content="Initial")]
|
|
chat_history = [
|
|
HumanMessage(content="A" * 1000),
|
|
AIMessage(content="B" * 1000),
|
|
HumanMessage(content="C" * 1000),
|
|
]
|
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
|
# Should keep initial message and last 2 messages (max_history_messages=2)
|
|
assert len(result) == 3
|
|
assert result[0] == initial_messages[0]
|
|
assert result[1:] == chat_history[-2:]
|
|
|
|
|
|
def test_trim_chat_history_both_limits():
|
|
"""Test trimming with both message count and token limits."""
|
|
agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=35)
|
|
initial_messages = [HumanMessage(content="Init")] # ~1 token
|
|
chat_history = [
|
|
HumanMessage(content="A" * 40), # ~10 tokens
|
|
AIMessage(content="B" * 40), # ~10 tokens
|
|
HumanMessage(content="C" * 40), # ~10 tokens
|
|
AIMessage(content="D" * 40), # ~10 tokens
|
|
]
|
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
|
# Should first apply message limit (keeping last 3)
|
|
# Then token limit should further reduce to fit under 15 tokens
|
|
assert len(result) == 2 # Initial message + 1 message under token limit
|
|
assert result[0] == initial_messages[0]
|
|
assert result[1] == chat_history[-1]
|
|
|
|
|
|
# Fallback tests
|
|
class TestCiaynAgentFallback(unittest.TestCase):
|
|
def setUp(self):
|
|
# Reset dummy_tool attempt counter before each test
|
|
dummy_tool.attempt = 0
|
|
self.dummy_tool = DummyTool(dummy_tool)
|
|
self.model = DummyModel()
|
|
# Create a CiaynAgent with the dummy tool
|
|
self.agent = CiaynAgent(self.model, [self.dummy_tool])
|
|
|
|
def test_retry_logic_with_failure_recovery(self):
|
|
# Test that _execute_tool retries and eventually returns success
|
|
result = self.agent._execute_tool("dummy_tool()")
|
|
self.assertEqual(result, "dummy success")
|
|
|
|
def test_switch_models_on_fallback(self):
|
|
# Test fallback behavior by making dummy_tool always fail
|
|
def always_fail():
|
|
raise Exception("Persistent failure")
|
|
|
|
always_fail_tool = DummyTool(always_fail)
|
|
agent = CiaynAgent(self.model, [always_fail_tool])
|
|
with self.assertRaises(ToolExecutionError):
|
|
agent._execute_tool("always_fail()")
|
|
|
|
|
|
# Function call validation tests
|
|
class TestFunctionCallValidation:
|
|
@pytest.mark.parametrize(
|
|
"test_input",
|
|
[
|
|
"basic_func()",
|
|
'func_with_arg("test")',
|
|
'complex_func(1, "two", three)',
|
|
'nested_parens(func("test"))',
|
|
"under_score()",
|
|
"with-dash()",
|
|
],
|
|
)
|
|
def test_valid_function_calls(self, test_input):
|
|
"""Test function call patterns that should pass validation."""
|
|
assert not validate_function_call_pattern(test_input)
|
|
|
|
@pytest.mark.parametrize(
|
|
"test_input",
|
|
[
|
|
"",
|
|
"Invalid!function()",
|
|
"missing_parens",
|
|
"unmatched(parens))",
|
|
"multiple()calls()",
|
|
"no spaces()()",
|
|
],
|
|
)
|
|
def test_invalid_function_calls(self, test_input):
|
|
"""Test function call patterns that should fail validation."""
|
|
assert validate_function_call_pattern(test_input)
|
|
|
|
@pytest.mark.parametrize(
|
|
"test_input",
|
|
[
|
|
" leading_space()",
|
|
"trailing_space() ",
|
|
"func (arg)",
|
|
"func( spaced args )",
|
|
],
|
|
)
|
|
def test_whitespace_handling(self, test_input):
|
|
"""Test whitespace variations in function calls."""
|
|
assert not validate_function_call_pattern(test_input)
|
|
|
|
@pytest.mark.parametrize(
|
|
"test_input",
|
|
[
|
|
"""multiline(
|
|
arg
|
|
)""",
|
|
"func(\n arg1,\n arg2\n)",
|
|
],
|
|
)
|
|
def test_multiline_responses(self, test_input):
|
|
"""Test function calls spanning multiple lines."""
|
|
assert not validate_function_call_pattern(test_input)
|
|
|
|
|
|
class TestCiaynAgentNewMethods(unittest.TestCase):
|
|
def setUp(self):
|
|
# Create a dummy tool that always fails for testing fallback
|
|
def always_fail():
|
|
raise Exception("Failure for fallback test")
|
|
|
|
self.always_fail_tool = DummyTool(always_fail)
|
|
# Create a dummy model that does minimal work for fallback tests
|
|
self.dummy_model = DummyModel()
|
|
# Initialize CiaynAgent with configuration to trigger fallback quickly
|
|
self.agent = CiaynAgent(
|
|
self.dummy_model,
|
|
[self.always_fail_tool],
|
|
config={
|
|
"max_tool_failures": 2,
|
|
"fallback_tool_models": "dummy-fallback-model",
|
|
},
|
|
)
|
|
|
|
def test_handle_tool_failure_increments_counter(self):
|
|
initial_failures = self.agent.tool_failure_consecutive_failures
|
|
self.agent._handle_tool_failure("dummy_call()", Exception("Test error"))
|
|
self.assertEqual(
|
|
self.agent.tool_failure_consecutive_failures, initial_failures + 1
|
|
)
|
|
|
|
def test_attempt_fallback_invokes_fallback_logic(self):
|
|
# Monkey-patch initialize_llm, merge_chat_history, and validate_provider_env
|
|
# to simulate fallback switching without external dependencies.
|
|
def dummy_initialize_llm(provider, model_name, temperature=None):
|
|
return self.dummy_model
|
|
|
|
def dummy_merge_chat_history():
|
|
return ["merged"]
|
|
|
|
def dummy_validate_provider_env(provider):
|
|
return True
|
|
|
|
import ra_aid.llm as llm
|
|
|
|
original_initialize = llm.initialize_llm
|
|
original_merge = llm.merge_chat_history
|
|
original_validate = llm.validate_provider_env
|
|
llm.initialize_llm = dummy_initialize_llm
|
|
llm.merge_chat_history = dummy_merge_chat_history
|
|
llm.validate_provider_env = dummy_validate_provider_env
|
|
|
|
# Set failure counter high enough to trigger fallback in _handle_tool_failure
|
|
self.agent.tool_failure_consecutive_failures = 2
|
|
# Call _attempt_fallback; it should reset the failure counter to 0 on success.
|
|
self.agent._attempt_fallback("always_fail_tool()")
|
|
self.assertEqual(self.agent.tool_failure_consecutive_failures, 0)
|
|
# Restore original functions
|
|
llm.initialize_llm = original_initialize
|
|
llm.merge_chat_history = original_merge
|
|
llm.validate_provider_env = original_validate
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|