205 lines
6.8 KiB
Python
205 lines
6.8 KiB
Python
"""Unit tests for agent_utils.py."""
|
|
|
|
import pytest
|
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
|
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT
|
|
from ra_aid.agent_utils import state_modifier, AgentState
|
|
from unittest.mock import Mock, patch
|
|
from langchain_core.language_models import BaseChatModel
|
|
|
|
from ra_aid.agent_utils import create_agent, get_model_token_limit
|
|
from ra_aid.models_tokens import models_tokens
|
|
|
|
|
|
@pytest.fixture
def mock_model():
    """Provide a stand-in chat model restricted to the ``BaseChatModel`` interface."""
    return Mock(spec=BaseChatModel)
|
|
|
|
|
|
@pytest.fixture
def mock_memory():
    """Patch ``ra_aid.agent_utils._global_memory`` for the duration of a test.

    The patched object's ``get`` returns an empty dict by default; individual
    tests overwrite ``get.return_value`` to inject the config they need.
    """
    patcher = patch("ra_aid.agent_utils._global_memory")
    memory = patcher.start()
    memory.get.return_value = {}
    try:
        yield memory
    finally:
        # Always undo the patch, mirroring what the context-manager form does.
        patcher.stop()
|
|
|
|
|
|
def test_get_model_token_limit_anthropic(mock_memory):
    """An Anthropic provider/model pair resolves to its ``models_tokens`` entry."""
    anthropic_config = {"provider": "anthropic", "model": "claude2"}

    limit = get_model_token_limit(anthropic_config)

    expected = models_tokens["anthropic"]["claude2"]
    assert limit == expected
|
|
|
|
|
|
def test_get_model_token_limit_openai(mock_memory):
    """An OpenAI provider/model pair resolves to its ``models_tokens`` entry."""
    openai_config = {"provider": "openai", "model": "gpt-4"}

    limit = get_model_token_limit(openai_config)

    expected = models_tokens["openai"]["gpt-4"]
    assert limit == expected
|
|
|
|
|
|
def test_get_model_token_limit_unknown(mock_memory):
    """An unrecognised provider/model pair yields no token limit (``None``)."""
    unknown_config = {"provider": "unknown", "model": "unknown-model"}
    assert get_model_token_limit(unknown_config) is None
|
|
|
|
|
|
def test_get_model_token_limit_missing_config(mock_memory):
    """An empty config dict yields no token limit (``None``)."""
    assert get_model_token_limit({}) is None
|
|
|
|
|
|
def test_create_agent_anthropic(mock_model, mock_memory):
    """Test create_agent with Anthropic Claude model.

    A Claude config should be routed to ``create_react_agent``. The model and
    tools are passed positionally, and a callable ``state_modifier`` is the
    only keyword argument.
    """
    mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"}

    with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
        mock_react.return_value = "react_agent"

        agent = create_agent(mock_model, [])

        assert agent == "react_agent"
        # Verify the call happened before reading call_args. If
        # create_react_agent was never called, call_args is None and indexing
        # it would raise TypeError instead of a readable assertion failure.
        mock_react.assert_called_once()
        args, kwargs = mock_react.call_args
        assert args == (mock_model, [])
        assert set(kwargs) == {"state_modifier"}
        assert callable(kwargs["state_modifier"])
|
|
|
|
|
|
def test_create_agent_openai(mock_model, mock_memory):
    """An OpenAI config builds a CiaynAgent capped at the model's token limit."""
    mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}

    with patch("ra_aid.agent_utils.CiaynAgent") as ciayn_cls:
        ciayn_cls.return_value = "ciayn_agent"
        result = create_agent(mock_model, [])

        assert result == "ciayn_agent"
        gpt4_limit = models_tokens["openai"]["gpt-4"]
        ciayn_cls.assert_called_once_with(mock_model, [], max_tokens=gpt4_limit)
|
|
|
|
|
|
def test_create_agent_no_token_limit(mock_model, mock_memory):
    """With no known limit for the model, CiaynAgent receives DEFAULT_TOKEN_LIMIT."""
    mock_memory.get.return_value = {"provider": "unknown", "model": "unknown-model"}

    with patch("ra_aid.agent_utils.CiaynAgent") as agent_cls:
        agent_cls.return_value = "ciayn_agent"
        built = create_agent(mock_model, [])

        assert built == "ciayn_agent"
        agent_cls.assert_called_once_with(
            mock_model, [], max_tokens=DEFAULT_TOKEN_LIMIT
        )
