"""Unit tests for agent_utils.py.""" from unittest.mock import Mock, patch import litellm import pytest from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from ra_aid.agent_utils import ( AgentState, create_agent, get_model_token_limit, state_modifier, ) from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params @pytest.fixture def mock_model(): """Fixture providing a mock LLM model.""" model = Mock(spec=BaseChatModel) return model @pytest.fixture def mock_memory(): """Fixture providing a mock global memory store.""" with patch("ra_aid.agent_utils._global_memory") as mock_mem: mock_mem.get.return_value = {} yield mock_mem def test_get_model_token_limit_anthropic(mock_memory): """Test get_model_token_limit with Anthropic model.""" config = {"provider": "anthropic", "model": "claude2"} token_limit = get_model_token_limit(config, "default") assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] def test_get_model_token_limit_openai(mock_memory): """Test get_model_token_limit with OpenAI model.""" config = {"provider": "openai", "model": "gpt-4"} token_limit = get_model_token_limit(config, "default") assert token_limit == models_params["openai"]["gpt-4"]["token_limit"] def test_get_model_token_limit_unknown(mock_memory): """Test get_model_token_limit with unknown provider/model.""" config = {"provider": "unknown", "model": "unknown-model"} token_limit = get_model_token_limit(config, "default") assert token_limit is None def test_get_model_token_limit_missing_config(mock_memory): """Test get_model_token_limit with missing configuration.""" config = {} token_limit = get_model_token_limit(config, "default") assert token_limit is None def test_get_model_token_limit_litellm_success(): """Test get_model_token_limit successfully getting limit from litellm.""" config = {"provider": "anthropic", "model": "claude-2"} with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: mock_get_info.return_value = {"max_input_tokens": 100000} token_limit = get_model_token_limit(config, "default") assert token_limit == 100000 def test_get_model_token_limit_litellm_not_found(): """Test fallback to models_tokens when litellm raises NotFoundError.""" config = {"provider": "anthropic", "model": "claude-2"} with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: mock_get_info.side_effect = litellm.exceptions.NotFoundError( message="Model not found", model="claude-2", llm_provider="anthropic" ) token_limit = get_model_token_limit(config, "default") assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] def test_get_model_token_limit_litellm_error(): """Test fallback to models_tokens when litellm raises other exceptions.""" config = {"provider": "anthropic", "model": "claude-2"} with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: mock_get_info.side_effect = Exception("Unknown error") token_limit = get_model_token_limit(config, "default") assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] def test_get_model_token_limit_unexpected_error(): """Test returning None when unexpected errors occur.""" config = None # This will cause an attribute error when accessed token_limit = get_model_token_limit(config, "default") assert token_limit is None def test_create_agent_anthropic(mock_model, mock_memory): """Test create_agent with Anthropic Claude model.""" mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"} with patch("ra_aid.agent_utils.create_react_agent") as mock_react: mock_react.return_value = "react_agent" agent = create_agent(mock_model, []) assert agent == "react_agent" mock_react.assert_called_once_with( mock_model, [], state_modifier=mock_react.call_args[1]["state_modifier"] ) def test_create_agent_openai(mock_model, mock_memory): """Test create_agent with OpenAI model.""" mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"} with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn: mock_ciayn.return_value = "ciayn_agent" agent = create_agent(mock_model, []) assert agent == "ciayn_agent" mock_ciayn.assert_called_once_with( mock_model, [], max_tokens=models_params["openai"]["gpt-4"]["token_limit"] ) def test_create_agent_no_token_limit(mock_model, mock_memory): """Test create_agent when no token limit is found.""" mock_memory.get.return_value = {"provider": "unknown", "model": "unknown-model"} with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn: mock_ciayn.return_value = "ciayn_agent" agent = create_agent(mock_model, []) assert agent == "ciayn_agent" mock_ciayn.assert_called_once_with( mock_model, [], max_tokens=DEFAULT_TOKEN_LIMIT ) def test_create_agent_missing_config(mock_model, mock_memory): """Test create_agent with missing configuration.""" mock_memory.get.return_value = {"provider": "openai"} with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn: mock_ciayn.return_value = "ciayn_agent" agent = create_agent(mock_model, []) assert agent == "ciayn_agent" mock_ciayn.assert_called_once_with( mock_model, [], max_tokens=DEFAULT_TOKEN_LIMIT, ) @pytest.fixture def mock_messages(): """Fixture providing mock message objects.""" return [ SystemMessage(content="System prompt"), HumanMessage(content="Human message 1"), AIMessage(content="AI response 1"), HumanMessage(content="Human message 2"), AIMessage(content="AI response 2"), ] def test_state_modifier(mock_messages): """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens.""" state = AgentState(messages=mock_messages) with patch( "ra_aid.agents.ciayn_agent.CiaynAgent._estimate_tokens" ) as mock_estimate: mock_estimate.side_effect = lambda msg: 100 if msg else 0 result = state_modifier(state, max_input_tokens=250) assert len(result) < len(mock_messages) assert isinstance(result[0], SystemMessage) assert result[-1] == mock_messages[-1] def test_create_agent_with_checkpointer(mock_model, mock_memory): """Test create_agent with checkpointer argument.""" mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"} mock_checkpointer = Mock() with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn: mock_ciayn.return_value = "ciayn_agent" agent = create_agent(mock_model, [], checkpointer=mock_checkpointer) assert agent == "ciayn_agent" mock_ciayn.assert_called_once_with( mock_model, [], max_tokens=models_params["openai"]["gpt-4"]["token_limit"] ) def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory): """Test create_agent sets up token limiting for Claude models when enabled.""" mock_memory.get.return_value = { "provider": "anthropic", "model": "claude-2", "limit_tokens": True, } with ( patch("ra_aid.agent_utils.create_react_agent") as mock_react, patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit, ): mock_react.return_value = "react_agent" mock_limit.return_value = 100000 agent = create_agent(mock_model, []) assert agent == "react_agent" args = mock_react.call_args assert "state_modifier" in args[1] assert callable(args[1]["state_modifier"]) def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory): """Test create_agent doesn't set up token limiting for Claude models when disabled.""" mock_memory.get.return_value = { "provider": "anthropic", "model": "claude-2", "limit_tokens": False, } with ( patch("ra_aid.agent_utils.create_react_agent") as mock_react, patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit, ): mock_react.return_value = "react_agent" mock_limit.return_value = 100000 agent = create_agent(mock_model, []) assert agent == "react_agent" mock_react.assert_called_once_with(mock_model, []) def test_get_model_token_limit_research(mock_memory): """Test get_model_token_limit with research provider and model.""" config = { "provider": "openai", "model": "gpt-4", "research_provider": "anthropic", "research_model": "claude-2", } with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: mock_get_info.return_value = {"max_input_tokens": 150000} token_limit = get_model_token_limit(config, "research") assert token_limit == 150000 def test_get_model_token_limit_planner(mock_memory): """Test get_model_token_limit with planner provider and model.""" config = { "provider": "openai", "model": "gpt-4", "planner_provider": "deepseek", "planner_model": "dsm-1", } with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: mock_get_info.return_value = {"max_input_tokens": 120000} token_limit = get_model_token_limit(config, "planner") assert token_limit == 120000