diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py
index d613041..53f4c43 100644
--- a/ra_aid/__main__.py
+++ b/ra_aid/__main__.py
@@ -5,24 +5,8 @@ import sys
 import uuid
 from datetime import datetime
 
-# Add litellm import
 import litellm
 
-# Configure litellm to suppress debug logs
-os.environ["LITELLM_LOG"] = "ERROR"
-litellm.suppress_debug_info = True
-litellm.set_verbose = False
-
-# Explicitly configure LiteLLM's loggers
-for logger_name in ["litellm", "LiteLLM"]:
-    litellm_logger = logging.getLogger(logger_name)
-    litellm_logger.setLevel(logging.WARNING)
-    litellm_logger.propagate = True
-
-# Use litellm's internal method to disable debugging
-if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
-    litellm._logging._disable_debugging()
-
 from langgraph.checkpoint.memory import MemorySaver
 from rich.console import Console
 from rich.panel import Panel
@@ -99,6 +83,21 @@ from ra_aid.tools.human import ask_human
 
 logger = get_logger(__name__)
 
+# Configure litellm to suppress debug logs
+os.environ["LITELLM_LOG"] = "ERROR"
+litellm.suppress_debug_info = True
+litellm.set_verbose = False
+
+# Explicitly configure LiteLLM's loggers
+for logger_name in ["litellm", "LiteLLM"]:
+    litellm_logger = logging.getLogger(logger_name)
+    litellm_logger.setLevel(logging.WARNING)
+    litellm_logger.propagate = True
+
+# Use litellm's internal method to disable debugging
+if hasattr(litellm, "_logging") and hasattr(litellm._logging, "_disable_debugging"):
+    litellm._logging._disable_debugging()
+
 
 def launch_webui(host: str, port: int):
     """Launch the RA.Aid web interface."""
diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py
index 0abb164..1bb2ccd 100644
--- a/ra_aid/agent_utils.py
+++ b/ra_aid/agent_utils.py
@@ -51,7 +51,12 @@ from ra_aid.database.repositories.human_input_repository import (
 )
 from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
 from ra_aid.database.repositories.config_repository import get_config_repository
-from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
+from ra_aid.anthropic_token_limiter import (
+    sonnet_35_state_modifier,
+    state_modifier,
+    get_model_token_limit,
+)
+from ra_aid.model_detection import is_anthropic_claude
 
 console = Console()
 
@@ -67,8 +72,6 @@ def output_markdown_message(message: str) -> str:
     return "Message output."
 
 
-
-
 def build_agent_kwargs(
     checkpointer: Optional[Any] = None,
     model: ChatAnthropic = None,
@@ -99,8 +102,13 @@ def build_agent_kwargs(
     ):
 
         def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
-            if any(pattern in model.model for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]):
-                return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
+            if any(
+                pattern in model.model
+                for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]
+            ):
+                return sonnet_35_state_modifier(
+                    state, max_input_tokens=max_input_tokens
+                )
 
             return state_modifier(state, model, max_input_tokens=max_input_tokens)
 
