diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py
index d613041..53f4c43 100644
--- a/ra_aid/__main__.py
+++ b/ra_aid/__main__.py
@@ -5,24 +5,8 @@
 import sys
 import uuid
 from datetime import datetime
 
-# Add litellm import
 import litellm
-# Configure litellm to suppress debug logs
-os.environ["LITELLM_LOG"] = "ERROR"
-litellm.suppress_debug_info = True
-litellm.set_verbose = False
-
-# Explicitly configure LiteLLM's loggers
-for logger_name in ["litellm", "LiteLLM"]:
-    litellm_logger = logging.getLogger(logger_name)
-    litellm_logger.setLevel(logging.WARNING)
-    litellm_logger.propagate = True
-
-# Use litellm's internal method to disable debugging
-if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
-    litellm._logging._disable_debugging()
-
 from langgraph.checkpoint.memory import MemorySaver
 from rich.console import Console
 from rich.panel import Panel
@@ -99,6 +83,21 @@ from ra_aid.tools.human import ask_human
 
 logger = get_logger(__name__)
 
+# Configure litellm to suppress debug logs
+os.environ["LITELLM_LOG"] = "ERROR"
+litellm.suppress_debug_info = True
+litellm.set_verbose = False
+
+# Explicitly configure LiteLLM's loggers
+for logger_name in ["litellm", "LiteLLM"]:
+    litellm_logger = logging.getLogger(logger_name)
+    litellm_logger.setLevel(logging.WARNING)
+    litellm_logger.propagate = True
+
+# Use litellm's internal method to disable debugging
+if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
+    litellm._logging._disable_debugging()
+
 
 def launch_webui(host: str, port: int):
     """Launch the RA.Aid web interface."""
diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py
index 0abb164..1bb2ccd 100644
--- a/ra_aid/agent_utils.py
+++ b/ra_aid/agent_utils.py
@@ -51,7 +51,12 @@ from ra_aid.database.repositories.human_input_repository import (
 )
 from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
 from ra_aid.database.repositories.config_repository import get_config_repository
-from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
+from ra_aid.anthropic_token_limiter import (
+    sonnet_35_state_modifier,
+    state_modifier,
+    get_model_token_limit,
+)
+from ra_aid.model_detection import is_anthropic_claude
 
 
 console = Console()
@@ -67,8 +72,6 @@ def output_markdown_message(message: str) -> str:
     return "Message output."
 
 
-
-
 def build_agent_kwargs(
     checkpointer: Optional[Any] = None,
     model: ChatAnthropic = None,
@@ -99,8 +102,13 @@ def build_agent_kwargs(
     ):
 
         def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
-            if any(pattern in model.model for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]):
-                return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
+            if any(
+                pattern in model.model
+                for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]
+            ):
+                return sonnet_35_state_modifier(
+                    state, max_input_tokens=max_input_tokens
+                )
 
             return state_modifier(state, model, max_input_tokens=max_input_tokens)
 