|
|
|
|
|
|
def test_create_agent_missing_config(mock_model, mock_memory):
    """A config without a model name falls back to DEFAULT_TOKEN_LIMIT for CiaynAgent."""
    partial_config = {"provider": "openai"}
    mock_memory.get.return_value = partial_config

    with patch("ra_aid.agent_utils.CiaynAgent") as ciayn_factory:
        ciayn_factory.return_value = "ciayn_agent"
        outcome = create_agent(mock_model, [])

        assert outcome == "ciayn_agent"
        ciayn_factory.assert_called_once_with(
            mock_model, [], max_tokens=DEFAULT_TOKEN_LIMIT
        )
|
|
|
|
|
|
@pytest.fixture
def mock_messages():
    """Return a five-message conversation: a system prompt, then two human/AI exchanges."""
    exchanges = (
        ("Human message 1", "AI response 1"),
        ("Human message 2", "AI response 2"),
    )
    conversation = [SystemMessage(content="System prompt")]
    for human_text, ai_text in exchanges:
        conversation += [HumanMessage(content=human_text), AIMessage(content=ai_text)]
    return conversation
|
|
|
|
|
|
def test_state_modifier(mock_messages):
    """Test that state_modifier trims older messages but keeps the first one.

    The patched estimator charges 100 tokens for any truthy message (0
    otherwise), so a 250-token budget cannot hold all five fixture messages.
    NOTE(review): this assumes state_modifier counts tokens through
    ``CiaynAgent._estimate_tokens``; confirm if the counting path changes.
    """
    state = AgentState(messages=mock_messages)

    with patch(
        "ra_aid.agents.ciayn_agent.CiaynAgent._estimate_tokens"
    ) as mock_estimate:
        mock_estimate.side_effect = lambda msg: 100 if msg else 0

        result = state_modifier(state, max_tokens=250)

        assert len(result) < len(mock_messages)
        # The original first message must survive trimming. Checking only the
        # type would also pass if some other SystemMessage were substituted.
        assert result[0] == mock_messages[0]
        assert isinstance(result[0], SystemMessage)
        assert result[-1] == mock_messages[-1]
|
|
|
|
|
|
def test_create_agent_with_checkpointer(mock_model, mock_memory):
    """Passing a checkpointer for an OpenAI config still builds a CiaynAgent.

    The expected CiaynAgent call carries only the model, the tools and
    ``max_tokens``, so the checkpointer is not forwarded to it.
    """
    mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}
    checkpointer = Mock()

    with patch("ra_aid.agent_utils.CiaynAgent") as ciayn_cls:
        ciayn_cls.return_value = "ciayn_agent"
        result = create_agent(mock_model, [], checkpointer=checkpointer)

        assert result == "ciayn_agent"
        gpt4_limit = models_tokens["openai"]["gpt-4"]
        ciayn_cls.assert_called_once_with(mock_model, [], max_tokens=gpt4_limit)
|
|
|
|
|
|
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory):
    """Test create_agent sets up token limiting for Claude models when enabled.

    With ``limit_tokens`` set to True, ``create_react_agent`` must be called
    exactly once with the model and tools, plus a callable ``state_modifier``.
    """
    mock_memory.get.return_value = {
        "provider": "anthropic",
        "model": "claude-2",
        "limit_tokens": True,
    }

    with (
        patch("ra_aid.agent_utils.create_react_agent") as mock_react,
        patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
    ):
        mock_react.return_value = "react_agent"
        mock_limit.return_value = 100000

        agent = create_agent(mock_model, [])

        assert agent == "react_agent"
        # Assert the call first. Indexing a None call_args (no call made)
        # would otherwise surface as a TypeError rather than a test failure.
        mock_react.assert_called_once()
        args, kwargs = mock_react.call_args
        assert args == (mock_model, [])
        assert "state_modifier" in kwargs
        assert callable(kwargs["state_modifier"])
|
|
|
|
|
|
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory):
    """With ``limit_tokens`` False, a Claude config gets no ``state_modifier``.

    ``create_react_agent`` must be called with only the model and tools.
    """
    claude_config = {"provider": "anthropic", "model": "claude-2", "limit_tokens": False}
    mock_memory.get.return_value = claude_config

    react_patch = patch("ra_aid.agent_utils.create_react_agent")
    limit_patch = patch("ra_aid.agent_utils.get_model_token_limit")
    with react_patch as react_fn, limit_patch as limit_fn:
        react_fn.return_value = "react_agent"
        limit_fn.return_value = 100000

        result = create_agent(mock_model, [])

        assert result == "react_agent"
        react_fn.assert_called_once_with(mock_model, [])
|