Merge pull request #132 from ariel-frischer/fix-token-limiter-2

Fix Sonnet 3.7 Token Limiter - Adjust Effective Max Input Tokens
Andrew I. Christianson 2025-03-14 16:42:39 -04:00 committed by GitHub
commit aaf09c5df6
8 changed files with 342 additions and 141 deletions
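
At its core, the change reserves the model's output budget out of its advertised context window: `create_llm_client` now pins `max_tokens=64000` for Claude 3.7, and a new `adjust_claude_37_token_limit` helper subtracts that budget from litellm's `max_input_tokens` before messages are trimmed. A minimal standalone sketch of the arithmetic (illustrative only, not code from this PR):

    # Sketch of the adjustment introduced by this PR (illustrative).
    def effective_max_input_tokens(max_input_tokens: int, max_tokens: int) -> int:
        """Reserve the output budget out of the advertised input limit."""
        return max_input_tokens - max_tokens

    # With the values used in the diff: Claude 3.7 Sonnet advertises a
    # 200000-token input limit and is created with max_tokens=64000.
    assert effective_max_input_tokens(200000, 64000) == 136000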

View File

@@ -5,24 +5,8 @@ import sys
 import uuid
 from datetime import datetime

-# Add litellm import
 import litellm

-# Configure litellm to suppress debug logs
-os.environ["LITELLM_LOG"] = "ERROR"
-litellm.suppress_debug_info = True
-litellm.set_verbose = False
-
-# Explicitly configure LiteLLM's loggers
-for logger_name in ["litellm", "LiteLLM"]:
-    litellm_logger = logging.getLogger(logger_name)
-    litellm_logger.setLevel(logging.WARNING)
-    litellm_logger.propagate = True
-
-# Use litellm's internal method to disable debugging
-if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
-    litellm._logging._disable_debugging()
-
 from langgraph.checkpoint.memory import MemorySaver
 from rich.console import Console
 from rich.panel import Panel

@@ -99,6 +83,21 @@ from ra_aid.tools.human import ask_human
 logger = get_logger(__name__)

+# Configure litellm to suppress debug logs
+os.environ["LITELLM_LOG"] = "ERROR"
+litellm.suppress_debug_info = True
+litellm.set_verbose = False
+
+# Explicitly configure LiteLLM's loggers
+for logger_name in ["litellm", "LiteLLM"]:
+    litellm_logger = logging.getLogger(logger_name)
+    litellm_logger.setLevel(logging.WARNING)
+    litellm_logger.propagate = True
+
+# Use litellm's internal method to disable debugging
+if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
+    litellm._logging._disable_debugging()
+

 def launch_webui(host: str, port: int):
     """Launch the RA.Aid web interface."""

View File

@@ -51,7 +51,12 @@ from ra_aid.database.repositories.human_input_repository import (
 )
 from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
 from ra_aid.database.repositories.config_repository import get_config_repository
-from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
+from ra_aid.anthropic_token_limiter import (
+    sonnet_35_state_modifier,
+    state_modifier,
+    get_model_token_limit,
+)
+from ra_aid.model_detection import is_anthropic_claude

 console = Console()

@@ -67,8 +72,6 @@ def output_markdown_message(message: str) -> str:
     return "Message output."

-
-
 def build_agent_kwargs(
     checkpointer: Optional[Any] = None,
     model: ChatAnthropic = None,

@@ -99,8 +102,13 @@ def build_agent_kwargs(
 ):
     def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
-        if any(pattern in model.model for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]):
-            return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
+        if any(
+            pattern in model.model
+            for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]
+        ):
+            return sonnet_35_state_modifier(
+                state, max_input_tokens=max_input_tokens
+            )
         return state_modifier(state, model, max_input_tokens=max_input_tokens)

@@ -110,27 +118,6 @@ def build_agent_kwargs(
     return agent_kwargs

-
-def is_anthropic_claude(config: Dict[str, Any]) -> bool:
-    """Check if the provider and model name indicate an Anthropic Claude model.
-
-    Args:
-        config: Configuration dictionary containing provider and model information
-
-    Returns:
-        bool: True if this is an Anthropic Claude model
-    """
-    # For backwards compatibility, allow passing of config directly
-    provider = config.get("provider", "")
-    model_name = config.get("model", "")
-    result = (
-        provider.lower() == "anthropic"
-        and model_name
-        and "claude" in model_name.lower()
-    ) or (
-        provider.lower() == "openrouter"
-        and model_name.lower().startswith("anthropic/claude-")
-    )
-    return result
-

 def create_agent(

@@ -169,7 +156,7 @@ def create_agent(
         # So we'll use the passed config directly
         pass

     max_input_tokens = (
-        get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
+        get_model_token_limit(config, agent_type, model) or DEFAULT_TOKEN_LIMIT
     )

     # Use REACT agent for Anthropic Claude models, otherwise use CIAYN

@@ -188,7 +175,7 @@ def create_agent(
         # Default to REACT agent if provider/model detection fails
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
         config = get_config_repository().get_all()
-        max_input_tokens = get_model_token_limit(config, agent_type)
+        max_input_tokens = get_model_token_limit(config, agent_type, model)
         agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
         return create_react_agent(
             model, tools, interrupt_after=["tools"], **agent_kwargs

@@ -289,7 +276,7 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
     logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
     delay = base_delay * (2**attempt)
     error_message = f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
-
+    # Record error in trajectory
     trajectory_repo = get_trajectory_repository()
     human_input_id = get_human_input_repository().get_most_recent_id()

@@ -301,9 +288,9 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
             record_type="error",
             human_input_id=human_input_id,
             is_error=True,
-            error_message=error_message
+            error_message=error_message,
         )

     print_error(error_message)
     start = time.monotonic()
     while time.monotonic() - start < delay:

@@ -464,7 +451,9 @@ def run_agent_with_retry(
         try:
             _run_agent_stream(agent, msg_list)
-            if fallback_handler and hasattr(fallback_handler, 'reset_fallback_handler'):
+            if fallback_handler and hasattr(
+                fallback_handler, "reset_fallback_handler"
+            ):
                 fallback_handler.reset_fallback_handler()
             should_break, prompt, auto_test, test_attempts = (
                 _execute_test_command_wrapper(

View File

@@ -1,31 +1,27 @@
 """Utilities for handling token limits with Anthropic models."""

 from functools import partial
-from typing import Any, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Tuple

-from dataclasses import dataclass
+from langchain_core.language_models import BaseChatModel
+from ra_aid.model_detection import is_claude_37
 from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import (
-    AIMessage,
     BaseMessage,
-    RemoveMessage,
-    ToolMessage,
     trim_messages,
 )
 from langchain_core.messages.base import message_to_dict
 from ra_aid.anthropic_message_utils import (
     anthropic_trim_messages,
-    has_tool_use,
 )
 from langgraph.prebuilt.chat_agent_executor import AgentState
-from litellm import token_counter
+from litellm import token_counter, get_model_info

 from ra_aid.agent_backends.ciayn_agent import CiaynAgent
 from ra_aid.database.repositories.config_repository import get_config_repository
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
-from ra_aid.console.output import cpm, print_messages_compact

 logger = get_logger(__name__)

@@ -168,14 +164,62 @@ def sonnet_35_state_modifier(
     return result


+def get_provider_and_model_for_agent_type(
+    config: Dict[str, Any], agent_type: str
+) -> Tuple[str, str]:
+    """Get the provider and model name for the specified agent type.
+
+    Args:
+        config: Configuration dictionary containing provider and model information
+        agent_type: Type of agent ("default", "research", or "planner")
+
+    Returns:
+        Tuple[str, str]: A tuple containing (provider, model_name)
+    """
+    if agent_type == "research":
+        provider = config.get("research_provider", "") or config.get("provider", "")
+        model_name = config.get("research_model", "") or config.get("model", "")
+    elif agent_type == "planner":
+        provider = config.get("planner_provider", "") or config.get("provider", "")
+        model_name = config.get("planner_model", "") or config.get("model", "")
+    else:
+        provider = config.get("provider", "")
+        model_name = config.get("model", "")
+    return provider, model_name
+
+
+def adjust_claude_37_token_limit(
+    max_input_tokens: int, model: Optional[BaseChatModel]
+) -> Optional[int]:
+    """Adjust token limit for Claude 3.7 models by subtracting max_tokens.
+
+    Args:
+        max_input_tokens: The original token limit
+        model: The model instance to check
+
+    Returns:
+        Optional[int]: Adjusted token limit if model is Claude 3.7, otherwise original limit
+    """
+    if not max_input_tokens:
+        return max_input_tokens
+
+    if model and hasattr(model, "model") and is_claude_37(model.model):
+        if hasattr(model, "max_tokens") and model.max_tokens:
+            effective_max_input_tokens = max_input_tokens - model.max_tokens
+            logger.debug(
+                f"Adjusting token limit for Claude 3.7 model: {max_input_tokens} - {model.max_tokens} = {effective_max_input_tokens}"
+            )
+            return effective_max_input_tokens
+
+    return max_input_tokens
+
+
 def get_model_token_limit(
-    config: Dict[str, Any], agent_type: str = "default"
+    config: Dict[str, Any], agent_type: str = "default", model: Optional[BaseChatModel] = None
 ) -> Optional[int]:
     """Get the token limit for the current model configuration based on agent type.

     Args:
         config: Configuration dictionary containing provider and model information
         agent_type: Type of agent ("default", "research", or "planner")
+        model: Optional BaseChatModel instance to check for model-specific attributes

     Returns:
         Optional[int]: The token limit if found, None otherwise

@@ -190,27 +234,20 @@ def get_model_token_limit(
         # In tests, this may fail because the repository isn't set up
         # So we'll use the passed config directly
         pass

-    if agent_type == "research":
-        provider = config.get("research_provider", "") or config.get("provider", "")
-        model_name = config.get("research_model", "") or config.get("model", "")
-    elif agent_type == "planner":
-        provider = config.get("planner_provider", "") or config.get("provider", "")
-        model_name = config.get("planner_model", "") or config.get("model", "")
-    else:
-        provider = config.get("provider", "")
-        model_name = config.get("model", "")
+    provider, model_name = get_provider_and_model_for_agent_type(config, agent_type)
+
+    # Always attempt to get model info from litellm first
+    provider_model = model_name if not provider else f"{provider}/{model_name}"

     try:
-        from litellm import get_model_info
-
-        provider_model = model_name if not provider else f"{provider}/{model_name}"
         model_info = get_model_info(provider_model)
         max_input_tokens = model_info.get("max_input_tokens")
         if max_input_tokens:
             logger.debug(
                 f"Using litellm token limit for {model_name}: {max_input_tokens}"
             )
-            return max_input_tokens
+            return adjust_claude_37_token_limit(max_input_tokens, model)
     except Exception as e:
         logger.debug(
             f"Error getting model info from litellm: {e}, falling back to models_params"
         )

@@ -229,7 +266,7 @@ def get_model_token_limit(
             max_input_tokens = None
             logger.debug(f"Could not find token limit for {provider}/{model_name}")

-        return max_input_tokens
+        return adjust_claude_37_token_limit(max_input_tokens, model)

     except Exception as e:
         logger.warning(f"Failed to get model token limit: {e}")
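
With the new signature, callers pass the chat model instance through so the Claude 3.7 adjustment can see `model.max_tokens`. A hedged usage sketch (assumes `ANTHROPIC_API_KEY` is set; the snippet is illustrative and not part of the PR):

    from langchain_anthropic import ChatAnthropic
    from ra_aid.anthropic_token_limiter import get_model_token_limit

    model = ChatAnthropic(model="claude-3-7-sonnet-20250219", max_tokens=64000)
    config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
    # For Claude 3.7 this returns max_input_tokens - model.max_tokens;
    # for any other model the litellm/models_params value is returned unchanged.
    limit = get_model_token_limit(config, agent_type="default", model=model)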

View File

@@ -10,6 +10,7 @@ from openai import OpenAI
 from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
 from ra_aid.console.output import cpm
 from ra_aid.logging_config import get_logger
+from ra_aid.model_detection import is_claude_37

 from .models_params import models_params

@@ -218,7 +219,6 @@ def create_llm_client(
         is_expert,
     )

-    # Get model configuration
     model_config = models_params.get(provider, {}).get(model_name, {})

     # Default to True for known providers that support temperature if not specified

@@ -228,6 +228,10 @@ def create_llm_client(
         supports_temperature = model_config["supports_temperature"]
     supports_thinking = model_config.get("supports_thinking", False)

+    other_kwargs = {}
+    if is_claude_37(model_name):
+        other_kwargs = {"max_tokens": 64000}
+
     # Handle temperature settings
     if is_expert:
         temp_kwargs = {"temperature": 0} if supports_temperature else {}

@@ -235,22 +239,26 @@ def create_llm_client(
         if temperature is None:
             temperature = 0.7

             # Import repository classes directly to avoid circular imports
-            from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
-            from ra_aid.database.repositories.human_input_repository import HumanInputRepository
+            from ra_aid.database.repositories.trajectory_repository import (
+                TrajectoryRepository,
+            )
+            from ra_aid.database.repositories.human_input_repository import (
+                HumanInputRepository,
+            )
             from ra_aid.database.connection import get_db

             # Create repositories directly
             trajectory_repo = TrajectoryRepository(get_db())
             human_input_repo = HumanInputRepository(get_db())
             human_input_id = human_input_repo.get_most_recent_id()
             trajectory_repo.create(
                 step_data={
                     "message": "This model supports temperature argument but none was given. Setting default temperature to 0.7.",
                     "display_title": "Information",
                 },
                 record_type="info",
-                human_input_id=human_input_id
+                human_input_id=human_input_id,
             )
             cpm(
                 "This model supports temperature argument but none was given. Setting default temperature to 0.7."

@@ -302,9 +310,9 @@ def create_llm_client(
             model_name=model_name,
             timeout=LLM_REQUEST_TIMEOUT,
             max_retries=LLM_MAX_RETRIES,
-            max_tokens=model_config.get("max_tokens", 64000),
             **temp_kwargs,
             **thinking_kwargs,
+            **other_kwargs,
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(
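
The `other_kwargs` dict above is plain conditional keyword expansion into the client constructor. A self-contained sketch of the pattern, with `is_claude_37` re-inlined for illustration (the real helper lives in `ra_aid/model_detection.py`, shown next):

    # Illustrative re-implementation of the conditional-kwarg pattern above.
    def is_claude_37(model: str) -> bool:
        return any(p in model for p in ["claude-3.7", "claude3.7", "claude-3-7"])

    def build_client_kwargs(model_name: str) -> dict:
        other_kwargs = {}
        if is_claude_37(model_name):
            other_kwargs = {"max_tokens": 64000}  # fixed output budget for 3.7
        return other_kwargs

    print(build_client_kwargs("claude-3-7-sonnet-20250219"))  # {'max_tokens': 64000}
    print(build_client_kwargs("claude-3-5-sonnet-20241022"))  # {}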

ra_aid/model_detection.py (new file, 39 lines)
View File

@@ -0,0 +1,39 @@
+"""Utilities for detecting and working with specific model types."""
+
+from typing import Optional, Dict, Any
+
+
+def is_claude_37(model: str) -> bool:
+    """Check if the model is a Claude 3.7 model.
+
+    Args:
+        model: The model name to check
+
+    Returns:
+        bool: True if the model is a Claude 3.7 model, False otherwise
+    """
+    patterns = ["claude-3.7", "claude3.7", "claude-3-7"]
+    return any(pattern in model for pattern in patterns)
+
+
+def is_anthropic_claude(config: Dict[str, Any]) -> bool:
+    """Check if the provider and model name indicate an Anthropic Claude model.
+
+    Args:
+        config: Configuration dictionary containing provider and model information
+
+    Returns:
+        bool: True if this is an Anthropic Claude model
+    """
+    # For backwards compatibility, allow passing of config directly
+    provider = config.get("provider", "")
+    model_name = config.get("model", "")
+    result = (
+        provider.lower() == "anthropic"
+        and model_name
+        and "claude" in model_name.lower()
+    ) or (
+        provider.lower() == "openrouter"
+        and model_name.lower().startswith("anthropic/claude-")
+    )
+    return result

View File

@@ -13,8 +13,7 @@ from ra_aid.agent_context import (
 )
 from ra_aid.agent_utils import (
     AgentState,
-    create_agent,
-    is_anthropic_claude,
+    create_agent
 )
 from ra_aid.anthropic_token_limiter import (
     get_model_token_limit,

@@ -453,31 +452,6 @@ def test_handle_api_error_retry(monkeypatch):
     _handle_api_error(Exception("error code 429"), 0, 5, 1)


-def test_is_anthropic_claude():
-    """Test is_anthropic_claude function with various configurations."""
-    # Test Anthropic provider cases
-    assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
-    assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
-    assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
-
-    # Test OpenRouter provider cases
-    assert is_anthropic_claude(
-        {"provider": "openrouter", "model": "anthropic/claude-2"}
-    )
-    assert is_anthropic_claude(
-        {"provider": "openrouter", "model": "anthropic/claude-instant"}
-    )
-    assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
-
-    # Test edge cases
-    assert not is_anthropic_claude({})  # Empty config
-    assert not is_anthropic_claude({"provider": "anthropic"})  # Missing model
-    assert not is_anthropic_claude({"model": "claude-2"})  # Missing provider
-    assert not is_anthropic_claude(
-        {"provider": "other", "model": "claude-2"}
-    )  # Wrong provider
-
-
 def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repository):
     """Test that run_agent_with_retry checks for crash status at the beginning of each iteration."""
     from ra_aid.agent_context import agent_context, mark_agent_crashed

View File

@@ -17,10 +17,11 @@ from ra_aid.anthropic_token_limiter import (
     get_model_token_limit,
     state_modifier,
     sonnet_35_state_modifier,
-    convert_message_to_litellm_format
+    convert_message_to_litellm_format,
+    adjust_claude_37_token_limit
 )
 from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
-from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT
+from ra_aid.models_params import models_params


 class TestAnthropicTokenLimiter(unittest.TestCase):
@@ -113,9 +114,8 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(result, 0)

     @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
-    @patch("ra_aid.anthropic_token_limiter.print_messages_compact")
     @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
-    def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
+    def test_state_modifier(self, mock_trim_messages, mock_create_wrapper):
         # Setup a proper token counter function that returns integers
         def token_counter(msgs):
             # Return token count based on number of messages
@@ -155,8 +155,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         model.model = "claude-3-opus-20240229"

         with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
-            patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
-            patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
+            patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim:
             # Setup mock to return a fixed token count per message
             mock_wrapper.return_value = lambda msgs: len(msgs) * 100
             # Setup mock to return a subset of messages
@@ -206,23 +205,35 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(call_args["include_system"], True)

     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
-    def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
-        from ra_aid.config import DEFAULT_MODEL
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.is_claude_37")
+    @patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
+    def test_get_model_token_limit_from_litellm(self, mock_adjust, mock_is_claude_37, mock_get_model_info, mock_get_config_repo):
+        # Use a specific model name instead of DEFAULT_MODEL to avoid test dependency
+        model_name = "claude-3-7-sonnet-20250219"

         # Setup mocks
-        mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
+        mock_config = {"provider": "anthropic", "model": model_name}
         mock_get_config_repo.return_value.get_all.return_value = mock_config

         # Mock litellm's get_model_info to return a token limit
         mock_get_model_info.return_value = {"max_input_tokens": 100000}

+        # Mock is_claude_37 to return True
+        mock_is_claude_37.return_value = True
+
+        # Mock adjust_claude_37_token_limit to return the original value
+        mock_adjust.return_value = 100000
+
         # Test getting token limit
         result = get_model_token_limit(mock_config)
         self.assertEqual(result, 100000)

         # Verify get_model_info was called with the right model
-        mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
+        mock_get_model_info.assert_called_once_with(f"anthropic/{model_name}")
+
+        # Verify adjust_claude_37_token_limit was called
+        mock_adjust.assert_called_once_with(100000, None)

     def test_get_model_token_limit_research(self):
         """Test get_model_token_limit with research provider and model."""
@@ -230,17 +241,24 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
             "provider": "openai",
             "model": "gpt-4",
             "research_provider": "anthropic",
-            "research_model": "claude-2",
+            "research_model": "claude-3-7-sonnet-20250219",
         }

         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-            patch("litellm.get_model_info") as mock_get_info:
+            patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+            patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 150000}
+            mock_adjust.return_value = 150000

+            # Call the function to check the return value
             token_limit = get_model_token_limit(config, "research")
             self.assertEqual(token_limit, 150000)

             # Verify get_model_info was called with the research model
-            mock_get_info.assert_called_with("anthropic/claude-2")
+            mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
+
+            # Verify adjust_claude_37_token_limit was called
+            mock_adjust.assert_called_once_with(150000, None)

     def test_get_model_token_limit_planner(self):
         """Test get_model_token_limit with planner provider and model."""
@@ -252,16 +270,23 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         }

         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-            patch("litellm.get_model_info") as mock_get_info:
+            patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+            patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 120000}
+            mock_adjust.return_value = 120000

+            # Call the function to check the return value
             token_limit = get_model_token_limit(config, "planner")
             self.assertEqual(token_limit, 120000)

             # Verify get_model_info was called with the planner model
-            mock_get_info.assert_called_with("deepseek/dsm-1")
+            mock_get_info.assert_called_once_with("deepseek/dsm-1")
+
+            # Verify adjust_claude_37_token_limit was called
+            mock_adjust.assert_called_once_with(120000, None)

     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
     def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
         # Setup mocks
         mock_config = {"provider": "anthropic", "model": "claude-2"}
@@ -280,54 +305,87 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(result, 100000)

     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
-    @patch("litellm.get_model_info")
-    def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
-        from ra_aid.config import DEFAULT_MODEL
+    @patch("ra_aid.anthropic_token_limiter.get_model_info")
+    @patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit")
+    def test_get_model_token_limit_for_different_agent_types(self, mock_adjust, mock_get_model_info, mock_get_config_repo):
+        # Use specific model names instead of DEFAULT_MODEL to avoid test dependency
+        claude_model = "claude-3-7-sonnet-20250219"

         # Setup mocks for different agent types
         mock_config = {
             "provider": "anthropic",
-            "model": DEFAULT_MODEL,
+            "model": claude_model,
             "research_provider": "openai",
             "research_model": "gpt-4",
             "planner_provider": "anthropic",
-            "planner_model": "claude-3-sonnet-20240229"
+            "planner_model": "claude-3-7-opus-20250301"
         }
         mock_get_config_repo.return_value.get_all.return_value = mock_config

         # Mock different returns for different models
         def model_info_side_effect(model_name):
-            if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
+            if "claude-3-7-sonnet" in model_name:
                 return {"max_input_tokens": 200000}
             elif "gpt-4" in model_name:
                 return {"max_input_tokens": 8192}
-            elif "claude-3-sonnet" in model_name:
-                return {"max_input_tokens": 100000}
+            elif "claude-3-7-opus" in model_name:
+                return {"max_input_tokens": 250000}
             else:
                 raise Exception(f"Unknown model: {model_name}")

         mock_get_model_info.side_effect = model_info_side_effect

+        # Mock adjust_claude_37_token_limit to return the same values
+        mock_adjust.side_effect = lambda tokens, model: tokens
+
         # Test default agent type
         result = get_model_token_limit(mock_config, "default")
         self.assertEqual(result, 200000)
+        mock_get_model_info.assert_called_with(f"anthropic/{claude_model}")
+
+        # Reset mock
+        mock_get_model_info.reset_mock()

         # Test research agent type
         result = get_model_token_limit(mock_config, "research")
         self.assertEqual(result, 8192)
+        mock_get_model_info.assert_called_with("openai/gpt-4")
+
+        # Reset mock
+        mock_get_model_info.reset_mock()

         # Test planner agent type
         result = get_model_token_limit(mock_config, "planner")
-        self.assertEqual(result, 100000)
+        self.assertEqual(result, 250000)
+        mock_get_model_info.assert_called_with("anthropic/claude-3-7-opus-20250301")
     def test_get_model_token_limit_anthropic(self):
         """Test get_model_token_limit with Anthropic model."""
-        config = {"provider": "anthropic", "model": "claude2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}

-        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+            patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
+            patch("litellm.get_model_info") as mock_get_info, \
+            patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
+            # Setup mocks
             mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = Exception("Model not found")
+
+            # Create a mock models_params with claude-3-7
+            mock_models_params_dict = {
+                "anthropic": {
+                    "claude-3-7-sonnet-20250219": {"token_limit": 200000}
+                }
+            }
+            mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
+            mock_models_params.get.side_effect = mock_models_params_dict.get
+
+            # Mock adjust to return the same value
+            mock_adjust.return_value = 200000
+
             token_limit = get_model_token_limit(config, "default")
-            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+            self.assertEqual(token_limit, 200000)

     def test_get_model_token_limit_openai(self):
         """Test get_model_token_limit with OpenAI model."""
@@ -358,28 +416,51 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
     def test_get_model_token_limit_litellm_success(self):
         """Test get_model_token_limit successfully getting limit from litellm."""
-        config = {"provider": "anthropic", "model": "claude-2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}

         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-            patch("litellm.get_model_info") as mock_get_info:
+            patch("ra_aid.anthropic_token_limiter.get_model_info") as mock_get_info, \
+            patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.return_value = {"max_input_tokens": 100000}
+            mock_adjust.return_value = 100000

+            # Call the function to check the return value
             token_limit = get_model_token_limit(config, "default")
             self.assertEqual(token_limit, 100000)
-            mock_get_info.assert_called_with("anthropic/claude-2")
+
+            # Verify get_model_info was called with the right model
+            mock_get_info.assert_called_once_with("anthropic/claude-3-7-sonnet-20250219")
+            mock_adjust.assert_called_once_with(100000, None)

     def test_get_model_token_limit_litellm_not_found(self):
         """Test fallback to models_tokens when litellm raises NotFoundError."""
-        config = {"provider": "anthropic", "model": "claude-2"}
+        config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}

         with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
-            patch("litellm.get_model_info") as mock_get_info:
+            patch("litellm.get_model_info") as mock_get_info, \
+            patch("ra_aid.anthropic_token_limiter.models_params") as mock_models_params, \
+            patch("ra_aid.anthropic_token_limiter.adjust_claude_37_token_limit") as mock_adjust:
             mock_get_config_repo.return_value.get_all.return_value = config
             mock_get_info.side_effect = litellm.exceptions.NotFoundError(
-                message="Model not found", model="claude-2", llm_provider="anthropic"
+                message="Model not found", model="claude-3-7-sonnet-20250219", llm_provider="anthropic"
             )

+            # Create a mock models_params with claude-3-7
+            mock_models_params_dict = {
+                "anthropic": {
+                    "claude-3-7-sonnet-20250219": {"token_limit": 200000}
+                }
+            }
+            mock_models_params.__getitem__.side_effect = mock_models_params_dict.__getitem__
+            mock_models_params.get.side_effect = mock_models_params_dict.get
+
+            # Mock adjust to return the same value
+            mock_adjust.return_value = 200000
+
             token_limit = get_model_token_limit(config, "default")
-            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+            self.assertEqual(token_limit, 200000)
     def test_get_model_token_limit_litellm_error(self):
         """Test fallback to models_tokens when litellm raises other exceptions."""

@@ -399,6 +480,30 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
             token_limit = get_model_token_limit(config, "default")
             self.assertIsNone(token_limit)

+    def test_adjust_claude_37_token_limit(self):
+        """Test adjust_claude_37_token_limit function."""
+        # Create a mock model
+        mock_model = MagicMock()
+        mock_model.model = "claude-3.7-sonnet"
+        mock_model.max_tokens = 4096
+
+        # Test with Claude 3.7 model
+        result = adjust_claude_37_token_limit(100000, mock_model)
+        self.assertEqual(result, 95904)  # 100000 - 4096
+
+        # Test with non-Claude 3.7 model
+        mock_model.model = "claude-3-opus"
+        result = adjust_claude_37_token_limit(100000, mock_model)
+        self.assertEqual(result, 100000)  # No adjustment
+
+        # Test with None max_input_tokens
+        result = adjust_claude_37_token_limit(None, mock_model)
+        self.assertIsNone(result)
+
+        # Test with None model
+        result = adjust_claude_37_token_limit(100000, None)
+        self.assertEqual(result, 100000)
+
     def test_has_tool_use(self):
         """Test the has_tool_use function."""
         # Test with regular AI message

View File

@@ -0,0 +1,50 @@
+"""Unit tests for model_detection.py."""
+
+import pytest
+
+from ra_aid.model_detection import is_anthropic_claude, is_claude_37
+
+
+def test_is_anthropic_claude():
+    """Test is_anthropic_claude function with various configurations."""
+    # Test Anthropic provider cases
+    assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
+    assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
+    assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
+
+    # Test OpenRouter provider cases
+    assert is_anthropic_claude(
+        {"provider": "openrouter", "model": "anthropic/claude-2"}
+    )
+    assert is_anthropic_claude(
+        {"provider": "openrouter", "model": "anthropic/claude-instant"}
+    )
+    assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
+
+    # Test edge cases
+    assert not is_anthropic_claude({})  # Empty config
+    assert not is_anthropic_claude({"provider": "anthropic"})  # Missing model
+    assert not is_anthropic_claude({"model": "claude-2"})  # Missing provider
+    assert not is_anthropic_claude(
+        {"provider": "other", "model": "claude-2"}
+    )  # Wrong provider
+
+
+def test_is_claude_37():
+    """Test is_claude_37 function with various model names."""
+    # Test positive cases
+    assert is_claude_37("claude-3.7")
+    assert is_claude_37("claude3.7")
+    assert is_claude_37("claude-3-7")
+    assert is_claude_37("anthropic/claude-3.7")
+    assert is_claude_37("anthropic/claude3.7")
+    assert is_claude_37("anthropic/claude-3-7")
+    assert is_claude_37("claude-3.7-sonnet")
+    assert is_claude_37("claude3.7-haiku")
+
+    # Test negative cases
+    assert not is_claude_37("claude-3")
+    assert not is_claude_37("claude-3.5")
+    assert not is_claude_37("claude3.5")
+    assert not is_claude_37("claude-3-5")
+    assert not is_claude_37("gpt-4")
+    assert not is_claude_37("")