@@ -110,27 +118,6 @@ def build_agent_kwargs(
     return agent_kwargs
 
 
-def is_anthropic_claude(config: Dict[str, Any]) -> bool:
-    """Check if the provider and model name indicate an Anthropic Claude model.
-
-    Args:
-        config: Configuration dictionary containing provider and model information
-
-    Returns:
-        bool: True if this is an Anthropic Claude model
-    """
-    # For backwards compatibility, allow passing of config directly
-    provider = config.get("provider", "")
-    model_name = config.get("model", "")
-    result = (
-        provider.lower() == "anthropic"
-        and model_name
-        and "claude" in model_name.lower()
-    ) or (
-        provider.lower() == "openrouter"
-        and model_name.lower().startswith("anthropic/claude-")
-    )
-    return result
 
 
 def create_agent(
@@ -169,7 +156,7 @@ def create_agent(
             # So we'll use the passed config directly
             pass
         max_input_tokens = (
-            get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
+            get_model_token_limit(config, agent_type, model) or DEFAULT_TOKEN_LIMIT
        )
 
         # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
@@ -188,7 +175,7 @@ def create_agent(
         # Default to REACT agent if provider/model detection fails
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
         config = get_config_repository().get_all()
-        max_input_tokens = get_model_token_limit(config, agent_type)
+        max_input_tokens = get_model_token_limit(config, agent_type, model)
         agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
         return create_react_agent(
             model, tools, interrupt_after=["tools"], **agent_kwargs
@@ -289,7 +276,7 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
     logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
     delay = base_delay * (2**attempt)
     error_message = f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
-    
+
     # Record error in trajectory
     trajectory_repo = get_trajectory_repository()
     human_input_id = get_human_input_repository().get_most_recent_id()
@@ -301,9 +288,9 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
         record_type="error",
         human_input_id=human_input_id,
         is_error=True,
-        error_message=error_message
+        error_message=error_message,
     )
-    
+
     print_error(error_message)
     start = time.monotonic()
     while time.monotonic() - start < delay:
@@ -464,7 +451,9 @@ def run_agent_with_retry(
             try:
                 _run_agent_stream(agent, msg_list)
 
-                if fallback_handler and hasattr(fallback_handler, 'reset_fallback_handler'):
+                if fallback_handler and hasattr(
+                    fallback_handler, "reset_fallback_handler"
+                ):
                     fallback_handler.reset_fallback_handler()
                 should_break, prompt, auto_test, test_attempts = (
                     _execute_test_command_wrapper(
diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py
index 45a79d4..ba7b890 100644
--- a/ra_aid/anthropic_token_limiter.py
+++ b/ra_aid/anthropic_token_limiter.py
@@ -1,7 +1,9 @@
 """Utilities for handling token limits with Anthropic models."""
 
 from functools import partial
-from typing import Any, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Union
+from langchain_core.language_models import BaseChatModel
+from ra_aid.model_detection import is_claude_37
 
 from dataclasses import dataclass
 from langchain_anthropic import ChatAnthropic
@@ -19,7 +21,7 @@ from ra_aid.anthropic_message_utils import (
     has_tool_use,
 )
 from langgraph.prebuilt.chat_agent_executor import AgentState
-from litellm import token_counter
+from litellm import token_counter, get_model_info
 
 from ra_aid.agent_backends.ciayn_agent import CiaynAgent
 from ra_aid.database.repositories.config_repository import get_config_repository
@@ -168,14 +170,39 @@ def sonnet_35_state_modifier(
     return result
 
 
+def adjust_claude_37_token_limit(max_input_tokens: int, model: Optional[BaseChatModel]) -> Optional[int]:
+    """Adjust token limit for Claude 3.7 models by subtracting max_tokens.
+
+    Args:
+        max_input_tokens: The original token limit
+        model: The model instance to check
+
+    Returns:
+        Optional[int]: Adjusted token limit if model is Claude 3.7, otherwise original limit
+    """
+    if not max_input_tokens:
+        return max_input_tokens
+
+    if model and hasattr(model, 'model') and is_claude_37(model.model):
+        if hasattr(model, 'max_tokens') and model.max_tokens:
+            effective_max_input_tokens = max_input_tokens - model.max_tokens
+            logger.debug(
+                f"Adjusting token limit for Claude 3.7 model: {max_input_tokens} - {model.max_tokens} = {effective_max_input_tokens}"
+            )
+            return effective_max_input_tokens
+
+    return max_input_tokens
+
+
 def get_model_token_limit(
-    config: Dict[str, Any], agent_type: str = "default"
+    config: Dict[str, Any], agent_type: str = "default", model: Optional[BaseChatModel] = None
 ) -> Optional[int]:
     """Get the token limit for the current model configuration based on agent type.
 
     Args:
         config: Configuration dictionary containing provider and model information
         agent_type: Type of agent ("default", "research", or "planner")
+        model: Optional BaseChatModel instance to check for model-specific attributes
 
     Returns:
         Optional[int]: The token limit if found, None otherwise
@@ -201,7 +228,6 @@ def get_model_token_limit(
         model_name = config.get("model", "")
 
         try:
-            from litellm import get_model_info
 
             provider_model = model_name if not provider else f"{provider}/{model_name}"
             model_info = get_model_info(provider_model)
@@ -210,7 +236,7 @@ def get_model_token_limit(
             logger.debug(
                 f"Using litellm token limit for {model_name}: {max_input_tokens}"
             )
-            return max_input_tokens
+            return adjust_claude_37_token_limit(max_input_tokens, model)
         except Exception as e:
             logger.debug(
                 f"Error getting model info from litellm: {e}, falling back to models_params"
@@ -229,7 +255,7 @@ def get_model_token_limit(
                 max_input_tokens = None
                 logger.debug(f"Could not find token limit for {provider}/{model_name}")
 
-        return max_input_tokens
+        return adjust_claude_37_token_limit(max_input_tokens, model)
 
     except Exception as e:
         logger.warning(f"Failed to get model token limit: {e}")
diff --git a/ra_aid/llm.py b/ra_aid/llm.py
index 56ec811..3f60d6f 100644
--- a/ra_aid/llm.py
+++ b/ra_aid/llm.py
@@ -10,6 +10,7 @@ from openai import OpenAI
 from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
 from ra_aid.console.output import cpm
 from ra_aid.logging_config import get_logger
+from ra_aid.model_detection import is_claude_37
 
 from .models_params import models_params
 
@@ -218,7 +219,6 @@ def create_llm_client(
         is_expert,
     )
 
-    # Get model configuration
     model_config = models_params.get(provider, {}).get(model_name, {})
 
     # Default to True for known providers that support temperature if not specified
@@ -228,6 +228,10 @@ def create_llm_client(
         supports_temperature = model_config["supports_temperature"]
 
     supports_thinking = model_config.get("supports_thinking", False)
 
+    other_kwargs = {}
+    if is_claude_37(model_name):
+        other_kwargs = {"max_tokens": 64000}
+
     # Handle temperature settings
     if is_expert:
         temp_kwargs = {"temperature": 0} if supports_temperature else {}
@@ -235,22 +239,26 @@ def create_llm_client(
         if temperature is None:
             temperature = 0.7
             # Import repository classes directly to avoid circular imports
-            from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
-            from ra_aid.database.repositories.human_input_repository import HumanInputRepository
+            from ra_aid.database.repositories.trajectory_repository import (
+                TrajectoryRepository,
+            )
+            from ra_aid.database.repositories.human_input_repository import (
+                HumanInputRepository,
+            )
             from ra_aid.database.connection import get_db
-            
+
             # Create repositories directly
             trajectory_repo = TrajectoryRepository(get_db())
             human_input_repo = HumanInputRepository(get_db())
             human_input_id = human_input_repo.get_most_recent_id()
-            
+
             trajectory_repo.create(
                 step_data={
                     "message": "This model supports temperature argument but none was given. Setting default temperature to 0.7.",
                     "display_title": "Information",
                 },
                 record_type="info",
-                human_input_id=human_input_id
+                human_input_id=human_input_id,
             )
             cpm(
                 "This model supports temperature argument but none was given. Setting default temperature to 0.7."
             )
@@ -302,9 +310,9 @@ def create_llm_client(
             model_name=model_name,
             timeout=LLM_REQUEST_TIMEOUT,
             max_retries=LLM_MAX_RETRIES,
-            max_tokens=model_config.get("max_tokens", 64000),
             **temp_kwargs,
             **thinking_kwargs,
+            **other_kwargs,
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(
diff --git a/ra_aid/model_detection.py b/ra_aid/model_detection.py
new file mode 100644
index 0000000..ca45c51
--- /dev/null
+++ b/ra_aid/model_detection.py
@@ -0,0 +1,39 @@
+"""Utilities for detecting and working with specific model types."""
+
+from typing import Optional, Dict, Any
+
+
+def is_claude_37(model: str) -> bool:
+    """Check if the model is a Claude 3.7 model.
+
+    Args:
+        model: The model name to check
+
+    Returns:
+        bool: True if the model is a Claude 3.7 model, False otherwise
+    """
+    patterns = ["claude-3.7", "claude3.7", "claude-3-7"]
+    return any(pattern in model for pattern in patterns)
+
+
+def is_anthropic_claude(config: Dict[str, Any]) -> bool:
+    """Check if the provider and model name indicate an Anthropic Claude model.
+
+    Args:
+        config: Configuration dictionary containing provider and model information
+
+    Returns:
+        bool: True if this is an Anthropic Claude model
+    """
+    # For backwards compatibility, allow passing of config directly
+    provider = config.get("provider", "")
+    model_name = config.get("model", "")
+    result = (
+        provider.lower() == "anthropic"
+        and model_name
+        and "claude" in model_name.lower()
+    ) or (
+        provider.lower() == "openrouter"
+        and model_name.lower().startswith("anthropic/claude-")
+    )
+    return result
diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py
index ff978a8..1d7ed50 100644
--- a/tests/ra_aid/test_agent_utils.py
+++ b/tests/ra_aid/test_agent_utils.py
@@ -13,8 +13,7 @@ from ra_aid.agent_context import (
 )
 from ra_aid.agent_utils import (
     AgentState,
-    create_agent,
-    is_anthropic_claude,
+    create_agent
 )
 from ra_aid.anthropic_token_limiter import (
     get_model_token_limit,
@@ -453,31 +452,6 @@ def test_handle_api_error_retry(monkeypatch):
     _handle_api_error(Exception("error code 429"), 0, 5, 1)
 
 
-def test_is_anthropic_claude():
-    """Test is_anthropic_claude function with various configurations."""
-    # Test Anthropic provider cases
-    assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
-    assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
-    assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
-
-    # Test OpenRouter provider cases
-    assert is_anthropic_claude(
-        {"provider": "openrouter", "model": "anthropic/claude-2"}
-    )
-    assert is_anthropic_claude(
-        {"provider": "openrouter", "model": "anthropic/claude-instant"}
-    )
-    assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
-
-    # Test edge cases
-    assert not is_anthropic_claude({})  # Empty config
-    assert not is_anthropic_claude({"provider": "anthropic"})  # Missing model
-    assert not is_anthropic_claude({"model": "claude-2"})  # Missing provider
-    assert not is_anthropic_claude(
-        {"provider": "other", "model": "claude-2"}
-    )  # Wrong provider
-
-
 def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repository):
     """Test that run_agent_with_retry checks for crash status at the beginning of each iteration."""
     from ra_aid.agent_context import agent_context, mark_agent_crashed
diff --git a/tests/ra_aid/test_model_detection.py b/tests/ra_aid/test_model_detection.py
new file mode 100644
index 0000000..4127c50
--- /dev/null
+++ b/tests/ra_aid/test_model_detection.py
@@ -0,0 +1,50 @@
+"""Unit tests for model_detection.py."""
+
+import pytest
+from ra_aid.model_detection import is_anthropic_claude, is_claude_37
+
+
+def test_is_anthropic_claude():
+    """Test is_anthropic_claude function with various configurations."""
+    # Test Anthropic provider cases
+    assert is_anthropic_claude({"provider": "anthropic", "model": "claude-2"})
+    assert is_anthropic_claude({"provider": "ANTHROPIC", "model": "claude-instant"})
+    assert not is_anthropic_claude({"provider": "anthropic", "model": "gpt-4"})
+
+    # Test OpenRouter provider cases
+    assert is_anthropic_claude(
+        {"provider": "openrouter", "model": "anthropic/claude-2"}
+    )
+    assert is_anthropic_claude(
+        {"provider": "openrouter", "model": "anthropic/claude-instant"}
+    )
+    assert not is_anthropic_claude({"provider": "openrouter", "model": "openai/gpt-4"})
+
+    # Test edge cases
+    assert not is_anthropic_claude({})  # Empty config
+    assert not is_anthropic_claude({"provider": "anthropic"})  # Missing model
+    assert not is_anthropic_claude({"model": "claude-2"})  # Missing provider
+    assert not is_anthropic_claude(
+        {"provider": "other", "model": "claude-2"}
+    )  # Wrong provider
+
+
+def test_is_claude_37():
+    """Test is_claude_37 function with various model names."""
+    # Test positive cases
+    assert is_claude_37("claude-3.7")
+    assert is_claude_37("claude3.7")
+    assert is_claude_37("claude-3-7")
+    assert is_claude_37("anthropic/claude-3.7")
+    assert is_claude_37("anthropic/claude3.7")
+    assert is_claude_37("anthropic/claude-3-7")
+    assert is_claude_37("claude-3.7-sonnet")
+    assert is_claude_37("claude3.7-haiku")
+
+    # Test negative cases
+    assert not is_claude_37("claude-3")
+    assert not is_claude_37("claude-3.5")
+    assert not is_claude_37("claude3.5")
+    assert not is_claude_37("claude-3-5")
+    assert not is_claude_37("gpt-4")
+    assert not is_claude_37("")