@@ -110,27 +118,6 @@ def build_agent_kwargs(
     return agent_kwargs
 
 
-def is_anthropic_claude(config: Dict[str, Any]) -> bool:
-    """Check if the provider and model name indicate an Anthropic Claude model.
-
-    Args:
-        config: Configuration dictionary containing provider and model information
-
-    Returns:
-        bool: True if this is an Anthropic Claude model
-    """
-    # For backwards compatibility, allow passing of config directly
-    provider = config.get("provider", "")
-    model_name = config.get("model", "")
-    result = (
-        provider.lower() == "anthropic"
-        and model_name
-        and "claude" in model_name.lower()
-    ) or (
-        provider.lower() == "openrouter"
-        and model_name.lower().startswith("anthropic/claude-")
-    )
-    return result
 
 
 def create_agent(
@@ -169,7 +156,7 @@ def create_agent(
             # So we'll use the passed config directly
            pass
        max_input_tokens = (
-            get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
+            get_model_token_limit(config, agent_type, model) or DEFAULT_TOKEN_LIMIT
        )
 
        # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
@@ -188,7 +175,7 @@ def create_agent(
         # Default to REACT agent if provider/model detection fails
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
         config = get_config_repository().get_all()
-        max_input_tokens = get_model_token_limit(config, agent_type)
+        max_input_tokens = get_model_token_limit(config, agent_type, model)
         agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
         return create_react_agent(
             model, tools, interrupt_after=["tools"], **agent_kwargs
         )
@@ -289,7 +276,7 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
     logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
     delay = base_delay * (2**attempt)
     error_message = f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
-    
+
     # Record error in trajectory
     trajectory_repo = get_trajectory_repository()
     human_input_id = get_human_input_repository().get_most_recent_id()
@@ -301,9 +288,9 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
         record_type="error",
         human_input_id=human_input_id,
         is_error=True,
-        error_message=error_message
+        error_message=error_message,
     )
-    
+
     print_error(error_message)
     start = time.monotonic()
     while time.monotonic() - start < delay:
@@ -464,7 +451,9 @@ def run_agent_with_retry(
 
         try:
             _run_agent_stream(agent, msg_list)
-            if fallback_handler and hasattr(fallback_handler, 'reset_fallback_handler'):
+            if fallback_handler and hasattr(
+                fallback_handler, "reset_fallback_handler"
+            ):
                 fallback_handler.reset_fallback_handler()
             should_break, prompt, auto_test, test_attempts = (
                 _execute_test_command_wrapper(
diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py
index 45a79d4..794459a 100644
--- a/ra_aid/anthropic_token_limiter.py
+++ b/ra_aid/anthropic_token_limiter.py
@@ -1,31 +1,27 @@
 """Utilities for handling token limits with Anthropic models."""
 
 from functools import partial
-from typing import Any, Dict, List, Optional, Sequence
-from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
+from langchain_core.language_models import BaseChatModel
+from ra_aid.model_detection import is_claude_37
 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import (
-    AIMessage,
     BaseMessage,
-    RemoveMessage,
-    ToolMessage,
     trim_messages,
 )
 from langchain_core.messages.base import message_to_dict
 
 from ra_aid.anthropic_message_utils import (
     anthropic_trim_messages,
-    has_tool_use,
 )
 from langgraph.prebuilt.chat_agent_executor import AgentState
-from litellm import token_counter
+from litellm import token_counter, get_model_info
 
 from ra_aid.agent_backends.ciayn_agent import CiaynAgent
 from ra_aid.database.repositories.config_repository import get_config_repository
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
-from ra_aid.console.output import cpm, print_messages_compact
 
 logger = get_logger(__name__)
 
@@ -168,14 +164,62 @@ def sonnet_35_state_modifier(
     return result
 
 
+def get_provider_and_model_for_agent_type(config: Dict[str, Any], agent_type: str) -> Tuple[str, str]:
+    """Get the provider and model name for the specified agent type.
+
+    Args:
+        config: Configuration dictionary containing provider and model information
+        agent_type: Type of agent ("default", "research", or "planner")
+
+    Returns:
+        Tuple[str, str]: A tuple containing (provider, model_name)
+    """
+    if agent_type == "research":
+        provider = config.get("research_provider", "") or config.get("provider", "")
+        model_name = config.get("research_model", "") or config.get("model", "")
+    elif agent_type == "planner":
+        provider = config.get("planner_provider", "") or config.get("provider", "")
+        model_name = config.get("planner_model", "") or config.get("model", "")
+    else:
+        provider = config.get("provider", "")
+        model_name = config.get("model", "")
+
+    return provider, model_name
+
+
+def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChatModel]) -> Optional[int]:
+    """Adjust token limit for Claude 3.7 models by subtracting max_tokens.
+
+    Args:
+        max_input_tokens: The original token limit
+        model: The model instance to check
+
+    Returns:
+        Optional[int]: Adjusted token limit if model is Claude 3.7, otherwise original limit
+    """
+    if not max_input_tokens:
+        return max_input_tokens
+
+    if model and hasattr(model, 'model') and is_claude_37(model.model):
+        if hasattr(model, 'max_tokens') and model.max_tokens:
+            effective_max_input_tokens = max_input_tokens - model.max_tokens
+            logger.debug(
+                f"Adjusting token limit for Claude 3.7 model: {max_input_tokens} - {model.max_tokens} = {effective_max_input_tokens}"
+            )
+            return effective_max_input_tokens
+
+    return max_input_tokens
+
+
 def get_model_token_limit(
-    config: Dict[str, Any], agent_type: str = "default"
+    config: Dict[str, Any], agent_type: str = "default", model: Optional[BaseChatModel] = None
 ) -> Optional[int]:
     """Get the token limit for the current model configuration based on agent type.
 
     Args:
         config: Configuration dictionary containing provider and model information
         agent_type: Type of agent ("default", "research", or "planner")
+        model: Optional BaseChatModel instance to check for model-specific attributes
 
     Returns:
         Optional[int]: The token limit if found, None otherwise
@@ -190,27 +234,20 @@ def get_model_token_limit(
         # In tests, this may fail because the repository isn't set up
         # So we'll use the passed config directly
         pass
-    if agent_type == "research":
-        provider = config.get("research_provider", "") or config.get("provider", "")
-        model_name = config.get("research_model", "") or config.get("model", "")
-    elif agent_type == "planner":
-        provider = config.get("planner_provider", "") or config.get("provider", "")
-        model_name = config.get("planner_model", "") or config.get("model", "")
-    else:
-        provider = config.get("provider", "")
-        model_name = config.get("model", "")
+
+    provider, model_name = get_provider_and_model_for_agent_type(config, agent_type)
+
+    # Always attempt to get model info from litellm first
+    provider_model = model_name if not provider else f"{provider}/{model_name}"
     try:
-        from litellm import get_model_info
-
-        provider_model = model_name if not provider else f"{provider}/{model_name}"
         model_info = get_model_info(provider_model)
         max_input_tokens = model_info.get("max_input_tokens")
         if max_input_tokens:
             logger.debug(
                 f"Using litellm token limit for {model_name}: {max_input_tokens}"
             )
-            return max_input_tokens
+            return adjust_claude_37_token_limit(max_input_tokens, model)
     except Exception as e:
         logger.debug(
             f"Error getting model info from litellm: {e}, falling back to models_params"
         )
@@ -229,7 +266,7 @@ def get_model_token_limit(
             max_input_tokens = None
             logger.debug(f"Could not find token limit for {provider}/{model_name}")
 
-        return max_input_tokens
+        return adjust_claude_37_token_limit(max_input_tokens, model)
 
     except Exception as e:
         logger.warning(f"Failed to get model token limit: {e}")
diff --git a/ra_aid/llm.py b/ra_aid/llm.py
index 56ec811..3f60d6f 100644
--- a/ra_aid/llm.py
+++ b/ra_aid/llm.py
@@ -10,6 +10,7 @@ from openai import OpenAI
 
 from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
 from ra_aid.console.output import cpm
 from ra_aid.logging_config import get_logger
+from ra_aid.model_detection import is_claude_37
 
 from .models_params import models_params
@@ -218,7 +219,6 @@ def create_llm_client(
         is_expert,
     )
 
-    # Get model configuration
     model_config = models_params.get(provider, {}).get(model_name, {})
 
     # Default to True for known providers that support temperature if not specified
@@ -228,6 +228,10 @@ def create_llm_client(
         supports_temperature = model_config["supports_temperature"]
     supports_thinking = model_config.get("supports_thinking", False)
 
+    other_kwargs = {}
+    if is_claude_37(model_name):
+        other_kwargs = {"max_tokens": 64000}
+
     # Handle temperature settings
     if is_expert:
         temp_kwargs = {"temperature": 0} if supports_temperature else {}
@@ -235,22 +239,26 @@ def create_llm_client(
         if temperature is None:
             temperature = 0.7
             # Import repository classes directly to avoid circular imports
-            from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
-            from ra_aid.database.repositories.human_input_repository import HumanInputRepository
+            from ra_aid.database.repositories.trajectory_repository import (
+                TrajectoryRepository,
+            )
+            from ra_aid.database.repositories.human_input_repository import (
+                HumanInputRepository,
+            )
             from ra_aid.database.connection import get_db
-            
+
             # Create repositories directly
             trajectory_repo = TrajectoryRepository(get_db())
             human_input_repo = HumanInputRepository(get_db())
             human_input_id = human_input_repo.get_most_recent_id()
-            
+
             trajectory_repo.create(
                 step_data={
                     "message": "This model supports temperature argument but none was given. Setting default temperature to 0.7.",
                     "display_title": "Information",
                 },
                 record_type="info",
-                human_input_id=human_input_id
+                human_input_id=human_input_id,
             )
             cpm(
                 "This model supports temperature argument but none was given. Setting default temperature to 0.7."
             )
@@ -302,9 +310,9 @@ def create_llm_client(
             model_name=model_name,
             timeout=LLM_REQUEST_TIMEOUT,
             max_retries=LLM_MAX_RETRIES,
-            max_tokens=model_config.get("max_tokens", 64000),
             **temp_kwargs,
             **thinking_kwargs,
+            **other_kwargs,
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(
diff --git a/ra_aid/model_detection.py b/ra_aid/model_detection.py
new file mode 100644
index 0000000..ca45c51
--- /dev/null
+++ b/ra_aid/model_detection.py
@@ -0,0 +1,39 @@
+"""Utilities for detecting and working with specific model types."""
+
+from typing import Optional, Dict, Any
+
+
+def is_claude_37(model: str) -> bool:
+    """Check if the model is a Claude 3.7 model.
+
+    Args:
+        model: The model name to check
+
+    Returns:
+        bool: True if the model is a Claude 3.7 model, False otherwise
+    """
+    patterns = ["claude-3.7", "claude3.7", "claude-3-7"]
+    return any(pattern in model for pattern in patterns)
+
+
+def is_anthropic_claude(config: Dict[str, Any]) -> bool:
+    """Check if the provider and model name indicate an Anthropic Claude model.
+
+    Args:
+        config: Configuration dictionary containing provider and model information
+
+    Returns:
+        bool: True if this is an Anthropic Claude model
+    """
+    # For backwards compatibility, allow passing of config directly
+    provider = config.get("provider", "")
+    model_name = config.get("model", "")
+    result = (
+        provider.lower() == "anthropic"
+        and model_name
+        and "claude" in model_name.lower()
+    ) or (
+        provider.lower() == "openrouter"
+        and model_name.lower().startswith("anthropic/claude-")
+    )
+    return result
diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py
index ff978a8..1d7ed50 100644
--- a/tests/ra_aid/test_agent_utils.py
+++ b/tests/ra_aid/test_agent_utils.py
@@ -13,8 +13,7 @@ from ra_aid.agent_context import (
 )
 from ra_aid.agent_utils import (
     AgentState,
-    create_agent,
-    is_anthropic_claude,
+    create_agent
 )
 from ra_aid.anthropic_token_limiter import (
     get_model_token_limit,
@@ -453,31 +452,6 @@ def test_handle_api_error_retry(monkeypatch):
         _handle_api_error(Exception("error code 429"), 0, 5, 1)
 
 
-def test_is_anthropic_claude():
-    """Test is_anthropic_claude function with various configurations."""
-    # Test Anthropic provider cases
-    assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
-    assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
-    assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
-
-    # Test OpenRouter provider cases
-    assert is_anthropic_claude(
-        {"provider": "openrouter", "model": "anthropic/claude-2"}
-    )
-    assert is_anthropic_claude(
-        {"provider": "openrouter", "model": "anthropic/claude-instant"}
-    )
-    assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
-
-    # Test edge cases
-    assert not is_anthropic_claude({})  # Empty config
-    assert not is_anthropic_claude({"provider": "anthropic"})  # Missing model
-    assert not is_anthropic_claude({"model": "claude-2"})  # Missing provider
-    assert not is_anthropic_claude(
-        {"provider": "other", "model": "claude-2"}
-    )  # Wrong provider
-
-
 def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repository):
     """Test that run_agent_with_retry checks for crash status at the beginning of each iteration."""
     from ra_aid.agent_context import agent_context, mark_agent_crashed
diff --git a/tests/ra_aid/test_anthropic_token_limiter.py b/tests/ra_aid/test_anthropic_token_limiter.py
index 36f9528..a3ac861 100644
--- a/tests/ra_aid/test_anthropic_token_limiter.py
+++ b/tests/ra_aid/test_anthropic_token_limiter.py
@@ -17,10 +17,11 @@ from ra_aid.anthropic_token_limiter import (
     get_model_token_limit,
     state_modifier,
     sonnet_35_state_modifier,
-    convert_message_to_litellm_format
+    convert_message_to_litellm_format,
+    adjust_claude_37_token_limit
 )
 from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
-from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT
+from ra_aid.models_params import models_params
 
 
 class TestAnthropicTokenLimiter(unittest.TestCase):
@@ -113,9 +114,8 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(result, 0)
 
     @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
-    @patch("ra_aid.anthropic_token_limiter.print_messages_compact")
     @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
-    def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
+    def test_state_modifier(self, mock_trim_messages, mock_create_wrapper):
         # Setup a proper token counter function that returns integers
         def token_counter(msgs):
             # Return token count based on number of messages
@@ -155,8 +155,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         model.model = "claude-3-opus-20240229"
 
         with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
-             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
-             patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
+             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim:
             # Setup mock to return a fixed token count per message
             mock_wrapper.return_value = lambda msgs: len(msgs) * 100
             # Setup mock to return a subset of messages
@@ -206,23 +205,35 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(call_args["include_system"], True)
 
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
-    def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
-        from ra_aid.config import DEFAULT_MODEL
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.is_claude_37")
+    @patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
+    def test_get_model_token_limit_from_litellm(self, mock_adjust, mock_is_claude_37, mock_get_model_info, mock_get_config_repo):
+        # Use a specific model name instead of DEFAULT_MODEL to avoid test dependency
+        model_name = "claude-3-7-sonnet-20250219"
 
         # Setup mocks
-        mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
+        mock_config = {"provider": "anthropic", "model": model_name}
         mock_get_config_repo.return_value.get_all.return_value = mock_config
 
         # Mock litellm's get_model_info to return a token limit
         mock_get_model_info.return_value = {"max_input_tokens": 100000}
 
+        # Mock is_claude_37 to return True
+        mock_is_claude_37.return_value = True
+
+        # Mock adjust_claude_37_token_limit to return the original value
+        mock_adjust.return_value = 100000
+
         # Test getting token limit
         result = get_model_token_limit(mock_config)
         self.assertEqual(result, 100000)
 
         # Verify get_model_info was called with the right model
-        mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
+        mock_get_model_info.assert_called_once_with(f"anthropic/{model_name}")
+
+        # Verify adjust_claude_37_token_limit was called
+        mock_adjust.assert_called_once_with(100000, None)
 
     def test_get_model_token_limit_research(self):
         """Test get_model_token_limit with research provider and model."""
@@ -230,17 +241,24 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
             "provider": "openai",
             "model": "gpt-4",
             "research_provider": "anthropic",
-            "research_model": "claude-2",
+            "research_model": "claude-3-7-sonnet-20250219",
         }
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 150000}
+            mock_adjust.return_value = 150000
+
+            # Call the function to check the return value
             token_limit = get_model_token_limit(config, "research")
             self.assertEqual(token_limit, 150000)
+
             # Verify get_model_info was called with the research model
-            mock_get_info.assert_called_with("anthropic/claude-2")
+            mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
+            # Verify adjust_claude_37_token_limit was called
+            mock_adjust.assert_called_once_with(150000, None)
 
     def test_get_model_token_limit_planner(self):
         """Test get_model_token_limit with planner provider and model."""
@@ -252,16 +270,23 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         }
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 120000}
+            mock_adjust.return_value = 120000
+
+            # Call the function to check the return value
             token_limit = get_model_token_limit(config, "planner")
             self.assertEqual(token_limit, 120000)
+
             # Verify get_model_info was called with the planner model
-            mock_get_info.assert_called_with("deepseek/dsm-1")
+            mock_get_info.assert_called_once_with("deepseek/dsm-1")
+            # Verify adjust_claude_37_token_limit was called
+            mock_adjust.assert_called_once_with(120000, None)
 
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
     def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
         # Setup mocks
         mock_config = {"provider": "anthropic", "model": "claude-2"}
@@ -280,54 +305,87 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(result, 100000)
 
     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
-    def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
-        from ra_aid.config import DEFAULT_MODEL
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
+    def test_get_model_token_limit_for_different_agent_types(self, mock_adjust, mock_get_model_info, mock_get_config_repo):
+        # Use specific model names instead of DEFAULT_MODEL to avoid test dependency
+        claude_model = "claude-3-7-sonnet-20250219"
 
         # Setup mocks for different agent types
         mock_config = {
             "provider": "anthropic",
-            "model": DEFAULT_MODEL,
+            "model": claude_model,
             "research_provider": "openai",
             "research_model": "gpt-4",
             "planner_provider": "anthropic",
-            "planner_model": "claude-3-sonnet-20240229"
+            "planner_model": "claude-3-7-opus-20250301"
         }
         mock_get_config_repo.return_value.get_all.return_value = mock_config
 
         # Mock different returns for different models
         def model_info_side_effect(model_name):
-            if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
+            if "claude-3-7-sonnet" in model_name:
                 return {"max_input_tokens": 200000}
             elif "gpt-4" in model_name:
                 return {"max_input_tokens": 8192}
-            elif "claude-3-sonnet" in model_name:
-                return {"max_input_tokens": 100000}
+            elif "claude-3-7-opus" in model_name:
+                return {"max_input_tokens": 250000}
             else:
                 raise Exception(f"Unknown model: {model_name}")
 
         mock_get_model_info.side_effect = model_info_side_effect
 
+        # Mock adjust_claude_37_token_limit to return the same values
+        mock_adjust.side_effect = lambda tokens, model: tokens
+
         # Test default agent type
         result = get_model_token_limit(mock_config, "default")
         self.assertEqual(result, 200000)
+        mock_get_model_info.assert_called_with(f"anthropic/{claude_model}")
+
+        # Reset mock
+        mock_get_model_info.reset_mock()
 
         # Test research agent type
         result = get_model_token_limit(mock_config, "research")
         self.assertEqual(result, 8192)
+        mock_get_model_info.assert_called_with("openai/gpt-4")
+
+        # Reset mock
+        mock_get_model_info.reset_mock()
 
         # Test planner agent type
         result = get_model_token_limit(mock_config, "planner")
-        self.assertEqual(result, 100000)
+        self.assertEqual(result, 250000)
+        mock_get_model_info.assert_called_with("anthropic/claude-3-7-opus-20250301")
 
     def test_get_model_token_limit_anthropic(self):
         """Test get_model_token_limit with Anthropic model."""
-        config = {"provider": "anthropic", "model": "claude2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
 
-        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
+             patch("litellm.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
+
+            # Setup mocks
             mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = Exception("Model not found")
+
+            # Create a mock models_params with claude-3-7
+            mock_models_params_dict = {
+                "anthropic": {
+                    "claude-3-7-sonnet-20250219": {"token_limit": 200000}
+                }
+            }
+            mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
+            mock_models_params.get.side_effect = mock_models_params_dict.get
+
+            # Mock adjust to return the same value
+            mock_adjust.return_value = 200000
+
             token_limit = get_model_token_limit(config, "default")
-            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+            self.assertEqual(token_limit, 200000)
 
     def test_get_model_token_limit_openai(self):
         """Test get_model_token_limit with OpenAI model."""
@@ -358,28 +416,51 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
 
     def test_get_model_token_limit_litellm_success(self):
         """Test get_model_token_limit successfully getting limit from litellm."""
-        config = {"provider": "anthropic", "model": "claude-2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 100000}
+            mock_adjust.return_value = 100000
+
+            # Call the function to check the return value
             token_limit = get_model_token_limit(config, "default")
             self.assertEqual(token_limit, 100000)
-            mock_get_info.assert_called_with("anthropic/claude-2")
+
+            # Verify get_model_info was called with the right model
+            mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
+            mock_adjust.assert_called_once_with(100000, None)
 
     def test_get_model_token_limit_litellm_not_found(self):
         """Test fallback to models_tokens when litellm raises NotFoundError."""
-        config = {"provider": "anthropic", "model": "claude-2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
 
         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-             patch("litellm.get_model_info") as mock_get_info:
+             patch("litellm.get_model_info") as mock_get_info, \
+             patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
+             patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
+
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.side_effect = litellm.exceptions.NotFoundError(
-                message="Model not found", model="claude-2", llm_provider="anthropic"
+                message="Model not found", model="claude-3-7-sonnet-20250219", llm_provider="anthropic"
             )
+
+            # Create a mock models_params with claude-3-7
+            mock_models_params_dict = {
+                "anthropic": {
+                    "claude-3-7-sonnet-20250219": {"token_limit": 200000}
+                }
+            }
+            mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
+            mock_models_params.get.side_effect = mock_models_params_dict.get
+
+            # Mock adjust to return the same value
+            mock_adjust.return_value = 200000
+
             token_limit = get_model_token_limit(config, "default")
-            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+            self.assertEqual(token_limit, 200000)
 
     def test_get_model_token_limit_litellm_error(self):
         """Test fallback to models_tokens when litellm raises other exceptions."""
@@ -399,6 +480,30 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
             token_limit = get_model_token_limit(config, "default")
             self.assertIsNone(token_limit)
 
+    def test_adjust_claude_37_token_limit(self):
+        """Test adjust_claude_37_token_limit function."""
+        # Create a mock model
+        mock_model = MagicMock()
+        mock_model.model = "claude-3.7-sonnet"
+        mock_model.max_tokens = 4096
+
+        # Test with Claude 3.7 model
+        result = adjust_claude_37_token_limit(100000, mock_model)
+        self.assertEqual(result, 95904)  # 100000 - 4096
+
+        # Test with non-Claude 3.7 model
+        mock_model.model = "claude-3-opus"
+        result = adjust_claude_37_token_limit(100000, mock_model)
+        self.assertEqual(result, 100000)  # No adjustment
+
+        # Test with None max_input_tokens
+        result = adjust_claude_37_token_limit(None, mock_model)
+        self.assertIsNone(result)
+
+        # Test with None model
+        result = adjust_claude_37_token_limit(100000, None)
+        self.assertEqual(result, 100000)
+
     def test_has_tool_use(self):
         """Test the has_tool_use function."""
         # Test with regular AI message
diff --git a/tests/ra_aid/test_model_detection.py b/tests/ra_aid/test_model_detection.py
new file mode 100644
index 0000000..4127c50
--- /dev/null
+++ b/tests/ra_aid/test_model_detection.py
@@ -0,0 +1,50 @@
+"""Unit tests for model_detection.py."""
+
+import pytest
+from ra_aid.model_detection import is_anthropic_claude, is_claude_37
+
+
+def test_is_anthropic_claude():
+    """Test is_anthropic_claude function with various configurations."""
+    # Test Anthropic provider cases
+    assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
+    assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
+    assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
+
+    # Test OpenRouter provider cases
+    assert is_anthropic_claude(
+        {"provider": "openrouter", "model": "anthropic/claude-2"}
+    )
+    assert is_anthropic_claude(
+        {"provider": "openrouter", "model": "anthropic/claude-instant"}
+    )
+    assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
+
+    # Test edge cases
+    assert not is_anthropic_claude({})  # Empty config
+    assert not is_anthropic_claude({"provider": "anthropic"})  # Missing model
+    assert not is_anthropic_claude({"model": "claude-2"})  # Missing provider
+    assert not is_anthropic_claude(
+        {"provider": "other", "model": "claude-2"}
+    )  # Wrong provider
+
+
+def test_is_claude_37():
+    """Test is_claude_37 function with various model names."""
+    # Test positive cases
+    assert is_claude_37("claude-3.7")
+    assert is_claude_37("claude3.7")
+    assert is_claude_37("claude-3-7")
+    assert is_claude_37("anthropic/claude-3.7")
+    assert is_claude_37("anthropic/claude3.7")
+    assert is_claude_37("anthropic/claude-3-7")
+    assert is_claude_37("claude-3.7-sonnet")
+    assert is_claude_37("claude3.7-haiku")
+
+    # Test negative cases
+    assert not is_claude_37("claude-3")
+    assert not is_claude_37("claude-3.5")
+    assert not is_claude_37("claude3.5")
+    assert not is_claude_37("claude-3-5")
+    assert not is_claude_37("gpt-4")
+    assert not is_claude_37("")