From 5c9a1e81d2bd8969452d46b5e4799a43e873da22 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Mar 2025 14:03:18 -0700 Subject: [PATCH 01/11] feat(main.py): refactor imports for better organization and readability feat(main.py): add DEFAULT_MODEL constant to centralize model configuration feat(main.py): enhance logging and error handling for better debugging feat(main.py): implement state_modifier for managing token limits in agent state feat(anthropic_token_limiter.py): create utilities for handling token limits with Anthropic models feat(output.py): add print_messages_compact function for debugging message output test(anthropic_token_limiter.py): add unit tests for token limit utilities and state management --- ra_aid/__main__.py | 153 ++++++++----- ra_aid/agent_utils.py | 218 ++----------------- ra_aid/anthropic_token_limiter.py | 210 ++++++++++++++++++ ra_aid/config.py | 1 + ra_aid/console/output.py | 58 ++++- ra_aid/llm.py | 9 +- ra_aid/tools/agent.py | 7 +- tests/ra_aid/test_anthropic_token_limiter.py | 198 +++++++++++++++++ 8 files changed, 592 insertions(+), 262 deletions(-) create mode 100644 ra_aid/anthropic_token_limiter.py create mode 100644 tests/ra_aid/test_anthropic_token_limiter.py diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index bf42ec8..837f261 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -39,32 +39,38 @@ from ra_aid.agents.research_agent import run_research_agent from ra_aid.agents import run_planning_agent from ra_aid.config import ( DEFAULT_MAX_TEST_CMD_RETRIES, + DEFAULT_MODEL, DEFAULT_RECURSION_LIMIT, DEFAULT_TEST_CMD_TIMEOUT, VALID_PROVIDERS, ) -from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository +from ra_aid.database.repositories.key_fact_repository import ( + KeyFactRepositoryManager, + get_key_fact_repository, +) from ra_aid.database.repositories.key_snippet_repository import ( - KeySnippetRepositoryManager, get_key_snippet_repository + KeySnippetRepositoryManager, + get_key_snippet_repository, ) from ra_aid.database.repositories.human_input_repository import ( - HumanInputRepositoryManager, get_human_input_repository + HumanInputRepositoryManager, + get_human_input_repository, ) from ra_aid.database.repositories.research_note_repository import ( - ResearchNoteRepositoryManager, get_research_note_repository + ResearchNoteRepositoryManager, + get_research_note_repository, ) from ra_aid.database.repositories.trajectory_repository import ( - TrajectoryRepositoryManager, get_trajectory_repository + TrajectoryRepositoryManager, + get_trajectory_repository, ) from ra_aid.database.repositories.related_files_repository import ( - RelatedFilesRepositoryManager -) -from ra_aid.database.repositories.work_log_repository import ( - WorkLogRepositoryManager + RelatedFilesRepositoryManager, ) +from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager from ra_aid.database.repositories.config_repository import ( ConfigRepositoryManager, - get_config_repository + get_config_repository, ) from ra_aid.env_inv import EnvDiscovery from ra_aid.env_inv_context import EnvInvManager, get_env_inv @@ -100,9 +106,9 @@ def launch_webui(host: str, port: int): def parse_arguments(args=None): - ANTHROPIC_DEFAULT_MODEL = "claude-3-7-sonnet-20250219" + ANTHROPIC_DEFAULT_MODEL = DEFAULT_MODEL OPENAI_DEFAULT_MODEL = "gpt-4o" - + # Case-insensitive log level argument type def log_level_type(value): value = value.lower() @@ -199,8 +205,10 @@ Examples: help="Enable chat mode with 
direct human interaction (implies --hil)", ) parser.add_argument( - "--log-mode", choices=["console", "file"], default="file", - help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console" + "--log-mode", + choices=["console", "file"], + default="file", + help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console", ) parser.add_argument( "--pretty-logger", action="store_true", help="Enable pretty logging output" @@ -378,20 +386,20 @@ def is_stage_requested(stage: str) -> bool: def wipe_project_memory(): """Delete the project database file to wipe all stored memory. - + Returns: str: A message indicating the result of the operation """ import os from pathlib import Path - + cwd = os.getcwd() ra_aid_dir = Path(os.path.join(cwd, ".ra-aid")) db_path = os.path.join(ra_aid_dir, "pk.db") - + if not os.path.exists(db_path): return "No project memory found to wipe." - + try: os.remove(db_path) return "Project memory wiped successfully." @@ -403,11 +411,11 @@ def wipe_project_memory(): def build_status(): """Build status panel with model and feature information. - + Includes memory statistics at the bottom with counts of key facts, snippets, and research notes. """ status = Text() - + # Get the config repository to get model/provider information config_repo = get_config_repository() provider = config_repo.get("provider", "") @@ -415,12 +423,14 @@ def build_status(): temperature = config_repo.get("temperature") expert_provider = config_repo.get("expert_provider", "") expert_model = config_repo.get("expert_model", "") - experimental_fallback_handler = config_repo.get("experimental_fallback_handler", False) + experimental_fallback_handler = config_repo.get( + "experimental_fallback_handler", False + ) web_research_enabled = config_repo.get("web_research_enabled", False) - + # Get the expert enabled status expert_enabled = bool(expert_provider and expert_model) - + # Basic model information status.append("šŸ¤– ") status.append(f"{provider}/{model}") @@ -452,39 +462,41 @@ def build_status(): [fb_handler._format_model(m) for m in fb_handler.fallback_tool_models] ) status.append(msg) - + # Add memory statistics # Get counts of key facts, snippets, and research notes with error handling fact_count = 0 snippet_count = 0 note_count = 0 - + try: fact_count = len(get_key_fact_repository().get_all()) except RuntimeError as e: logger.debug(f"Failed to get key facts count: {e}") - + try: snippet_count = len(get_key_snippet_repository().get_all()) except RuntimeError as e: logger.debug(f"Failed to get key snippets count: {e}") - + try: note_count = len(get_research_note_repository().get_all()) except RuntimeError as e: logger.debug(f"Failed to get research notes count: {e}") - + # Add memory statistics line with reset option note - status.append(f"\nšŸ’¾ Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes") + status.append( + f"\nšŸ’¾ Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes" + ) if fact_count > 0 or snippet_count > 0 or note_count > 0: status.append(" (use --wipe-project-memory to reset)") - + # Check for newer version version_message = check_for_newer_version() if version_message: status.append("\n\n") status.append(version_message, style="yellow") - + return status @@ -493,7 +505,7 @@ def main(): args = parse_arguments() setup_logging(args.log_mode, args.pretty_logger, args.log_level) logger.debug("Starting RA.Aid with arguments: %s", args) - + # Check if we need 
to wipe project memory before starting if args.wipe_project_memory: result = wipe_project_memory() @@ -519,22 +531,24 @@ def main(): # Initialize empty config dictionary to be populated later config = {} - + # Initialize repositories with database connection # Create environment inventory data env_discovery = EnvDiscovery() env_discovery.discover() env_data = env_discovery.format_markdown() - - with KeyFactRepositoryManager(db) as key_fact_repo, \ - KeySnippetRepositoryManager(db) as key_snippet_repo, \ - HumanInputRepositoryManager(db) as human_input_repo, \ - ResearchNoteRepositoryManager(db) as research_note_repo, \ - RelatedFilesRepositoryManager() as related_files_repo, \ - TrajectoryRepositoryManager(db) as trajectory_repo, \ - WorkLogRepositoryManager() as work_log_repo, \ - ConfigRepositoryManager(config) as config_repo, \ - EnvInvManager(env_data) as env_inv: + + with ( + KeyFactRepositoryManager(db) as key_fact_repo, + KeySnippetRepositoryManager(db) as key_snippet_repo, + HumanInputRepositoryManager(db) as human_input_repo, + ResearchNoteRepositoryManager(db) as research_note_repo, + RelatedFilesRepositoryManager() as related_files_repo, + TrajectoryRepositoryManager(db) as trajectory_repo, + WorkLogRepositoryManager() as work_log_repo, + ConfigRepositoryManager(config) as config_repo, + EnvInvManager(env_data) as env_inv, + ): # This initializes all repositories and makes them available via their respective get methods logger.debug("Initialized KeyFactRepository") logger.debug("Initialized KeySnippetRepository") @@ -554,7 +568,9 @@ def main(): expert_missing, web_research_enabled, web_research_missing, - ) = validate_environment(args) # Will exit if main env vars missing + ) = validate_environment( + args + ) # Will exit if main env vars missing logger.debug("Environment validation successful") # Validate model configuration early @@ -590,11 +606,15 @@ def main(): config_repo.set("expert_provider", args.expert_provider) config_repo.set("expert_model", args.expert_model) config_repo.set("temperature", args.temperature) - config_repo.set("experimental_fallback_handler", args.experimental_fallback_handler) + config_repo.set( + "experimental_fallback_handler", args.experimental_fallback_handler + ) config_repo.set("web_research_enabled", web_research_enabled) config_repo.set("show_thoughts", args.show_thoughts) config_repo.set("force_reasoning_assistance", args.reasoning_assistance) - config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance) + config_repo.set( + "disable_reasoning_assistance", args.no_reasoning_assistance + ) # Build status panel with memory statistics status = build_status() @@ -633,13 +653,15 @@ def main(): initial_request = ask_human.invoke( {"question": "What would you like help with?"} ) - + # Record chat input in database (redundant as ask_human already records it, # but needed in case the ask_human implementation changes) try: # Using get_human_input_repository() to access the repository from context human_input_repository = get_human_input_repository() - human_input_repository.create(content=initial_request, source='chat') + human_input_repository.create( + content=initial_request, source="chat" + ) human_input_repository.garbage_collect() except Exception as e: logger.error(f"Failed to record initial chat input: {str(e)}") @@ -668,8 +690,12 @@ def main(): config_repo.set("expert_model", args.expert_model) config_repo.set("temperature", args.temperature) config_repo.set("show_thoughts", args.show_thoughts) - 
config_repo.set("force_reasoning_assistance", args.reasoning_assistance) - config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance) + config_repo.set( + "force_reasoning_assistance", args.reasoning_assistance + ) + config_repo.set( + "disable_reasoning_assistance", args.no_reasoning_assistance + ) # Set modification tools based on use_aider flag set_modification_tools(args.use_aider) @@ -696,8 +722,12 @@ def main(): ), working_directory=working_directory, current_date=current_date, - key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()), - key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()), + key_facts=format_key_facts_dict( + get_key_fact_repository().get_facts_dict() + ), + key_snippets=format_key_snippets_dict( + get_key_snippet_repository().get_snippets_dict() + ), project_info=formatted_project_info, env_inv=get_env_inv(), ), @@ -711,12 +741,12 @@ def main(): sys.exit(1) base_task = args.message - + # Record CLI input in database try: # Using get_human_input_repository() to access the repository from context human_input_repository = get_human_input_repository() - human_input_repository.create(content=base_task, source='cli') + human_input_repository.create(content=base_task, source="cli") # Run garbage collection to ensure we don't exceed 100 inputs human_input_repository.garbage_collect() logger.debug(f"Recorded CLI input: {base_task}") @@ -750,19 +780,25 @@ def main(): config_repo.set("expert_model", args.expert_model) # Store planner config with fallback to base values - config_repo.set("planner_provider", args.planner_provider or args.provider) + config_repo.set( + "planner_provider", args.planner_provider or args.provider + ) config_repo.set("planner_model", args.planner_model or args.model) # Store research config with fallback to base values - config_repo.set("research_provider", args.research_provider or args.provider) + config_repo.set( + "research_provider", args.research_provider or args.provider + ) config_repo.set("research_model", args.research_model or args.model) # Store temperature in config config_repo.set("temperature", args.temperature) - + # Store reasoning assistance flags config_repo.set("force_reasoning_assistance", args.reasoning_assistance) - config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance) + config_repo.set( + "disable_reasoning_assistance", args.no_reasoning_assistance + ) # Set modification tools based on use_aider flag set_modification_tools(args.use_aider) @@ -794,5 +830,6 @@ def main(): print() sys.exit(0) + if __name__ == "__main__": main() diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index e325745..0b18d34 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -1,19 +1,14 @@ """Utility functions for working with agents.""" -import inspect -import os import signal import sys import threading import time -import uuid -from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Sequence +from typing import Any, Dict, List, Literal, Optional +from langchain_anthropic import ChatAnthropic from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler - -import litellm from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError from openai import RateLimitError as OpenAIRateLimitError from litellm.exceptions import RateLimitError as LiteLLMRateLimitError @@ -23,28 +18,24 @@ from langchain_core.messages import ( BaseMessage, HumanMessage, SystemMessage, - 
trim_messages, ) from langchain_core.tools import tool -from langgraph.checkpoint.memory import MemorySaver from langgraph.prebuilt import create_react_agent from langgraph.prebuilt.chat_agent_executor import AgentState -from litellm import get_model_info from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel from ra_aid.agent_context import ( agent_context, - get_depth, is_completed, reset_completion_flags, should_exit, ) from ra_aid.agent_backends.ciayn_agent import CiaynAgent from ra_aid.agents_alias import RAgents -from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT -from ra_aid.console.formatting import print_error, print_stage_header +from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES +from ra_aid.console.formatting import print_error from ra_aid.console.output import print_agent_output from ra_aid.exceptions import ( AgentInterrupt, @@ -53,76 +44,16 @@ from ra_aid.exceptions import ( ) from ra_aid.fallback_handler import FallbackHandler from ra_aid.logging_config import get_logger -from ra_aid.llm import initialize_expert_llm -from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params -from ra_aid.text.processing import process_thinking_content -from ra_aid.project_info import ( - display_project_status, - format_project_info, - get_project_info, -) -from ra_aid.prompts.expert_prompts import ( - EXPERT_PROMPT_SECTION_IMPLEMENTATION, - EXPERT_PROMPT_SECTION_PLANNING, - EXPERT_PROMPT_SECTION_RESEARCH, -) -from ra_aid.prompts.human_prompts import ( - HUMAN_PROMPT_SECTION_IMPLEMENTATION, - HUMAN_PROMPT_SECTION_PLANNING, - HUMAN_PROMPT_SECTION_RESEARCH, -) -from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT -from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS -from ra_aid.prompts.planning_prompts import PLANNING_PROMPT -from ra_aid.prompts.reasoning_assist_prompt import ( - REASONING_ASSIST_PROMPT_PLANNING, - REASONING_ASSIST_PROMPT_IMPLEMENTATION, - REASONING_ASSIST_PROMPT_RESEARCH, -) -from ra_aid.prompts.research_prompts import ( - RESEARCH_ONLY_PROMPT, - RESEARCH_PROMPT, -) -from ra_aid.prompts.web_research_prompts import ( - WEB_RESEARCH_PROMPT, - WEB_RESEARCH_PROMPT_SECTION_CHAT, - WEB_RESEARCH_PROMPT_SECTION_PLANNING, - WEB_RESEARCH_PROMPT_SECTION_RESEARCH, -) -from ra_aid.tool_configs import ( - get_implementation_tools, - get_planning_tools, - get_research_tools, - get_web_research_tools, -) +from ra_aid.models_params import DEFAULT_TOKEN_LIMIT from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command -from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository -from ra_aid.database.repositories.key_snippet_repository import ( - get_key_snippet_repository, -) -from ra_aid.database.repositories.human_input_repository import ( - get_human_input_repository, -) -from ra_aid.database.repositories.research_note_repository import ( - get_research_note_repository, -) -from ra_aid.database.repositories.work_log_repository import get_work_log_repository -from ra_aid.model_formatters import format_key_facts_dict -from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict -from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict -from ra_aid.tools.memory import ( - get_related_files, - log_work_event, -) from ra_aid.database.repositories.config_repository import get_config_repository -from ra_aid.env_inv_context import get_env_inv +from ra_aid.anthropic_token_limiter import 
state_modifier, get_model_token_limit console = Console() logger = get_logger(__name__) # Import repositories using get_* functions -from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository @tool @@ -132,131 +63,19 @@ def output_markdown_message(message: str) -> str: return "Message output." -def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int: - """Helper function to estimate total tokens in a sequence of messages. - - Args: - messages: Sequence of messages to count tokens for - - Returns: - Total estimated token count - """ - if not messages: - return 0 - - estimate_tokens = CiaynAgent._estimate_tokens - return sum(estimate_tokens(msg) for msg in messages) - - -def state_modifier( - state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT -) -> list[BaseMessage]: - """Given the agent state and max_tokens, return a trimmed list of messages. - - Args: - state: The current agent state containing messages - max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT) - - Returns: - list[BaseMessage]: Trimmed list of messages that fits within token limit - """ - messages = state["messages"] - - if not messages: - return [] - - first_message = messages[0] - remaining_messages = messages[1:] - first_tokens = estimate_messages_tokens([first_message]) - new_max_tokens = max_input_tokens - first_tokens - - trimmed_remaining = trim_messages( - remaining_messages, - token_counter=estimate_messages_tokens, - max_tokens=new_max_tokens, - strategy="last", - allow_partial=False, - ) - - return [first_message] + trimmed_remaining - - -def get_model_token_limit( - config: Dict[str, Any], agent_type: Literal["default", "research", "planner"] -) -> Optional[int]: - """Get the token limit for the current model configuration based on agent type. - - Returns: - Optional[int]: The token limit if found, None otherwise - """ - try: - # Try to get config from repository for production use - try: - config_from_repo = get_config_repository().get_all() - # If we succeeded, use the repository config instead of passed config - config = config_from_repo - except RuntimeError: - # In tests, this may fail because the repository isn't set up - # So we'll use the passed config directly - pass - if agent_type == "research": - provider = config.get("research_provider", "") or config.get("provider", "") - model_name = config.get("research_model", "") or config.get("model", "") - elif agent_type == "planner": - provider = config.get("planner_provider", "") or config.get("provider", "") - model_name = config.get("planner_model", "") or config.get("model", "") - else: - provider = config.get("provider", "") - model_name = config.get("model", "") - - try: - provider_model = model_name if not provider else f"{provider}/{model_name}" - model_info = get_model_info(provider_model) - max_input_tokens = model_info.get("max_input_tokens") - if max_input_tokens: - logger.debug( - f"Using litellm token limit for {model_name}: {max_input_tokens}" - ) - return max_input_tokens - except litellm.exceptions.NotFoundError: - logger.debug( - f"Model {model_name} not found in litellm, falling back to models_params" - ) - except Exception as e: - logger.debug( - f"Error getting model info from litellm: {e}, falling back to models_params" - ) - - # Fallback to models_params dict - # Normalize model name for fallback lookup (e.g. 
claude-2 -> claude2) - normalized_name = model_name.replace("-", "") - provider_tokens = models_params.get(provider, {}) - if normalized_name in provider_tokens: - max_input_tokens = provider_tokens[normalized_name]["token_limit"] - logger.debug( - f"Found token limit for {provider}/{model_name}: {max_input_tokens}" - ) - else: - max_input_tokens = None - logger.debug(f"Could not find token limit for {provider}/{model_name}") - - return max_input_tokens - - except Exception as e: - logger.warning(f"Failed to get model token limit: {e}") - return None def build_agent_kwargs( checkpointer: Optional[Any] = None, + model: ChatAnthropic = None, max_input_tokens: Optional[int] = None, ) -> Dict[str, Any]: """Build kwargs dictionary for agent creation. Args: checkpointer: Optional memory checkpointer - config: Optional configuration dictionary - token_limit: Optional token limit for the model + model: The language model to use for token counting + max_input_tokens: Optional token limit for the model Returns: Dictionary of kwargs for agent creation @@ -269,12 +88,17 @@ def build_agent_kwargs( agent_kwargs["checkpointer"] = checkpointer config = get_config_repository().get_all() - if config.get("limit_tokens", True) and is_anthropic_claude(config): + if ( + config.get("limit_tokens", True) + and is_anthropic_claude(config) + and model is not None + ): def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]: - return state_modifier(state, max_input_tokens=max_input_tokens) + return state_modifier(state, model, max_input_tokens=max_input_tokens) agent_kwargs["state_modifier"] = wrapped_state_modifier + agent_kwargs["name"] = "React" return agent_kwargs @@ -340,11 +164,13 @@ def create_agent( max_input_tokens = ( get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT ) + print(f"max_input_tokens={max_input_tokens}") # Use REACT agent for Anthropic Claude models, otherwise use CIAYN if is_anthropic_claude(config): logger.debug("Using create_react_agent to instantiate agent.") - agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens) + + agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens) return create_react_agent( model, tools, interrupt_after=["tools"], **agent_kwargs ) @@ -357,16 +183,12 @@ def create_agent( logger.warning(f"Failed to detect model type: {e}. 
Defaulting to REACT agent.") config = get_config_repository().get_all() max_input_tokens = get_model_token_limit(config, agent_type) - agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens) + agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens) return create_react_agent( model, tools, interrupt_after=["tools"], **agent_kwargs ) -from ra_aid.agents.research_agent import run_research_agent, run_web_research_agent -from ra_aid.agents.implementation_agent import run_task_implementation_agent - - _CONTEXT_STACK = [] _INTERRUPT_CONTEXT = None _FEEDBACK_MODE = False diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py new file mode 100644 index 0000000..97f235c --- /dev/null +++ b/ra_aid/anthropic_token_limiter.py @@ -0,0 +1,210 @@ +"""Utilities for handling token limits with Anthropic models.""" + +from functools import partial +from typing import Any, Dict, List, Optional, Sequence, Union + +from langchain_anthropic import ChatAnthropic +from langchain_core.messages import BaseMessage, trim_messages +from langchain_core.messages.base import messages_to_dict +from langgraph.prebuilt.chat_agent_executor import AgentState +from litellm import token_counter + +from ra_aid.agent_backends.ciayn_agent import CiaynAgent +from ra_aid.database.repositories.config_repository import get_config_repository +from ra_aid.logging_config import get_logger +from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params +from ra_aid.console.output import print_messages_compact + +logger = get_logger(__name__) + + +def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int: + """Helper function to estimate total tokens in a sequence of messages. + + Args: + messages: Sequence of messages to count tokens for + + Returns: + Total estimated token count + """ + if not messages: + return 0 + + estimate_tokens = CiaynAgent._estimate_tokens + return sum(estimate_tokens(msg) for msg in messages) + + +def create_token_counter_wrapper(model: str): + """Create a wrapper for token counter that handles BaseMessage conversion. + + Args: + model: The model name to use for token counting + + Returns: + A function that accepts BaseMessage objects and returns token count + """ + + # Create a partial function that already has the model parameter set + base_token_counter = partial(token_counter, model=model) + + def wrapped_token_counter(messages: List[Union[BaseMessage, Dict]]) -> int: + """Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage. + + Args: + messages: List of messages (either BaseMessage objects or dicts) + + Returns: + Token count for the messages + """ + if not messages: + return 0 + + if isinstance(messages[0], BaseMessage): + messages_dicts = [msg["data"] for msg in messages_to_dict(messages)] + return base_token_counter(messages=messages_dicts) + else: + # Already in dict format + return base_token_counter(messages=messages) + + return wrapped_token_counter + + +def state_modifier( + state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT +) -> list[BaseMessage]: + """Given the agent state and max_tokens, return a trimmed list of messages but always keep the first message. 
+ + Args: + state: The current agent state containing messages + model: The language model to use for token counting + max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT) + + Returns: + list[BaseMessage]: Trimmed list of messages that fits within token limit + """ + messages = state["messages"] + + if not messages: + return [] + + first_message = messages[0] + remaining_messages = messages[1:] + + + wrapped_token_counter = create_token_counter_wrapper(model.model) + + first_tokens = wrapped_token_counter([first_message]) + new_max_tokens = max_input_tokens - first_tokens + + print_messages_compact(messages) + + trimmed_remaining = trim_messages( + remaining_messages, + token_counter=wrapped_token_counter, + max_tokens=new_max_tokens, + strategy="last", + allow_partial=False, + ) + + return [first_message] + trimmed_remaining + + +def sonnet_3_5_state_modifier( + state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT +) -> list[BaseMessage]: + """Given the agent state and max_tokens, return a trimmed list of messages. + + Args: + state: The current agent state containing messages + max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT) + + Returns: + list[BaseMessage]: Trimmed list of messages that fits within token limit + """ + messages = state["messages"] + + if not messages: + return [] + + first_message = messages[0] + remaining_messages = messages[1:] + first_tokens = estimate_messages_tokens([first_message]) + new_max_tokens = max_input_tokens - first_tokens + + trimmed_remaining = trim_messages( + remaining_messages, + token_counter=estimate_messages_tokens, + max_tokens=new_max_tokens, + strategy="last", + allow_partial=False, + ) + + return [first_message] + trimmed_remaining + + +def get_model_token_limit( + config: Dict[str, Any], agent_type: str = "default" +) -> Optional[int]: + """Get the token limit for the current model configuration based on agent type. 
+ + Args: + config: Configuration dictionary containing provider and model information + agent_type: Type of agent ("default", "research", or "planner") + + Returns: + Optional[int]: The token limit if found, None otherwise + """ + try: + # Try to get config from repository for production use + try: + config_from_repo = get_config_repository().get_all() + # If we succeeded, use the repository config instead of passed config + config = config_from_repo + except RuntimeError: + # In tests, this may fail because the repository isn't set up + # So we'll use the passed config directly + pass + if agent_type == "research": + provider = config.get("research_provider", "") or config.get("provider", "") + model_name = config.get("research_model", "") or config.get("model", "") + elif agent_type == "planner": + provider = config.get("planner_provider", "") or config.get("provider", "") + model_name = config.get("planner_model", "") or config.get("model", "") + else: + provider = config.get("provider", "") + model_name = config.get("model", "") + + try: + from litellm import get_model_info + + provider_model = model_name if not provider else f"{provider}/{model_name}" + model_info = get_model_info(provider_model) + max_input_tokens = model_info.get("max_input_tokens") + if max_input_tokens: + logger.debug( + f"Using litellm token limit for {model_name}: {max_input_tokens}" + ) + return max_input_tokens + except Exception as e: + logger.debug( + f"Error getting model info from litellm: {e}, falling back to models_params" + ) + + # Fallback to models_params dict + # Normalize model name for fallback lookup (e.g. claude-2 -> claude2) + normalized_name = model_name.replace("-", "") + provider_tokens = models_params.get(provider, {}) + if normalized_name in provider_tokens: + max_input_tokens = provider_tokens[normalized_name]["token_limit"] + logger.debug( + f"Found token limit for {provider}/{model_name}: {max_input_tokens}" + ) + else: + max_input_tokens = None + logger.debug(f"Could not find token limit for {provider}/{model_name}") + + return max_input_tokens + + except Exception as e: + logger.warning(f"Failed to get model token limit: {e}") + return None diff --git a/ra_aid/config.py b/ra_aid/config.py index c414f3b..ac69365 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3 FALLBACK_TOOL_MODEL_LIMIT = 5 RETRY_FALLBACK_COUNT = 3 DEFAULT_TEST_CMD_TIMEOUT = 60 * 5 # 5 minutes in seconds +DEFAULT_MODEL="claude-3-7-sonnet-20250219" VALID_PROVIDERS = [ diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index 8a45fec..14ef069 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -1,6 +1,6 @@ -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Sequence -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from rich.markdown import Markdown from rich.panel import Panel @@ -94,3 +94,57 @@ def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") - """ console.print(Panel(Markdown(message), title=title, border_style=border_style)) + + +def print_messages_compact(messages: Sequence[BaseMessage]) -> None: + """Print a compact representation of a list of messages. + + Warning: Used mainly for debugging purposes so do not delete if not referenced anywhere! + For all message types, only the first 30 characters of content are shown. 
+ + Args: + messages: A sequence of BaseMessage objects to print + """ + if not messages: + console.print("[italic]No messages[/italic]") + return + + for i, msg in enumerate(messages): + msg_type = msg.__class__.__name__ + content = msg.content + + # Process content based on its type + if isinstance(content, str): + display_content = f"{content[:30]}..." if len(content) > 30 else content + elif isinstance(content, list): + # Handle structured content (list of content blocks) + content_preview = [] + for item in content[:2]: # Show first 2 items at most + if isinstance(item, dict): + if item.get("type") == "text": + text = item.get("text", "") + content_preview.append(f"text: {text[:20]}..." if len(text) > 20 else f"text: {text}") + elif item.get("type") == "tool_call": + tool_name = item.get("tool_call", {}).get("name", "unknown") + content_preview.append(f"tool_call: {tool_name}") + else: + content_preview.append(f"{item.get('type', 'unknown')}") + + if len(content) > 2: + content_preview.append(f"...({len(content)-2} more)") + + display_content = ", ".join(content_preview) + else: + display_content = str(content)[:30] + "..." if len(str(content)) > 30 else str(content) + + # Add additional tool message info if available + additional_info = [] + if hasattr(msg, "tool_call_id") and msg.tool_call_id: + additional_info.append(f"tool_call_id: {msg.tool_call_id}") + if hasattr(msg, "name") and msg.name: + additional_info.append(f"name: {msg.name}") + if hasattr(msg, "status") and msg.status: + additional_info.append(f"status: {msg.status}") + + info_str = f" ({', '.join(additional_info)})" if additional_info else "" + console.print(f"[{i}] [bold]{msg_type}{info_str}[/bold]: {display_content}") diff --git a/ra_aid/llm.py b/ra_aid/llm.py index 3091260..baedf02 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -241,8 +241,9 @@ def create_llm_client( else: temp_kwargs = {} + thinking_kwargs = {} if supports_thinking: - temp_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}} + thinking_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}} if provider == "deepseek": return create_deepseek_client( @@ -250,6 +251,7 @@ def create_llm_client( api_key=config["api_key"], base_url=config["base_url"], **temp_kwargs, + **thinking_kwargs, is_expert=is_expert, ) elif provider == "openrouter": @@ -257,6 +259,7 @@ def create_llm_client( model_name=model_name, api_key=config["api_key"], **temp_kwargs, + **thinking_kwargs, is_expert=is_expert, ) elif provider == "openai": @@ -271,6 +274,7 @@ def create_llm_client( return ChatOpenAI( **{ **openai_kwargs, + **thinking_kwargs, "timeout": LLM_REQUEST_TIMEOUT, "max_retries": LLM_MAX_RETRIES, } @@ -283,6 +287,7 @@ def create_llm_client( max_retries=LLM_MAX_RETRIES, max_tokens=model_config.get("max_tokens", 64000), **temp_kwargs, + **thinking_kwargs, ) elif provider == "openai-compatible": return ChatOpenAI( @@ -292,6 +297,7 @@ def create_llm_client( timeout=LLM_REQUEST_TIMEOUT, max_retries=LLM_MAX_RETRIES, **temp_kwargs, + **thinking_kwargs, ) elif provider == "gemini": return ChatGoogleGenerativeAI( @@ -300,6 +306,7 @@ def create_llm_client( timeout=LLM_REQUEST_TIMEOUT, max_retries=LLM_MAX_RETRIES, **temp_kwargs, + **thinking_kwargs, ) else: raise ValueError(f"Unsupported provider: {provider}") diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 26190e3..c2c09be 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -14,6 +14,7 @@ from ra_aid.agent_context import ( is_crashed, reset_completion_flags, ) +from 
ra_aid.config import DEFAULT_MODEL from ra_aid.console.formatting import print_error from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository @@ -337,7 +338,7 @@ def request_task_implementation(task_spec: str) -> str: config = get_config_repository().get_all() model = initialize_llm( config.get("provider", "anthropic"), - config.get("model", "claude-3-5-sonnet-20241022"), + config.get("model",DEFAULT_MODEL), temperature=config.get("temperature"), ) @@ -475,7 +476,7 @@ def request_implementation(task_spec: str) -> str: config = get_config_repository().get_all() model = initialize_llm( config.get("provider", "anthropic"), - config.get("model", "claude-3-5-sonnet-20241022"), + config.get("model", DEFAULT_MODEL), temperature=config.get("temperature"), ) @@ -592,4 +593,4 @@ def request_implementation(task_spec: str) -> str: # Join all parts into a single markdown string markdown_output = "".join(markdown_parts) - return markdown_output \ No newline at end of file + return markdown_output diff --git a/tests/ra_aid/test_anthropic_token_limiter.py b/tests/ra_aid/test_anthropic_token_limiter.py new file mode 100644 index 0000000..933c73e --- /dev/null +++ b/tests/ra_aid/test_anthropic_token_limiter.py @@ -0,0 +1,198 @@ +import unittest +from unittest.mock import MagicMock, patch + +from langchain_anthropic import ChatAnthropic +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langgraph.prebuilt.chat_agent_executor import AgentState + +from ra_aid.anthropic_token_limiter import ( + create_token_counter_wrapper, + estimate_messages_tokens, + get_model_token_limit, + state_modifier, +) + + +class TestAnthropicTokenLimiter(unittest.TestCase): + def setUp(self): + from ra_aid.config import DEFAULT_MODEL + + self.mock_model = MagicMock(spec=ChatAnthropic) + self.mock_model.model = DEFAULT_MODEL + + # Sample messages for testing + self.system_message = SystemMessage(content="You are a helpful assistant.") + self.human_message = HumanMessage(content="Hello, can you help me with a task?") + self.long_message = HumanMessage(content="A" * 1000) # Long message to test trimming + + # Create more messages for testing + self.extra_messages = [ + HumanMessage(content=f"Extra message {i}") for i in range(5) + ] + + # Mock state for testing state_modifier with many messages + self.state = AgentState( + messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages, + next=None, + ) + + @patch("ra_aid.anthropic_token_limiter.token_counter") + def test_create_token_counter_wrapper(self, mock_token_counter): + from ra_aid.config import DEFAULT_MODEL + + # Setup mock return values + mock_token_counter.return_value = 50 + + # Create the wrapper + wrapper = create_token_counter_wrapper(DEFAULT_MODEL) + + # Test with BaseMessage objects + result = wrapper([self.human_message]) + self.assertEqual(result, 50) + + # Test with empty list + result = wrapper([]) + self.assertEqual(result, 0) + + # Verify the mock was called with the right parameters + mock_token_counter.assert_called_with(messages=unittest.mock.ANY, model=DEFAULT_MODEL) + + @patch("ra_aid.anthropic_token_limiter.CiaynAgent._estimate_tokens") + def test_estimate_messages_tokens(self, mock_estimate_tokens): + # Setup mock to return different values for different messages + mock_estimate_tokens.side_effect = lambda msg: 10 if isinstance(msg, SystemMessage) else 20 + + # Test with multiple 
messages + messages = [self.system_message, self.human_message] + result = estimate_messages_tokens(messages) + + # Should be sum of individual token counts (10 + 20) + self.assertEqual(result, 30) + + # Test with empty list + result = estimate_messages_tokens([]) + self.assertEqual(result, 0) + + @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") + @patch("ra_aid.anthropic_token_limiter.print_messages_compact") + def test_state_modifier(self, mock_print, mock_create_wrapper): + # Setup a proper token counter function that returns integers + # This function needs to return values that will cause trim_messages to keep only the first message + def token_counter(msgs): + # For a single message, return a small token count + if len(msgs) == 1: + return 10 + # For two messages (first + one more), return a value under our limit + elif len(msgs) == 2: + return 30 # This is under our 40 token remaining budget (50-10) + # For three messages, return a value just under our limit + elif len(msgs) == 3: + return 40 # This is exactly at our 40 token remaining budget (50-10) + # For four messages, return a value just at our limit + elif len(msgs) == 4: + return 40 # This is exactly at our 40 token remaining budget (50-10) + # For five messages, return a value that exceeds our 40 token budget + elif len(msgs) == 5: + return 60 # This exceeds our 40 token budget, forcing only 4 more messages + # For more messages, return a value over our limit + else: + return 100 # This exceeds our limit + + # Don't use side_effect here, directly return the function + mock_create_wrapper.return_value = token_counter + + # Call state_modifier with a max token limit of 50 + result = state_modifier(self.state, self.mock_model, max_input_tokens=50) + + # Should keep first message and some of the others (up to 5 total) + self.assertEqual(len(result), 5) # First message plus four more + self.assertEqual(result[0], self.system_message) # First message is preserved + + # Verify the wrapper was created with the right model + mock_create_wrapper.assert_called_with(self.mock_model.model) + + # Verify print_messages_compact was called + mock_print.assert_called_once() + + @patch("ra_aid.anthropic_token_limiter.get_config_repository") + @patch("litellm.get_model_info") + def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo): + from ra_aid.config import DEFAULT_MODEL + + # Setup mocks + mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL} + mock_get_config_repo.return_value.get_all.return_value = mock_config + + # Mock litellm's get_model_info to return a token limit + mock_get_model_info.return_value = {"max_input_tokens": 100000} + + # Test getting token limit + result = get_model_token_limit(mock_config) + self.assertEqual(result, 100000) + + # Verify get_model_info was called with the right model + mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}") + + @patch("ra_aid.anthropic_token_limiter.get_config_repository") + @patch("litellm.get_model_info") + def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo): + # Setup mocks + mock_config = {"provider": "anthropic", "model": "claude-2"} + mock_get_config_repo.return_value.get_all.return_value = mock_config + + # Make litellm's get_model_info raise an exception to test fallback + mock_get_model_info.side_effect = Exception("Model not found") + + # Test getting token limit from models_params fallback + with patch("ra_aid.anthropic_token_limiter.models_params", { + 
"anthropic": { + "claude2": {"token_limit": 100000} + } + }): + result = get_model_token_limit(mock_config) + self.assertEqual(result, 100000) + + @patch("ra_aid.anthropic_token_limiter.get_config_repository") + @patch("litellm.get_model_info") + def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo): + from ra_aid.config import DEFAULT_MODEL + + # Setup mocks for different agent types + mock_config = { + "provider": "anthropic", + "model": DEFAULT_MODEL, + "research_provider": "openai", + "research_model": "gpt-4", + "planner_provider": "anthropic", + "planner_model": "claude-3-sonnet-20240229" + } + mock_get_config_repo.return_value.get_all.return_value = mock_config + + # Mock different returns for different models + def model_info_side_effect(model_name): + if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name: + return {"max_input_tokens": 200000} + elif "gpt-4" in model_name: + return {"max_input_tokens": 8192} + elif "claude-3-sonnet" in model_name: + return {"max_input_tokens": 100000} + else: + raise Exception(f"Unknown model: {model_name}") + + mock_get_model_info.side_effect = model_info_side_effect + + # Test default agent type + result = get_model_token_limit(mock_config, "default") + self.assertEqual(result, 200000) + + # Test research agent type + result = get_model_token_limit(mock_config, "research") + self.assertEqual(result, 8192) + + # Test planner agent type + result = get_model_token_limit(mock_config, "planner") + self.assertEqual(result, 100000) + + +if __name__ == "__main__": + unittest.main() From 09ba1ee0b9adf7b4ae1769c064e6df1e57d92076 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Mar 2025 21:26:57 -0700 Subject: [PATCH 02/11] refactor(anthropic_token_limiter.py): rename messages_to_dict to message_to_dict for consistency and clarity feat(anthropic_token_limiter.py): add convert_message_to_litellm_format function to standardize message format for litellm fix(anthropic_token_limiter.py): update wrapped_token_counter to handle only BaseMessage objects and improve token counting logic chore(anthropic_token_limiter.py): add debug print statements to track token counts before and after trimming messages --- ra_aid/anthropic_token_limiter.py | 98 ++++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 22 deletions(-) diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py index 97f235c..fc6d3ac 100644 --- a/ra_aid/anthropic_token_limiter.py +++ b/ra_aid/anthropic_token_limiter.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union from langchain_anthropic import ChatAnthropic from langchain_core.messages import BaseMessage, trim_messages -from langchain_core.messages.base import messages_to_dict +from langchain_core.messages.base import message_to_dict from langgraph.prebuilt.chat_agent_executor import AgentState from litellm import token_counter @@ -34,38 +34,51 @@ def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int: return sum(estimate_tokens(msg) for msg in messages) +def convert_message_to_litellm_format(message: BaseMessage) -> Dict: + """Convert a BaseMessage to the format expected by litellm. 
+ + Args: + message: The BaseMessage to convert + + Returns: + Dict in litellm format + """ + message_dict = message_to_dict(message) + return { + "role": message_dict["type"], + "content": message_dict["data"]["content"], + } + + def create_token_counter_wrapper(model: str): """Create a wrapper for token counter that handles BaseMessage conversion. - + Args: model: The model name to use for token counting - + Returns: A function that accepts BaseMessage objects and returns token count """ - + # Create a partial function that already has the model parameter set base_token_counter = partial(token_counter, model=model) - - def wrapped_token_counter(messages: List[Union[BaseMessage, Dict]]) -> int: + + def wrapped_token_counter(messages: List[BaseMessage]) -> int: """Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage. - + Args: - messages: List of messages (either BaseMessage objects or dicts) - + messages: List of BaseMessage objects + Returns: Token count for the messages """ if not messages: return 0 - - if isinstance(messages[0], BaseMessage): - messages_dicts = [msg["data"] for msg in messages_to_dict(messages)] - return base_token_counter(messages=messages_dicts) - else: - # Already in dict format - return base_token_counter(messages=messages) - + + litellm_messages = [convert_message_to_litellm_format(msg) for msg in messages] + result = base_token_counter(messages=litellm_messages) + return result + return wrapped_token_counter @@ -90,12 +103,31 @@ def state_modifier( first_message = messages[0] remaining_messages = messages[1:] - wrapped_token_counter = create_token_counter_wrapper(model.model) - + + print(f"max_input_tokens={max_input_tokens}") + max_input_tokens = 17000 first_tokens = wrapped_token_counter([first_message]) + print(f"first_tokens={first_tokens}") new_max_tokens = max_input_tokens - first_tokens + # Calculate total tokens before trimming + total_tokens_before = wrapped_token_counter(messages) + print( + f"Current token total: {total_tokens_before} (should be at least {first_tokens})" + ) + + # Verify the token count is correct + if total_tokens_before < first_tokens: + print(f"WARNING: Token count inconsistency detected! Recounting...") + # Count message by message to debug + for i, msg in enumerate(messages): + msg_tokens = wrapped_token_counter([msg]) + print(f" Message {i}: {msg_tokens} tokens") + # Try alternative counting method + alt_count = sum(wrapped_token_counter([msg]) for msg in messages) + print(f" Alternative count method: {alt_count} tokens") + print_messages_compact(messages) trimmed_remaining = trim_messages( @@ -106,10 +138,19 @@ def state_modifier( allow_partial=False, ) - return [first_message] + trimmed_remaining + result = [first_message] + trimmed_remaining + + # Only show message if some messages were trimmed + if len(result) < len(messages): + print(f"TRIMMED: {len(messages)} messages → {len(result)} messages") + # Calculate total tokens after trimming + total_tokens_after = wrapped_token_counter(result) + print(f"New token total: {total_tokens_after}") + + return result -def sonnet_3_5_state_modifier( +def sonnet_35_state_modifier( state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT ) -> list[BaseMessage]: """Given the agent state and max_tokens, return a trimmed list of messages. 
@@ -131,6 +172,10 @@ def sonnet_3_5_state_modifier( first_tokens = estimate_messages_tokens([first_message]) new_max_tokens = max_input_tokens - first_tokens + # Calculate total tokens before trimming + total_tokens_before = estimate_messages_tokens(messages) + print(f"Current token total: {total_tokens_before}") + trimmed_remaining = trim_messages( remaining_messages, token_counter=estimate_messages_tokens, @@ -139,7 +184,16 @@ def sonnet_3_5_state_modifier( allow_partial=False, ) - return [first_message] + trimmed_remaining + result = [first_message] + trimmed_remaining + + # Only show message if some messages were trimmed + if len(result) < len(messages): + print(f"TRIMMED: {len(messages)} messages → {len(result)} messages") + # Calculate total tokens after trimming + total_tokens_after = estimate_messages_tokens(result) + print(f"New token total: {total_tokens_after}") + + return result def get_model_token_limit( From ee73c85b02123c51c201ed3d9c7f4f181f9d7e80 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Mar 2025 23:24:57 -0700 Subject: [PATCH 03/11] feat(anthropic_message_utils.py): add utilities for handling Anthropic-specific message formats and trimming to improve message processing fix(agent_utils.py): remove debug print statement for max_input_tokens to clean up code refactor(anthropic_token_limiter.py): update state_modifier to use anthropic_trim_messages for better token management and maintain message structure --- ra_aid/agent_utils.py | 1 - ra_aid/anthropic_message_utils.py | 393 ++++++++++++++++++++++++++++++ ra_aid/anthropic_token_limiter.py | 80 +++--- 3 files changed, 433 insertions(+), 41 deletions(-) create mode 100644 ra_aid/anthropic_message_utils.py diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 0b18d34..d0f26cb 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -164,7 +164,6 @@ def create_agent( max_input_tokens = ( get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT ) - print(f"max_input_tokens={max_input_tokens}") # Use REACT agent for Anthropic Claude models, otherwise use CIAYN if is_anthropic_claude(config): diff --git a/ra_aid/anthropic_message_utils.py b/ra_aid/anthropic_message_utils.py new file mode 100644 index 0000000..91c285f --- /dev/null +++ b/ra_aid/anthropic_message_utils.py @@ -0,0 +1,393 @@ +"""Utilities for handling Anthropic-specific message formats and trimming.""" + +from typing import Callable, List, Literal, Optional, Sequence, Union, cast + +from langchain_core.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) + + +def _is_message_type( + message: BaseMessage, message_types: Union[str, type, List[Union[str, type]]] +) -> bool: + """Check if a message is of a specific type or types. + + Args: + message: The message to check + message_types: Type(s) to check against (string name or class) + + Returns: + bool: True if message matches any of the specified types + """ + if not isinstance(message_types, list): + message_types = [message_types] + + types_str = [t for t in message_types if isinstance(t, str)] + types_classes = tuple(t for t in message_types if isinstance(t, type)) + + return message.type in types_str or isinstance(message, types_classes) + + +def has_tool_use(message: BaseMessage) -> bool: + """Check if a message contains tool use. 
+ + Args: + message: The message to check + + Returns: + bool: True if the message contains tool use + """ + if not isinstance(message, AIMessage): + return False + + # Check content for tool_use + if isinstance(message.content, str) and "tool_use" in message.content: + return True + + # Check content list for tool_use blocks + if isinstance(message.content, list): + for item in message.content: + if isinstance(item, dict) and item.get("type") == "tool_use": + return True + + # Check additional_kwargs for tool_calls + if hasattr(message, "additional_kwargs") and message.additional_kwargs.get("tool_calls"): + return True + + return False + + +def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool: + """Check if two messages form a tool use/result pair. + + Args: + message1: First message + message2: Second message + + Returns: + bool: True if the messages form a tool use/result pair + """ + return ( + isinstance(message1, AIMessage) and + isinstance(message2, ToolMessage) and + has_tool_use(message1) + ) + + +def fix_anthropic_message_content(message: BaseMessage) -> BaseMessage: + """Fix message content format for Anthropic API compatibility.""" + if not isinstance(message, AIMessage) or not isinstance(message.content, list): + return message + + fixed_message = message.model_copy(deep=True) + + # Ensure first block is valid thinking type + if fixed_message.content and isinstance(fixed_message.content[0], dict): + first_block_type = fixed_message.content[0].get("type") + if first_block_type not in ("thinking", "redacted_thinking"): + # Prepend redacted_thinking block instead of thinking + fixed_message.content.insert( + 0, + { + "type": "redacted_thinking", + "data": "ENCRYPTED_REASONING", # Required field for redacted_thinking + }, + ) + + # Ensure all thinking blocks have valid structure + for i, block in enumerate(fixed_message.content): + if block.get("type") == "thinking": + # Convert thinking blocks to redacted_thinking to avoid signature validation + fixed_message.content[i] = { + "type": "redacted_thinking", + "data": "ENCRYPTED_REASONING", + } + elif block.get("type") == "redacted_thinking": + # Ensure required data field exists + if "data" not in block: + fixed_message.content[i]["data"] = "ENCRYPTED_REASONING" + + return fixed_message + + +def anthropic_trim_messages( + messages: Sequence[BaseMessage], + *, + max_tokens: int, + token_counter: Callable[[List[BaseMessage]], int], + strategy: Literal["first", "last"] = "last", + num_messages_to_keep: int = 2, + allow_partial: bool = False, + include_system: bool = True, + start_on: Optional[Union[str, type, List[Union[str, type]]]] = None, +) -> List[BaseMessage]: + """Trim messages to fit within a token limit, with Anthropic-specific handling. + + This function is similar to langchain_core's trim_messages but with special + handling for Anthropic message formats to avoid API errors. + + It always keeps the first num_messages_to_keep messages. 
+ + Args: + messages: Sequence of messages to trim + max_tokens: Maximum number of tokens allowed + token_counter: Function to count tokens in messages + strategy: Whether to keep the "first" or "last" messages + allow_partial: Whether to allow partial messages + include_system: Whether to always include the system message + start_on: Message type to start on (only for "last" strategy) + + Returns: + List[BaseMessage]: Trimmed messages that fit within token limit + """ + if not messages: + return [] + + messages = list(messages) + + # Always keep the first num_messages_to_keep messages + kept_messages = messages[:num_messages_to_keep] + remaining_msgs = messages[num_messages_to_keep:] + + # Debug: Print message types for all messages + print("\nDEBUG - All messages:") + for i, msg in enumerate(messages): + msg_type = type(msg).__name__ + tool_use = ( + "tool_use" + if isinstance(msg, AIMessage) + and hasattr(msg, "additional_kwargs") + and msg.additional_kwargs.get("tool_calls") + else "" + ) + tool_result = ( + f"tool_call_id: {msg.tool_call_id}" + if isinstance(msg, ToolMessage) and hasattr(msg, "tool_call_id") + else "" + ) + print(f" [{i}] {msg_type} {tool_use} {tool_result}") + + # For Anthropic, we need to maintain the conversation structure where: + # 1. Every AIMessage with tool_use must be followed by a ToolMessage + # 2. Every AIMessage that follows a ToolMessage must start with a tool_result + + # First, check if we have any tool_use in the messages + has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages) + print(f"DEBUG - Has tool_use anywhere in messages: {has_tool_use_anywhere}") + + # Print debug info for AIMessages + for i, msg in enumerate(messages): + if isinstance(msg, AIMessage): + print(f"DEBUG - AIMessage[{i}] details:") + print(f" has_tool_use: {has_tool_use(msg)}") + if hasattr(msg, "additional_kwargs"): + print(f" additional_kwargs keys: {list(msg.additional_kwargs.keys())}") + + # If we have tool_use anywhere, we need to be very careful about trimming + if has_tool_use_anywhere: + # For safety, just keep all messages if we're under the token limit + if token_counter(messages) <= max_tokens: + print("DEBUG - All messages fit within token limit, keeping all") + return messages + + # We need to identify all tool_use/tool_result relationships + # First, find all AIMessage+ToolMessage pairs + pairs = [] + i = 0 + while i < len(messages) - 1: + if is_tool_pair(messages[i], messages[i+1]): + pairs.append((i, i+1)) + print(f"DEBUG - Found tool_use pair: ({i}, {i+1})") + i += 2 + else: + i += 1 + + print(f"DEBUG - Found {len(pairs)} AIMessage+ToolMessage pairs") + + # For Anthropic, we need to ensure that: + # 1. If we include an AIMessage with tool_use, we must include the following ToolMessage + # 2. 
If we include a ToolMessage, we must include the preceding AIMessage with tool_use + + # The safest approach is to always keep complete AIMessage+ToolMessage pairs together + # First, identify all complete pairs + complete_pairs = [] + for start, end in pairs: + complete_pairs.append((start, end)) + + print(f"DEBUG - Found {len(complete_pairs)} complete AIMessage+ToolMessage pairs") + + # Now we'll build our result, starting with the kept_messages + # But we need to be careful about the first message if it has tool_use + result = [] + + # Check if the last message in kept_messages has tool_use + if kept_messages and isinstance(kept_messages[-1], AIMessage) and has_tool_use(kept_messages[-1]): + # We need to find the corresponding ToolMessage + for i, (ai_idx, tool_idx) in enumerate(pairs): + if messages[ai_idx] is kept_messages[-1]: + # Found the pair, add all kept_messages except the last one + result.extend(kept_messages[:-1]) + # Add the AIMessage and ToolMessage as a pair + result.extend([messages[ai_idx], messages[tool_idx]]) + # Remove this pair from the list of pairs to process later + pairs = pairs[:i] + pairs[i+1:] + break + else: + # If we didn't find a matching pair, just add all kept_messages + result.extend(kept_messages) + else: + # No tool_use in the last kept message, just add all kept_messages + result.extend(kept_messages) + + # If we're using the "last" strategy, we'll try to include pairs from the end + if strategy == "last": + # First collect all pairs we can include within the token limit + pairs_to_include = [] + + # Process pairs from the end (newest first) + for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)): + # Try adding this pair + test_msgs = result.copy() + + # Add all previously selected pairs + for prev_ai_idx, prev_tool_idx in pairs_to_include: + test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]]) + + # Add this pair + test_msgs.extend([messages[ai_idx], messages[tool_idx]]) + + if token_counter(test_msgs) <= max_tokens: + # This pair fits, add it to our list + pairs_to_include.append((ai_idx, tool_idx)) + print(f"DEBUG - Added complete pair ({ai_idx}, {tool_idx})") + else: + # This pair would exceed the token limit + print(f"DEBUG - Pair ({ai_idx}, {tool_idx}) would exceed token limit, stopping") + break + + # Now add the pairs in the correct order + # Sort by index to maintain the original conversation flow + pairs_to_include.sort(key=lambda x: x[0]) + for ai_idx, tool_idx in pairs_to_include: + result.extend([messages[ai_idx], messages[tool_idx]]) + + # No need to sort - we've already added messages in the correct order + + print(f"DEBUG - Final result has {len(result)} messages") + return result + + # If no tool_use, proceed with normal segmentation + segments = [] + i = 0 + + # Group messages into segments + while i < len(remaining_msgs): + segments.append([remaining_msgs[i]]) + print(f"DEBUG - Added message as segment: [{i}]") + i += 1 + + print(f"\nDEBUG - Created {len(segments)} segments") + for i, segment in enumerate(segments): + segment_types = [type(msg).__name__ for msg in segment] + print(f" Segment {i}: {segment_types}") + + # Now we have segments that maintain the required structure + # We'll add segments from the end (for "last" strategy) or beginning (for "first") + # until we hit the token limit + + if strategy == "last": + # If we have no segments, just return kept_messages + if not segments: + return kept_messages + + result = [] + + # Process segments from the end + for i, segment in 
enumerate(reversed(segments)): + # Try adding this segment + test_msgs = segment + result + + if token_counter(kept_messages + test_msgs) <= max_tokens: + result = segment + result + print(f"DEBUG - Added segment {len(segments)-i-1} to result") + else: + # This segment would exceed the token limit + print(f"DEBUG - Segment {len(segments)-i-1} would exceed token limit, stopping") + break + + final_result = kept_messages + result + + # For Anthropic, we need to ensure the conversation follows a valid structure + # We'll do a final check of the entire conversation + print("\nDEBUG - Final result before validation:") + for i, msg in enumerate(final_result): + msg_type = type(msg).__name__ + print(f" [{i}] {msg_type}") + + # Validate the conversation structure + valid_result = [] + i = 0 + + # Process messages in order + while i < len(final_result): + current_msg = final_result[i] + + # If this is an AIMessage with tool_use, it must be followed by a ToolMessage + if i < len(final_result) - 1 and isinstance(current_msg, AIMessage) and has_tool_use(current_msg): + if isinstance(final_result[i+1], ToolMessage): + # This is a valid tool_use + tool_result pair + valid_result.append(current_msg) + valid_result.append(final_result[i+1]) + print(f"DEBUG - Added valid tool_use + tool_result pair at positions {i}, {i+1}") + i += 2 + else: + # Invalid: AIMessage with tool_use not followed by ToolMessage + print(f"WARNING: AIMessage at position {i} has tool_use but is not followed by a ToolMessage") + # Skip this message to maintain valid structure + i += 1 + else: + # Regular message, just add it + valid_result.append(current_msg) + print(f"DEBUG - Added regular message at position {i}") + i += 1 + + # Final check: don't end with an AIMessage that has tool_use + if valid_result and isinstance(valid_result[-1], AIMessage) and has_tool_use(valid_result[-1]): + print("WARNING: Last message is AIMessage with tool_use but no following ToolMessage") + valid_result.pop() # Remove the last message + + print("\nDEBUG - Final validated result:") + for i, msg in enumerate(valid_result): + msg_type = type(msg).__name__ + print(f" [{i}] {msg_type}") + + return valid_result + + elif strategy == "first": + result = [] + + # Process segments from the beginning + for i, segment in enumerate(segments): + # Try adding this segment + test_msgs = result + segment + if token_counter(kept_messages + test_msgs) <= max_tokens: + result = result + segment + print(f"DEBUG - Added segment {i} to result") + else: + # This segment would exceed the token limit + print(f"DEBUG - Segment {i} would exceed token limit, stopping") + break + + final_result = kept_messages + result + print("\nDEBUG - Final result:") + for i, msg in enumerate(final_result): + msg_type = type(msg).__name__ + print(f" [{i}] {msg_type}") + + return final_result diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py index fc6d3ac..c46cbc9 100644 --- a/ra_aid/anthropic_token_limiter.py +++ b/ra_aid/anthropic_token_limiter.py @@ -1,11 +1,17 @@ """Utilities for handling token limits with Anthropic models.""" from functools import partial -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence from langchain_anthropic import ChatAnthropic -from langchain_core.messages import BaseMessage, trim_messages +from langchain_core.messages import AIMessage, BaseMessage, ToolMessage, trim_messages from langchain_core.messages.base import message_to_dict + +from 
ra_aid.anthropic_message_utils import ( + fix_anthropic_message_content, + anthropic_trim_messages, + has_tool_use, +) from langgraph.prebuilt.chat_agent_executor import AgentState from litellm import token_counter @@ -13,7 +19,7 @@ from ra_aid.agent_backends.ciayn_agent import CiaynAgent from ra_aid.database.repositories.config_repository import get_config_repository from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params -from ra_aid.console.output import print_messages_compact +from ra_aid.console.output import cpm, print_messages_compact logger = get_logger(__name__) @@ -85,68 +91,58 @@ def create_token_counter_wrapper(model: str): def state_modifier( state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT ) -> list[BaseMessage]: - """Given the agent state and max_tokens, return a trimmed list of messages but always keep the first message. + """Given the agent state and max_tokens, return a trimmed list of messages. + + This uses anthropic_trim_messages which always keeps the first 2 messages. Args: state: The current agent state containing messages model: The language model to use for token counting - max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT) + max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT) Returns: list[BaseMessage]: Trimmed list of messages that fits within token limit """ messages = state["messages"] - if not messages: return [] - first_message = messages[0] - remaining_messages = messages[1:] - wrapped_token_counter = create_token_counter_wrapper(model.model) - print(f"max_input_tokens={max_input_tokens}") - max_input_tokens = 17000 - first_tokens = wrapped_token_counter([first_message]) - print(f"first_tokens={first_tokens}") - new_max_tokens = max_input_tokens - first_tokens + # Keep max_input_tokens at 21000 as requested + max_input_tokens = 21000 - # Calculate total tokens before trimming - total_tokens_before = wrapped_token_counter(messages) - print( - f"Current token total: {total_tokens_before} (should be at least {first_tokens})" - ) + print("\nDEBUG - Starting token trimming with max_tokens:", max_input_tokens) + print(f"Current token total: {wrapped_token_counter(messages)}") - # Verify the token count is correct - if total_tokens_before < first_tokens: - print(f"WARNING: Token count inconsistency detected! 
Recounting...") - # Count message by message to debug - for i, msg in enumerate(messages): - msg_tokens = wrapped_token_counter([msg]) - print(f" Message {i}: {msg_tokens} tokens") - # Try alternative counting method - alt_count = sum(wrapped_token_counter([msg]) for msg in messages) - print(f" Alternative count method: {alt_count} tokens") + # Print more details about the messages to help debug + for i, msg in enumerate(messages): + if isinstance(msg, AIMessage): + print(f"DEBUG - AIMessage[{i}] content type: {type(msg.content)}") + print(f"DEBUG - AIMessage[{i}] has_tool_use: {has_tool_use(msg)}") + if has_tool_use(msg) and i < len(messages) - 1: + print( + f"DEBUG - Next message is ToolMessage: {isinstance(messages[i+1], ToolMessage)}" + ) - print_messages_compact(messages) - - trimmed_remaining = trim_messages( - remaining_messages, + result = anthropic_trim_messages( + messages, token_counter=wrapped_token_counter, - max_tokens=new_max_tokens, + max_tokens=max_input_tokens, strategy="last", allow_partial=False, + include_system=True, + num_messages_to_keep=2, ) - result = [first_message] + trimmed_remaining - - # Only show message if some messages were trimmed if len(result) < len(messages): print(f"TRIMMED: {len(messages)} messages → {len(result)} messages") - # Calculate total tokens after trimming total_tokens_after = wrapped_token_counter(result) print(f"New token total: {total_tokens_after}") - + print("BEFORE TRIMMING") + print_messages_compact(messages) + print("AFTER TRIMMING") + print_messages_compact(result) return result @@ -176,12 +172,14 @@ def sonnet_35_state_modifier( total_tokens_before = estimate_messages_tokens(messages) print(f"Current token total: {total_tokens_before}") - trimmed_remaining = trim_messages( + # Trim remaining messages + trimmed_remaining = anthropic_trim_messages( remaining_messages, token_counter=estimate_messages_tokens, max_tokens=new_max_tokens, strategy="last", allow_partial=False, + include_system=True, ) result = [first_message] + trimmed_remaining @@ -193,6 +191,8 @@ def sonnet_35_state_modifier( total_tokens_after = estimate_messages_tokens(result) print(f"New token total: {total_tokens_after}") + # No need to fix message content as anthropic_trim_messages already handles this + return result From a3284c9d7e66a4b3f42997c7e40f5fa55eee8d28 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Mar 2025 23:37:20 -0700 Subject: [PATCH 04/11] feat(anthropic_token_limiter.py): add dataclass import for future use and improve code readability by restructuring import statements --- ra_aid/anthropic_token_limiter.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py index c46cbc9..ae82ab2 100644 --- a/ra_aid/anthropic_token_limiter.py +++ b/ra_aid/anthropic_token_limiter.py @@ -2,9 +2,16 @@ from functools import partial from typing import Any, Dict, List, Optional, Sequence +from dataclasses import dataclass from langchain_anthropic import ChatAnthropic -from langchain_core.messages import AIMessage, BaseMessage, ToolMessage, trim_messages +from langchain_core.messages import ( + AIMessage, + BaseMessage, + RemoveMessage, + ToolMessage, + trim_messages, +) from langchain_core.messages.base import message_to_dict from ra_aid.anthropic_message_utils import ( @@ -143,6 +150,7 @@ def state_modifier( print_messages_compact(messages) print("AFTER TRIMMING") print_messages_compact(result) + return result From 376d486db86014c0dd8e6454a396cab188b0aa49 Mon Sep 
17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Mar 2025 23:38:31 -0700 Subject: [PATCH 05/11] refactor(anthropic_message_utils.py): clean up whitespace and improve code readability by removing unnecessary blank lines and aligning code formatting fix(anthropic_message_utils.py): add warning in docstring for anthropic_trim_messages function to indicate incomplete implementation and clarify behavior fix(anthropic_message_utils.py): ensure consistent formatting in conditional statements and improve readability of logical checks --- ra_aid/anthropic_message_utils.py | 142 ++++++++++++++++++------------ 1 file changed, 85 insertions(+), 57 deletions(-) diff --git a/ra_aid/anthropic_message_utils.py b/ra_aid/anthropic_message_utils.py index 91c285f..0df4564 100644 --- a/ra_aid/anthropic_message_utils.py +++ b/ra_aid/anthropic_message_utils.py @@ -36,47 +36,49 @@ def _is_message_type( def has_tool_use(message: BaseMessage) -> bool: """Check if a message contains tool use. - + Args: message: The message to check - + Returns: bool: True if the message contains tool use """ if not isinstance(message, AIMessage): return False - + # Check content for tool_use if isinstance(message.content, str) and "tool_use" in message.content: return True - + # Check content list for tool_use blocks if isinstance(message.content, list): for item in message.content: if isinstance(item, dict) and item.get("type") == "tool_use": return True - + # Check additional_kwargs for tool_calls - if hasattr(message, "additional_kwargs") and message.additional_kwargs.get("tool_calls"): + if hasattr(message, "additional_kwargs") and message.additional_kwargs.get( + "tool_calls" + ): return True - + return False def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool: """Check if two messages form a tool use/result pair. - + Args: message1: First message message2: Second message - + Returns: bool: True if the messages form a tool use/result pair """ return ( - isinstance(message1, AIMessage) and - isinstance(message2, ToolMessage) and - has_tool_use(message1) + isinstance(message1, AIMessage) + and isinstance(message2, ToolMessage) + and has_tool_use(message1) ) @@ -129,6 +131,8 @@ def anthropic_trim_messages( ) -> List[BaseMessage]: """Trim messages to fit within a token limit, with Anthropic-specific handling. + Warning - not fully implemented - last strategy is supported and test, not + allow partial, not 'first' strategy either. This function is similar to langchain_core's trim_messages but with special handling for Anthropic message formats to avoid API errors. @@ -176,11 +180,11 @@ def anthropic_trim_messages( # For Anthropic, we need to maintain the conversation structure where: # 1. Every AIMessage with tool_use must be followed by a ToolMessage # 2. 
Every AIMessage that follows a ToolMessage must start with a tool_result - + # First, check if we have any tool_use in the messages has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages) print(f"DEBUG - Has tool_use anywhere in messages: {has_tool_use_anywhere}") - + # Print debug info for AIMessages for i, msg in enumerate(messages): if isinstance(msg, AIMessage): @@ -188,46 +192,52 @@ def anthropic_trim_messages( print(f" has_tool_use: {has_tool_use(msg)}") if hasattr(msg, "additional_kwargs"): print(f" additional_kwargs keys: {list(msg.additional_kwargs.keys())}") - + # If we have tool_use anywhere, we need to be very careful about trimming if has_tool_use_anywhere: # For safety, just keep all messages if we're under the token limit if token_counter(messages) <= max_tokens: print("DEBUG - All messages fit within token limit, keeping all") return messages - + # We need to identify all tool_use/tool_result relationships # First, find all AIMessage+ToolMessage pairs pairs = [] i = 0 while i < len(messages) - 1: - if is_tool_pair(messages[i], messages[i+1]): - pairs.append((i, i+1)) + if is_tool_pair(messages[i], messages[i + 1]): + pairs.append((i, i + 1)) print(f"DEBUG - Found tool_use pair: ({i}, {i+1})") i += 2 else: i += 1 - + print(f"DEBUG - Found {len(pairs)} AIMessage+ToolMessage pairs") - + # For Anthropic, we need to ensure that: # 1. If we include an AIMessage with tool_use, we must include the following ToolMessage # 2. If we include a ToolMessage, we must include the preceding AIMessage with tool_use - + # The safest approach is to always keep complete AIMessage+ToolMessage pairs together # First, identify all complete pairs complete_pairs = [] for start, end in pairs: complete_pairs.append((start, end)) - - print(f"DEBUG - Found {len(complete_pairs)} complete AIMessage+ToolMessage pairs") - + + print( + f"DEBUG - Found {len(complete_pairs)} complete AIMessage+ToolMessage pairs" + ) + # Now we'll build our result, starting with the kept_messages # But we need to be careful about the first message if it has tool_use result = [] - + # Check if the last message in kept_messages has tool_use - if kept_messages and isinstance(kept_messages[-1], AIMessage) and has_tool_use(kept_messages[-1]): + if ( + kept_messages + and isinstance(kept_messages[-1], AIMessage) + and has_tool_use(kept_messages[-1]) + ): # We need to find the corresponding ToolMessage for i, (ai_idx, tool_idx) in enumerate(pairs): if messages[ai_idx] is kept_messages[-1]: @@ -236,7 +246,7 @@ def anthropic_trim_messages( # Add the AIMessage and ToolMessage as a pair result.extend([messages[ai_idx], messages[tool_idx]]) # Remove this pair from the list of pairs to process later - pairs = pairs[:i] + pairs[i+1:] + pairs = pairs[:i] + pairs[i + 1 :] break else: # If we didn't find a matching pair, just add all kept_messages @@ -244,48 +254,50 @@ def anthropic_trim_messages( else: # No tool_use in the last kept message, just add all kept_messages result.extend(kept_messages) - + # If we're using the "last" strategy, we'll try to include pairs from the end if strategy == "last": # First collect all pairs we can include within the token limit pairs_to_include = [] - + # Process pairs from the end (newest first) for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)): # Try adding this pair test_msgs = result.copy() - + # Add all previously selected pairs for prev_ai_idx, prev_tool_idx in pairs_to_include: test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]]) - + # Add this pair 
test_msgs.extend([messages[ai_idx], messages[tool_idx]]) - + if token_counter(test_msgs) <= max_tokens: # This pair fits, add it to our list pairs_to_include.append((ai_idx, tool_idx)) print(f"DEBUG - Added complete pair ({ai_idx}, {tool_idx})") else: # This pair would exceed the token limit - print(f"DEBUG - Pair ({ai_idx}, {tool_idx}) would exceed token limit, stopping") + print( + f"DEBUG - Pair ({ai_idx}, {tool_idx}) would exceed token limit, stopping" + ) break - + # Now add the pairs in the correct order # Sort by index to maintain the original conversation flow pairs_to_include.sort(key=lambda x: x[0]) for ai_idx, tool_idx in pairs_to_include: result.extend([messages[ai_idx], messages[tool_idx]]) - + # No need to sort - we've already added messages in the correct order - + print(f"DEBUG - Final result has {len(result)} messages") return result - + # If no tool_use, proceed with normal segmentation segments = [] i = 0 - + # Group messages into segments while i < len(remaining_msgs): segments.append([remaining_msgs[i]]) @@ -305,50 +317,60 @@ def anthropic_trim_messages( # If we have no segments, just return kept_messages if not segments: return kept_messages - + result = [] - + # Process segments from the end for i, segment in enumerate(reversed(segments)): # Try adding this segment test_msgs = segment + result - + if token_counter(kept_messages + test_msgs) <= max_tokens: result = segment + result print(f"DEBUG - Added segment {len(segments)-i-1} to result") else: # This segment would exceed the token limit - print(f"DEBUG - Segment {len(segments)-i-1} would exceed token limit, stopping") + print( + f"DEBUG - Segment {len(segments)-i-1} would exceed token limit, stopping" + ) break - + final_result = kept_messages + result - + # For Anthropic, we need to ensure the conversation follows a valid structure # We'll do a final check of the entire conversation print("\nDEBUG - Final result before validation:") for i, msg in enumerate(final_result): msg_type = type(msg).__name__ print(f" [{i}] {msg_type}") - + # Validate the conversation structure valid_result = [] i = 0 - + # Process messages in order while i < len(final_result): current_msg = final_result[i] - + # If this is an AIMessage with tool_use, it must be followed by a ToolMessage - if i < len(final_result) - 1 and isinstance(current_msg, AIMessage) and has_tool_use(current_msg): - if isinstance(final_result[i+1], ToolMessage): + if ( + i < len(final_result) - 1 + and isinstance(current_msg, AIMessage) + and has_tool_use(current_msg) + ): + if isinstance(final_result[i + 1], ToolMessage): # This is a valid tool_use + tool_result pair valid_result.append(current_msg) - valid_result.append(final_result[i+1]) - print(f"DEBUG - Added valid tool_use + tool_result pair at positions {i}, {i+1}") + valid_result.append(final_result[i + 1]) + print( + f"DEBUG - Added valid tool_use + tool_result pair at positions {i}, {i+1}" + ) i += 2 else: # Invalid: AIMessage with tool_use not followed by ToolMessage - print(f"WARNING: AIMessage at position {i} has tool_use but is not followed by a ToolMessage") + print( + f"WARNING: AIMessage at position {i} has tool_use but is not followed by a ToolMessage" + ) # Skip this message to maintain valid structure i += 1 else: @@ -356,17 +378,23 @@ def anthropic_trim_messages( valid_result.append(current_msg) print(f"DEBUG - Added regular message at position {i}") i += 1 - + # Final check: don't end with an AIMessage that has tool_use - if valid_result and isinstance(valid_result[-1], AIMessage) and 
has_tool_use(valid_result[-1]): - print("WARNING: Last message is AIMessage with tool_use but no following ToolMessage") + if ( + valid_result + and isinstance(valid_result[-1], AIMessage) + and has_tool_use(valid_result[-1]) + ): + print( + "WARNING: Last message is AIMessage with tool_use but no following ToolMessage" + ) valid_result.pop() # Remove the last message - + print("\nDEBUG - Final validated result:") for i, msg in enumerate(valid_result): msg_type = type(msg).__name__ print(f" [{i}] {msg_type}") - + return valid_result elif strategy == "first": From e42f281f94853bad796e279716fed89d344cdd44 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Mar 2025 23:48:08 -0700 Subject: [PATCH 06/11] chore(anthropic_message_utils.py): remove unused fix_anthropic_message_content function to clean up codebase chore(anthropic_token_limiter.py): remove import of fix_anthropic_message_content as it is no longer needed test: add unit tests for has_tool_use and is_tool_pair functions to ensure correct functionality test: enhance test coverage for anthropic_trim_messages with tool use scenarios to validate message handling --- ra_aid/anthropic_message_utils.py | 35 --- ra_aid/anthropic_token_limiter.py | 1 - tests/ra_aid/test_anthropic_token_limiter.py | 216 ++++++++++++++++--- 3 files changed, 189 insertions(+), 63 deletions(-) diff --git a/ra_aid/anthropic_message_utils.py b/ra_aid/anthropic_message_utils.py index 0df4564..e71d0ed 100644 --- a/ra_aid/anthropic_message_utils.py +++ b/ra_aid/anthropic_message_utils.py @@ -82,41 +82,6 @@ def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool: ) -def fix_anthropic_message_content(message: BaseMessage) -> BaseMessage: - """Fix message content format for Anthropic API compatibility.""" - if not isinstance(message, AIMessage) or not isinstance(message.content, list): - return message - - fixed_message = message.model_copy(deep=True) - - # Ensure first block is valid thinking type - if fixed_message.content and isinstance(fixed_message.content[0], dict): - first_block_type = fixed_message.content[0].get("type") - if first_block_type not in ("thinking", "redacted_thinking"): - # Prepend redacted_thinking block instead of thinking - fixed_message.content.insert( - 0, - { - "type": "redacted_thinking", - "data": "ENCRYPTED_REASONING", # Required field for redacted_thinking - }, - ) - - # Ensure all thinking blocks have valid structure - for i, block in enumerate(fixed_message.content): - if block.get("type") == "thinking": - # Convert thinking blocks to redacted_thinking to avoid signature validation - fixed_message.content[i] = { - "type": "redacted_thinking", - "data": "ENCRYPTED_REASONING", - } - elif block.get("type") == "redacted_thinking": - # Ensure required data field exists - if "data" not in block: - fixed_message.content[i]["data"] = "ENCRYPTED_REASONING" - - return fixed_message - def anthropic_trim_messages( messages: Sequence[BaseMessage], diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py index ae82ab2..d9e7355 100644 --- a/ra_aid/anthropic_token_limiter.py +++ b/ra_aid/anthropic_token_limiter.py @@ -15,7 +15,6 @@ from langchain_core.messages import ( from langchain_core.messages.base import message_to_dict from ra_aid.anthropic_message_utils import ( - fix_anthropic_message_content, anthropic_trim_messages, has_tool_use, ) diff --git a/tests/ra_aid/test_anthropic_token_limiter.py b/tests/ra_aid/test_anthropic_token_limiter.py index 933c73e..3f7e35e 100644 --- 
a/tests/ra_aid/test_anthropic_token_limiter.py +++ b/tests/ra_aid/test_anthropic_token_limiter.py @@ -2,7 +2,12 @@ import unittest from unittest.mock import MagicMock, patch from langchain_anthropic import ChatAnthropic -from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage +) from langgraph.prebuilt.chat_agent_executor import AgentState from ra_aid.anthropic_token_limiter import ( @@ -10,7 +15,10 @@ from ra_aid.anthropic_token_limiter import ( estimate_messages_tokens, get_model_token_limit, state_modifier, + sonnet_35_state_modifier, + convert_message_to_litellm_format ) +from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair class TestAnthropicTokenLimiter(unittest.TestCase): @@ -23,6 +31,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase): # Sample messages for testing self.system_message = SystemMessage(content="You are a helpful assistant.") self.human_message = HumanMessage(content="Hello, can you help me with a task?") + self.ai_message = AIMessage(content="I'd be happy to help! What do you need?") self.long_message = HumanMessage(content="A" * 1000) # Long message to test trimming # Create more messages for testing @@ -35,6 +44,34 @@ class TestAnthropicTokenLimiter(unittest.TestCase): messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages, next=None, ) + + # Create tool-related messages for testing + self.ai_with_tool_use = AIMessage( + content="I'll use a tool to help you", + additional_kwargs={"tool_calls": [{"name": "calculator", "input": {"expression": "2+2"}}]} + ) + self.tool_message = ToolMessage( + content="4", + tool_call_id="tool_call_1", + name="calculator" + ) + + def test_convert_message_to_litellm_format(self): + """Test conversion of BaseMessage to litellm format.""" + # Test human message + human_result = convert_message_to_litellm_format(self.human_message) + self.assertEqual(human_result["role"], "human") + self.assertEqual(human_result["content"], "Hello, can you help me with a task?") + + # Test system message + system_result = convert_message_to_litellm_format(self.system_message) + self.assertEqual(system_result["role"], "system") + self.assertEqual(system_result["content"], "You are a helpful assistant.") + + # Test AI message + ai_result = convert_message_to_litellm_format(self.ai_message) + self.assertEqual(ai_result["role"], "ai") + self.assertEqual(ai_result["content"], "I'd be happy to help! 
What do you need?") @patch("ra_aid.anthropic_token_limiter.token_counter") def test_create_token_counter_wrapper(self, mock_token_counter): @@ -75,44 +112,66 @@ class TestAnthropicTokenLimiter(unittest.TestCase): @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") @patch("ra_aid.anthropic_token_limiter.print_messages_compact") - def test_state_modifier(self, mock_print, mock_create_wrapper): + @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") + def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper): # Setup a proper token counter function that returns integers - # This function needs to return values that will cause trim_messages to keep only the first message def token_counter(msgs): - # For a single message, return a small token count - if len(msgs) == 1: - return 10 - # For two messages (first + one more), return a value under our limit - elif len(msgs) == 2: - return 30 # This is under our 40 token remaining budget (50-10) - # For three messages, return a value just under our limit - elif len(msgs) == 3: - return 40 # This is exactly at our 40 token remaining budget (50-10) - # For four messages, return a value just at our limit - elif len(msgs) == 4: - return 40 # This is exactly at our 40 token remaining budget (50-10) - # For five messages, return a value that exceeds our 40 token budget - elif len(msgs) == 5: - return 60 # This exceeds our 40 token budget, forcing only 4 more messages - # For more messages, return a value over our limit - else: - return 100 # This exceeds our limit + # Return token count based on number of messages + return len(msgs) * 10 - # Don't use side_effect here, directly return the function + # Configure the mock to return our token counter mock_create_wrapper.return_value = token_counter + # Configure anthropic_trim_messages to return a subset of messages + mock_trim_messages.return_value = [self.system_message, self.human_message] + # Call state_modifier with a max token limit of 50 result = state_modifier(self.state, self.mock_model, max_input_tokens=50) - # Should keep first message and some of the others (up to 5 total) - self.assertEqual(len(result), 5) # First message plus four more - self.assertEqual(result[0], self.system_message) # First message is preserved + # Should return what anthropic_trim_messages returned + self.assertEqual(result, [self.system_message, self.human_message]) # Verify the wrapper was created with the right model mock_create_wrapper.assert_called_with(self.mock_model.model) - # Verify print_messages_compact was called - mock_print.assert_called_once() + # Verify anthropic_trim_messages was called with the right parameters + mock_trim_messages.assert_called_once() + + # Verify print_messages_compact was called at least once + self.assertTrue(mock_print.call_count >= 1) + + @patch("ra_aid.anthropic_token_limiter.estimate_messages_tokens") + @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") + def test_sonnet_35_state_modifier(self, mock_trim, mock_estimate): + """Test the sonnet 35 state modifier function.""" + # Setup mocks + mock_estimate.side_effect = lambda msgs: len(msgs) * 1000 + mock_trim.return_value = [self.human_message, self.ai_message] + + # Create a state with messages + state = {"messages": [self.system_message, self.human_message, self.ai_message]} + + # Test with empty messages + empty_state = {"messages": []} + self.assertEqual(sonnet_35_state_modifier(empty_state), []) + + # Test with messages under the limit + result = 
sonnet_35_state_modifier(state, max_input_tokens=10000) + + # Should keep the first message and call anthropic_trim_messages for the rest + self.assertEqual(len(result), 3) + self.assertEqual(result[0], self.system_message) + self.assertEqual(result[1:], [self.human_message, self.ai_message]) + + # Verify anthropic_trim_messages was called with the right parameters + mock_trim.assert_called_once_with( + [self.human_message, self.ai_message], + token_counter=mock_estimate, + max_tokens=9000, # 10000 - 1000 (first message) + strategy="last", + allow_partial=False, + include_system=True + ) @patch("ra_aid.anthropic_token_limiter.get_config_repository") @patch("litellm.get_model_info") @@ -192,6 +251,109 @@ class TestAnthropicTokenLimiter(unittest.TestCase): # Test planner agent type result = get_model_token_limit(mock_config, "planner") self.assertEqual(result, 100000) + + def test_has_tool_use(self): + """Test the has_tool_use function.""" + # Test with regular AI message + self.assertFalse(has_tool_use(self.ai_message)) + + # Test with AI message containing tool_use in string content + ai_with_tool_str = AIMessage(content="I'll use a tool_use to help you") + self.assertTrue(has_tool_use(ai_with_tool_str)) + + # Test with AI message containing tool_use in structured content + ai_with_tool_dict = AIMessage(content=[ + {"type": "text", "text": "I'll use a tool to help you"}, + {"type": "tool_use", "tool_use": {"name": "calculator", "input": {"expression": "2+2"}}} + ]) + self.assertTrue(has_tool_use(ai_with_tool_dict)) + + # Test with AI message containing tool_calls in additional_kwargs + self.assertTrue(has_tool_use(self.ai_with_tool_use)) + + # Test with non-AI message + self.assertFalse(has_tool_use(self.human_message)) + + def test_is_tool_pair(self): + """Test the is_tool_pair function.""" + # Test with valid tool pair + self.assertTrue(is_tool_pair(self.ai_with_tool_use, self.tool_message)) + + # Test with non-tool pair (wrong order) + self.assertFalse(is_tool_pair(self.tool_message, self.ai_with_tool_use)) + + # Test with non-tool pair (wrong types) + self.assertFalse(is_tool_pair(self.ai_message, self.human_message)) + + # Test with non-tool pair (AI message without tool use) + self.assertFalse(is_tool_pair(self.ai_message, self.tool_message)) + + @patch("ra_aid.anthropic_message_utils.has_tool_use") + def test_anthropic_trim_messages_with_tool_use(self, mock_has_tool_use): + """Test anthropic_trim_messages with a sequence of messages including tool use.""" + from ra_aid.anthropic_message_utils import anthropic_trim_messages + + # Setup mock for has_tool_use to return True for AI messages at even indices + def side_effect(msg): + if isinstance(msg, AIMessage) and hasattr(msg, 'test_index'): + return msg.test_index % 2 == 0 # Even indices have tool use + return False + + mock_has_tool_use.side_effect = side_effect + + # Create a sequence of alternating human and AI messages with tool use + messages = [] + + # Start with system message + system_msg = SystemMessage(content="You are a helpful assistant.") + messages.append(system_msg) + + # Add alternating human and AI messages with tool use + for i in range(8): + if i % 2 == 0: + # Human message + msg = HumanMessage(content=f"Human message {i}") + messages.append(msg) + else: + # AI message, every other one has tool use + ai_msg = AIMessage(content=f"AI message {i}") + # Add a test_index attribute to track position + ai_msg.test_index = i + messages.append(ai_msg) + + # If this AI message has tool use (even index), add a tool message 
after it + if i % 4 == 1: # 1, 5, etc. + tool_msg = ToolMessage( + content=f"Tool result {i}", + tool_call_id=f"tool_call_{i}", + name="test_tool" + ) + messages.append(tool_msg) + + # Define a token counter that returns a fixed value per message + def token_counter(msgs): + return len(msgs) * 1000 + + # Test with a token limit that will require trimming + result = anthropic_trim_messages( + messages, + token_counter=token_counter, + max_tokens=5000, # This will allow 5 messages + strategy="last", + allow_partial=False, + include_system=True, + num_messages_to_keep=2 # Keep system and first human message + ) + + # We should have kept the first 2 messages (system + human) + self.assertEqual(len(result), 5) # 2 kept + 3 more that fit in token limit + self.assertEqual(result[0], system_msg) + + # Verify that we don't have any AI messages with tool use that aren't followed by a tool message + for i in range(len(result) - 1): + if isinstance(result[i], AIMessage) and mock_has_tool_use(result[i]): + self.assertTrue(isinstance(result[i+1], ToolMessage), + f"AI message with tool use at index {i} not followed by ToolMessage") if __name__ == "__main__": From 8d2d273c6bd5b7687eb080c2d6f98fb125e47d7b Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Mar 2025 23:53:37 -0700 Subject: [PATCH 07/11] refactor(tests): move token limit tests from test_agent_utils.py to test_anthropic_token_limiter.py for better organization and clarity --- tests/ra_aid/test_agent_utils.py | 132 ++--------------- tests/ra_aid/test_anthropic_token_limiter.py | 146 +++++++++++++++++++ 2 files changed, 156 insertions(+), 122 deletions(-) diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 5292317..f9a894c 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -14,8 +14,10 @@ from ra_aid.agent_context import ( from ra_aid.agent_utils import ( AgentState, create_agent, - get_model_token_limit, is_anthropic_claude, +) +from ra_aid.anthropic_token_limiter import ( + get_model_token_limit, state_modifier, ) from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params @@ -63,87 +65,15 @@ def mock_config_repository(): yield mock_repo -def test_get_model_token_limit_anthropic(mock_config_repository): - """Test get_model_token_limit with Anthropic model.""" - config = {"provider": "anthropic", "model": "claude2"} - mock_config_repository.update(config) - - token_limit = get_model_token_limit(config, "default") - assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] - - -def test_get_model_token_limit_openai(mock_config_repository): - """Test get_model_token_limit with OpenAI model.""" - config = {"provider": "openai", "model": "gpt-4"} - mock_config_repository.update(config) - - token_limit = get_model_token_limit(config, "default") - assert token_limit == models_params["openai"]["gpt-4"]["token_limit"] - - -def test_get_model_token_limit_unknown(mock_config_repository): - """Test get_model_token_limit with unknown provider/model.""" - config = {"provider": "unknown", "model": "unknown-model"} - mock_config_repository.update(config) - - token_limit = get_model_token_limit(config, "default") - assert token_limit is None - - -def test_get_model_token_limit_missing_config(mock_config_repository): - """Test get_model_token_limit with missing configuration.""" - config = {} - mock_config_repository.update(config) - - token_limit = get_model_token_limit(config, "default") - assert token_limit is None - - -def 
test_get_model_token_limit_litellm_success(): - """Test get_model_token_limit successfully getting limit from litellm.""" - config = {"provider": "anthropic", "model": "claude-2"} - - with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: - mock_get_info.return_value = {"max_input_tokens": 100000} - token_limit = get_model_token_limit(config, "default") - assert token_limit == 100000 - - -def test_get_model_token_limit_litellm_not_found(): - """Test fallback to models_tokens when litellm raises NotFoundError.""" - config = {"provider": "anthropic", "model": "claude-2"} - - with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: - mock_get_info.side_effect = litellm.exceptions.NotFoundError( - message="Model not found", model="claude-2", llm_provider="anthropic" - ) - token_limit = get_model_token_limit(config, "default") - assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] - - -def test_get_model_token_limit_litellm_error(): - """Test fallback to models_tokens when litellm raises other exceptions.""" - config = {"provider": "anthropic", "model": "claude-2"} - - with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: - mock_get_info.side_effect = Exception("Unknown error") - token_limit = get_model_token_limit(config, "default") - assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] - - -def test_get_model_token_limit_unexpected_error(): - """Test returning None when unexpected errors occur.""" - config = None # This will cause an attribute error when accessed - - token_limit = get_model_token_limit(config, "default") - assert token_limit is None +# These tests have been moved to test_anthropic_token_limiter.py def test_create_agent_anthropic(mock_model, mock_config_repository): """Test create_agent with Anthropic Claude model.""" mock_config_repository.update({"provider": "anthropic", "model": "claude-2"}) - with patch("ra_aid.agent_utils.create_react_agent") as mock_react: + with patch("ra_aid.agent_utils.create_react_agent") as mock_react, \ + patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier: mock_react.return_value = "react_agent" agent = create_agent(mock_model, []) @@ -221,20 +151,7 @@ def mock_messages(): ] -def test_state_modifier(mock_messages): - """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens.""" - state = AgentState(messages=mock_messages) - - with patch( - "ra_aid.agent_backends.ciayn_agent.CiaynAgent._estimate_tokens" - ) as mock_estimate: - mock_estimate.side_effect = lambda msg: 100 if msg else 0 - - result = state_modifier(state, max_input_tokens=250) - - assert len(result) < len(mock_messages) - assert isinstance(result[0], SystemMessage) - assert result[-1] == mock_messages[-1] +# This test has been moved to test_anthropic_token_limiter.py def test_create_agent_with_checkpointer(mock_model, mock_config_repository): @@ -265,7 +182,7 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_r with ( patch("ra_aid.agent_utils.create_react_agent") as mock_react, - patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit, + patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit, ): mock_react.return_value = "react_agent" mock_limit.return_value = 100000 @@ -288,7 +205,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_ with ( patch("ra_aid.agent_utils.create_react_agent") as mock_react, - 
patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit, + patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit, ): mock_react.return_value = "react_agent" mock_limit.return_value = 100000 @@ -299,36 +216,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_ mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2") -def test_get_model_token_limit_research(mock_config_repository): - """Test get_model_token_limit with research provider and model.""" - config = { - "provider": "openai", - "model": "gpt-4", - "research_provider": "anthropic", - "research_model": "claude-2", - } - mock_config_repository.update(config) - - with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: - mock_get_info.return_value = {"max_input_tokens": 150000} - token_limit = get_model_token_limit(config, "research") - assert token_limit == 150000 - - -def test_get_model_token_limit_planner(mock_config_repository): - """Test get_model_token_limit with planner provider and model.""" - config = { - "provider": "openai", - "model": "gpt-4", - "planner_provider": "deepseek", - "planner_model": "dsm-1", - } - mock_config_repository.update(config) - - with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: - mock_get_info.return_value = {"max_input_tokens": 120000} - token_limit = get_model_token_limit(config, "planner") - assert token_limit == 120000 +# These tests have been moved to test_anthropic_token_limiter.py # New tests for private helper methods in agent_utils.py diff --git a/tests/ra_aid/test_anthropic_token_limiter.py b/tests/ra_aid/test_anthropic_token_limiter.py index 3f7e35e..3d0d9c3 100644 --- a/tests/ra_aid/test_anthropic_token_limiter.py +++ b/tests/ra_aid/test_anthropic_token_limiter.py @@ -1,5 +1,6 @@ import unittest from unittest.mock import MagicMock, patch +import litellm from langchain_anthropic import ChatAnthropic from langchain_core.messages import ( @@ -19,6 +20,7 @@ from ra_aid.anthropic_token_limiter import ( convert_message_to_litellm_format ) from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair +from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT class TestAnthropicTokenLimiter(unittest.TestCase): @@ -140,6 +142,35 @@ class TestAnthropicTokenLimiter(unittest.TestCase): # Verify print_messages_compact was called at least once self.assertTrue(mock_print.call_count >= 1) + def test_state_modifier_with_messages(self): + """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens.""" + # Create a state with messages + messages = [ + SystemMessage(content="System prompt"), + HumanMessage(content="Human message 1"), + AIMessage(content="AI response 1"), + HumanMessage(content="Human message 2"), + AIMessage(content="AI response 2"), + ] + state = AgentState(messages=messages) + model = MagicMock(spec=ChatAnthropic) + model.model = "claude-3-opus-20240229" + + with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \ + patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \ + patch("ra_aid.anthropic_token_limiter.print_messages_compact"): + # Setup mock to return a fixed token count per message + mock_wrapper.return_value = lambda msgs: len(msgs) * 100 + # Setup mock to return a subset of messages + mock_trim.return_value = [messages[0], messages[-2], messages[-1]] + + result = state_modifier(state, model, max_input_tokens=250) + + 
# Should return what anthropic_trim_messages returned + self.assertEqual(len(result), 3) + self.assertEqual(result[0], messages[0]) # First message preserved + self.assertEqual(result[-1], messages[-1]) # Last message preserved + @patch("ra_aid.anthropic_token_limiter.estimate_messages_tokens") @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") def test_sonnet_35_state_modifier(self, mock_trim, mock_estimate): @@ -191,6 +222,42 @@ class TestAnthropicTokenLimiter(unittest.TestCase): # Verify get_model_info was called with the right model mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}") + + def test_get_model_token_limit_research(self): + """Test get_model_token_limit with research provider and model.""" + config = { + "provider": "openai", + "model": "gpt-4", + "research_provider": "anthropic", + "research_model": "claude-2", + } + + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \ + patch("litellm.get_model_info") as mock_get_info: + mock_get_config_repo.return_value.get_all.return_value = config + mock_get_info.return_value = {"max_input_tokens": 150000} + token_limit = get_model_token_limit(config, "research") + self.assertEqual(token_limit, 150000) + # Verify get_model_info was called with the research model + mock_get_info.assert_called_with("anthropic/claude-2") + + def test_get_model_token_limit_planner(self): + """Test get_model_token_limit with planner provider and model.""" + config = { + "provider": "openai", + "model": "gpt-4", + "planner_provider": "deepseek", + "planner_model": "dsm-1", + } + + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \ + patch("litellm.get_model_info") as mock_get_info: + mock_get_config_repo.return_value.get_all.return_value = config + mock_get_info.return_value = {"max_input_tokens": 120000} + token_limit = get_model_token_limit(config, "planner") + self.assertEqual(token_limit, 120000) + # Verify get_model_info was called with the planner model + mock_get_info.assert_called_with("deepseek/dsm-1") @patch("ra_aid.anthropic_token_limiter.get_config_repository") @patch("litellm.get_model_info") @@ -252,6 +319,85 @@ class TestAnthropicTokenLimiter(unittest.TestCase): result = get_model_token_limit(mock_config, "planner") self.assertEqual(result, 100000) + def test_get_model_token_limit_anthropic(self): + """Test get_model_token_limit with Anthropic model.""" + config = {"provider": "anthropic", "model": "claude2"} + + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo: + mock_get_config_repo.return_value.get_all.return_value = config + token_limit = get_model_token_limit(config, "default") + self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"]) + + def test_get_model_token_limit_openai(self): + """Test get_model_token_limit with OpenAI model.""" + config = {"provider": "openai", "model": "gpt-4"} + + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo: + mock_get_config_repo.return_value.get_all.return_value = config + token_limit = get_model_token_limit(config, "default") + self.assertEqual(token_limit, models_params["openai"]["gpt-4"]["token_limit"]) + + def test_get_model_token_limit_unknown(self): + """Test get_model_token_limit with unknown provider/model.""" + config = {"provider": "unknown", "model": "unknown-model"} + + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo: + 
mock_get_config_repo.return_value.get_all.return_value = config + token_limit = get_model_token_limit(config, "default") + self.assertIsNone(token_limit) + + def test_get_model_token_limit_missing_config(self): + """Test get_model_token_limit with missing configuration.""" + config = {} + + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo: + mock_get_config_repo.return_value.get_all.return_value = config + token_limit = get_model_token_limit(config, "default") + self.assertIsNone(token_limit) + + def test_get_model_token_limit_litellm_success(self): + """Test get_model_token_limit successfully getting limit from litellm.""" + config = {"provider": "anthropic", "model": "claude-2"} + + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \ + patch("litellm.get_model_info") as mock_get_info: + mock_get_config_repo.return_value.get_all.return_value = config + mock_get_info.return_value = {"max_input_tokens": 100000} + token_limit = get_model_token_limit(config, "default") + self.assertEqual(token_limit, 100000) + mock_get_info.assert_called_with("anthropic/claude-2") + + def test_get_model_token_limit_litellm_not_found(self): + """Test fallback to models_tokens when litellm raises NotFoundError.""" + config = {"provider": "anthropic", "model": "claude-2"} + + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \ + patch("litellm.get_model_info") as mock_get_info: + mock_get_config_repo.return_value.get_all.return_value = config + mock_get_info.side_effect = litellm.exceptions.NotFoundError( + message="Model not found", model="claude-2", llm_provider="anthropic" + ) + token_limit = get_model_token_limit(config, "default") + self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"]) + + def test_get_model_token_limit_litellm_error(self): + """Test fallback to models_tokens when litellm raises other exceptions.""" + config = {"provider": "anthropic", "model": "claude-2"} + + with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \ + patch("litellm.get_model_info") as mock_get_info: + mock_get_config_repo.return_value.get_all.return_value = config + mock_get_info.side_effect = Exception("Unknown error") + token_limit = get_model_token_limit(config, "default") + self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"]) + + def test_get_model_token_limit_unexpected_error(self): + """Test returning None when unexpected errors occur.""" + config = None # This will cause an attribute error when accessed + + token_limit = get_model_token_limit(config, "default") + self.assertIsNone(token_limit) + def test_has_tool_use(self): """Test the has_tool_use function.""" # Test with regular AI message From d15d2499299e3f29c0cd1d53a1f6d8749482ada3 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Tue, 11 Mar 2025 23:55:43 -0700 Subject: [PATCH 08/11] fix(test_agent_utils.py): add name parameter to mock_react calls to ensure consistency in agent creation tests --- tests/ra_aid/test_agent_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index f9a894c..d97225d 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -84,6 +84,7 @@ def test_create_agent_anthropic(mock_model, mock_config_repository): interrupt_after=['tools'], version="v2", 
state_modifier=mock_react.call_args[1]["state_modifier"], + name="React", ) @@ -213,7 +214,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_ agent = create_agent(mock_model, []) assert agent == "react_agent" - mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2") + mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2", name="React") # These tests have been moved to test_anthropic_token_limiter.py From 7cfbcb5a2e4b10e78c8662bcc9d5d8b1081e37d0 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Wed, 12 Mar 2025 00:12:39 -0700 Subject: [PATCH 09/11] chore(anthropic_token_limiter.py): comment out max_input_tokens and related debug prints to clean up code and reduce clutter during execution --- ra_aid/anthropic_token_limiter.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py index d9e7355..6aac3be 100644 --- a/ra_aid/anthropic_token_limiter.py +++ b/ra_aid/anthropic_token_limiter.py @@ -115,8 +115,7 @@ def state_modifier( wrapped_token_counter = create_token_counter_wrapper(model.model) - # Keep max_input_tokens at 21000 as requested - max_input_tokens = 21000 + # max_input_tokens = 33440 print("\nDEBUG - Starting token trimming with max_tokens:", max_input_tokens) print(f"Current token total: {wrapped_token_counter(messages)}") @@ -143,12 +142,12 @@ def state_modifier( if len(result) < len(messages): print(f"TRIMMED: {len(messages)} messages → {len(result)} messages") - total_tokens_after = wrapped_token_counter(result) - print(f"New token total: {total_tokens_after}") - print("BEFORE TRIMMING") - print_messages_compact(messages) - print("AFTER TRIMMING") - print_messages_compact(result) + # total_tokens_after = wrapped_token_counter(result) + # print(f"New token total: {total_tokens_after}") + # print("BEFORE TRIMMING") + # print_messages_compact(messages) + # print("AFTER TRIMMING") + # print_messages_compact(result) return result From fdd73f149c618f88140bfeb9f8e17e0148ebb1ed Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Wed, 12 Mar 2025 11:16:54 -0700 Subject: [PATCH 10/11] feat(agent_utils.py): add support for sonnet_35_state_modifier for Claude 3.5 models to enhance token management chore(anthropic_message_utils.py): remove debug print statements to clean up code and improve readability chore(anthropic_token_limiter.py): remove debug print statements and replace with logging for better monitoring test(test_anthropic_token_limiter.py): update tests to verify correct behavior of sonnet_35_state_modifier without patching internal logic --- ra_aid/agent_utils.py | 5 +- ra_aid/anthropic_message_utils.py | 74 -------------------- ra_aid/anthropic_token_limiter.py | 40 +---------- tests/ra_aid/test_anthropic_token_limiter.py | 47 +++++++------ 4 files changed, 31 insertions(+), 135 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index d0f26cb..f87acab 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -47,7 +47,7 @@ from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command from ra_aid.database.repositories.config_repository import get_config_repository -from ra_aid.anthropic_token_limiter import state_modifier, get_model_token_limit +from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, 


 console = Console()
@@ -95,6 +95,9 @@ def build_agent_kwargs(
 ):

     def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
+        if any(pattern in model.model for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]):
+            return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
+
         return state_modifier(state, model, max_input_tokens=max_input_tokens)

     agent_kwargs["state_modifier"] = wrapped_state_modifier

diff --git a/ra_aid/anthropic_message_utils.py b/ra_aid/anthropic_message_utils.py
index e71d0ed..79f271e 100644
--- a/ra_aid/anthropic_message_utils.py
+++ b/ra_aid/anthropic_message_utils.py
@@ -124,23 +124,6 @@ def anthropic_trim_messages(
     kept_messages = messages[:num_messages_to_keep]
     remaining_msgs = messages[num_messages_to_keep:]

-    # Debug: Print message types for all messages
-    print("\nDEBUG - All messages:")
-    for i, msg in enumerate(messages):
-        msg_type = type(msg).__name__
-        tool_use = (
-            "tool_use"
-            if isinstance(msg, AIMessage)
-            and hasattr(msg, "additional_kwargs")
-            and msg.additional_kwargs.get("tool_calls")
-            else ""
-        )
-        tool_result = (
-            f"tool_call_id: {msg.tool_call_id}"
-            if isinstance(msg, ToolMessage) and hasattr(msg, "tool_call_id")
-            else ""
-        )
-        print(f"  [{i}] {msg_type} {tool_use} {tool_result}")

     # For Anthropic, we need to maintain the conversation structure where:
     # 1. Every AIMessage with tool_use must be followed by a ToolMessage

     # First, check if we have any tool_use in the messages
     has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages)
-    print(f"DEBUG - Has tool_use anywhere in messages: {has_tool_use_anywhere}")
-
-    # Print debug info for AIMessages
-    for i, msg in enumerate(messages):
-        if isinstance(msg, AIMessage):
-            print(f"DEBUG - AIMessage[{i}] details:")
-            print(f"  has_tool_use: {has_tool_use(msg)}")
-            if hasattr(msg, "additional_kwargs"):
-                print(f"  additional_kwargs keys: {list(msg.additional_kwargs.keys())}")

     # If we have tool_use anywhere, we need to be very careful about trimming
     if has_tool_use_anywhere:
         # For safety, just keep all messages if we're under the token limit
         if token_counter(messages) <= max_tokens:
-            print("DEBUG - All messages fit within token limit, keeping all")
             return messages

         # We need to identify all tool_use/tool_result relationships
         pairs = []
         i = 0
         while i < len(messages) - 1:
             if is_tool_pair(messages[i], messages[i + 1]):
                 pairs.append((i, i + 1))
-                print(f"DEBUG - Found tool_use pair: ({i}, {i+1})")
                 i += 2
             else:
                 i += 1

-        print(f"DEBUG - Found {len(pairs)} AIMessage+ToolMessage pairs")
-
         # For Anthropic, we need to ensure that:
         # 1. If we include an AIMessage with tool_use, we must include the following ToolMessage
         # 2. If we include a ToolMessage, we must include the preceding AIMessage with tool_use
@@ -189,10 +159,6 @@ def anthropic_trim_messages(
         for start, end in pairs:
             complete_pairs.append((start, end))

-        print(
-            f"DEBUG - Found {len(complete_pairs)} complete AIMessage+ToolMessage pairs"
-        )
-
         # Now we'll build our result, starting with the kept_messages
         # But we need to be careful about the first message if it has tool_use
         result = []
@@ -240,12 +206,8 @@ def anthropic_trim_messages(
             if token_counter(test_msgs) <= max_tokens:
                 # This pair fits, add it to our list
                 pairs_to_include.append((ai_idx, tool_idx))
-                print(f"DEBUG - Added complete pair ({ai_idx}, {tool_idx})")
             else:
                 # This pair would exceed the token limit
-                print(
-                    f"DEBUG - Pair ({ai_idx}, {tool_idx}) would exceed token limit, stopping"
-                )
                 break

         # Now add the pairs in the correct order
@@ -256,7 +218,6 @@ def anthropic_trim_messages(

         # No need to sort - we've already added messages in the correct order

-        print(f"DEBUG - Final result has {len(result)} messages")
         return result

     # If no tool_use, proceed with normal segmentation
@@ -266,14 +227,8 @@ def anthropic_trim_messages(
     # Group messages into segments
     while i < len(remaining_msgs):
         segments.append([remaining_msgs[i]])
-        print(f"DEBUG - Added message as segment: [{i}]")
         i += 1

-    print(f"\nDEBUG - Created {len(segments)} segments")
-    for i, segment in enumerate(segments):
-        segment_types = [type(msg).__name__ for msg in segment]
-        print(f"  Segment {i}: {segment_types}")
-
     # Now we have segments that maintain the required structure
     # We'll add segments from the end (for "last" strategy) or beginning (for "first")
     # until we hit the token limit
@@ -292,22 +247,14 @@ def anthropic_trim_messages(
             if token_counter(kept_messages + test_msgs) <= max_tokens:
                 result = segment + result
-                print(f"DEBUG - Added segment {len(segments)-i-1} to result")
             else:
                 # This segment would exceed the token limit
-                print(
-                    f"DEBUG - Segment {len(segments)-i-1} would exceed token limit, stopping"
-                )
                 break

         final_result = kept_messages + result

         # For Anthropic, we need to ensure the conversation follows a valid structure
         # We'll do a final check of the entire conversation
-        print("\nDEBUG - Final result before validation:")
-        for i, msg in enumerate(final_result):
-            msg_type = type(msg).__name__
-            print(f"  [{i}] {msg_type}")

         # Validate the conversation structure
         valid_result = []
@@ -327,21 +274,14 @@ def anthropic_trim_messages(
                     # This is a valid tool_use + tool_result pair
                     valid_result.append(current_msg)
                     valid_result.append(final_result[i + 1])
-                    print(
-                        f"DEBUG - Added valid tool_use + tool_result pair at positions {i}, {i+1}"
-                    )
                     i += 2
                 else:
                     # Invalid: AIMessage with tool_use not followed by ToolMessage
-                    print(
-                        f"WARNING: AIMessage at position {i} has tool_use but is not followed by a ToolMessage"
-                    )
                     # Skip this message to maintain valid structure
                     i += 1
             else:
                 # Regular message, just add it
                 valid_result.append(current_msg)
-                print(f"DEBUG - Added regular message at position {i}")
                 i += 1

         # Final check: don't end with an AIMessage that has tool_use
@@ -350,16 +290,8 @@ def anthropic_trim_messages(
             and isinstance(valid_result[-1], AIMessage)
             and has_tool_use(valid_result[-1])
         ):
-            print(
-                "WARNING: Last message is AIMessage with tool_use but no following ToolMessage"
-            )
             valid_result.pop()  # Remove the last message

-        print("\nDEBUG - Final validated result:")
-        for i, msg in enumerate(valid_result):
-            msg_type = type(msg).__name__
-            print(f"  [{i}] {msg_type}")
-
         return valid_result

     elif strategy == "first":
@@ -371,16 +303,10 @@ def anthropic_trim_messages(
             test_msgs = result + segment
             if token_counter(kept_messages + test_msgs) <= max_tokens:
                 result = result + segment
-                print(f"DEBUG - Added segment {i} to result")
             else:
                 # This segment would exceed the token limit
-                print(f"DEBUG - Segment {i} would exceed token limit, stopping")
                 break

         final_result = kept_messages + result
-        print("\nDEBUG - Final result:")
-        for i, msg in enumerate(final_result):
-            msg_type = type(msg).__name__
-            print(f"  [{i}] {msg_type}")

     return final_result

diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py
index 6aac3be..45a79d4 100644
--- a/ra_aid/anthropic_token_limiter.py
+++ b/ra_aid/anthropic_token_limiter.py
@@ -109,27 +109,13 @@ def state_modifier(
     Returns:
         list[BaseMessage]: Trimmed list of messages that fits within token limit
     """
