diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 3121e5d..d613041 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -39,35 +39,41 @@ from ra_aid.agents.research_agent import run_research_agent from ra_aid.agents import run_planning_agent from ra_aid.config import ( DEFAULT_MAX_TEST_CMD_RETRIES, + DEFAULT_MODEL, DEFAULT_RECURSION_LIMIT, DEFAULT_TEST_CMD_TIMEOUT, VALID_PROVIDERS, ) -from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository +from ra_aid.database.repositories.key_fact_repository import ( + KeyFactRepositoryManager, + get_key_fact_repository, +) from ra_aid.database.repositories.key_snippet_repository import ( - KeySnippetRepositoryManager, get_key_snippet_repository + KeySnippetRepositoryManager, + get_key_snippet_repository, ) from ra_aid.database.repositories.human_input_repository import ( - HumanInputRepositoryManager, get_human_input_repository + HumanInputRepositoryManager, + get_human_input_repository, ) from ra_aid.database.repositories.research_note_repository import ( - ResearchNoteRepositoryManager, get_research_note_repository + ResearchNoteRepositoryManager, + get_research_note_repository, ) from ra_aid.database.repositories.trajectory_repository import ( - TrajectoryRepositoryManager, get_trajectory_repository + TrajectoryRepositoryManager, + get_trajectory_repository, ) from ra_aid.database.repositories.session_repository import ( SessionRepositoryManager, get_session_repository ) from ra_aid.database.repositories.related_files_repository import ( - RelatedFilesRepositoryManager -) -from ra_aid.database.repositories.work_log_repository import ( - WorkLogRepositoryManager + RelatedFilesRepositoryManager, ) +from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager from ra_aid.database.repositories.config_repository import ( ConfigRepositoryManager, - get_config_repository + get_config_repository, ) from ra_aid.env_inv import EnvDiscovery from ra_aid.env_inv_context import EnvInvManager, get_env_inv @@ -103,9 +109,9 @@ def launch_webui(host: str, port: int): def parse_arguments(args=None): - ANTHROPIC_DEFAULT_MODEL = "claude-3-7-sonnet-20250219" + ANTHROPIC_DEFAULT_MODEL = DEFAULT_MODEL OPENAI_DEFAULT_MODEL = "gpt-4o" - + # Case-insensitive log level argument type def log_level_type(value): value = value.lower() @@ -202,8 +208,10 @@ Examples: help="Enable chat mode with direct human interaction (implies --hil)", ) parser.add_argument( - "--log-mode", choices=["console", "file"], default="file", - help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console" + "--log-mode", + choices=["console", "file"], + default="file", + help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console", ) parser.add_argument( "--pretty-logger", action="store_true", help="Enable pretty logging output" @@ -386,20 +394,20 @@ def is_stage_requested(stage: str) -> bool: def wipe_project_memory(): """Delete the project database file to wipe all stored memory. - + Returns: str: A message indicating the result of the operation """ import os from pathlib import Path - + cwd = os.getcwd() ra_aid_dir = Path(os.path.join(cwd, ".ra-aid")) db_path = os.path.join(ra_aid_dir, "pk.db") - + if not os.path.exists(db_path): return "No project memory found to wipe." - + try: os.remove(db_path) return "Project memory wiped successfully." 
@@ -411,11 +419,11 @@ def wipe_project_memory(): def build_status(): """Build status panel with model and feature information. - + Includes memory statistics at the bottom with counts of key facts, snippets, and research notes. """ status = Text() - + # Get the config repository to get model/provider information config_repo = get_config_repository() provider = config_repo.get("provider", "") @@ -423,12 +431,14 @@ def build_status(): temperature = config_repo.get("temperature") expert_provider = config_repo.get("expert_provider", "") expert_model = config_repo.get("expert_model", "") - experimental_fallback_handler = config_repo.get("experimental_fallback_handler", False) + experimental_fallback_handler = config_repo.get( + "experimental_fallback_handler", False + ) web_research_enabled = config_repo.get("web_research_enabled", False) - + # Get the expert enabled status expert_enabled = bool(expert_provider and expert_model) - + # Basic model information status.append("šŸ¤– ") status.append(f"{provider}/{model}") @@ -460,39 +470,41 @@ def build_status(): [fb_handler._format_model(m) for m in fb_handler.fallback_tool_models] ) status.append(msg) - + # Add memory statistics # Get counts of key facts, snippets, and research notes with error handling fact_count = 0 snippet_count = 0 note_count = 0 - + try: fact_count = len(get_key_fact_repository().get_all()) except RuntimeError as e: logger.debug(f"Failed to get key facts count: {e}") - + try: snippet_count = len(get_key_snippet_repository().get_all()) except RuntimeError as e: logger.debug(f"Failed to get key snippets count: {e}") - + try: note_count = len(get_research_note_repository().get_all()) except RuntimeError as e: logger.debug(f"Failed to get research notes count: {e}") - + # Add memory statistics line with reset option note - status.append(f"\nšŸ’¾ Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes") + status.append( + f"\nšŸ’¾ Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes" + ) if fact_count > 0 or snippet_count > 0 or note_count > 0: status.append(" (use --wipe-project-memory to reset)") - + # Check for newer version version_message = check_for_newer_version() if version_message: status.append("\n\n") status.append(version_message, style="yellow") - + return status @@ -501,7 +513,7 @@ def main(): args = parse_arguments() setup_logging(args.log_mode, args.pretty_logger, args.log_level) logger.debug("Starting RA.Aid with arguments: %s", args) - + # Check if we need to wipe project memory before starting if args.wipe_project_memory: result = wipe_project_memory() @@ -527,7 +539,7 @@ def main(): # Initialize empty config dictionary to be populated later config = {} - + # Initialize repositories with database connection # Create environment inventory data env_discovery = EnvDiscovery() @@ -568,7 +580,9 @@ def main(): expert_missing, web_research_enabled, web_research_missing, - ) = validate_environment(args) # Will exit if main env vars missing + ) = validate_environment( + args + ) # Will exit if main env vars missing logger.debug("Environment validation successful") # Validate model configuration early @@ -604,12 +618,16 @@ def main(): config_repo.set("expert_provider", args.expert_provider) config_repo.set("expert_model", args.expert_model) config_repo.set("temperature", args.temperature) - config_repo.set("experimental_fallback_handler", args.experimental_fallback_handler) + config_repo.set( + "experimental_fallback_handler", args.experimental_fallback_handler + ) 
config_repo.set("web_research_enabled", web_research_enabled) config_repo.set("show_thoughts", args.show_thoughts) config_repo.set("show_cost", args.show_cost) config_repo.set("force_reasoning_assistance", args.reasoning_assistance) - config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance) + config_repo.set( + "disable_reasoning_assistance", args.no_reasoning_assistance + ) # Build status panel with memory statistics status = build_status() @@ -678,13 +696,15 @@ def main(): initial_request = ask_human.invoke( {"question": "What would you like help with?"} ) - + # Record chat input in database (redundant as ask_human already records it, # but needed in case the ask_human implementation changes) try: # Using get_human_input_repository() to access the repository from context human_input_repository = get_human_input_repository() - human_input_repository.create(content=initial_request, source='chat') + human_input_repository.create( + content=initial_request, source="chat" + ) human_input_repository.garbage_collect() except Exception as e: logger.error(f"Failed to record initial chat input: {str(e)}") @@ -742,8 +762,12 @@ def main(): ), working_directory=working_directory, current_date=current_date, - key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()), - key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()), + key_facts=format_key_facts_dict( + get_key_fact_repository().get_facts_dict() + ), + key_snippets=format_key_snippets_dict( + get_key_snippet_repository().get_snippets_dict() + ), project_info=formatted_project_info, env_inv=get_env_inv(), ), @@ -775,12 +799,12 @@ def main(): sys.exit(1) base_task = args.message - + # Record CLI input in database try: # Using get_human_input_repository() to access the repository from context human_input_repository = get_human_input_repository() - human_input_repository.create(content=base_task, source='cli') + human_input_repository.create(content=base_task, source="cli") # Run garbage collection to ensure we don't exceed 100 inputs human_input_repository.garbage_collect() logger.debug(f"Recorded CLI input: {base_task}") @@ -814,19 +838,25 @@ def main(): config_repo.set("expert_model", args.expert_model) # Store planner config with fallback to base values - config_repo.set("planner_provider", args.planner_provider or args.provider) + config_repo.set( + "planner_provider", args.planner_provider or args.provider + ) config_repo.set("planner_model", args.planner_model or args.model) # Store research config with fallback to base values - config_repo.set("research_provider", args.research_provider or args.provider) + config_repo.set( + "research_provider", args.research_provider or args.provider + ) config_repo.set("research_model", args.research_model or args.model) # Store temperature in config config_repo.set("temperature", args.temperature) - + # Store reasoning assistance flags config_repo.set("force_reasoning_assistance", args.reasoning_assistance) - config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance) + config_repo.set( + "disable_reasoning_assistance", args.no_reasoning_assistance + ) # Set modification tools based on use_aider flag set_modification_tools(args.use_aider) @@ -870,5 +900,6 @@ def main(): print() sys.exit(0) + if __name__ == "__main__": main() diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index bef5bd5..fa1ff62 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -1,19 +1,14 @@ """Utility functions for 
working with agents.""" -import inspect -import os import signal import sys import threading import time -import uuid -from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Sequence +from typing import Any, Dict, List, Literal, Optional +from langchain_anthropic import ChatAnthropic from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler - -import litellm from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError from openai import RateLimitError as OpenAIRateLimitError from litellm.exceptions import RateLimitError as LiteLLMRateLimitError @@ -23,28 +18,24 @@ from langchain_core.messages import ( BaseMessage, HumanMessage, SystemMessage, - trim_messages, ) from langchain_core.tools import tool -from langgraph.checkpoint.memory import MemorySaver from langgraph.prebuilt import create_react_agent from langgraph.prebuilt.chat_agent_executor import AgentState -from litellm import get_model_info from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel from ra_aid.agent_context import ( agent_context, - get_depth, is_completed, reset_completion_flags, should_exit, ) from ra_aid.agent_backends.ciayn_agent import CiaynAgent from ra_aid.agents_alias import RAgents -from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT -from ra_aid.console.formatting import print_error, print_stage_header +from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES +from ra_aid.console.formatting import print_error from ra_aid.console.output import print_agent_output from ra_aid.exceptions import ( AgentInterrupt, @@ -53,77 +44,20 @@ from ra_aid.exceptions import ( ) from ra_aid.fallback_handler import FallbackHandler from ra_aid.logging_config import get_logger -from ra_aid.llm import initialize_expert_llm -from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params -from ra_aid.text.processing import process_thinking_content -from ra_aid.project_info import ( - display_project_status, - format_project_info, - get_project_info, -) -from ra_aid.prompts.expert_prompts import ( - EXPERT_PROMPT_SECTION_IMPLEMENTATION, - EXPERT_PROMPT_SECTION_PLANNING, - EXPERT_PROMPT_SECTION_RESEARCH, -) -from ra_aid.prompts.human_prompts import ( - HUMAN_PROMPT_SECTION_IMPLEMENTATION, - HUMAN_PROMPT_SECTION_PLANNING, - HUMAN_PROMPT_SECTION_RESEARCH, -) -from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT -from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS -from ra_aid.prompts.planning_prompts import PLANNING_PROMPT -from ra_aid.prompts.reasoning_assist_prompt import ( - REASONING_ASSIST_PROMPT_PLANNING, - REASONING_ASSIST_PROMPT_IMPLEMENTATION, - REASONING_ASSIST_PROMPT_RESEARCH, -) -from ra_aid.prompts.research_prompts import ( - RESEARCH_ONLY_PROMPT, - RESEARCH_PROMPT, -) -from ra_aid.prompts.web_research_prompts import ( - WEB_RESEARCH_PROMPT, - WEB_RESEARCH_PROMPT_SECTION_CHAT, - WEB_RESEARCH_PROMPT_SECTION_PLANNING, - WEB_RESEARCH_PROMPT_SECTION_RESEARCH, -) -from ra_aid.tool_configs import ( - get_implementation_tools, - get_planning_tools, - get_research_tools, - get_web_research_tools, -) +from ra_aid.models_params import DEFAULT_TOKEN_LIMIT from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command -from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository -from ra_aid.database.repositories.key_snippet_repository import ( - get_key_snippet_repository, -) from 
ra_aid.database.repositories.human_input_repository import ( get_human_input_repository, ) -from ra_aid.database.repositories.research_note_repository import ( - get_research_note_repository, -) from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository -from ra_aid.database.repositories.work_log_repository import get_work_log_repository -from ra_aid.model_formatters import format_key_facts_dict -from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict -from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict -from ra_aid.tools.memory import ( - get_related_files, - log_work_event, -) from ra_aid.database.repositories.config_repository import get_config_repository -from ra_aid.env_inv_context import get_env_inv +from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit console = Console() logger = get_logger(__name__) # Import repositories using get_* functions -from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository @tool @@ -133,131 +67,19 @@ def output_markdown_message(message: str) -> str: return "Message output." -def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int: - """Helper function to estimate total tokens in a sequence of messages. - - Args: - messages: Sequence of messages to count tokens for - - Returns: - Total estimated token count - """ - if not messages: - return 0 - - estimate_tokens = CiaynAgent._estimate_tokens - return sum(estimate_tokens(msg) for msg in messages) - - -def state_modifier( - state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT -) -> list[BaseMessage]: - """Given the agent state and max_tokens, return a trimmed list of messages. - - Args: - state: The current agent state containing messages - max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT) - - Returns: - list[BaseMessage]: Trimmed list of messages that fits within token limit - """ - messages = state["messages"] - - if not messages: - return [] - - first_message = messages[0] - remaining_messages = messages[1:] - first_tokens = estimate_messages_tokens([first_message]) - new_max_tokens = max_input_tokens - first_tokens - - trimmed_remaining = trim_messages( - remaining_messages, - token_counter=estimate_messages_tokens, - max_tokens=new_max_tokens, - strategy="last", - allow_partial=False, - ) - - return [first_message] + trimmed_remaining - - -def get_model_token_limit( - config: Dict[str, Any], agent_type: Literal["default", "research", "planner"] -) -> Optional[int]: - """Get the token limit for the current model configuration based on agent type. 
- - Returns: - Optional[int]: The token limit if found, None otherwise - """ - try: - # Try to get config from repository for production use - try: - config_from_repo = get_config_repository().get_all() - # If we succeeded, use the repository config instead of passed config - config = config_from_repo - except RuntimeError: - # In tests, this may fail because the repository isn't set up - # So we'll use the passed config directly - pass - if agent_type == "research": - provider = config.get("research_provider", "") or config.get("provider", "") - model_name = config.get("research_model", "") or config.get("model", "") - elif agent_type == "planner": - provider = config.get("planner_provider", "") or config.get("provider", "") - model_name = config.get("planner_model", "") or config.get("model", "") - else: - provider = config.get("provider", "") - model_name = config.get("model", "") - - try: - provider_model = model_name if not provider else f"{provider}/{model_name}" - model_info = get_model_info(provider_model) - max_input_tokens = model_info.get("max_input_tokens") - if max_input_tokens: - logger.debug( - f"Using litellm token limit for {model_name}: {max_input_tokens}" - ) - return max_input_tokens - except litellm.exceptions.NotFoundError: - logger.debug( - f"Model {model_name} not found in litellm, falling back to models_params" - ) - except Exception as e: - logger.debug( - f"Error getting model info from litellm: {e}, falling back to models_params" - ) - - # Fallback to models_params dict - # Normalize model name for fallback lookup (e.g. claude-2 -> claude2) - normalized_name = model_name.replace("-", "") - provider_tokens = models_params.get(provider, {}) - if normalized_name in provider_tokens: - max_input_tokens = provider_tokens[normalized_name]["token_limit"] - logger.debug( - f"Found token limit for {provider}/{model_name}: {max_input_tokens}" - ) - else: - max_input_tokens = None - logger.debug(f"Could not find token limit for {provider}/{model_name}") - - return max_input_tokens - - except Exception as e: - logger.warning(f"Failed to get model token limit: {e}") - return None def build_agent_kwargs( checkpointer: Optional[Any] = None, + model: ChatAnthropic = None, max_input_tokens: Optional[int] = None, ) -> Dict[str, Any]: """Build kwargs dictionary for agent creation. 
Args: checkpointer: Optional memory checkpointer - config: Optional configuration dictionary - token_limit: Optional token limit for the model + model: The language model to use for token counting + max_input_tokens: Optional token limit for the model Returns: Dictionary of kwargs for agent creation @@ -270,12 +92,20 @@ def build_agent_kwargs( agent_kwargs["checkpointer"] = checkpointer config = get_config_repository().get_all() - if config.get("limit_tokens", True) and is_anthropic_claude(config): + if ( + config.get("limit_tokens", True) + and is_anthropic_claude(config) + and model is not None + ): def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]: - return state_modifier(state, max_input_tokens=max_input_tokens) + if any(pattern in model.model for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]): + return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens) + + return state_modifier(state, model, max_input_tokens=max_input_tokens) agent_kwargs["state_modifier"] = wrapped_state_modifier + agent_kwargs["name"] = "React" return agent_kwargs @@ -345,7 +175,8 @@ def create_agent( # Use REACT agent for Anthropic Claude models, otherwise use CIAYN if is_anthropic_claude(config): logger.debug("Using create_react_agent to instantiate agent.") - agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens) + + agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens) return create_react_agent( model, tools, interrupt_after=["tools"], **agent_kwargs ) @@ -358,16 +189,12 @@ def create_agent( logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.") config = get_config_repository().get_all() max_input_tokens = get_model_token_limit(config, agent_type) - agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens) + agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens) return create_react_agent( model, tools, interrupt_after=["tools"], **agent_kwargs ) -from ra_aid.agents.research_agent import run_research_agent, run_web_research_agent -from ra_aid.agents.implementation_agent import run_task_implementation_agent - - _CONTEXT_STACK = [] _INTERRUPT_CONTEXT = None _FEEDBACK_MODE = False diff --git a/ra_aid/anthropic_message_utils.py b/ra_aid/anthropic_message_utils.py new file mode 100644 index 0000000..79f271e --- /dev/null +++ b/ra_aid/anthropic_message_utils.py @@ -0,0 +1,312 @@ +"""Utilities for handling Anthropic-specific message formats and trimming.""" + +from typing import Callable, List, Literal, Optional, Sequence, Union, cast + +from langchain_core.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) + + +def _is_message_type( + message: BaseMessage, message_types: Union[str, type, List[Union[str, type]]] +) -> bool: + """Check if a message is of a specific type or types. + + Args: + message: The message to check + message_types: Type(s) to check against (string name or class) + + Returns: + bool: True if message matches any of the specified types + """ + if not isinstance(message_types, list): + message_types = [message_types] + + types_str = [t for t in message_types if isinstance(t, str)] + types_classes = tuple(t for t in message_types if isinstance(t, type)) + + return message.type in types_str or isinstance(message, types_classes) + + +def has_tool_use(message: BaseMessage) -> bool: + """Check if a message contains tool use. 
+ + Args: + message: The message to check + + Returns: + bool: True if the message contains tool use + """ + if not isinstance(message, AIMessage): + return False + + # Check content for tool_use + if isinstance(message.content, str) and "tool_use" in message.content: + return True + + # Check content list for tool_use blocks + if isinstance(message.content, list): + for item in message.content: + if isinstance(item, dict) and item.get("type") == "tool_use": + return True + + # Check additional_kwargs for tool_calls + if hasattr(message, "additional_kwargs") and message.additional_kwargs.get( + "tool_calls" + ): + return True + + return False + + +def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool: + """Check if two messages form a tool use/result pair. + + Args: + message1: First message + message2: Second message + + Returns: + bool: True if the messages form a tool use/result pair + """ + return ( + isinstance(message1, AIMessage) + and isinstance(message2, ToolMessage) + and has_tool_use(message1) + ) + + + +def anthropic_trim_messages( + messages: Sequence[BaseMessage], + *, + max_tokens: int, + token_counter: Callable[[List[BaseMessage]], int], + strategy: Literal["first", "last"] = "last", + num_messages_to_keep: int = 2, + allow_partial: bool = False, + include_system: bool = True, + start_on: Optional[Union[str, type, List[Union[str, type]]]] = None, +) -> List[BaseMessage]: + """Trim messages to fit within a token limit, with Anthropic-specific handling. + + Warning - not fully implemented - last strategy is supported and test, not + allow partial, not 'first' strategy either. + This function is similar to langchain_core's trim_messages but with special + handling for Anthropic message formats to avoid API errors. + + It always keeps the first num_messages_to_keep messages. + + Args: + messages: Sequence of messages to trim + max_tokens: Maximum number of tokens allowed + token_counter: Function to count tokens in messages + strategy: Whether to keep the "first" or "last" messages + allow_partial: Whether to allow partial messages + include_system: Whether to always include the system message + start_on: Message type to start on (only for "last" strategy) + + Returns: + List[BaseMessage]: Trimmed messages that fit within token limit + """ + if not messages: + return [] + + messages = list(messages) + + # Always keep the first num_messages_to_keep messages + kept_messages = messages[:num_messages_to_keep] + remaining_msgs = messages[num_messages_to_keep:] + + + # For Anthropic, we need to maintain the conversation structure where: + # 1. Every AIMessage with tool_use must be followed by a ToolMessage + # 2. Every AIMessage that follows a ToolMessage must start with a tool_result + + # First, check if we have any tool_use in the messages + has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages) + + # If we have tool_use anywhere, we need to be very careful about trimming + if has_tool_use_anywhere: + # For safety, just keep all messages if we're under the token limit + if token_counter(messages) <= max_tokens: + return messages + + # We need to identify all tool_use/tool_result relationships + # First, find all AIMessage+ToolMessage pairs + pairs = [] + i = 0 + while i < len(messages) - 1: + if is_tool_pair(messages[i], messages[i + 1]): + pairs.append((i, i + 1)) + i += 2 + else: + i += 1 + + # For Anthropic, we need to ensure that: + # 1. If we include an AIMessage with tool_use, we must include the following ToolMessage + # 2. 
If we include a ToolMessage, we must include the preceding AIMessage with tool_use + + # The safest approach is to always keep complete AIMessage+ToolMessage pairs together + # First, identify all complete pairs + complete_pairs = [] + for start, end in pairs: + complete_pairs.append((start, end)) + + # Now we'll build our result, starting with the kept_messages + # But we need to be careful about the first message if it has tool_use + result = [] + + # Check if the last message in kept_messages has tool_use + if ( + kept_messages + and isinstance(kept_messages[-1], AIMessage) + and has_tool_use(kept_messages[-1]) + ): + # We need to find the corresponding ToolMessage + for i, (ai_idx, tool_idx) in enumerate(pairs): + if messages[ai_idx] is kept_messages[-1]: + # Found the pair, add all kept_messages except the last one + result.extend(kept_messages[:-1]) + # Add the AIMessage and ToolMessage as a pair + result.extend([messages[ai_idx], messages[tool_idx]]) + # Remove this pair from the list of pairs to process later + pairs = pairs[:i] + pairs[i + 1 :] + break + else: + # If we didn't find a matching pair, just add all kept_messages + result.extend(kept_messages) + else: + # No tool_use in the last kept message, just add all kept_messages + result.extend(kept_messages) + + # If we're using the "last" strategy, we'll try to include pairs from the end + if strategy == "last": + # First collect all pairs we can include within the token limit + pairs_to_include = [] + + # Process pairs from the end (newest first) + for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)): + # Try adding this pair + test_msgs = result.copy() + + # Add all previously selected pairs + for prev_ai_idx, prev_tool_idx in pairs_to_include: + test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]]) + + # Add this pair + test_msgs.extend([messages[ai_idx], messages[tool_idx]]) + + if token_counter(test_msgs) <= max_tokens: + # This pair fits, add it to our list + pairs_to_include.append((ai_idx, tool_idx)) + else: + # This pair would exceed the token limit + break + + # Now add the pairs in the correct order + # Sort by index to maintain the original conversation flow + pairs_to_include.sort(key=lambda x: x[0]) + for ai_idx, tool_idx in pairs_to_include: + result.extend([messages[ai_idx], messages[tool_idx]]) + + # No need to sort - we've already added messages in the correct order + + return result + + # If no tool_use, proceed with normal segmentation + segments = [] + i = 0 + + # Group messages into segments + while i < len(remaining_msgs): + segments.append([remaining_msgs[i]]) + i += 1 + + # Now we have segments that maintain the required structure + # We'll add segments from the end (for "last" strategy) or beginning (for "first") + # until we hit the token limit + + if strategy == "last": + # If we have no segments, just return kept_messages + if not segments: + return kept_messages + + result = [] + + # Process segments from the end + for i, segment in enumerate(reversed(segments)): + # Try adding this segment + test_msgs = segment + result + + if token_counter(kept_messages + test_msgs) <= max_tokens: + result = segment + result + else: + # This segment would exceed the token limit + break + + final_result = kept_messages + result + + # For Anthropic, we need to ensure the conversation follows a valid structure + # We'll do a final check of the entire conversation + + # Validate the conversation structure + valid_result = [] + i = 0 + + # Process messages in order + while i < 
len(final_result): + current_msg = final_result[i] + + # If this is an AIMessage with tool_use, it must be followed by a ToolMessage + if ( + i < len(final_result) - 1 + and isinstance(current_msg, AIMessage) + and has_tool_use(current_msg) + ): + if isinstance(final_result[i + 1], ToolMessage): + # This is a valid tool_use + tool_result pair + valid_result.append(current_msg) + valid_result.append(final_result[i + 1]) + i += 2 + else: + # Invalid: AIMessage with tool_use not followed by ToolMessage + # Skip this message to maintain valid structure + i += 1 + else: + # Regular message, just add it + valid_result.append(current_msg) + i += 1 + + # Final check: don't end with an AIMessage that has tool_use + if ( + valid_result + and isinstance(valid_result[-1], AIMessage) + and has_tool_use(valid_result[-1]) + ): + valid_result.pop() # Remove the last message + + return valid_result + + elif strategy == "first": + result = [] + + # Process segments from the beginning + for i, segment in enumerate(segments): + # Try adding this segment + test_msgs = result + segment + if token_counter(kept_messages + test_msgs) <= max_tokens: + result = result + segment + else: + # This segment would exceed the token limit + break + + final_result = kept_messages + result + + return final_result diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py new file mode 100644 index 0000000..45a79d4 --- /dev/null +++ b/ra_aid/anthropic_token_limiter.py @@ -0,0 +1,236 @@ +"""Utilities for handling token limits with Anthropic models.""" + +from functools import partial +from typing import Any, Dict, List, Optional, Sequence +from dataclasses import dataclass + +from langchain_anthropic import ChatAnthropic +from langchain_core.messages import ( + AIMessage, + BaseMessage, + RemoveMessage, + ToolMessage, + trim_messages, +) +from langchain_core.messages.base import message_to_dict + +from ra_aid.anthropic_message_utils import ( + anthropic_trim_messages, + has_tool_use, +) +from langgraph.prebuilt.chat_agent_executor import AgentState +from litellm import token_counter + +from ra_aid.agent_backends.ciayn_agent import CiaynAgent +from ra_aid.database.repositories.config_repository import get_config_repository +from ra_aid.logging_config import get_logger +from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params +from ra_aid.console.output import cpm, print_messages_compact + +logger = get_logger(__name__) + + +def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int: + """Helper function to estimate total tokens in a sequence of messages. + + Args: + messages: Sequence of messages to count tokens for + + Returns: + Total estimated token count + """ + if not messages: + return 0 + + estimate_tokens = CiaynAgent._estimate_tokens + return sum(estimate_tokens(msg) for msg in messages) + + +def convert_message_to_litellm_format(message: BaseMessage) -> Dict: + """Convert a BaseMessage to the format expected by litellm. + + Args: + message: The BaseMessage to convert + + Returns: + Dict in litellm format + """ + message_dict = message_to_dict(message) + return { + "role": message_dict["type"], + "content": message_dict["data"]["content"], + } + + +def create_token_counter_wrapper(model: str): + """Create a wrapper for token counter that handles BaseMessage conversion. 
+
+    Args:
+        model: The model name to use for token counting
+
+    Returns:
+        A function that accepts BaseMessage objects and returns token count
+    """
+
+    # Create a partial function that already has the model parameter set
+    base_token_counter = partial(token_counter, model=model)
+
+    def wrapped_token_counter(messages: List[BaseMessage]) -> int:
+        """Count tokens in a list of messages, converting each BaseMessage to the dict format expected by litellm.
+
+        Args:
+            messages: List of BaseMessage objects
+
+        Returns:
+            Token count for the messages
+        """
+        if not messages:
+            return 0
+
+        litellm_messages = [convert_message_to_litellm_format(msg) for msg in messages]
+        result = base_token_counter(messages=litellm_messages)
+        return result
+
+    return wrapped_token_counter
+
+
+def state_modifier(
+    state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
+) -> list[BaseMessage]:
+    """Given the agent state and max_input_tokens, return a trimmed list of messages.
+
+    This uses anthropic_trim_messages, which always keeps the first 2 messages.
+
+    Args:
+        state: The current agent state containing messages
+        model: The language model to use for token counting
+        max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
+
+    Returns:
+        list[BaseMessage]: Trimmed list of messages that fits within token limit
+    """
+
+    messages = state["messages"]
+    if not messages:
+        return []
+
+    wrapped_token_counter = create_token_counter_wrapper(model.model)
+
+    result = anthropic_trim_messages(
+        messages,
+        token_counter=wrapped_token_counter,
+        max_tokens=max_input_tokens,
+        strategy="last",
+        allow_partial=False,
+        include_system=True,
+        num_messages_to_keep=2,
+    )
+
+    if len(result) < len(messages):
+        logger.info(
+            f"Anthropic token limiter trimmed: {len(messages)} messages → {len(result)} messages"
+        )
+
+    return result
+
+
+def sonnet_35_state_modifier(
+    state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
+) -> list[BaseMessage]:
+    """Given the agent state and max_input_tokens, return a trimmed list of messages.
+
+    Args:
+        state: The current agent state containing messages
+        max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
+
+    Returns:
+        list[BaseMessage]: Trimmed list of messages that fits within token limit
+    """
+    messages = state["messages"]
+
+    if not messages:
+        return []
+
+    first_message = messages[0]
+    remaining_messages = messages[1:]
+    first_tokens = estimate_messages_tokens([first_message])
+    new_max_tokens = max_input_tokens - first_tokens
+
+    trimmed_remaining = trim_messages(
+        remaining_messages,
+        token_counter=estimate_messages_tokens,
+        max_tokens=new_max_tokens,
+        strategy="last",
+        allow_partial=False,
+        include_system=True,
+    )
+
+    result = [first_message] + trimmed_remaining
+
+    return result
+
+
+def get_model_token_limit(
+    config: Dict[str, Any], agent_type: str = "default"
+) -> Optional[int]:
+    """Get the token limit for the current model configuration based on agent type.
+
+    Args:
+        config: Configuration dictionary containing provider and model information
+        agent_type: Type of agent ("default", "research", or "planner")
+
+    Returns:
+        Optional[int]: The token limit if found, None otherwise
+    """
+    try:
+        # Try to get config from repository for production use
+        try:
+            config_from_repo = get_config_repository().get_all()
+            # If we succeeded, use the repository config instead of passed config
+            config = config_from_repo
+        except RuntimeError:
+            # In tests, this may fail because the repository isn't set up
+            # So we'll use the passed config directly
+            pass
+        if agent_type == "research":
+            provider = config.get("research_provider", "") or config.get("provider", "")
+            model_name = config.get("research_model", "") or config.get("model", "")
+        elif agent_type == "planner":
+            provider = config.get("planner_provider", "") or config.get("provider", "")
+            model_name = config.get("planner_model", "") or config.get("model", "")
+        else:
+            provider = config.get("provider", "")
+            model_name = config.get("model", "")
+
+        try:
+            from litellm import get_model_info
+
+            provider_model = model_name if not provider else f"{provider}/{model_name}"
+            model_info = get_model_info(provider_model)
+            max_input_tokens = model_info.get("max_input_tokens")
+            if max_input_tokens:
+                logger.debug(
+                    f"Using litellm token limit for {model_name}: {max_input_tokens}"
+                )
+                return max_input_tokens
+        except Exception as e:
+            logger.debug(
+                f"Error getting model info from litellm: {e}, falling back to models_params"
+            )
+
+        # Fallback to models_params dict
+        # Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
+        normalized_name = model_name.replace("-", "")
+        provider_tokens = models_params.get(provider, {})
+        if normalized_name in provider_tokens:
+            max_input_tokens = provider_tokens[normalized_name]["token_limit"]
+            logger.debug(
+                f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
+            )
+        else:
+            max_input_tokens = None
+            logger.debug(f"Could not find token limit for {provider}/{model_name}")
+
+        return max_input_tokens
+
+    except Exception as e:
+        logger.warning(f"Failed to get model token limit: {e}")
+        return None
diff --git a/ra_aid/config.py b/ra_aid/config.py
index 936fc26..47afde7 100644
--- a/ra_aid/config.py
+++ b/ra_aid/config.py
@@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3
 FALLBACK_TOOL_MODEL_LIMIT = 5
 RETRY_FALLBACK_COUNT = 3
 DEFAULT_TEST_CMD_TIMEOUT = 60 * 5  # 5 minutes in seconds
+DEFAULT_MODEL="claude-3-7-sonnet-20250219"

 DEFAULT_SHOW_COST = False

@@ -16,4 +17,4 @@ VALID_PROVIDERS = [
     "openai-compatible",
     "deepseek",
     "gemini",
-]
\ No newline at end of file
+]
diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py
index 68b0b85..b45aeba 100644
--- a/ra_aid/console/output.py
+++ b/ra_aid/console/output.py
@@ -1,6 +1,6 @@
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Sequence

-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 from rich.markdown import Markdown
 from rich.panel import Panel

@@ -98,3 +98,57 @@ def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -
     """
     console.print(Panel(Markdown(message), title=title, border_style=border_style))

+
+
+def print_messages_compact(messages: Sequence[BaseMessage]) -> None:
+    """Print a compact representation of a list of messages.
+
+    Warning: Used mainly for debugging purposes so do not delete if not referenced anywhere!
+ For all message types, only the first 30 characters of content are shown. + + Args: + messages: A sequence of BaseMessage objects to print + """ + if not messages: + console.print("[italic]No messages[/italic]") + return + + for i, msg in enumerate(messages): + msg_type = msg.__class__.__name__ + content = msg.content + + # Process content based on its type + if isinstance(content, str): + display_content = f"{content[:30]}..." if len(content) > 30 else content + elif isinstance(content, list): + # Handle structured content (list of content blocks) + content_preview = [] + for item in content[:2]: # Show first 2 items at most + if isinstance(item, dict): + if item.get("type") == "text": + text = item.get("text", "") + content_preview.append(f"text: {text[:20]}..." if len(text) > 20 else f"text: {text}") + elif item.get("type") == "tool_call": + tool_name = item.get("tool_call", {}).get("name", "unknown") + content_preview.append(f"tool_call: {tool_name}") + else: + content_preview.append(f"{item.get('type', 'unknown')}") + + if len(content) > 2: + content_preview.append(f"...({len(content)-2} more)") + + display_content = ", ".join(content_preview) + else: + display_content = str(content)[:30] + "..." if len(str(content)) > 30 else str(content) + + # Add additional tool message info if available + additional_info = [] + if hasattr(msg, "tool_call_id") and msg.tool_call_id: + additional_info.append(f"tool_call_id: {msg.tool_call_id}") + if hasattr(msg, "name") and msg.name: + additional_info.append(f"name: {msg.name}") + if hasattr(msg, "status") and msg.status: + additional_info.append(f"status: {msg.status}") + + info_str = f" ({', '.join(additional_info)})" if additional_info else "" + console.print(f"[{i}] [bold]{msg_type}{info_str}[/bold]: {display_content}") diff --git a/ra_aid/llm.py b/ra_aid/llm.py index 2ad07f4..56ec811 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -259,8 +259,9 @@ def create_llm_client( else: temp_kwargs = {} + thinking_kwargs = {} if supports_thinking: - temp_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}} + thinking_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}} if provider == "deepseek": return create_deepseek_client( @@ -268,6 +269,7 @@ def create_llm_client( api_key=config["api_key"], base_url=config["base_url"], **temp_kwargs, + **thinking_kwargs, is_expert=is_expert, ) elif provider == "openrouter": @@ -275,6 +277,7 @@ def create_llm_client( model_name=model_name, api_key=config["api_key"], **temp_kwargs, + **thinking_kwargs, is_expert=is_expert, ) elif provider == "openai": @@ -301,6 +304,7 @@ def create_llm_client( max_retries=LLM_MAX_RETRIES, max_tokens=model_config.get("max_tokens", 64000), **temp_kwargs, + **thinking_kwargs, ) elif provider == "openai-compatible": return ChatOpenAI( @@ -310,6 +314,7 @@ def create_llm_client( timeout=LLM_REQUEST_TIMEOUT, max_retries=LLM_MAX_RETRIES, **temp_kwargs, + **thinking_kwargs, ) elif provider == "gemini": return ChatGoogleGenerativeAI( @@ -318,6 +323,7 @@ def create_llm_client( timeout=LLM_REQUEST_TIMEOUT, max_retries=LLM_MAX_RETRIES, **temp_kwargs, + **thinking_kwargs, ) else: raise ValueError(f"Unsupported provider: {provider}") diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index e066690..740ba80 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -14,8 +14,9 @@ from ra_aid.agent_context import ( is_crashed, reset_completion_flags, ) +from ra_aid.config import DEFAULT_MODEL from ra_aid.console.formatting import print_error, 
print_task_header -from ra_aid.database.repositories.human_input_repository import HumanInputRepository, get_human_input_repository +from ra_aid.database.repositories.human_input_repository import get_human_input_repository from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository from ra_aid.database.repositories.config_repository import get_config_repository @@ -385,7 +386,7 @@ def request_task_implementation(task_spec: str) -> str: config = get_config_repository().get_all() model = initialize_llm( config.get("provider", "anthropic"), - config.get("model", "claude-3-5-sonnet-20241022"), + config.get("model",DEFAULT_MODEL), temperature=config.get("temperature"), ) @@ -552,7 +553,7 @@ def request_implementation(task_spec: str) -> str: config = get_config_repository().get_all() model = initialize_llm( config.get("provider", "anthropic"), - config.get("model", "claude-3-5-sonnet-20241022"), + config.get("model", DEFAULT_MODEL), temperature=config.get("temperature"), ) @@ -685,4 +686,4 @@ def request_implementation(task_spec: str) -> str: # Join all parts into a single markdown string markdown_output = "".join(markdown_parts) - return markdown_output \ No newline at end of file + return markdown_output diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 13739c8..ff978a8 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -14,12 +14,18 @@ from ra_aid.agent_context import ( from ra_aid.agent_utils import ( AgentState, create_agent, - get_model_token_limit, is_anthropic_claude, +) +from ra_aid.anthropic_token_limiter import ( + get_model_token_limit, state_modifier, ) from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params -from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository, config_repo_var +from ra_aid.database.repositories.config_repository import ( + ConfigRepositoryManager, + get_config_repository, + config_repo_var, +) @pytest.fixture @@ -32,154 +38,91 @@ def mock_model(): @pytest.fixture def mock_config_repository(): """Mock the ConfigRepository to avoid database operations during tests""" - with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var: + with patch( + "ra_aid.database.repositories.config_repository.config_repo_var" + ) as mock_repo_var: # Setup a mock repository mock_repo = MagicMock() - + # Create a dictionary to simulate config config = {} - + # Setup get method to return config values def get_config(key, default=None): return config.get(key, default) + mock_repo.get.side_effect = get_config - + # Setup get_all method to return all config values mock_repo.get_all.return_value = config - + # Setup set method to update config values def set_config(key, value): config[key] = value + mock_repo.set.side_effect = set_config - + # Setup update method to update multiple config values def update_config(update_dict): config.update(update_dict) + mock_repo.update.side_effect = update_config - + # Make the mock context var return our mock repo mock_repo_var.get.return_value = mock_repo - + yield mock_repo @pytest.fixture(autouse=True) def mock_trajectory_repository(): """Mock the TrajectoryRepository to avoid database operations during tests""" - with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var: + with patch( + 
"ra_aid.database.repositories.trajectory_repository.trajectory_repo_var" + ) as mock_repo_var: # Setup a mock repository mock_repo = MagicMock() - + # Setup create method to return a mock trajectory def mock_create(**kwargs): mock_trajectory = MagicMock() mock_trajectory.id = 1 return mock_trajectory + mock_repo.create.side_effect = mock_create - + # Make the mock context var return our mock repo mock_repo_var.get.return_value = mock_repo - + yield mock_repo @pytest.fixture(autouse=True) def mock_human_input_repository(): """Mock the HumanInputRepository to avoid database operations during tests""" - with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var: + with patch( + "ra_aid.database.repositories.human_input_repository.human_input_repo_var" + ) as mock_repo_var: # Setup a mock repository mock_repo = MagicMock() - + # Setup get_most_recent_id method to return a dummy ID mock_repo.get_most_recent_id.return_value = 1 - + # Make the mock context var return our mock repo mock_repo_var.get.return_value = mock_repo - + yield mock_repo -def test_get_model_token_limit_anthropic(mock_config_repository): - """Test get_model_token_limit with Anthropic model.""" - config = {"provider": "anthropic", "model": "claude2"} - mock_config_repository.update(config) - - token_limit = get_model_token_limit(config, "default") - assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] - - -def test_get_model_token_limit_openai(mock_config_repository): - """Test get_model_token_limit with OpenAI model.""" - config = {"provider": "openai", "model": "gpt-4"} - mock_config_repository.update(config) - - token_limit = get_model_token_limit(config, "default") - assert token_limit == models_params["openai"]["gpt-4"]["token_limit"] - - -def test_get_model_token_limit_unknown(mock_config_repository): - """Test get_model_token_limit with unknown provider/model.""" - config = {"provider": "unknown", "model": "unknown-model"} - mock_config_repository.update(config) - - token_limit = get_model_token_limit(config, "default") - assert token_limit is None - - -def test_get_model_token_limit_missing_config(mock_config_repository): - """Test get_model_token_limit with missing configuration.""" - config = {} - mock_config_repository.update(config) - - token_limit = get_model_token_limit(config, "default") - assert token_limit is None - - -def test_get_model_token_limit_litellm_success(): - """Test get_model_token_limit successfully getting limit from litellm.""" - config = {"provider": "anthropic", "model": "claude-2"} - - with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: - mock_get_info.return_value = {"max_input_tokens": 100000} - token_limit = get_model_token_limit(config, "default") - assert token_limit == 100000 - - -def test_get_model_token_limit_litellm_not_found(): - """Test fallback to models_tokens when litellm raises NotFoundError.""" - config = {"provider": "anthropic", "model": "claude-2"} - - with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: - mock_get_info.side_effect = litellm.exceptions.NotFoundError( - message="Model not found", model="claude-2", llm_provider="anthropic" - ) - token_limit = get_model_token_limit(config, "default") - assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] - - -def test_get_model_token_limit_litellm_error(): - """Test fallback to models_tokens when litellm raises other exceptions.""" - config = {"provider": "anthropic", "model": "claude-2"} - - with 
patch("ra_aid.agent_utils.get_model_info") as mock_get_info: - mock_get_info.side_effect = Exception("Unknown error") - token_limit = get_model_token_limit(config, "default") - assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] - - -def test_get_model_token_limit_unexpected_error(): - """Test returning None when unexpected errors occur.""" - config = None # This will cause an attribute error when accessed - - token_limit = get_model_token_limit(config, "default") - assert token_limit is None - - def test_create_agent_anthropic(mock_model, mock_config_repository): """Test create_agent with Anthropic Claude model.""" mock_config_repository.update({"provider": "anthropic", "model": "claude-2"}) - with patch("ra_aid.agent_utils.create_react_agent") as mock_react: + with ( + patch("ra_aid.agent_utils.create_react_agent") as mock_react, + patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier, + ): mock_react.return_value = "react_agent" agent = create_agent(mock_model, []) @@ -187,9 +130,10 @@ def test_create_agent_anthropic(mock_model, mock_config_repository): mock_react.assert_called_once_with( mock_model, [], - interrupt_after=['tools'], + interrupt_after=["tools"], version="v2", state_modifier=mock_react.call_args[1]["state_modifier"], + name="React", ) @@ -257,20 +201,7 @@ def mock_messages(): ] -def test_state_modifier(mock_messages): - """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens.""" - state = AgentState(messages=mock_messages) - - with patch( - "ra_aid.agent_backends.ciayn_agent.CiaynAgent._estimate_tokens" - ) as mock_estimate: - mock_estimate.side_effect = lambda msg: 100 if msg else 0 - - result = state_modifier(state, max_input_tokens=250) - - assert len(result) < len(mock_messages) - assert isinstance(result[0], SystemMessage) - assert result[-1] == mock_messages[-1] +# This test has been moved to test_anthropic_token_limiter.py def test_create_agent_with_checkpointer(mock_model, mock_config_repository): @@ -291,17 +222,21 @@ def test_create_agent_with_checkpointer(mock_model, mock_config_repository): ) -def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository): +def test_create_agent_anthropic_token_limiting_enabled( + mock_model, mock_config_repository +): """Test create_agent sets up token limiting for Claude models when enabled.""" - mock_config_repository.update({ - "provider": "anthropic", - "model": "claude-2", - "limit_tokens": True, - }) + mock_config_repository.update( + { + "provider": "anthropic", + "model": "claude-2", + "limit_tokens": True, + } + ) with ( patch("ra_aid.agent_utils.create_react_agent") as mock_react, - patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit, + patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit, ): mock_react.return_value = "react_agent" mock_limit.return_value = 100000 @@ -314,17 +249,21 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_r assert callable(args[1]["state_modifier"]) -def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository): +def test_create_agent_anthropic_token_limiting_disabled( + mock_model, mock_config_repository +): """Test create_agent doesn't set up token limiting for Claude models when disabled.""" - mock_config_repository.update({ - "provider": "anthropic", - "model": "claude-2", - "limit_tokens": False, - }) + mock_config_repository.update( + { + 
"provider": "anthropic", + "model": "claude-2", + "limit_tokens": False, + } + ) with ( patch("ra_aid.agent_utils.create_react_agent") as mock_react, - patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit, + patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit, ): mock_react.return_value = "react_agent" mock_limit.return_value = 100000 @@ -332,39 +271,12 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_ agent = create_agent(mock_model, []) assert agent == "react_agent" - mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2") + mock_react.assert_called_once_with( + mock_model, [], interrupt_after=["tools"], version="v2", name="React" + ) -def test_get_model_token_limit_research(mock_config_repository): - """Test get_model_token_limit with research provider and model.""" - config = { - "provider": "openai", - "model": "gpt-4", - "research_provider": "anthropic", - "research_model": "claude-2", - } - mock_config_repository.update(config) - - with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: - mock_get_info.return_value = {"max_input_tokens": 150000} - token_limit = get_model_token_limit(config, "research") - assert token_limit == 150000 - - -def test_get_model_token_limit_planner(mock_config_repository): - """Test get_model_token_limit with planner provider and model.""" - config = { - "provider": "openai", - "model": "gpt-4", - "planner_provider": "deepseek", - "planner_model": "dsm-1", - } - mock_config_repository.update(config) - - with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: - mock_get_info.return_value = {"max_input_tokens": 120000} - token_limit = get_model_token_limit(config, "planner") - assert token_limit == 120000 +# These tests have been moved to test_anthropic_token_limiter.py # New tests for private helper methods in agent_utils.py @@ -396,11 +308,11 @@ def test_agent_context_depth(): with agent_context() as ctx1: assert get_depth() == 0 # Root context has depth 0 assert ctx1.depth == 0 - + with agent_context() as ctx2: assert get_depth() == 1 # Nested context has depth 1 assert ctx2.depth == 1 - + with agent_context() as ctx3: assert get_depth() == 2 # Doubly nested context has depth 2 assert ctx3.depth == 2 @@ -418,7 +330,7 @@ def test_run_agent_stream(monkeypatch, mock_config_repository): class DummyAgent: def stream(self, input_data, cfg: dict): yield {"content": "chunk1"} - + def get_state(self, state_config=None): # Return an object with a next property set to None return State() @@ -469,28 +381,28 @@ def test_handle_api_error_valueerror(): # ValueError not containing "code" or rate limit phrases should be re-raised with pytest.raises(ValueError): _handle_api_error(ValueError("some unrelated error"), 0, 5, 1) - + # ValueError with "429" should be handled without raising _handle_api_error(ValueError("error code 429"), 0, 5, 1) - + # ValueError with "rate limit" phrase should be handled without raising _handle_api_error(ValueError("hit rate limit"), 0, 5, 1) - + # ValueError with "too many requests" phrase should be handled without raising _handle_api_error(ValueError("too many requests, try later"), 0, 5, 1) - + # ValueError with "quota exceeded" phrase should be handled without raising _handle_api_error(ValueError("quota exceeded for this month"), 0, 5, 1) def test_handle_api_error_status_code(): from ra_aid.agent_utils import _handle_api_error - + # Error with status_code=429 attribute should be handled without raising 
error_with_status = Exception("Rate limited") error_with_status.status_code = 429 _handle_api_error(error_with_status, 0, 5, 1) - + # Error with http_status=429 attribute should be handled without raising error_with_http_status = Exception("Too many requests") error_with_http_status.http_status = 429 @@ -499,16 +411,16 @@ def test_handle_api_error_status_code(): def test_handle_api_error_rate_limit_phrases(): from ra_aid.agent_utils import _handle_api_error - + # Generic exception with "rate limit" phrase should be handled without raising _handle_api_error(Exception("You have exceeded your rate limit"), 0, 5, 1) - + # Generic exception with "too many requests" phrase should be handled without raising _handle_api_error(Exception("Too many requests, please slow down"), 0, 5, 1) - + # Generic exception with "quota exceeded" phrase should be handled without raising _handle_api_error(Exception("API quota exceeded for this billing period"), 0, 5, 1) - + # Generic exception with "rate" and "limit" separate but in message should be handled _handle_api_error(Exception("You hit the rate at which we limit requests"), 0, 5, 1) @@ -629,7 +541,9 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repos assert "Agent has crashed: Test crash message" in result -def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_repository): +def test_run_agent_with_retry_handles_badrequest_error( + monkeypatch, mock_config_repository +): """Test that run_agent_with_retry properly handles BadRequestError as unretryable.""" from ra_aid.agent_context import agent_context, is_crashed from ra_aid.agent_utils import run_agent_with_retry @@ -687,7 +601,9 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_ assert is_crashed() -def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch, mock_config_repository): +def test_run_agent_with_retry_handles_api_badrequest_error( + monkeypatch, mock_config_repository +): """Test that run_agent_with_retry properly handles API BadRequestError as unretryable.""" # Import APIError from anthropic module and patch it on the agent_utils module @@ -758,7 +674,9 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch, mock_con def test_handle_api_error_resource_exhausted(): from google.api_core.exceptions import ResourceExhausted from ra_aid.agent_utils import _handle_api_error - + # ResourceExhausted exception should be handled without raising - resource_exhausted_error = ResourceExhausted("429 Resource has been exhausted (e.g. check quota).") - _handle_api_error(resource_exhausted_error, 0, 5, 1) \ No newline at end of file + resource_exhausted_error = ResourceExhausted( + "429 Resource has been exhausted (e.g. check quota)." 
+ ) + _handle_api_error(resource_exhausted_error, 0, 5, 1) diff --git a/tests/ra_aid/test_anthropic_token_limiter.py b/tests/ra_aid/test_anthropic_token_limiter.py new file mode 100644 index 0000000..36f9528 --- /dev/null +++ b/tests/ra_aid/test_anthropic_token_limiter.py @@ -0,0 +1,507 @@ +import unittest +from unittest.mock import MagicMock, patch +import litellm + +from langchain_anthropic import ChatAnthropic +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage +) +from langgraph.prebuilt.chat_agent_executor import AgentState + +from ra_aid.anthropic_token_limiter import ( + create_token_counter_wrapper, + estimate_messages_tokens, + get_model_token_limit, + state_modifier, + sonnet_35_state_modifier, + convert_message_to_litellm_format +) +from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair +from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT + + +class TestAnthropicTokenLimiter(unittest.TestCase): + def setUp(self): + from ra_aid.config import DEFAULT_MODEL + + self.mock_model = MagicMock(spec=ChatAnthropic) + self.mock_model.model = DEFAULT_MODEL + + # Sample messages for testing + self.system_message = SystemMessage(content="You are a helpful assistant.") + self.human_message = HumanMessage(content="Hello, can you help me with a task?") + self.ai_message = AIMessage(content="I'd be happy to help! What do you need?") + self.long_message = HumanMessage(content="A" * 1000) # Long message to test trimming + + # Create more messages for testing + self.extra_messages = [ + HumanMessage(content=f"Extra message {i}") for i in range(5) + ] + + # Mock state for testing state_modifier with many messages + self.state = AgentState( + messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages, + next=None, + ) + + # Create tool-related messages for testing + self.ai_with_tool_use = AIMessage( + content="I'll use a tool to help you", + additional_kwargs={"tool_calls": [{"name": "calculator", "input": {"expression": "2+2"}}]} + ) + self.tool_message = ToolMessage( + content="4", + tool_call_id="tool_call_1", + name="calculator" + ) + + def test_convert_message_to_litellm_format(self): + """Test conversion of BaseMessage to litellm format.""" + # Test human message + human_result = convert_message_to_litellm_format(self.human_message) + self.assertEqual(human_result["role"], "human") + self.assertEqual(human_result["content"], "Hello, can you help me with a task?") + + # Test system message + system_result = convert_message_to_litellm_format(self.system_message) + self.assertEqual(system_result["role"], "system") + self.assertEqual(system_result["content"], "You are a helpful assistant.") + + # Test AI message + ai_result = convert_message_to_litellm_format(self.ai_message) + self.assertEqual(ai_result["role"], "ai") + self.assertEqual(ai_result["content"], "I'd be happy to help! 
What do you need?") + + @patch("ra_aid.anthropic_token_limiter.token_counter") + def test_create_token_counter_wrapper(self, mock_token_counter): + from ra_aid.config import DEFAULT_MODEL + + # Setup mock return values + mock_token_counter.return_value = 50 + + # Create the wrapper + wrapper = create_token_counter_wrapper(DEFAULT_MODEL) + + # Test with BaseMessage objects + result = wrapper([self.human_message]) + self.assertEqual(result, 50) + + # Test with empty list + result = wrapper([]) + self.assertEqual(result, 0) + + # Verify the mock was called with the right parameters + mock_token_counter.assert_called_with(messages=unittest.mock.ANY, model=DEFAULT_MODEL) + + @patch("ra_aid.anthropic_token_limiter.CiaynAgent._estimate_tokens") + def test_estimate_messages_tokens(self, mock_estimate_tokens): + # Setup mock to return different values for different messages + mock_estimate_tokens.side_effect = lambda msg: 10 if isinstance(msg, SystemMessage) else 20 + + # Test with multiple messages + messages = [self.system_message, self.human_message] + result = estimate_messages_tokens(messages) + + # Should be sum of individual token counts (10 + 20) + self.assertEqual(result, 30) + + # Test with empty list + result = estimate_messages_tokens([]) + self.assertEqual(result, 0) + + @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") + @patch("ra_aid.anthropic_token_limiter.print_messages_compact") + @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") + def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper): + # Setup a proper token counter function that returns integers + def token_counter(msgs): + # Return token count based on number of messages + return len(msgs) * 10 + + # Configure the mock to return our token counter + mock_create_wrapper.return_value = token_counter + + # Configure anthropic_trim_messages to return a subset of messages + mock_trim_messages.return_value = [self.system_message, self.human_message] + + # Call state_modifier with a max token limit of 50 + result = state_modifier(self.state, self.mock_model, max_input_tokens=50) + + # Should return what anthropic_trim_messages returned + self.assertEqual(result, [self.system_message, self.human_message]) + + # Verify the wrapper was created with the right model + mock_create_wrapper.assert_called_with(self.mock_model.model) + + # Verify anthropic_trim_messages was called with the right parameters + mock_trim_messages.assert_called_once() + + + def test_state_modifier_with_messages(self): + """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens.""" + # Create a state with messages + messages = [ + SystemMessage(content="System prompt"), + HumanMessage(content="Human message 1"), + AIMessage(content="AI response 1"), + HumanMessage(content="Human message 2"), + AIMessage(content="AI response 2"), + ] + state = AgentState(messages=messages) + model = MagicMock(spec=ChatAnthropic) + model.model = "claude-3-opus-20240229" + + with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \ + patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \ + patch("ra_aid.anthropic_token_limiter.print_messages_compact"): + # Setup mock to return a fixed token count per message + mock_wrapper.return_value = lambda msgs: len(msgs) * 100 + # Setup mock to return a subset of messages + mock_trim.return_value = [messages[0], messages[-2], messages[-1]] + + result = 
state_modifier(state, model, max_input_tokens=250)
+
+            # Should return what anthropic_trim_messages returned
+            self.assertEqual(len(result), 3)
+            self.assertEqual(result[0], messages[0])  # First message preserved
+            self.assertEqual(result[-1], messages[-1])  # Last message preserved
+
+    def test_sonnet_35_state_modifier(self):
+        """Test the sonnet 35 state modifier function."""
+        # Create a state with messages
+        state = {"messages": [self.system_message, self.human_message, self.ai_message]}
+
+        # Test with empty messages
+        empty_state = {"messages": []}
+
+        # Instead of patching trim_messages which has complex internal logic,
+        # we'll directly patch the sonnet_35_state_modifier's call to trim_messages
+        with patch("ra_aid.anthropic_token_limiter.trim_messages") as mock_trim:
+            # Setup mock to return our desired messages
+            mock_trim.return_value = [self.human_message, self.ai_message]
+
+            # Test with empty messages
+            self.assertEqual(sonnet_35_state_modifier(empty_state), [])
+
+            # Test with messages under the limit
+            result = sonnet_35_state_modifier(state, max_input_tokens=10000)
+
+            # Should keep the first message and call trim_messages for the rest
+            self.assertEqual(len(result), 3)
+            self.assertEqual(result[0], self.system_message)
+            self.assertEqual(result[1:], [self.human_message, self.ai_message])
+
+            # Verify trim_messages was called with the right parameters
+            mock_trim.assert_called_once()
+            # We can check some of the key arguments
+            call_args = mock_trim.call_args[1]
+            # The actual value is based on the token estimation logic, not a hard-coded 9000
+            self.assertIn("max_tokens", call_args)
+            self.assertEqual(call_args["strategy"], "last")
+            self.assertEqual(call_args["allow_partial"], False)
+            self.assertEqual(call_args["include_system"], True)
+
+    @patch("ra_aid.anthropic_token_limiter.get_config_repository")
+    @patch("litellm.get_model_info")
+    def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
+        from ra_aid.config import DEFAULT_MODEL
+
+        # Setup mocks
+        mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
+        mock_get_config_repo.return_value.get_all.return_value = mock_config
+
+        # Mock litellm's get_model_info to return a token limit
+        mock_get_model_info.return_value = {"max_input_tokens": 100000}
+
+        # Test getting token limit
+        result = get_model_token_limit(mock_config)
+        self.assertEqual(result, 100000)
+
+        # Verify get_model_info was called with the right model
+        mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
+
+    def test_get_model_token_limit_research(self):
+        """Test get_model_token_limit with research provider and model."""
+        config = {
+            "provider": "openai",
+            "model": "gpt-4",
+            "research_provider": "anthropic",
+            "research_model": "claude-2",
+        }
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+            patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 150000}
+            token_limit = get_model_token_limit(config, "research")
+            self.assertEqual(token_limit, 150000)
+            # Verify get_model_info was called with the research model
+            mock_get_info.assert_called_with("anthropic/claude-2")
+
+    def test_get_model_token_limit_planner(self):
+        """Test get_model_token_limit with planner provider and model."""
+        config = {
+            "provider": "openai",
+            "model": "gpt-4",
+            "planner_provider": "deepseek",
+
"planner_model": "dsm-1", + } + + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \ + patch("litellm.get_model_info") as mock_get_info: + mock_get_config_repo.return_value.get_all.return_value = config + mock_get_info.return_value = {"max_input_tokens": 120000} + token_limit = get_model_token_limit(config, "planner") + self.assertEqual(token_limit, 120000) + # Verify get_model_info was called with the planner model + mock_get_info.assert_called_with("deepseek/dsm-1") + + @patch("ra_aid.anthropic_token_limiter.get_config_repository") + @patch("litellm.get_model_info") + def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo): + # Setup mocks + mock_config = {"provider": "anthropic", "model": "claude-2"} + mock_get_config_repo.return_value.get_all.return_value = mock_config + + # Make litellm's get_model_info raise an exception to test fallback + mock_get_model_info.side_effect = Exception("Model not found") + + # Test getting token limit from models_params fallback + with patch("ra_aid.anthropic_token_limiter.models_params", { + "anthropic": { + "claude2": {"token_limit": 100000} + } + }): + result = get_model_token_limit(mock_config) + self.assertEqual(result, 100000) + + @patch("ra_aid.anthropic_token_limiter.get_config_repository") + @patch("litellm.get_model_info") + def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo): + from ra_aid.config import DEFAULT_MODEL + + # Setup mocks for different agent types + mock_config = { + "provider": "anthropic", + "model": DEFAULT_MODEL, + "research_provider": "openai", + "research_model": "gpt-4", + "planner_provider": "anthropic", + "planner_model": "claude-3-sonnet-20240229" + } + mock_get_config_repo.return_value.get_all.return_value = mock_config + + # Mock different returns for different models + def model_info_side_effect(model_name): + if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name: + return {"max_input_tokens": 200000} + elif "gpt-4" in model_name: + return {"max_input_tokens": 8192} + elif "claude-3-sonnet" in model_name: + return {"max_input_tokens": 100000} + else: + raise Exception(f"Unknown model: {model_name}") + + mock_get_model_info.side_effect = model_info_side_effect + + # Test default agent type + result = get_model_token_limit(mock_config, "default") + self.assertEqual(result, 200000) + + # Test research agent type + result = get_model_token_limit(mock_config, "research") + self.assertEqual(result, 8192) + + # Test planner agent type + result = get_model_token_limit(mock_config, "planner") + self.assertEqual(result, 100000) + + def test_get_model_token_limit_anthropic(self): + """Test get_model_token_limit with Anthropic model.""" + config = {"provider": "anthropic", "model": "claude2"} + + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo: + mock_get_config_repo.return_value.get_all.return_value = config + token_limit = get_model_token_limit(config, "default") + self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"]) + + def test_get_model_token_limit_openai(self): + """Test get_model_token_limit with OpenAI model.""" + config = {"provider": "openai", "model": "gpt-4"} + + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo: + mock_get_config_repo.return_value.get_all.return_value = config + token_limit = get_model_token_limit(config, "default") + 
self.assertEqual(token_limit, models_params["openai"]["gpt-4"]["token_limit"])
+
+    def test_get_model_token_limit_unknown(self):
+        """Test get_model_token_limit with unknown provider/model."""
+        config = {"provider": "unknown", "model": "unknown-model"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertIsNone(token_limit)
+
+    def test_get_model_token_limit_missing_config(self):
+        """Test get_model_token_limit with missing configuration."""
+        config = {}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertIsNone(token_limit)
+
+    def test_get_model_token_limit_litellm_success(self):
+        """Test get_model_token_limit successfully getting limit from litellm."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+            patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 100000}
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, 100000)
+            mock_get_info.assert_called_with("anthropic/claude-2")
+
+    def test_get_model_token_limit_litellm_not_found(self):
+        """Test fallback to models_params when litellm raises NotFoundError."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+            patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = litellm.exceptions.NotFoundError(
+                message="Model not found", model="claude-2", llm_provider="anthropic"
+            )
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_litellm_error(self):
+        """Test fallback to models_params when litellm raises other exceptions."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+            patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = Exception("Unknown error")
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_unexpected_error(self):
+        """Test returning None when unexpected errors occur."""
+        config = None  # This will cause an attribute error when accessed
+
+        token_limit = get_model_token_limit(config, "default")
+        self.assertIsNone(token_limit)
+
+    def test_has_tool_use(self):
+        """Test the has_tool_use function."""
+        # Test with regular AI message
+        self.assertFalse(has_tool_use(self.ai_message))
+
+        # Test with AI message containing tool_use in string content
+        ai_with_tool_str = AIMessage(content="I'll use a tool_use to help you")
+        self.assertTrue(has_tool_use(ai_with_tool_str))
+
+        # Test with AI message containing tool_use in structured content
+        ai_with_tool_dict = AIMessage(content=[
+            {"type":
"text", "text": "I'll use a tool to help you"}, + {"type": "tool_use", "tool_use": {"name": "calculator", "input": {"expression": "2+2"}}} + ]) + self.assertTrue(has_tool_use(ai_with_tool_dict)) + + # Test with AI message containing tool_calls in additional_kwargs + self.assertTrue(has_tool_use(self.ai_with_tool_use)) + + # Test with non-AI message + self.assertFalse(has_tool_use(self.human_message)) + + def test_is_tool_pair(self): + """Test the is_tool_pair function.""" + # Test with valid tool pair + self.assertTrue(is_tool_pair(self.ai_with_tool_use, self.tool_message)) + + # Test with non-tool pair (wrong order) + self.assertFalse(is_tool_pair(self.tool_message, self.ai_with_tool_use)) + + # Test with non-tool pair (wrong types) + self.assertFalse(is_tool_pair(self.ai_message, self.human_message)) + + # Test with non-tool pair (AI message without tool use) + self.assertFalse(is_tool_pair(self.ai_message, self.tool_message)) + + @patch("ra_aid.anthropic_message_utils.has_tool_use") + def test_anthropic_trim_messages_with_tool_use(self, mock_has_tool_use): + """Test anthropic_trim_messages with a sequence of messages including tool use.""" + from ra_aid.anthropic_message_utils import anthropic_trim_messages + + # Setup mock for has_tool_use to return True for AI messages at even indices + def side_effect(msg): + if isinstance(msg, AIMessage) and hasattr(msg, 'test_index'): + return msg.test_index % 2 == 0 # Even indices have tool use + return False + + mock_has_tool_use.side_effect = side_effect + + # Create a sequence of alternating human and AI messages with tool use + messages = [] + + # Start with system message + system_msg = SystemMessage(content="You are a helpful assistant.") + messages.append(system_msg) + + # Add alternating human and AI messages with tool use + for i in range(8): + if i % 2 == 0: + # Human message + msg = HumanMessage(content=f"Human message {i}") + messages.append(msg) + else: + # AI message, every other one has tool use + ai_msg = AIMessage(content=f"AI message {i}") + # Add a test_index attribute to track position + ai_msg.test_index = i + messages.append(ai_msg) + + # If this AI message has tool use (even index), add a tool message after it + if i % 4 == 1: # 1, 5, etc. + tool_msg = ToolMessage( + content=f"Tool result {i}", + tool_call_id=f"tool_call_{i}", + name="test_tool" + ) + messages.append(tool_msg) + + # Define a token counter that returns a fixed value per message + def token_counter(msgs): + return len(msgs) * 1000 + + # Test with a token limit that will require trimming + result = anthropic_trim_messages( + messages, + token_counter=token_counter, + max_tokens=5000, # This will allow 5 messages + strategy="last", + allow_partial=False, + include_system=True, + num_messages_to_keep=2 # Keep system and first human message + ) + + # We should have kept the first 2 messages (system + human) + self.assertEqual(len(result), 5) # 2 kept + 3 more that fit in token limit + self.assertEqual(result[0], system_msg) + + # Verify that we don't have any AI messages with tool use that aren't followed by a tool message + for i in range(len(result) - 1): + if isinstance(result[i], AIMessage) and mock_has_tool_use(result[i]): + self.assertTrue(isinstance(result[i+1], ToolMessage), + f"AI message with tool use at index {i} not followed by ToolMessage") + + +if __name__ == "__main__": + unittest.main()