feat(main.py): reorganize litellm configuration to improve clarity and maintainability
feat(agent_utils.py): add model detection utilities for Claude 3.7 models fix(agent_utils.py): update get_model_token_limit to handle Claude 3.7 token limits correctly test(model_detection.py): add unit tests for model detection utilities chore(agent_utils.py): remove deprecated is_anthropic_claude function and related tests style(agent_utils.py): format code for better readability and consistency
This commit is contained in:
parent
07c6c2e5b5
commit
29c9cac4f4
|
|
@ -5,24 +5,8 @@ import sys
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# Add litellm import
|
||||
import litellm
|
||||
|
||||
# Configure litellm to suppress debug logs
|
||||
os.environ["LITELLM_LOG"] = "ERROR"
|
||||
litellm.suppress_debug_info = True
|
||||
litellm.set_verbose = False
|
||||
|
||||
# Explicitly configure LiteLLM's loggers
|
||||
for logger_name in ["litellm", "LiteLLM"]:
|
||||
litellm_logger = logging.getLogger(logger_name)
|
||||
litellm_logger.setLevel(logging.WARNING)
|
||||
litellm_logger.propagate = True
|
||||
|
||||
# Use litellm's internal method to disable debugging
|
||||
if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
|
||||
litellm._logging._disable_debugging()
|
||||
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
|
@ -99,6 +83,21 @@ from ra_aid.tools.human import ask_human
|
|||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Configure litellm to suppress debug logs
|
||||
os.environ["LITELLM_LOG"] = "ERROR"
|
||||
litellm.suppress_debug_info = True
|
||||
litellm.set_verbose = False
|
||||
|
||||
# Explicitly configure LiteLLM's loggers
|
||||
for logger_name in ["litellm", "LiteLLM"]:
|
||||
litellm_logger = logging.getLogger(logger_name)
|
||||
litellm_logger.setLevel(logging.WARNING)
|
||||
litellm_logger.propagate = True
|
||||
|
||||
# Use litellm's internal method to disable debugging
|
||||
if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
|
||||
litellm._logging._disable_debugging()
|
||||
|
||||
|
||||
def launch_webui(host: str, port: int):
|
||||
"""Launch the RA.Aid web interface."""
|
||||
|
|
|
|||
|
|
@ -51,7 +51,12 @@ from ra_aid.database.repositories.human_input_repository import (
|
|||
)
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
|
||||
from ra_aid.anthropic_token_limiter import (
|
||||
sonnet_35_state_modifier,
|
||||
state_modifier,
|
||||
get_model_token_limit,
|
||||
)
|
||||
from ra_aid.model_detection import is_anthropic_claude
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -67,8 +72,6 @@ def output_markdown_message(message: str) -> str:
|
|||
return "Message output."
|
||||
|
||||
|
||||
|
||||
|
||||
def build_agent_kwargs(
|
||||
checkpointer: Optional[Any] = None,
|
||||
model: ChatAnthropic = None,
|
||||
|
|
@ -99,8 +102,13 @@ def build_agent_kwargs(
|
|||
):
|
||||
|
||||
def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
|
||||
if any(pattern in model.model for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]):
|
||||
return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
|
||||
if any(
|
||||
pattern in model.model
|
||||
for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]
|
||||
):
|
||||
return sonnet_35_state_modifier(
|
||||
state, max_input_tokens=max_input_tokens
|
||||
)
|
||||
|
||||
return state_modifier(state, model, max_input_tokens=max_input_tokens)
|
||||
|
||||
|
|
@ -110,27 +118,6 @@ def build_agent_kwargs(
|
|||
return agent_kwargs
|
||||
|
||||
|
||||
def is_anthropic_claude(config: Dict[str, Any]) -> bool:
|
||||
"""Check if the provider and model name indicate an Anthropic Claude model.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary containing provider and model information
|
||||
|
||||
Returns:
|
||||
bool: True if this is an Anthropic Claude model
|
||||
"""
|
||||
# For backwards compatibility, allow passing of config directly
|
||||
provider = config.get("provider", "")
|
||||
model_name = config.get("model", "")
|
||||
result = (
|
||||
provider.lower() == "anthropic"
|
||||
and model_name
|
||||
and "claude" in model_name.lower()
|
||||
) or (
|
||||
provider.lower() == "openrouter"
|
||||
and model_name.lower().startswith("anthropic/claude-")
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def create_agent(
|
||||
|
|
@ -169,7 +156,7 @@ def create_agent(
|
|||
# So we'll use the passed config directly
|
||||
pass
|
||||
max_input_tokens = (
|
||||
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
|
||||
get_model_token_limit(config, agent_type, model) or DEFAULT_TOKEN_LIMIT
|
||||
)
|
||||
|
||||
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN
|
||||
|
|
@ -188,7 +175,7 @@ def create_agent(
|
|||
# Default to REACT agent if provider/model detection fails
|
||||
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
|
||||
config = get_config_repository().get_all()
|
||||
max_input_tokens = get_model_token_limit(config, agent_type)
|
||||
max_input_tokens = get_model_token_limit(config, agent_type, model)
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
|
||||
return create_react_agent(
|
||||
model, tools, interrupt_after=["tools"], **agent_kwargs
|
||||
|
|
@ -289,7 +276,7 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
|||
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
|
||||
delay = base_delay * (2**attempt)
|
||||
error_message = f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
||||
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
|
|
@ -301,9 +288,9 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
|||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
|
||||
print_error(error_message)
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < delay:
|
||||
|
|
@ -464,7 +451,9 @@ def run_agent_with_retry(
|
|||
|
||||
try:
|
||||
_run_agent_stream(agent, msg_list)
|
||||
if fallback_handler and hasattr(fallback_handler, 'reset_fallback_handler'):
|
||||
if fallback_handler and hasattr(
|
||||
fallback_handler, "reset_fallback_handler"
|
||||
):
|
||||
fallback_handler.reset_fallback_handler()
|
||||
should_break, prompt, auto_test, test_attempts = (
|
||||
_execute_test_command_wrapper(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
"""Utilities for handling token limits with Anthropic models."""
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from ra_aid.model_detection import is_claude_37
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
|
@ -19,7 +21,7 @@ from ra_aid.anthropic_message_utils import (
|
|||
has_tool_use,
|
||||
)
|
||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||
from litellm import token_counter
|
||||
from litellm import token_counter, get_model_info
|
||||
|
||||
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
|
|
@ -168,14 +170,39 @@ def sonnet_35_state_modifier(
|
|||
return result
|
||||
|
||||
|
||||
def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChatModel]) -> Optional[int]:
|
||||
"""Adjust token limit for Claude 3.7 models by subtracting max_tokens.
|
||||
|
||||
Args:
|
||||
max_input_tokens: The original token limit
|
||||
model: The model instance to check
|
||||
|
||||
Returns:
|
||||
Optional[int]: Adjusted token limit if model is Claude 3.7, otherwise original limit
|
||||
"""
|
||||
if not max_input_tokens:
|
||||
return max_input_tokens
|
||||
|
||||
if model and hasattr(model, 'model') and is_claude_37(model.model):
|
||||
if hasattr(model, 'max_tokens') and model.max_tokens:
|
||||
effective_max_input_tokens = max_input_tokens - model.max_tokens
|
||||
logger.debug(
|
||||
f"Adjusting token limit for Claude 3.7 model: {max_input_tokens} - {model.max_tokens} = {effective_max_input_tokens}"
|
||||
)
|
||||
return effective_max_input_tokens
|
||||
|
||||
return max_input_tokens
|
||||
|
||||
|
||||
def get_model_token_limit(
|
||||
config: Dict[str, Any], agent_type: str = "default"
|
||||
config: Dict[str, Any], agent_type: str = "default", model: Optional[BaseChatModel] = None
|
||||
) -> Optional[int]:
|
||||
"""Get the token limit for the current model configuration based on agent type.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary containing provider and model information
|
||||
agent_type: Type of agent ("default", "research", or "planner")
|
||||
model: Optional BaseChatModel instance to check for model-specific attributes
|
||||
|
||||
Returns:
|
||||
Optional[int]: The token limit if found, None otherwise
|
||||
|
|
@ -201,7 +228,6 @@ def get_model_token_limit(
|
|||
model_name = config.get("model", "")
|
||||
|
||||
try:
|
||||
from litellm import get_model_info
|
||||
|
||||
provider_model = model_name if not provider else f"{provider}/{model_name}"
|
||||
model_info = get_model_info(provider_model)
|
||||
|
|
@ -210,7 +236,7 @@ def get_model_token_limit(
|
|||
logger.debug(
|
||||
f"Using litellm token limit for {model_name}: {max_input_tokens}"
|
||||
)
|
||||
return max_input_tokens
|
||||
return adjust_claude_37_token_limit(max_input_tokens, model)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Error getting model info from litellm: {e}, falling back to models_params"
|
||||
|
|
@ -229,7 +255,7 @@ def get_model_token_limit(
|
|||
max_input_tokens = None
|
||||
logger.debug(f"Could not find token limit for {provider}/{model_name}")
|
||||
|
||||
return max_input_tokens
|
||||
return adjust_claude_37_token_limit(max_input_tokens, model)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get model token limit: {e}")
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from openai import OpenAI
|
|||
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.model_detection import is_claude_37
|
||||
|
||||
from .models_params import models_params
|
||||
|
||||
|
|
@ -218,7 +219,6 @@ def create_llm_client(
|
|||
is_expert,
|
||||
)
|
||||
|
||||
# Get model configuration
|
||||
model_config = models_params.get(provider, {}).get(model_name, {})
|
||||
|
||||
# Default to True for known providers that support temperature if not specified
|
||||
|
|
@ -228,6 +228,10 @@ def create_llm_client(
|
|||
supports_temperature = model_config["supports_temperature"]
|
||||
supports_thinking = model_config.get("supports_thinking", False)
|
||||
|
||||
other_kwargs = {}
|
||||
if is_claude_37(model_name):
|
||||
other_kwargs = {"max_tokens": 64000}
|
||||
|
||||
# Handle temperature settings
|
||||
if is_expert:
|
||||
temp_kwargs = {"temperature": 0} if supports_temperature else {}
|
||||
|
|
@ -235,22 +239,26 @@ def create_llm_client(
|
|||
if temperature is None:
|
||||
temperature = 0.7
|
||||
# Import repository classes directly to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.repositories.trajectory_repository import (
|
||||
TrajectoryRepository,
|
||||
)
|
||||
from ra_aid.database.repositories.human_input_repository import (
|
||||
HumanInputRepository,
|
||||
)
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"message": "This model supports temperature argument but none was given. Setting default temperature to 0.7.",
|
||||
"display_title": "Information",
|
||||
},
|
||||
record_type="info",
|
||||
human_input_id=human_input_id
|
||||
human_input_id=human_input_id,
|
||||
)
|
||||
cpm(
|
||||
"This model supports temperature argument but none was given. Setting default temperature to 0.7."
|
||||
|
|
@ -302,9 +310,9 @@ def create_llm_client(
|
|||
model_name=model_name,
|
||||
timeout=LLM_REQUEST_TIMEOUT,
|
||||
max_retries=LLM_MAX_RETRIES,
|
||||
max_tokens=model_config.get("max_tokens", 64000),
|
||||
**temp_kwargs,
|
||||
**thinking_kwargs,
|
||||
**other_kwargs,
|
||||
)
|
||||
elif provider == "openai-compatible":
|
||||
return ChatOpenAI(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
"""Utilities for detecting and working with specific model types."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
def is_claude_37(model: str) -> bool:
|
||||
"""Check if the model is a Claude 3.7 model.
|
||||
|
||||
Args:
|
||||
model: The model name to check
|
||||
|
||||
Returns:
|
||||
bool: True if the model is a Claude 3.7 model, False otherwise
|
||||
"""
|
||||
patterns = ["claude-3.7", "claude3.7", "claude-3-7"]
|
||||
return any(pattern in model for pattern in patterns)
|
||||
|
||||
|
||||
def is_anthropic_claude(config: Dict[str, Any]) -> bool:
|
||||
"""Check if the provider and model name indicate an Anthropic Claude model.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary containing provider and model information
|
||||
|
||||
Returns:
|
||||
bool: True if this is an Anthropic Claude model
|
||||
"""
|
||||
# For backwards compatibility, allow passing of config directly
|
||||
provider = config.get("provider", "")
|
||||
model_name = config.get("model", "")
|
||||
result = (
|
||||
provider.lower() == "anthropic"
|
||||
and model_name
|
||||
and "claude" in model_name.lower()
|
||||
) or (
|
||||
provider.lower() == "openrouter"
|
||||
and model_name.lower().startswith("anthropic/claude-")
|
||||
)
|
||||
return result
|
||||
|
|
@ -13,8 +13,7 @@ from ra_aid.agent_context import (
|
|||
)
|
||||
from ra_aid.agent_utils import (
|
||||
AgentState,
|
||||
create_agent,
|
||||
is_anthropic_claude,
|
||||
create_agent
|
||||
)
|
||||
from ra_aid.anthropic_token_limiter import (
|
||||
get_model_token_limit,
|
||||
|
|
@ -453,31 +452,6 @@ def test_handle_api_error_retry(monkeypatch):
|
|||
_handle_api_error(Exception("error code 429"), 0, 5, 1)
|
||||
|
||||
|
||||
def test_is_anthropic_claude():
|
||||
"""Test is_anthropic_claude function with various configurations."""
|
||||
# Test Anthropic provider cases
|
||||
assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
|
||||
assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
|
||||
assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
|
||||
|
||||
# Test OpenRouter provider cases
|
||||
assert is_anthropic_claude(
|
||||
{"provider": "openrouter", "model": "anthropic/claude-2"}
|
||||
)
|
||||
assert is_anthropic_claude(
|
||||
{"provider": "openrouter", "model": "anthropic/claude-instant"}
|
||||
)
|
||||
assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
|
||||
|
||||
# Test edge cases
|
||||
assert not is_anthropic_claude({}) # Empty config
|
||||
assert not is_anthropic_claude({"provider": "anthropic"}) # Missing model
|
||||
assert not is_anthropic_claude({"model": "claude-2"}) # Missing provider
|
||||
assert not is_anthropic_claude(
|
||||
{"provider": "other", "model": "claude-2"}
|
||||
) # Wrong provider
|
||||
|
||||
|
||||
def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repository):
|
||||
"""Test that run_agent_with_retry checks for crash status at the beginning of each iteration."""
|
||||
from ra_aid.agent_context import agent_context, mark_agent_crashed
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
"""Unit tests for model_detection.py."""
|
||||
|
||||
import pytest
|
||||
from ra_aid.model_detection import is_anthropic_claude, is_claude_37
|
||||
|
||||
|
||||
def test_is_anthropic_claude():
|
||||
"""Test is_anthropic_claude function with various configurations."""
|
||||
# Test Anthropic provider cases
|
||||
assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
|
||||
assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
|
||||
assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
|
||||
|
||||
# Test OpenRouter provider cases
|
||||
assert is_anthropic_claude(
|
||||
{"provider": "openrouter", "model": "anthropic/claude-2"}
|
||||
)
|
||||
assert is_anthropic_claude(
|
||||
{"provider": "openrouter", "model": "anthropic/claude-instant"}
|
||||
)
|
||||
assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
|
||||
|
||||
# Test edge cases
|
||||
assert not is_anthropic_claude({}) # Empty config
|
||||
assert not is_anthropic_claude({"provider": "anthropic"}) # Missing model
|
||||
assert not is_anthropic_claude({"model": "claude-2"}) # Missing provider
|
||||
assert not is_anthropic_claude(
|
||||
{"provider": "other", "model": "claude-2"}
|
||||
) # Wrong provider
|
||||
|
||||
|
||||
def test_is_claude_37():
|
||||
"""Test is_claude_37 function with various model names."""
|
||||
# Test positive cases
|
||||
assert is_claude_37("claude-3.7")
|
||||
assert is_claude_37("claude3.7")
|
||||
assert is_claude_37("claude-3-7")
|
||||
assert is_claude_37("anthropic/claude-3.7")
|
||||
assert is_claude_37("anthropic/claude3.7")
|
||||
assert is_claude_37("anthropic/claude-3-7")
|
||||
assert is_claude_37("claude-3.7-sonnet")
|
||||
assert is_claude_37("claude3.7-haiku")
|
||||
|
||||
# Test negative cases
|
||||
assert not is_claude_37("claude-3")
|
||||
assert not is_claude_37("claude-3.5")
|
||||
assert not is_claude_37("claude3.5")
|
||||
assert not is_claude_37("claude-3-5")
|
||||
assert not is_claude_37("gpt-4")
|
||||
assert not is_claude_37("")
|
||||
Loading…
Reference in New Issue