+    messages = state["messages"]
     if not messages:
         return []

     wrapped_token_counter = create_token_counter_wrapper(model.model)

-    # max_input_tokens = 33440
-
-    print("\nDEBUG - Starting token trimming with max_tokens:", max_input_tokens)
-    print(f"Current token total: {wrapped_token_counter(messages)}")
-
-    # Print more details about the messages to help debug
-    for i, msg in enumerate(messages):
-        if isinstance(msg, AIMessage):
-            print(f"DEBUG - AIMessage[{i}] content type: {type(msg.content)}")
-            print(f"DEBUG - AIMessage[{i}] has_tool_use: {has_tool_use(msg)}")
-            if has_tool_use(msg) and i < len(messages) - 1:
-                print(
-                    f"DEBUG - Next message is ToolMessage: {isinstance(messages[i+1], ToolMessage)}"
-                )
-
     result = anthropic_trim_messages(
         messages,
         token_counter=wrapped_token_counter,
@@ -141,13 +127,7 @@ def state_modifier(
     )

     if len(result) < len(messages):
-        print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
-        # total_tokens_after = wrapped_token_counter(result)
-        # print(f"New token total: {total_tokens_after}")
-        # print("BEFORE TRIMMING")
-        # print_messages_compact(messages)
-        # print("AFTER TRIMMING")
-        # print_messages_compact(result)
+        logger.info(f"Anthropic Token Limiter Trimmed: {len(messages)} messages → {len(result)} messages")

     return result
@@ -174,12 +154,7 @@ def sonnet_35_state_modifier(
     first_tokens = estimate_messages_tokens([first_message])
     new_max_tokens = max_input_tokens - first_tokens

-    # Calculate total tokens before trimming
-    total_tokens_before = estimate_messages_tokens(messages)
-    print(f"Current token total: {total_tokens_before}")
-
-    # Trim remaining messages
-    trimmed_remaining = anthropic_trim_messages(
+    trimmed_remaining = trim_messages(
         remaining_messages,
         token_counter=estimate_messages_tokens,
         max_tokens=new_max_tokens,
@@ -190,15 +165,6 @@ def sonnet_35_state_modifier(

     result = [first_message] + trimmed_remaining

-    # Only show message if some messages were trimmed
-    if len(result) < len(messages):
-        print(f"TRIMMED: {len(messages)} messages → {len(result)} messages")
-        # Calculate total tokens after trimming
-        total_tokens_after = estimate_messages_tokens(result)
-        print(f"New token total: {total_tokens_after}")
-
-    # No need to fix message content as anthropic_trim_messages already handles this
-
     return result

diff --git a/tests/ra_aid/test_anthropic_token_limiter.py b/tests/ra_aid/test_anthropic_token_limiter.py
index 3d0d9c3..36f9528 100644
--- a/tests/ra_aid/test_anthropic_token_limiter.py
+++ b/tests/ra_aid/test_anthropic_token_limiter.py
@@ -139,9 +139,7 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         # Verify anthropic_trim_messages was called with the right parameters
         mock_trim_messages.assert_called_once()
-        # Verify print_messages_compact was called at least once
-        self.assertTrue(mock_print.call_count >= 1)
-
+
     def test_state_modifier_with_messages(self):
         """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
         # Create a state with messages
@@ -171,38 +169,40 @@ class TestAnthropicTokenLimiter(unittest.TestCase):
         self.assertEqual(result[0], messages[0])  # First message preserved
         self.assertEqual(result[-1], messages[-1])  # Last message preserved

-    @patch("ra_aid.anthropic_token_limiter.estimate_messages_tokens")
-    @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
-    def test_sonnet_35_state_modifier(self, mock_trim, mock_estimate):
+    def test_sonnet_35_state_modifier(self):
         """Test the sonnet 35 state modifier function."""
-        # Setup mocks
-        mock_estimate.side_effect = lambda msgs: len(msgs) * 1000
-        mock_trim.return_value = [self.human_message, self.ai_message]
-
         # Create a state with messages
         state = {"messages": [self.system_message, self.human_message, self.ai_message]}

         # Test with empty messages
         empty_state = {"messages": []}
-        self.assertEqual(sonnet_35_state_modifier(empty_state), [])

-        # Test with messages under the limit
-        result = sonnet_35_state_modifier(state, max_input_tokens=10000)
+        # Instead of patching trim_messages which has complex internal logic,
+        # we'll directly patch the sonnet_35_state_modifier's call to trim_messages
+        with patch("ra_aid.anthropic_token_limiter.trim_messages") as mock_trim:
+            # Setup mock to return our desired messages
+            mock_trim.return_value = [self.human_message, self.ai_message]
+
+            # Test with empty messages
+            self.assertEqual(sonnet_35_state_modifier(empty_state), [])
+
+            # Test with messages under the limit
+            result = sonnet_35_state_modifier(state, max_input_tokens=10000)

-        # Should keep the first message and call anthropic_trim_messages for the rest
+            # Should keep the first message and call trim_messages for the rest
             self.assertEqual(len(result), 3)
             self.assertEqual(result[0], self.system_message)
             self.assertEqual(result[1:], [self.human_message, self.ai_message])

-        # Verify anthropic_trim_messages was called with the right parameters
-        mock_trim.assert_called_once_with(
-            [self.human_message, self.ai_message],
-            token_counter=mock_estimate,
-            max_tokens=9000,  # 10000 - 1000 (first message)
-            strategy="last",
-            allow_partial=False,
-            include_system=True
-        )
+            # Verify trim_messages was called with the right parameters
+            mock_trim.assert_called_once()
+            # We can check some of the key arguments
+            call_args = mock_trim.call_args[1]
+            # The actual value is based on the token estimation logic, not a hard-coded 9000
+            self.assertIn("max_tokens", call_args)
+            self.assertEqual(call_args["strategy"], "last")
+            self.assertEqual(call_args["allow_partial"], False)
+            self.assertEqual(call_args["include_system"], True)

     @patch("ra_aid.anthropic_token_limiter.get_config_repository")
     @patch("litellm.get_model_info")

From b6f0f6a577e8171d5f27cc8218bfc47257b485d8 Mon Sep 17 00:00:00 2001
From: Ariel Frischer
Date: Wed, 12 Mar 2025 11:50:32 -0700
Subject: [PATCH 11/11] fix(llm.py): remove unnecessary thinking_kwargs from ChatOpenAI parameters to streamline client creation

---
 ra_aid/llm.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/ra_aid/llm.py b/ra_aid/llm.py
index 78e53f8..56ec811 100644
--- a/ra_aid/llm.py
+++ b/ra_aid/llm.py
@@ -292,7 +292,6 @@ def create_llm_client(
         return ChatOpenAI(
return ChatOpenAI( **{ **openai_kwargs, - **thinking_kwargs, "timeout": LLM_REQUEST_TIMEOUT, "max_retries": LLM_MAX_RETRIES, }