diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py
index 5292317..f9a894c 100644
--- a/tests/ra_aid/test_agent_utils.py
+++ b/tests/ra_aid/test_agent_utils.py
@@ -14,8 +14,10 @@ from ra_aid.agent_context import (
 from ra_aid.agent_utils import (
     AgentState,
     create_agent,
-    get_model_token_limit,
     is_anthropic_claude,
+)
+from ra_aid.anthropic_token_limiter import (
+    get_model_token_limit,
     state_modifier,
 )
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
@@ -63,87 +65,15 @@ def mock_config_repository():
     yield mock_repo


-def test_get_model_token_limit_anthropic(mock_config_repository):
-    """Test get_model_token_limit with Anthropic model."""
-    config = {"provider": "anthropic", "model": "claude2"}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
-
-
-def test_get_model_token_limit_openai(mock_config_repository):
-    """Test get_model_token_limit with OpenAI model."""
-    config = {"provider": "openai", "model": "gpt-4"}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
-
-
-def test_get_model_token_limit_unknown(mock_config_repository):
-    """Test get_model_token_limit with unknown provider/model."""
-    config = {"provider": "unknown", "model": "unknown-model"}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit is None
-
-
-def test_get_model_token_limit_missing_config(mock_config_repository):
-    """Test get_model_token_limit with missing configuration."""
-    config = {}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit is None
-
-
-def test_get_model_token_limit_litellm_success():
-    """Test get_model_token_limit successfully getting limit from litellm."""
-    config = {"provider": "anthropic", "model": "claude-2"}
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.return_value = {"max_input_tokens": 100000}
-        token_limit = get_model_token_limit(config, "default")
-        assert token_limit == 100000
-
-
-def test_get_model_token_limit_litellm_not_found():
-    """Test fallback to models_tokens when litellm raises NotFoundError."""
-    config = {"provider": "anthropic", "model": "claude-2"}
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.side_effect = litellm.exceptions.NotFoundError(
-            message="Model not found", model="claude-2", llm_provider="anthropic"
-        )
-        token_limit = get_model_token_limit(config, "default")
-        assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
-
-
-def test_get_model_token_limit_litellm_error():
-    """Test fallback to models_tokens when litellm raises other exceptions."""
-    config = {"provider": "anthropic", "model": "claude-2"}
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.side_effect = Exception("Unknown error")
-        token_limit = get_model_token_limit(config, "default")
-        assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
-
-
-def test_get_model_token_limit_unexpected_error():
-    """Test returning None when unexpected errors occur."""
-    config = None  # This will cause an attribute error when accessed
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit is None
+# These tests have been moved to test_anthropic_token_limiter.py


 def test_create_agent_anthropic(mock_model, mock_config_repository):
     """Test create_agent with Anthropic Claude model."""
     mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})

-    with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
+    with patch("ra_aid.agent_utils.create_react_agent") as mock_react, \
+         patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier:
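+        # state_modifier is patched at its new home in anthropic_token_limiter
+        # so the real trimming logic never runs here (assuming create_agent
+        # resolves it from that module); the test asserts only on agent wiring.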
         mock_react.return_value = "react_agent"
         agent = create_agent(mock_model, [])

@@ -221,20 +151,7 @@ def mock_messages():
     ]


-def test_state_modifier(mock_messages):
-    """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
-    state = AgentState(messages=mock_messages)
-
-    with patch(
-        "ra_aid.agent_backends.ciayn_agent.CiaynAgent._estimate_tokens"
-    ) as mock_estimate:
-        mock_estimate.side_effect = lambda msg: 100 if msg else 0
-
-        result = state_modifier(state, max_input_tokens=250)
-
-        assert len(result) < len(mock_messages)
-        assert isinstance(result[0], SystemMessage)
-        assert result[-1] == mock_messages[-1]
+# This test has been moved to test_anthropic_token_limiter.py


 def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
@@ -265,7 +182,7 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_r

     with (
         patch("ra_aid.agent_utils.create_react_agent") as mock_react,
-        patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
+        patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit,
     ):
         mock_react.return_value = "react_agent"
         mock_limit.return_value = 100000
@@ -288,7 +205,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_

     with (
         patch("ra_aid.agent_utils.create_react_agent") as mock_react,
-        patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
+        patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit,
     ):
         mock_react.return_value = "react_agent"
         mock_limit.return_value = 100000
@@ -299,36 +216,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_
     mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2")


-def test_get_model_token_limit_research(mock_config_repository):
-    """Test get_model_token_limit with research provider and model."""
-    config = {
-        "provider": "openai",
-        "model": "gpt-4",
-        "research_provider": "anthropic",
-        "research_model": "claude-2",
-    }
-    mock_config_repository.update(config)
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.return_value = {"max_input_tokens": 150000}
-        token_limit = get_model_token_limit(config, "research")
-        assert token_limit == 150000
-
-
-def test_get_model_token_limit_planner(mock_config_repository):
-    """Test get_model_token_limit with planner provider and model."""
-    config = {
-        "provider": "openai",
-        "model": "gpt-4",
-        "planner_provider": "deepseek",
-        "planner_model": "dsm-1",
-    }
-    mock_config_repository.update(config)
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.return_value = {"max_input_tokens": 120000}
-        token_limit = get_model_token_limit(config, "planner")
-        assert token_limit == 120000
+# These tests have been moved to test_anthropic_token_limiter.py


 # New tests for private helper methods in agent_utils.py
diff --git a/tests/ra_aid/test_anthropic_token_limiter.py b/tests/ra_aid/test_anthropic_token_limiter.py
index 3f7e35e..3d0d9c3 100644
--- a/tests/ra_aid/test_anthropic_token_limiter.py
+++ b/tests/ra_aid/test_anthropic_token_limiter.py
@@ -1,5 +1,6 @@
 import unittest
 from unittest.mock import MagicMock, patch
+import litellm

 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import (
@@ -19,6 +20,7 @@ from ra_aid.anthropic_token_limiter import (
     convert_message_to_litellm_format
 )
 from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
+from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT


 class TestAnthropicTokenLimiter(unittest.TestCase):
@@ -140,6 +142,35 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         # Verify print_messages_compact was called at least once
         self.assertTrue(mock_print.call_count >= 1)

+    def test_state_modifier_with_messages(self):
+        """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
+        # Create a state with messages
+        messages = [
+            SystemMessage(content="System prompt"),
+            HumanMessage(content="Human message 1"),
+            AIMessage(content="AI response 1"),
+            HumanMessage(content="Human message 2"),
+            AIMessage(content="AI response 2"),
+        ]
+        state = AgentState(messages=messages)
+        model = MagicMock(spec=ChatAnthropic)
+        model.model = "claude-3-opus-20240229"
+
+        with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
+             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
+             patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
+            # Setup mock to return a fixed token count per message
+            mock_wrapper.return_value = lambda msgs: len(msgs) * 100
+            # Setup mock to return a subset of messages
+            mock_trim.return_value = [messages[0], messages[-2], messages[-1]]
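+
+            # The model argument is new here: state_modifier presumably feeds it
+            # to create_token_counter_wrapper to obtain a model-specific token
+            # counter before trimming the history down to max_input_tokens.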
+
+            result = state_modifier(state, model, max_input_tokens=250)
+
+            # Should return what anthropic_trim_messages returned
+            self.assertEqual(len(result), 3)
+            self.assertEqual(result[0], messages[0])  # First message preserved
+            self.assertEqual(result[-1], messages[-1])  # Last message preserved
+
     @patch("ra_aid.anthropic_token_limiter.estimate_messages_tokens")
     @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
     def test_sonnet_35_state_modifier(self, mock_trim, mock_estimate):
@@ -191,6 +222,42 @@ class TestAnthropicTokenLimiter(unittest.TestCase):

         # Verify get_model_info was called with the right model
         mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
+
+    def test_get_model_token_limit_research(self):
+        """Test get_model_token_limit with research provider and model."""
+        config = {
+            "provider": "openai",
+            "model": "gpt-4",
+            "research_provider": "anthropic",
+            "research_model": "claude-2",
+        }
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 150000}
+            token_limit = get_model_token_limit(config, "research")
+            self.assertEqual(token_limit, 150000)
+            # Verify get_model_info was called with the research model
+            mock_get_info.assert_called_with("anthropic/claude-2")
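+
+    # Together with the planner case below, this pins down the precedence:
+    # <agent_type>_provider / <agent_type>_model are expected to override the
+    # base provider/model when the litellm lookup key is assembled.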
+
+    def test_get_model_token_limit_planner(self):
+        """Test get_model_token_limit with planner provider and model."""
+        config = {
+            "provider": "openai",
+            "model": "gpt-4",
+            "planner_provider": "deepseek",
+            "planner_model": "dsm-1",
+        }
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 120000}
+            token_limit = get_model_token_limit(config, "planner")
+            self.assertEqual(token_limit, 120000)
+            # Verify get_model_info was called with the planner model
+            mock_get_info.assert_called_with("deepseek/dsm-1")

     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
     @patch("litellm.get_model_info")
@@ -252,6 +319,85 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         result = get_model_token_limit(mock_config, "planner")
         self.assertEqual(result, 100000)

+    def test_get_model_token_limit_anthropic(self):
+        """Test get_model_token_limit with Anthropic model."""
+        config = {"provider": "anthropic", "model": "claude2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_openai(self):
+        """Test get_model_token_limit with OpenAI model."""
+        config = {"provider": "openai", "model": "gpt-4"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["openai"]["gpt-4"]["token_limit"])
+
+    def test_get_model_token_limit_unknown(self):
+        """Test get_model_token_limit with unknown provider/model."""
+        config = {"provider": "unknown", "model": "unknown-model"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertIsNone(token_limit)
+
+    def test_get_model_token_limit_missing_config(self):
+        """Test get_model_token_limit with missing configuration."""
+        config = {}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertIsNone(token_limit)
+
+    def test_get_model_token_limit_litellm_success(self):
+        """Test get_model_token_limit successfully getting limit from litellm."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 100000}
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, 100000)
+            mock_get_info.assert_called_with("anthropic/claude-2")
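+
+    # The remaining cases exercise the fallback chain: litellm.get_model_info is
+    # consulted first, and on NotFoundError (or any other exception) the limit
+    # is taken from the static models_params table instead.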
+
+    def test_get_model_token_limit_litellm_not_found(self):
+        """Test fallback to models_tokens when litellm raises NotFoundError."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = litellm.exceptions.NotFoundError(
+                message="Model not found", model="claude-2", llm_provider="anthropic"
+            )
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_litellm_error(self):
+        """Test fallback to models_tokens when litellm raises other exceptions."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = Exception("Unknown error")
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_unexpected_error(self):
+        """Test returning None when unexpected errors occur."""
+        config = None  # This will cause an attribute error when accessed
+
+        token_limit = get_model_token_limit(config, "default")
+        self.assertIsNone(token_limit)
+
     def test_has_tool_use(self):
         """Test the has_tool_use function."""
         # Test with regular AI message