feat(main.py): reorganize litellm configuration to improve clarity and maintainability
feat(agent_utils.py): add model detection utilities for Claude 3.7 models
fix(agent_utils.py): update get_model_token_limit to handle Claude 3.7 token limits correctly
test(model_detection.py): add unit tests for model detection utilities
chore(agent_utils.py): remove deprecated is_anthropic_claude function and related tests
style(agent_utils.py): format code for better readability and consistency
parent 07c6c2e5b5
commit 29c9cac4f4
@@ -5,24 +5,8 @@ import sys
 import uuid
 from datetime import datetime
 
-# Add litellm import
 import litellm
-
-# Configure litellm to suppress debug logs
-os.environ["LITELLM_LOG"] = "ERROR"
-litellm.suppress_debug_info = True
-litellm.set_verbose = False
-
-# Explicitly configure LiteLLM's loggers
-for logger_name in ["litellm", "LiteLLM"]:
-    litellm_logger = logging.getLogger(logger_name)
-    litellm_logger.setLevel(logging.WARNING)
-    litellm_logger.propagate = True
-
-# Use litellm's internal method to disable debugging
-if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
-    litellm._logging._disable_debugging()
 
 from langgraph.checkpoint.memory import MemorySaver
 from rich.console import Console
 from rich.panel import Panel

@@ -99,6 +83,21 @@ from ra_aid.tools.human import ask_human
 
 logger = get_logger(__name__)
 
+# Configure litellm to suppress debug logs
+os.environ["LITELLM_LOG"] = "ERROR"
+litellm.suppress_debug_info = True
+litellm.set_verbose = False
+
+# Explicitly configure LiteLLM's loggers
+for logger_name in ["litellm", "LiteLLM"]:
+    litellm_logger = logging.getLogger(logger_name)
+    litellm_logger.setLevel(logging.WARNING)
+    litellm_logger.propagate = True
+
+# Use litellm's internal method to disable debugging
+if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
+    litellm._logging._disable_debugging()
+
 
 def launch_webui(host: str, port: int):
     """Launch the RA.Aid web interface."""
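
Note on the relocation: the suppression block itself is unchanged; it now simply runs after this module's logger is set up instead of sitting in the middle of the import section. A minimal standalone sketch of the same idea, using only the litellm attributes that appear in the diff (litellm must already be installed):

    import logging
    import os

    import litellm

    # Quiet litellm before any completion call is made.
    os.environ["LITELLM_LOG"] = "ERROR"
    litellm.suppress_debug_info = True
    litellm.set_verbose = False

    for name in ("litellm", "LiteLLM"):
        logging.getLogger(name).setLevel(logging.WARNING)
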
@@ -51,7 +51,12 @@ from ra_aid.database.repositories.human_input_repository import (
 )
 from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
 from ra_aid.database.repositories.config_repository import get_config_repository
-from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
+from ra_aid.anthropic_token_limiter import (
+    sonnet_35_state_modifier,
+    state_modifier,
+    get_model_token_limit,
+)
+from ra_aid.model_detection import is_anthropic_claude
 
 console = Console()
 
@@ -67,8 +72,6 @@ def output_markdown_message(message: str) -> str:
     return "Message output."
 
 
-
-
 def build_agent_kwargs(
     checkpointer: Optional[Any] = None,
     model: ChatAnthropic = None,
@@ -99,8 +102,13 @@ def build_agent_kwargs(
     ):
 
         def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
-            if any(pattern in model.model for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]):
-                return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
+            if any(
+                pattern in model.model
+                for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]
+            ):
+                return sonnet_35_state_modifier(
+                    state, max_input_tokens=max_input_tokens
+                )
 
             return state_modifier(state, model, max_input_tokens=max_input_tokens)
 
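
The reformatted condition only changes layout; the dispatch is the same: Claude 3.5 variants go through sonnet_35_state_modifier, everything else through the generic state_modifier. An illustrative, self-contained reduction of that dispatch (the helper name below is made up for the example):

    def pick_modifier(model_name: str) -> str:
        """Stand-in showing which modifier wrapped_state_modifier would call."""
        claude_35_patterns = ["claude-3.5", "claude3.5", "claude-3-5"]
        if any(pattern in model_name for pattern in claude_35_patterns):
            return "sonnet_35_state_modifier"
        return "state_modifier"

    assert pick_modifier("claude-3-5-sonnet-20241022") == "sonnet_35_state_modifier"
    assert pick_modifier("claude-3-7-sonnet-20250219") == "state_modifier"
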
@@ -110,27 +118,6 @@ def build_agent_kwargs(
     return agent_kwargs
 
 
-def is_anthropic_claude(config: Dict[str, Any]) -> bool:
-    """Check if the provider and model name indicate an Anthropic Claude model.
-
-    Args:
-        config: Configuration dictionary containing provider and model information
-
-    Returns:
-        bool: True if this is an Anthropic Claude model
-    """
-    # For backwards compatibility, allow passing of config directly
-    provider = config.get("provider", "")
-    model_name = config.get("model", "")
-    result = (
-        provider.lower() == "anthropic"
-        and model_name
-        and "claude" in model_name.lower()
-    ) or (
-        provider.lower() == "openrouter"
-        and model_name.lower().startswith("anthropic/claude-")
-    )
-    return result
 
 
 def create_agent(
@@ -169,7 +156,7 @@ def create_agent(
         # So we'll use the passed config directly
         pass
     max_input_tokens = (
-        get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
+        get_model_token_limit(config, agent_type, model) or DEFAULT_TOKEN_LIMIT
    )
 
     # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
@@ -188,7 +175,7 @@ def create_agent(
         # Default to REACT agent if provider/model detection fails
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
         config = get_config_repository().get_all()
-        max_input_tokens = get_model_token_limit(config, agent_type)
+        max_input_tokens = get_model_token_limit(config, agent_type, model)
         agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
         return create_react_agent(
             model, tools, interrupt_after=["tools"], **agent_kwargs
@@ -301,7 +288,7 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
             record_type="error",
             human_input_id=human_input_id,
             is_error=True,
-            error_message=error_message
+            error_message=error_message,
         )
 
         print_error(error_message)
@@ -464,7 +451,9 @@ def run_agent_with_retry(
 
         try:
             _run_agent_stream(agent, msg_list)
-            if fallback_handler and hasattr(fallback_handler, 'reset_fallback_handler'):
+            if fallback_handler and hasattr(
+                fallback_handler, "reset_fallback_handler"
+            ):
                 fallback_handler.reset_fallback_handler()
             should_break, prompt, auto_test, test_attempts = (
                 _execute_test_command_wrapper(
@@ -1,7 +1,9 @@
 """Utilities for handling token limits with Anthropic models."""
 
 from functools import partial
-from typing import Any, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Union
+from langchain_core.language_models import BaseChatModel
+from ra_aid.model_detection import is_claude_37
 from dataclasses import dataclass
 
 from langchain_anthropic import ChatAnthropic
@@ -19,7 +21,7 @@ from ra_aid.anthropic_message_utils import (
     has_tool_use,
 )
 from langgraph.prebuilt.chat_agent_executor import AgentState
-from litellm import token_counter
+from litellm import token_counter, get_model_info
 
 from ra_aid.agent_backends.ciayn_agent import CiaynAgent
 from ra_aid.database.repositories.config_repository import get_config_repository
@@ -168,14 +170,39 @@ def sonnet_35_state_modifier(
     return result
 
 
+def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChatModel]) -> Optional[int]:
+    """Adjust token limit for Claude 3.7 models by subtracting max_tokens.
+
+    Args:
+        max_input_tokens: The original token limit
+        model: The model instance to check
+
+    Returns:
+        Optional[int]: Adjusted token limit if model is Claude 3.7, otherwise original limit
+    """
+    if not max_input_tokens:
+        return max_input_tokens
+
+    if model and hasattr(model, 'model') and is_claude_37(model.model):
+        if hasattr(model, 'max_tokens') and model.max_tokens:
+            effective_max_input_tokens = max_input_tokens - model.max_tokens
+            logger.debug(
+                f"Adjusting token limit for Claude 3.7 model: {max_input_tokens} - {model.max_tokens} = {effective_max_input_tokens}"
+            )
+            return effective_max_input_tokens
+
+    return max_input_tokens
+
+
 def get_model_token_limit(
-    config: Dict[str, Any], agent_type: str = "default"
+    config: Dict[str, Any], agent_type: str = "default", model: Optional[BaseChatModel] = None
 ) -> Optional[int]:
     """Get the token limit for the current model configuration based on agent type.
 
     Args:
         config: Configuration dictionary containing provider and model information
         agent_type: Type of agent ("default", "research", or "planner")
+        model: Optional BaseChatModel instance to check for model-specific attributes
 
     Returns:
         Optional[int]: The token limit if found, None otherwise
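
A worked example of the new adjustment, with illustrative numbers (a 200k input limit reported by litellm, and the 64k max_tokens this commit configures for Claude 3.7):

    max_input_tokens = 200_000  # limit reported for the model
    max_tokens = 64_000         # completion budget set on the ChatAnthropic instance

    # adjust_claude_37_token_limit subtracts the completion budget so that
    # prompt plus completion stays inside the model's context window.
    effective_max_input_tokens = max_input_tokens - max_tokens
    assert effective_max_input_tokens == 136_000
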
@@ -201,7 +228,6 @@ def get_model_token_limit(
     model_name = config.get("model", "")
 
     try:
-        from litellm import get_model_info
 
         provider_model = model_name if not provider else f"{provider}/{model_name}"
         model_info = get_model_info(provider_model)
@@ -210,7 +236,7 @@ def get_model_token_limit(
             logger.debug(
                 f"Using litellm token limit for {model_name}: {max_input_tokens}"
             )
-            return max_input_tokens
+            return adjust_claude_37_token_limit(max_input_tokens, model)
     except Exception as e:
         logger.debug(
             f"Error getting model info from litellm: {e}, falling back to models_params"
@@ -229,7 +255,7 @@ def get_model_token_limit(
             max_input_tokens = None
             logger.debug(f"Could not find token limit for {provider}/{model_name}")
 
-        return max_input_tokens
+        return adjust_claude_37_token_limit(max_input_tokens, model)
 
     except Exception as e:
         logger.warning(f"Failed to get model token limit: {e}")
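
Both return paths of get_model_token_limit now funnel through the adjustment. Because the helper only reads the model's `model` and `max_tokens` attributes via hasattr, any object exposing them works; a self-contained sketch with a stand-in object rather than a real ChatAnthropic instance (the mirror function below is illustrative, not the project's API):

    from types import SimpleNamespace

    fake_model = SimpleNamespace(model="claude-3-7-sonnet", max_tokens=64_000)

    def adjusted_limit(raw_limit, model):
        """Illustrative mirror of adjust_claude_37_token_limit."""
        is_37 = any(p in model.model for p in ["claude-3.7", "claude3.7", "claude-3-7"])
        if raw_limit and is_37 and getattr(model, "max_tokens", None):
            return raw_limit - model.max_tokens
        return raw_limit

    print(adjusted_limit(200_000, fake_model))  # 136000
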
@@ -10,6 +10,7 @@ from openai import OpenAI
 from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
 from ra_aid.console.output import cpm
 from ra_aid.logging_config import get_logger
+from ra_aid.model_detection import is_claude_37
 
 from .models_params import models_params
 
@@ -218,7 +219,6 @@ def create_llm_client(
         is_expert,
     )
 
-    # Get model configuration
     model_config = models_params.get(provider, {}).get(model_name, {})
 
     # Default to True for known providers that support temperature if not specified
@@ -228,6 +228,10 @@ def create_llm_client(
         supports_temperature = model_config["supports_temperature"]
     supports_thinking = model_config.get("supports_thinking", False)
 
+    other_kwargs = {}
+    if is_claude_37(model_name):
+        other_kwargs = {"max_tokens": 64000}
+
     # Handle temperature settings
     if is_expert:
         temp_kwargs = {"temperature": 0} if supports_temperature else {}
@@ -235,8 +239,12 @@ def create_llm_client(
         if temperature is None:
             temperature = 0.7
             # Import repository classes directly to avoid circular imports
-            from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
-            from ra_aid.database.repositories.human_input_repository import HumanInputRepository
+            from ra_aid.database.repositories.trajectory_repository import (
+                TrajectoryRepository,
+            )
+            from ra_aid.database.repositories.human_input_repository import (
+                HumanInputRepository,
+            )
             from ra_aid.database.connection import get_db
 
             # Create repositories directly
@@ -250,7 +258,7 @@ def create_llm_client(
                     "display_title": "Information",
                 },
                 record_type="info",
-                human_input_id=human_input_id
+                human_input_id=human_input_id,
             )
             cpm(
                 "This model supports temperature argument but none was given. Setting default temperature to 0.7."
@@ -302,9 +310,9 @@ def create_llm_client(
             model_name=model_name,
             timeout=LLM_REQUEST_TIMEOUT,
             max_retries=LLM_MAX_RETRIES,
-            max_tokens=model_config.get("max_tokens", 64000),
             **temp_kwargs,
             **thinking_kwargs,
+            **other_kwargs,
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(
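
The unconditional max_tokens argument is replaced by the other_kwargs dict built earlier (see the @@ -228,6 +228,10 @@ hunk), so only Claude 3.7 gets the 64k completion budget. A small sketch of the merge semantics with plain dicts (no real client is constructed):

    temp_kwargs = {"temperature": 0}
    thinking_kwargs = {}
    other_kwargs = {"max_tokens": 64_000}  # only populated for Claude 3.7 models

    # **-unpacking merges the optional argument groups; empty groups add nothing.
    client_args = {
        "model_name": "claude-3-7-sonnet",
        **temp_kwargs,
        **thinking_kwargs,
        **other_kwargs,
    }
    assert client_args["max_tokens"] == 64_000
    assert client_args["temperature"] == 0
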
@@ -0,0 +1,39 @@
+"""Utilities for detecting and working with specific model types."""
+
+from typing import Optional, Dict, Any
+
+
+def is_claude_37(model: str) -> bool:
+    """Check if the model is a Claude 3.7 model.
+
+    Args:
+        model: The model name to check
+
+    Returns:
+        bool: True if the model is a Claude 3.7 model, False otherwise
+    """
+    patterns = ["claude-3.7", "claude3.7", "claude-3-7"]
+    return any(pattern in model for pattern in patterns)
+
+
+def is_anthropic_claude(config: Dict[str, Any]) -> bool:
+    """Check if the provider and model name indicate an Anthropic Claude model.
+
+    Args:
+        config: Configuration dictionary containing provider and model information
+
+    Returns:
+        bool: True if this is an Anthropic Claude model
+    """
+    # For backwards compatibility, allow passing of config directly
+    provider = config.get("provider", "")
+    model_name = config.get("model", "")
+    result = (
+        provider.lower() == "anthropic"
+        and model_name
+        and "claude" in model_name.lower()
+    ) or (
+        provider.lower() == "openrouter"
+        and model_name.lower().startswith("anthropic/claude-")
+    )
+    return result
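
Example usage of the two helpers added in this new module, runnable as-is once the ra_aid package is importable (the model strings are just samples):

    from ra_aid.model_detection import is_anthropic_claude, is_claude_37

    print(is_claude_37("claude-3-7-sonnet-20250219"))   # True
    print(is_claude_37("claude-3-5-sonnet-20241022"))   # False
    print(is_anthropic_claude({"provider": "anthropic", "model": "claude-3-7-sonnet"}))      # True
    print(is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-3.7"}))  # True
    print(is_anthropic_claude({"provider": "openai", "model": "gpt-4"}))                     # False
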
@@ -13,8 +13,7 @@ from ra_aid.agent_context import (
 )
 from ra_aid.agent_utils import (
     AgentState,
-    create_agent,
-    is_anthropic_claude,
+    create_agent
 )
 from ra_aid.anthropic_token_limiter import (
     get_model_token_limit,
@@ -453,31 +452,6 @@ def test_handle_api_error_retry(monkeypatch):
         _handle_api_error(Exception("error code 429"), 0, 5, 1)
 
 
-def test_is_anthropic_claude():
-    """Test is_anthropic_claude function with various configurations."""
-    # Test Anthropic provider cases
-    assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
-    assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
-    assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
-
-    # Test OpenRouter provider cases
-    assert is_anthropic_claude(
-        {"provider": "openrouter", "model": "anthropic/claude-2"}
-    )
-    assert is_anthropic_claude(
-        {"provider": "openrouter", "model": "anthropic/claude-instant"}
-    )
-    assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
-
-    # Test edge cases
-    assert not is_anthropic_claude({})  # Empty config
-    assert not is_anthropic_claude({"provider": "anthropic"})  # Missing model
-    assert not is_anthropic_claude({"model": "claude-2"})  # Missing provider
-    assert not is_anthropic_claude(
-        {"provider": "other", "model": "claude-2"}
-    )  # Wrong provider
-
-
 def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repository):
     """Test that run_agent_with_retry checks for crash status at the beginning of each iteration."""
     from ra_aid.agent_context import agent_context, mark_agent_crashed
@@ -0,0 +1,50 @@
+"""Unit tests for model_detection.py."""
+
+import pytest
+from ra_aid.model_detection import is_anthropic_claude, is_claude_37
+
+
+def test_is_anthropic_claude():
+    """Test is_anthropic_claude function with various configurations."""
+    # Test Anthropic provider cases
+    assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
+    assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
+    assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
+
+    # Test OpenRouter provider cases
+    assert is_anthropic_claude(
+        {"provider": "openrouter", "model": "anthropic/claude-2"}
+    )
+    assert is_anthropic_claude(
+        {"provider": "openrouter", "model": "anthropic/claude-instant"}
+    )
+    assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
+
+    # Test edge cases
+    assert not is_anthropic_claude({})  # Empty config
+    assert not is_anthropic_claude({"provider": "anthropic"})  # Missing model
+    assert not is_anthropic_claude({"model": "claude-2"})  # Missing provider
+    assert not is_anthropic_claude(
+        {"provider": "other", "model": "claude-2"}
+    )  # Wrong provider
+
+
+def test_is_claude_37():
+    """Test is_claude_37 function with various model names."""
+    # Test positive cases
+    assert is_claude_37("claude-3.7")
+    assert is_claude_37("claude3.7")
+    assert is_claude_37("claude-3-7")
+    assert is_claude_37("anthropic/claude-3.7")
+    assert is_claude_37("anthropic/claude3.7")
+    assert is_claude_37("anthropic/claude-3-7")
+    assert is_claude_37("claude-3.7-sonnet")
+    assert is_claude_37("claude3.7-haiku")
+
+    # Test negative cases
+    assert not is_claude_37("claude-3")
+    assert not is_claude_37("claude-3.5")
+    assert not is_claude_37("claude3.5")
+    assert not is_claude_37("claude-3-5")
+    assert not is_claude_37("gpt-4")
+    assert not is_claude_37("")
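
To run just this new test module, something like the following works once the file is in the repository's test tree (the path below is assumed, not taken from the diff):

    import pytest

    # Adjust the path to wherever test_model_detection.py actually lives.
    pytest.main(["-v", "tests/ra_aid/test_model_detection.py"])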