Merge pull request #124 from ariel-frischer/fix-token-limiter

Fix Sonnet 3.7 Token Limiter API Errors
Commit a9656552a9 by Andrew I. Christianson, 2025-03-12 14:59:33 -04:00 (committed by GitHub)
10 changed files with 1315 additions and 422 deletions

View File

@ -39,35 +39,41 @@ from ra_aid.agents.research_agent import run_research_agent
from ra_aid.agents import run_planning_agent
from ra_aid.config import (
DEFAULT_MAX_TEST_CMD_RETRIES,
DEFAULT_MODEL,
DEFAULT_RECURSION_LIMIT,
DEFAULT_TEST_CMD_TIMEOUT,
VALID_PROVIDERS,
)
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
from ra_aid.database.repositories.key_fact_repository import (
KeyFactRepositoryManager,
get_key_fact_repository,
)
from ra_aid.database.repositories.key_snippet_repository import (
KeySnippetRepositoryManager, get_key_snippet_repository
KeySnippetRepositoryManager,
get_key_snippet_repository,
)
from ra_aid.database.repositories.human_input_repository import (
HumanInputRepositoryManager, get_human_input_repository
HumanInputRepositoryManager,
get_human_input_repository,
)
from ra_aid.database.repositories.research_note_repository import (
ResearchNoteRepositoryManager, get_research_note_repository
ResearchNoteRepositoryManager,
get_research_note_repository,
)
from ra_aid.database.repositories.trajectory_repository import (
TrajectoryRepositoryManager, get_trajectory_repository
TrajectoryRepositoryManager,
get_trajectory_repository,
)
from ra_aid.database.repositories.session_repository import (
SessionRepositoryManager, get_session_repository
)
from ra_aid.database.repositories.related_files_repository import (
RelatedFilesRepositoryManager
)
from ra_aid.database.repositories.work_log_repository import (
WorkLogRepositoryManager
RelatedFilesRepositoryManager,
)
from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager
from ra_aid.database.repositories.config_repository import (
ConfigRepositoryManager,
get_config_repository
get_config_repository,
)
from ra_aid.env_inv import EnvDiscovery
from ra_aid.env_inv_context import EnvInvManager, get_env_inv
@ -103,7 +109,7 @@ def launch_webui(host: str, port: int):
def parse_arguments(args=None):
ANTHROPIC_DEFAULT_MODEL = "claude-3-7-sonnet-20250219"
ANTHROPIC_DEFAULT_MODEL = DEFAULT_MODEL
OPENAI_DEFAULT_MODEL = "gpt-4o"
# Case-insensitive log level argument type
@ -202,8 +208,10 @@ Examples:
help="Enable chat mode with direct human interaction (implies --hil)",
)
parser.add_argument(
"--log-mode", choices=["console", "file"], default="file",
help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console"
"--log-mode",
choices=["console", "file"],
default="file",
help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console",
)
parser.add_argument(
"--pretty-logger", action="store_true", help="Enable pretty logging output"
@ -423,7 +431,9 @@ def build_status():
temperature = config_repo.get("temperature")
expert_provider = config_repo.get("expert_provider", "")
expert_model = config_repo.get("expert_model", "")
experimental_fallback_handler = config_repo.get("experimental_fallback_handler", False)
experimental_fallback_handler = config_repo.get(
"experimental_fallback_handler", False
)
web_research_enabled = config_repo.get("web_research_enabled", False)
# Get the expert enabled status
@ -483,7 +493,9 @@ def build_status():
logger.debug(f"Failed to get research notes count: {e}")
# Add memory statistics line with reset option note
status.append(f"\n💾 Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes")
status.append(
f"\n💾 Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes"
)
if fact_count > 0 or snippet_count > 0 or note_count > 0:
status.append(" (use --wipe-project-memory to reset)")
@ -568,7 +580,9 @@ def main():
expert_missing,
web_research_enabled,
web_research_missing,
) = validate_environment(args) # Will exit if main env vars missing
) = validate_environment(
args
) # Will exit if main env vars missing
logger.debug("Environment validation successful")
# Validate model configuration early
@ -604,12 +618,16 @@ def main():
config_repo.set("expert_provider", args.expert_provider)
config_repo.set("expert_model", args.expert_model)
config_repo.set("temperature", args.temperature)
config_repo.set("experimental_fallback_handler", args.experimental_fallback_handler)
config_repo.set(
"experimental_fallback_handler", args.experimental_fallback_handler
)
config_repo.set("web_research_enabled", web_research_enabled)
config_repo.set("show_thoughts", args.show_thoughts)
config_repo.set("show_cost", args.show_cost)
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
config_repo.set(
"disable_reasoning_assistance", args.no_reasoning_assistance
)
# Build status panel with memory statistics
status = build_status()
@ -684,7 +702,9 @@ def main():
try:
# Using get_human_input_repository() to access the repository from context
human_input_repository = get_human_input_repository()
human_input_repository.create(content=initial_request, source='chat')
human_input_repository.create(
content=initial_request, source="chat"
)
human_input_repository.garbage_collect()
except Exception as e:
logger.error(f"Failed to record initial chat input: {str(e)}")
@ -742,8 +762,12 @@ def main():
),
working_directory=working_directory,
current_date=current_date,
key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
key_facts=format_key_facts_dict(
get_key_fact_repository().get_facts_dict()
),
key_snippets=format_key_snippets_dict(
get_key_snippet_repository().get_snippets_dict()
),
project_info=formatted_project_info,
env_inv=get_env_inv(),
),
@ -780,7 +804,7 @@ def main():
try:
# Using get_human_input_repository() to access the repository from context
human_input_repository = get_human_input_repository()
human_input_repository.create(content=base_task, source='cli')
human_input_repository.create(content=base_task, source="cli")
# Run garbage collection to ensure we don't exceed 100 inputs
human_input_repository.garbage_collect()
logger.debug(f"Recorded CLI input: {base_task}")
@ -814,11 +838,15 @@ def main():
config_repo.set("expert_model", args.expert_model)
# Store planner config with fallback to base values
config_repo.set("planner_provider", args.planner_provider or args.provider)
config_repo.set(
"planner_provider", args.planner_provider or args.provider
)
config_repo.set("planner_model", args.planner_model or args.model)
# Store research config with fallback to base values
config_repo.set("research_provider", args.research_provider or args.provider)
config_repo.set(
"research_provider", args.research_provider or args.provider
)
config_repo.set("research_model", args.research_model or args.model)
# Store temperature in config
@ -826,7 +854,9 @@ def main():
# Store reasoning assistance flags
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
config_repo.set(
"disable_reasoning_assistance", args.no_reasoning_assistance
)
# Set modification tools based on use_aider flag
set_modification_tools(args.use_aider)
@ -870,5 +900,6 @@ def main():
print()
sys.exit(0)
if __name__ == "__main__":
main()

View File

@ -1,19 +1,14 @@
"""Utility functions for working with agents."""
import inspect
import os
import signal
import sys
import threading
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Sequence
from typing import Any, Dict, List, Literal, Optional
from langchain_anthropic import ChatAnthropic
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
import litellm
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
from openai import RateLimitError as OpenAIRateLimitError
from litellm.exceptions import RateLimitError as LiteLLMRateLimitError
@ -23,28 +18,24 @@ from langchain_core.messages import (
BaseMessage,
HumanMessage,
SystemMessage,
trim_messages,
)
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from langgraph.prebuilt.chat_agent_executor import AgentState
from litellm import get_model_info
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.agent_context import (
agent_context,
get_depth,
is_completed,
reset_completion_flags,
should_exit,
)
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
from ra_aid.agents_alias import RAgents
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
from ra_aid.console.formatting import print_error, print_stage_header
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES
from ra_aid.console.formatting import print_error
from ra_aid.console.output import print_agent_output
from ra_aid.exceptions import (
AgentInterrupt,
@ -53,77 +44,20 @@ from ra_aid.exceptions import (
)
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.logging_config import get_logger
from ra_aid.llm import initialize_expert_llm
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.text.processing import process_thinking_content
from ra_aid.project_info import (
display_project_status,
format_project_info,
get_project_info,
)
from ra_aid.prompts.expert_prompts import (
EXPERT_PROMPT_SECTION_IMPLEMENTATION,
EXPERT_PROMPT_SECTION_PLANNING,
EXPERT_PROMPT_SECTION_RESEARCH,
)
from ra_aid.prompts.human_prompts import (
HUMAN_PROMPT_SECTION_IMPLEMENTATION,
HUMAN_PROMPT_SECTION_PLANNING,
HUMAN_PROMPT_SECTION_RESEARCH,
)
from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
from ra_aid.prompts.reasoning_assist_prompt import (
REASONING_ASSIST_PROMPT_PLANNING,
REASONING_ASSIST_PROMPT_IMPLEMENTATION,
REASONING_ASSIST_PROMPT_RESEARCH,
)
from ra_aid.prompts.research_prompts import (
RESEARCH_ONLY_PROMPT,
RESEARCH_PROMPT,
)
from ra_aid.prompts.web_research_prompts import (
WEB_RESEARCH_PROMPT,
WEB_RESEARCH_PROMPT_SECTION_CHAT,
WEB_RESEARCH_PROMPT_SECTION_PLANNING,
WEB_RESEARCH_PROMPT_SECTION_RESEARCH,
)
from ra_aid.tool_configs import (
get_implementation_tools,
get_planning_tools,
get_research_tools,
get_web_research_tools,
)
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import (
get_key_snippet_repository,
)
from ra_aid.database.repositories.human_input_repository import (
get_human_input_repository,
)
from ra_aid.database.repositories.research_note_repository import (
get_research_note_repository,
)
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
from ra_aid.tools.memory import (
get_related_files,
log_work_event,
)
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.env_inv_context import get_env_inv
from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
console = Console()
logger = get_logger(__name__)
# Import repositories using get_* functions
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
@tool
@ -133,131 +67,19 @@ def output_markdown_message(message: str) -> str:
return "Message output."
def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
"""Helper function to estimate total tokens in a sequence of messages.
Args:
messages: Sequence of messages to count tokens for
Returns:
Total estimated token count
"""
if not messages:
return 0
estimate_tokens = CiaynAgent._estimate_tokens
return sum(estimate_tokens(msg) for msg in messages)
def state_modifier(
state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
"""Given the agent state and max_tokens, return a trimmed list of messages.
Args:
state: The current agent state containing messages
max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
Returns:
list[BaseMessage]: Trimmed list of messages that fits within token limit
"""
messages = state["messages"]
if not messages:
return []
first_message = messages[0]
remaining_messages = messages[1:]
first_tokens = estimate_messages_tokens([first_message])
new_max_tokens = max_input_tokens - first_tokens
trimmed_remaining = trim_messages(
remaining_messages,
token_counter=estimate_messages_tokens,
max_tokens=new_max_tokens,
strategy="last",
allow_partial=False,
)
return [first_message] + trimmed_remaining
def get_model_token_limit(
config: Dict[str, Any], agent_type: Literal["default", "research", "planner"]
) -> Optional[int]:
"""Get the token limit for the current model configuration based on agent type.
Returns:
Optional[int]: The token limit if found, None otherwise
"""
try:
# Try to get config from repository for production use
try:
config_from_repo = get_config_repository().get_all()
# If we succeeded, use the repository config instead of passed config
config = config_from_repo
except RuntimeError:
# In tests, this may fail because the repository isn't set up
# So we'll use the passed config directly
pass
if agent_type == "research":
provider = config.get("research_provider", "") or config.get("provider", "")
model_name = config.get("research_model", "") or config.get("model", "")
elif agent_type == "planner":
provider = config.get("planner_provider", "") or config.get("provider", "")
model_name = config.get("planner_model", "") or config.get("model", "")
else:
provider = config.get("provider", "")
model_name = config.get("model", "")
try:
provider_model = model_name if not provider else f"{provider}/{model_name}"
model_info = get_model_info(provider_model)
max_input_tokens = model_info.get("max_input_tokens")
if max_input_tokens:
logger.debug(
f"Using litellm token limit for {model_name}: {max_input_tokens}"
)
return max_input_tokens
except litellm.exceptions.NotFoundError:
logger.debug(
f"Model {model_name} not found in litellm, falling back to models_params"
)
except Exception as e:
logger.debug(
f"Error getting model info from litellm: {e}, falling back to models_params"
)
# Fallback to models_params dict
# Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
normalized_name = model_name.replace("-", "")
provider_tokens = models_params.get(provider, {})
if normalized_name in provider_tokens:
max_input_tokens = provider_tokens[normalized_name]["token_limit"]
logger.debug(
f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
)
else:
max_input_tokens = None
logger.debug(f"Could not find token limit for {provider}/{model_name}")
return max_input_tokens
except Exception as e:
logger.warning(f"Failed to get model token limit: {e}")
return None
def build_agent_kwargs(
checkpointer: Optional[Any] = None,
model: ChatAnthropic = None,
max_input_tokens: Optional[int] = None,
) -> Dict[str, Any]:
"""Build kwargs dictionary for agent creation.
Args:
checkpointer: Optional memory checkpointer
config: Optional configuration dictionary
token_limit: Optional token limit for the model
model: The language model to use for token counting
max_input_tokens: Optional token limit for the model
Returns:
Dictionary of kwargs for agent creation
@ -270,12 +92,20 @@ def build_agent_kwargs(
agent_kwargs["checkpointer"] = checkpointer
config = get_config_repository().get_all()
if config.get("limit_tokens", True) and is_anthropic_claude(config):
if (
config.get("limit_tokens", True)
and is_anthropic_claude(config)
and model is not None
):
def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
return state_modifier(state, max_input_tokens=max_input_tokens)
if any(pattern in model.model for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]):
return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
return state_modifier(state, model, max_input_tokens=max_input_tokens)
agent_kwargs["state_modifier"] = wrapped_state_modifier
agent_kwargs["name"] = "React"
return agent_kwargs
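
For reference (not part of the changed file), the dispatch above boils down to a model-name check; a minimal sketch, assuming the two modifiers exported by the new ra_aid.anthropic_token_limiter module in this PR:

from langchain_core.messages import BaseMessage
from langgraph.prebuilt.chat_agent_executor import AgentState
from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier

def make_state_modifier(model, max_input_tokens: int):
    # Hypothetical helper mirroring wrapped_state_modifier above: Claude 3.5
    # models keep the older estimate-based trimmer, everything else uses the
    # model-aware, litellm-backed one.
    def wrapped(state: AgentState) -> list[BaseMessage]:
        if any(p in model.model for p in ["claude-3.5", "claude3.5", "claude-3-5"]):
            return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
        return state_modifier(state, model, max_input_tokens=max_input_tokens)
    return wrapped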
@ -345,7 +175,8 @@ def create_agent(
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN
if is_anthropic_claude(config):
logger.debug("Using create_react_agent to instantiate agent.")
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
return create_react_agent(
model, tools, interrupt_after=["tools"], **agent_kwargs
)
@ -358,16 +189,12 @@ def create_agent(
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
config = get_config_repository().get_all()
max_input_tokens = get_model_token_limit(config, agent_type)
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
return create_react_agent(
model, tools, interrupt_after=["tools"], **agent_kwargs
)
from ra_aid.agents.research_agent import run_research_agent, run_web_research_agent
from ra_aid.agents.implementation_agent import run_task_implementation_agent
_CONTEXT_STACK = []
_INTERRUPT_CONTEXT = None
_FEEDBACK_MODE = False

View File

@ -0,0 +1,312 @@
"""Utilities for handling Anthropic-specific message formats and trimming."""
from typing import Callable, List, Literal, Optional, Sequence, Union, cast
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
def _is_message_type(
message: BaseMessage, message_types: Union[str, type, List[Union[str, type]]]
) -> bool:
"""Check if a message is of a specific type or types.
Args:
message: The message to check
message_types: Type(s) to check against (string name or class)
Returns:
bool: True if message matches any of the specified types
"""
if not isinstance(message_types, list):
message_types = [message_types]
types_str = [t for t in message_types if isinstance(t, str)]
types_classes = tuple(t for t in message_types if isinstance(t, type))
return message.type in types_str or isinstance(message, types_classes)
def has_tool_use(message: BaseMessage) -> bool:
"""Check if a message contains tool use.
Args:
message: The message to check
Returns:
bool: True if the message contains tool use
"""
if not isinstance(message, AIMessage):
return False
# Check content for tool_use
if isinstance(message.content, str) and "tool_use" in message.content:
return True
# Check content list for tool_use blocks
if isinstance(message.content, list):
for item in message.content:
if isinstance(item, dict) and item.get("type") == "tool_use":
return True
# Check additional_kwargs for tool_calls
if hasattr(message, "additional_kwargs") and message.additional_kwargs.get(
"tool_calls"
):
return True
return False
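
A quick, illustrative check of has_tool_use (not part of the diff; message shapes follow the LangChain classes imported at the top of this file):

plain = AIMessage(content="Just text, no tools.")
block = AIMessage(content=[{"type": "tool_use", "id": "t1", "name": "calculator", "input": {"x": 1}}])
kwargs = AIMessage(content="", additional_kwargs={"tool_calls": [{"name": "calculator"}]})
assert not has_tool_use(plain)
assert has_tool_use(block)   # detected via the content block's "type"
assert has_tool_use(kwargs)  # detected via additional_kwargs["tool_calls"]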
def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool:
"""Check if two messages form a tool use/result pair.
Args:
message1: First message
message2: Second message
Returns:
bool: True if the messages form a tool use/result pair
"""
return (
isinstance(message1, AIMessage)
and isinstance(message2, ToolMessage)
and has_tool_use(message1)
)
def anthropic_trim_messages(
messages: Sequence[BaseMessage],
*,
max_tokens: int,
token_counter: Callable[[List[BaseMessage]], int],
strategy: Literal["first", "last"] = "last",
num_messages_to_keep: int = 2,
allow_partial: bool = False,
include_system: bool = True,
start_on: Optional[Union[str, type, List[Union[str, type]]]] = None,
) -> List[BaseMessage]:
"""Trim messages to fit within a token limit, with Anthropic-specific handling.
Warning: not fully implemented. Only the "last" strategy is supported and tested;
allow_partial and the "first" strategy are not.
This function is similar to langchain_core's trim_messages but with special
handling for Anthropic message formats to avoid API errors.
It always keeps the first num_messages_to_keep messages.
Args:
messages: Sequence of messages to trim
max_tokens: Maximum number of tokens allowed
token_counter: Function to count tokens in messages
strategy: Whether to keep the "first" or "last" messages
num_messages_to_keep: Number of leading messages that are always kept
allow_partial: Whether to allow partial messages
include_system: Whether to always include the system message
start_on: Message type to start on (only for "last" strategy)
Returns:
List[BaseMessage]: Trimmed messages that fit within token limit
"""
if not messages:
return []
messages = list(messages)
# Always keep the first num_messages_to_keep messages
kept_messages = messages[:num_messages_to_keep]
remaining_msgs = messages[num_messages_to_keep:]
# For Anthropic, we need to maintain the conversation structure where:
# 1. Every AIMessage with tool_use must be followed by a ToolMessage
# 2. Every AIMessage that follows a ToolMessage must start with a tool_result
# First, check if we have any tool_use in the messages
has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages)
# If we have tool_use anywhere, we need to be very careful about trimming
if has_tool_use_anywhere:
# For safety, just keep all messages if we're under the token limit
if token_counter(messages) <= max_tokens:
return messages
# We need to identify all tool_use/tool_result relationships
# First, find all AIMessage+ToolMessage pairs
pairs = []
i = 0
while i < len(messages) - 1:
if is_tool_pair(messages[i], messages[i + 1]):
pairs.append((i, i + 1))
i += 2
else:
i += 1
# For Anthropic, we need to ensure that:
# 1. If we include an AIMessage with tool_use, we must include the following ToolMessage
# 2. If we include a ToolMessage, we must include the preceding AIMessage with tool_use
# The safest approach is to always keep complete AIMessage+ToolMessage pairs together
# First, identify all complete pairs
complete_pairs = []
for start, end in pairs:
complete_pairs.append((start, end))
# Now we'll build our result, starting with the kept_messages
# But we need to be careful about the first message if it has tool_use
result = []
# Check if the last message in kept_messages has tool_use
if (
kept_messages
and isinstance(kept_messages[-1], AIMessage)
and has_tool_use(kept_messages[-1])
):
# We need to find the corresponding ToolMessage
for i, (ai_idx, tool_idx) in enumerate(pairs):
if messages[ai_idx] is kept_messages[-1]:
# Found the pair, add all kept_messages except the last one
result.extend(kept_messages[:-1])
# Add the AIMessage and ToolMessage as a pair
result.extend([messages[ai_idx], messages[tool_idx]])
# Remove this pair from the list of pairs to process later
pairs = pairs[:i] + pairs[i + 1 :]
break
else:
# If we didn't find a matching pair, just add all kept_messages
result.extend(kept_messages)
else:
# No tool_use in the last kept message, just add all kept_messages
result.extend(kept_messages)
# If we're using the "last" strategy, we'll try to include pairs from the end
if strategy == "last":
# First collect all pairs we can include within the token limit
pairs_to_include = []
# Process pairs from the end (newest first)
for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)):
# Try adding this pair
test_msgs = result.copy()
# Add all previously selected pairs
for prev_ai_idx, prev_tool_idx in pairs_to_include:
test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]])
# Add this pair
test_msgs.extend([messages[ai_idx], messages[tool_idx]])
if token_counter(test_msgs) <= max_tokens:
# This pair fits, add it to our list
pairs_to_include.append((ai_idx, tool_idx))
else:
# This pair would exceed the token limit
break
# Now add the pairs in the correct order
# Sort by index to maintain the original conversation flow
pairs_to_include.sort(key=lambda x: x[0])
for ai_idx, tool_idx in pairs_to_include:
result.extend([messages[ai_idx], messages[tool_idx]])
# No need to sort - we've already added messages in the correct order
return result
# If no tool_use, proceed with normal segmentation
segments = []
i = 0
# Group messages into segments
while i < len(remaining_msgs):
segments.append([remaining_msgs[i]])
i += 1
# Now we have segments that maintain the required structure
# We'll add segments from the end (for "last" strategy) or beginning (for "first")
# until we hit the token limit
if strategy == "last":
# If we have no segments, just return kept_messages
if not segments:
return kept_messages
result = []
# Process segments from the end
for i, segment in enumerate(reversed(segments)):
# Try adding this segment
test_msgs = segment + result
if token_counter(kept_messages + test_msgs) <= max_tokens:
result = segment + result
else:
# This segment would exceed the token limit
break
final_result = kept_messages + result
# For Anthropic, we need to ensure the conversation follows a valid structure
# We'll do a final check of the entire conversation
# Validate the conversation structure
valid_result = []
i = 0
# Process messages in order
while i < len(final_result):
current_msg = final_result[i]
# If this is an AIMessage with tool_use, it must be followed by a ToolMessage
if (
i < len(final_result) - 1
and isinstance(current_msg, AIMessage)
and has_tool_use(current_msg)
):
if isinstance(final_result[i + 1], ToolMessage):
# This is a valid tool_use + tool_result pair
valid_result.append(current_msg)
valid_result.append(final_result[i + 1])
i += 2
else:
# Invalid: AIMessage with tool_use not followed by ToolMessage
# Skip this message to maintain valid structure
i += 1
else:
# Regular message, just add it
valid_result.append(current_msg)
i += 1
# Final check: don't end with an AIMessage that has tool_use
if (
valid_result
and isinstance(valid_result[-1], AIMessage)
and has_tool_use(valid_result[-1])
):
valid_result.pop() # Remove the last message
return valid_result
elif strategy == "first":
result = []
# Process segments from the beginning
for i, segment in enumerate(segments):
# Try adding this segment
test_msgs = result + segment
if token_counter(kept_messages + test_msgs) <= max_tokens:
result = result + segment
else:
# This segment would exceed the token limit
break
final_result = kept_messages + result
return final_result
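
To make the trimming contract concrete, here is a usage sketch (illustrative only; toy_counter is a stand-in for the litellm-backed counter this PR wires in from anthropic_token_limiter):

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

def toy_counter(msgs):
    # Crude stand-in: roughly one token per four characters of content.
    return sum(len(str(m.content)) // 4 for m in msgs)

history = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Summarize the design."),
    AIMessage(content="Sure. " * 200),
    HumanMessage(content="Now shorten it."),
]
trimmed = anthropic_trim_messages(
    history,
    max_tokens=100,
    token_counter=toy_counter,
    strategy="last",
    num_messages_to_keep=2,
)
# The first two messages are always kept; with the "last" strategy, newer
# messages are re-added from the end until the budget is hit, so the long
# AIMessage is dropped here while the final HumanMessage survives.
assert trimmed[:2] == history[:2]
assert trimmed[-1] == history[-1]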

View File

@ -0,0 +1,236 @@
"""Utilities for handling token limits with Anthropic models."""
from functools import partial
from typing import Any, Dict, List, Optional, Sequence
from dataclasses import dataclass
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import (
AIMessage,
BaseMessage,
RemoveMessage,
ToolMessage,
trim_messages,
)
from langchain_core.messages.base import message_to_dict
from ra_aid.anthropic_message_utils import (
anthropic_trim_messages,
has_tool_use,
)
from langgraph.prebuilt.chat_agent_executor import AgentState
from litellm import token_counter
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.console.output import cpm, print_messages_compact
logger = get_logger(__name__)
def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
"""Helper function to estimate total tokens in a sequence of messages.
Args:
messages: Sequence of messages to count tokens for
Returns:
Total estimated token count
"""
if not messages:
return 0
estimate_tokens = CiaynAgent._estimate_tokens
return sum(estimate_tokens(msg) for msg in messages)
def convert_message_to_litellm_format(message: BaseMessage) -> Dict:
"""Convert a BaseMessage to the format expected by litellm.
Args:
message: The BaseMessage to convert
Returns:
Dict in litellm format
"""
message_dict = message_to_dict(message)
return {
"role": message_dict["type"],
"content": message_dict["data"]["content"],
}
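
For example (illustrative; the role comes straight from message_to_dict, so a HumanMessage maps to "human" rather than "user"):

from langchain_core.messages import HumanMessage

msg = HumanMessage(content="Hello")
print(convert_message_to_litellm_format(msg))
# -> {'role': 'human', 'content': 'Hello'}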
def create_token_counter_wrapper(model: str):
"""Create a wrapper for token counter that handles BaseMessage conversion.
Args:
model: The model name to use for token counting
Returns:
A function that accepts BaseMessage objects and returns token count
"""
# Create a partial function that already has the model parameter set
base_token_counter = partial(token_counter, model=model)
def wrapped_token_counter(messages: List[BaseMessage]) -> int:
"""Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage.
Args:
messages: List of BaseMessage objects
Returns:
Token count for the messages
"""
if not messages:
return 0
litellm_messages = [convert_message_to_litellm_format(msg) for msg in messages]
result = base_token_counter(messages=litellm_messages)
return result
return wrapped_token_counter
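
A small usage sketch (assumes litellm recognizes the model name; the exact count depends on its tokenizer for that model):

from langchain_core.messages import HumanMessage

counter = create_token_counter_wrapper("claude-3-7-sonnet-20250219")
print(counter([HumanMessage(content="Hello, can you help me with a task?")]))  # small positive int
print(counter([]))  # 0 -- empty input short-circuits before litellm is called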
def state_modifier(
state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
"""Given the agent state and max_tokens, return a trimmed list of messages.
This uses anthropic_trim_messages which always keeps the first 2 messages.
Args:
state: The current agent state containing messages
model: The language model to use for token counting
max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
Returns:
list[BaseMessage]: Trimmed list of messages that fits within token limit
"""
messages = state["messages"]
if not messages:
return []
wrapped_token_counter = create_token_counter_wrapper(model.model)
result = anthropic_trim_messages(
messages,
token_counter=wrapped_token_counter,
max_tokens=max_input_tokens,
strategy="last",
allow_partial=False,
include_system=True,
num_messages_to_keep=2,
)
if len(result) < len(messages):
logger.info(f"Anthropic Token Limiter Trimmed: {len(messages)} messages → {len(result)} messages")
return result
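
Called directly, the modifier needs the agent state plus the chat model (the model name selects the tokenizer); a minimal sketch, assuming ANTHROPIC_API_KEY is set so ChatAnthropic can be constructed:

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage

model = ChatAnthropic(model="claude-3-7-sonnet-20250219")
state = {"messages": [SystemMessage(content="sys"), HumanMessage(content="hi")]}
trimmed = state_modifier(state, model, max_input_tokens=100_000)
# With only two short messages nothing is trimmed; trimming starts once the
# litellm-counted total exceeds max_input_tokens.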
def sonnet_35_state_modifier(
state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
"""Given the agent state and max_tokens, return a trimmed list of messages.
Args:
state: The current agent state containing messages
max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
Returns:
list[BaseMessage]: Trimmed list of messages that fits within token limit
"""
messages = state["messages"]
if not messages:
return []
first_message = messages[0]
remaining_messages = messages[1:]
first_tokens = estimate_messages_tokens([first_message])
new_max_tokens = max_input_tokens - first_tokens
trimmed_remaining = trim_messages(
remaining_messages,
token_counter=estimate_messages_tokens,
max_tokens=new_max_tokens,
strategy="last",
allow_partial=False,
include_system=True,
)
result = [first_message] + trimmed_remaining
return result
def get_model_token_limit(
config: Dict[str, Any], agent_type: str = "default"
) -> Optional[int]:
"""Get the token limit for the current model configuration based on agent type.
Args:
config: Configuration dictionary containing provider and model information
agent_type: Type of agent ("default", "research", or "planner")
Returns:
Optional[int]: The token limit if found, None otherwise
"""
try:
# Try to get config from repository for production use
try:
config_from_repo = get_config_repository().get_all()
# If we succeeded, use the repository config instead of passed config
config = config_from_repo
except RuntimeError:
# In tests, this may fail because the repository isn't set up
# So we'll use the passed config directly
pass
if agent_type == "research":
provider = config.get("research_provider", "") or config.get("provider", "")
model_name = config.get("research_model", "") or config.get("model", "")
elif agent_type == "planner":
provider = config.get("planner_provider", "") or config.get("provider", "")
model_name = config.get("planner_model", "") or config.get("model", "")
else:
provider = config.get("provider", "")
model_name = config.get("model", "")
try:
from litellm import get_model_info
provider_model = model_name if not provider else f"{provider}/{model_name}"
model_info = get_model_info(provider_model)
max_input_tokens = model_info.get("max_input_tokens")
if max_input_tokens:
logger.debug(
f"Using litellm token limit for {model_name}: {max_input_tokens}"
)
return max_input_tokens
except Exception as e:
logger.debug(
f"Error getting model info from litellm: {e}, falling back to models_params"
)
# Fallback to models_params dict
# Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
normalized_name = model_name.replace("-", "")
provider_tokens = models_params.get(provider, {})
if normalized_name in provider_tokens:
max_input_tokens = provider_tokens[normalized_name]["token_limit"]
logger.debug(
f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
)
else:
max_input_tokens = None
logger.debug(f"Could not find token limit for {provider}/{model_name}")
return max_input_tokens
except Exception as e:
logger.warning(f"Failed to get model token limit: {e}")
return None
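
Usage sketch (illustrative; the config dict mirrors what the config repository holds, and the returned limit is what build_agent_kwargs feeds into the state modifiers):

cfg = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
limit = get_model_token_limit(cfg, agent_type="default")
# limit is litellm's max_input_tokens if the model is known, otherwise the
# models_params fallback, otherwise None.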

View File

@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3
FALLBACK_TOOL_MODEL_LIMIT = 5
RETRY_FALLBACK_COUNT = 3
DEFAULT_TEST_CMD_TIMEOUT = 60 * 5 # 5 minutes in seconds
DEFAULT_MODEL="claude-3-7-sonnet-20250219"
DEFAULT_SHOW_COST = False

View File

@ -1,6 +1,6 @@
from typing import Any, Dict, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, Sequence
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from rich.markdown import Markdown
from rich.panel import Panel
@ -98,3 +98,57 @@ def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -
"""
console.print(Panel(Markdown(message), title=title, border_style=border_style))
def print_messages_compact(messages: Sequence[BaseMessage]) -> None:
"""Print a compact representation of a list of messages.
Warning: Used mainly for debugging purposes, so do not delete even if it is not referenced anywhere!
For all message types, only the first 30 characters of content are shown.
Args:
messages: A sequence of BaseMessage objects to print
"""
if not messages:
console.print("[italic]No messages[/italic]")
return
for i, msg in enumerate(messages):
msg_type = msg.__class__.__name__
content = msg.content
# Process content based on its type
if isinstance(content, str):
display_content = f"{content[:30]}..." if len(content) > 30 else content
elif isinstance(content, list):
# Handle structured content (list of content blocks)
content_preview = []
for item in content[:2]: # Show first 2 items at most
if isinstance(item, dict):
if item.get("type") == "text":
text = item.get("text", "")
content_preview.append(f"text: {text[:20]}..." if len(text) > 20 else f"text: {text}")
elif item.get("type") == "tool_call":
tool_name = item.get("tool_call", {}).get("name", "unknown")
content_preview.append(f"tool_call: {tool_name}")
else:
content_preview.append(f"{item.get('type', 'unknown')}")
if len(content) > 2:
content_preview.append(f"...({len(content)-2} more)")
display_content = ", ".join(content_preview)
else:
display_content = str(content)[:30] + "..." if len(str(content)) > 30 else str(content)
# Add additional tool message info if available
additional_info = []
if hasattr(msg, "tool_call_id") and msg.tool_call_id:
additional_info.append(f"tool_call_id: {msg.tool_call_id}")
if hasattr(msg, "name") and msg.name:
additional_info.append(f"name: {msg.name}")
if hasattr(msg, "status") and msg.status:
additional_info.append(f"status: {msg.status}")
info_str = f" ({', '.join(additional_info)})" if additional_info else ""
console.print(f"[{i}] [bold]{msg_type}{info_str}[/bold]: {display_content}")

View File

@ -259,8 +259,9 @@ def create_llm_client(
else:
temp_kwargs = {}
thinking_kwargs = {}
if supports_thinking:
temp_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
thinking_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
if provider == "deepseek":
return create_deepseek_client(
@ -268,6 +269,7 @@ def create_llm_client(
api_key=config["api_key"],
base_url=config["base_url"],
**temp_kwargs,
**thinking_kwargs,
is_expert=is_expert,
)
elif provider == "openrouter":
@ -275,6 +277,7 @@ def create_llm_client(
model_name=model_name,
api_key=config["api_key"],
**temp_kwargs,
**thinking_kwargs,
is_expert=is_expert,
)
elif provider == "openai":
@ -301,6 +304,7 @@ def create_llm_client(
max_retries=LLM_MAX_RETRIES,
max_tokens=model_config.get("max_tokens", 64000),
**temp_kwargs,
**thinking_kwargs,
)
elif provider == "openai-compatible":
return ChatOpenAI(
@ -310,6 +314,7 @@ def create_llm_client(
timeout=LLM_REQUEST_TIMEOUT,
max_retries=LLM_MAX_RETRIES,
**temp_kwargs,
**thinking_kwargs,
)
elif provider == "gemini":
return ChatGoogleGenerativeAI(
@ -318,6 +323,7 @@ def create_llm_client(
timeout=LLM_REQUEST_TIMEOUT,
max_retries=LLM_MAX_RETRIES,
**temp_kwargs,
**thinking_kwargs,
)
else:
raise ValueError(f"Unsupported provider: {provider}")

View File

@ -14,8 +14,9 @@ from ra_aid.agent_context import (
is_crashed,
reset_completion_flags,
)
from ra_aid.config import DEFAULT_MODEL
from ra_aid.console.formatting import print_error, print_task_header
from ra_aid.database.repositories.human_input_repository import HumanInputRepository, get_human_input_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.config_repository import get_config_repository
@ -385,7 +386,7 @@ def request_task_implementation(task_spec: str) -> str:
config = get_config_repository().get_all()
model = initialize_llm(
config.get("provider", "anthropic"),
config.get("model", "claude-3-5-sonnet-20241022"),
config.get("model",DEFAULT_MODEL),
temperature=config.get("temperature"),
)
@ -552,7 +553,7 @@ def request_implementation(task_spec: str) -> str:
config = get_config_repository().get_all()
model = initialize_llm(
config.get("provider", "anthropic"),
config.get("model", "claude-3-5-sonnet-20241022"),
config.get("model", DEFAULT_MODEL),
temperature=config.get("temperature"),
)

View File

@ -14,12 +14,18 @@ from ra_aid.agent_context import (
from ra_aid.agent_utils import (
AgentState,
create_agent,
get_model_token_limit,
is_anthropic_claude,
)
from ra_aid.anthropic_token_limiter import (
get_model_token_limit,
state_modifier,
)
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository, config_repo_var
from ra_aid.database.repositories.config_repository import (
ConfigRepositoryManager,
get_config_repository,
config_repo_var,
)
@pytest.fixture
@ -32,7 +38,9 @@ def mock_model():
@pytest.fixture
def mock_config_repository():
"""Mock the ConfigRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
with patch(
"ra_aid.database.repositories.config_repository.config_repo_var"
) as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
@ -42,6 +50,7 @@ def mock_config_repository():
# Setup get method to return config values
def get_config(key, default=None):
return config.get(key, default)
mock_repo.get.side_effect = get_config
# Setup get_all method to return all config values
@ -50,11 +59,13 @@ def mock_config_repository():
# Setup set method to update config values
def set_config(key, value):
config[key] = value
mock_repo.set.side_effect = set_config
# Setup update method to update multiple config values
def update_config(update_dict):
config.update(update_dict)
mock_repo.update.side_effect = update_config
# Make the mock context var return our mock repo
@ -66,7 +77,9 @@ def mock_config_repository():
@pytest.fixture(autouse=True)
def mock_trajectory_repository():
"""Mock the TrajectoryRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var:
with patch(
"ra_aid.database.repositories.trajectory_repository.trajectory_repo_var"
) as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
@ -75,6 +88,7 @@ def mock_trajectory_repository():
mock_trajectory = MagicMock()
mock_trajectory.id = 1
return mock_trajectory
mock_repo.create.side_effect = mock_create
# Make the mock context var return our mock repo
@ -86,7 +100,9 @@ def mock_trajectory_repository():
@pytest.fixture(autouse=True)
def mock_human_input_repository():
"""Mock the HumanInputRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var:
with patch(
"ra_aid.database.repositories.human_input_repository.human_input_repo_var"
) as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
@ -99,87 +115,14 @@ def mock_human_input_repository():
yield mock_repo
def test_get_model_token_limit_anthropic(mock_config_repository):
"""Test get_model_token_limit with Anthropic model."""
config = {"provider": "anthropic", "model": "claude2"}
mock_config_repository.update(config)
token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
def test_get_model_token_limit_openai(mock_config_repository):
"""Test get_model_token_limit with OpenAI model."""
config = {"provider": "openai", "model": "gpt-4"}
mock_config_repository.update(config)
token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
def test_get_model_token_limit_unknown(mock_config_repository):
"""Test get_model_token_limit with unknown provider/model."""
config = {"provider": "unknown", "model": "unknown-model"}
mock_config_repository.update(config)
token_limit = get_model_token_limit(config, "default")
assert token_limit is None
def test_get_model_token_limit_missing_config(mock_config_repository):
"""Test get_model_token_limit with missing configuration."""
config = {}
mock_config_repository.update(config)
token_limit = get_model_token_limit(config, "default")
assert token_limit is None
def test_get_model_token_limit_litellm_success():
"""Test get_model_token_limit successfully getting limit from litellm."""
config = {"provider": "anthropic", "model": "claude-2"}
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.return_value = {"max_input_tokens": 100000}
token_limit = get_model_token_limit(config, "default")
assert token_limit == 100000
def test_get_model_token_limit_litellm_not_found():
"""Test fallback to models_tokens when litellm raises NotFoundError."""
config = {"provider": "anthropic", "model": "claude-2"}
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.side_effect = litellm.exceptions.NotFoundError(
message="Model not found", model="claude-2", llm_provider="anthropic"
)
token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
def test_get_model_token_limit_litellm_error():
"""Test fallback to models_tokens when litellm raises other exceptions."""
config = {"provider": "anthropic", "model": "claude-2"}
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.side_effect = Exception("Unknown error")
token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
def test_get_model_token_limit_unexpected_error():
"""Test returning None when unexpected errors occur."""
config = None # This will cause an attribute error when accessed
token_limit = get_model_token_limit(config, "default")
assert token_limit is None
def test_create_agent_anthropic(mock_model, mock_config_repository):
"""Test create_agent with Anthropic Claude model."""
mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})
with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
with (
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier,
):
mock_react.return_value = "react_agent"
agent = create_agent(mock_model, [])
@ -187,9 +130,10 @@ def test_create_agent_anthropic(mock_model, mock_config_repository):
mock_react.assert_called_once_with(
mock_model,
[],
interrupt_after=['tools'],
interrupt_after=["tools"],
version="v2",
state_modifier=mock_react.call_args[1]["state_modifier"],
name="React",
)
@ -257,20 +201,7 @@ def mock_messages():
]
def test_state_modifier(mock_messages):
"""Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
state = AgentState(messages=mock_messages)
with patch(
"ra_aid.agent_backends.ciayn_agent.CiaynAgent._estimate_tokens"
) as mock_estimate:
mock_estimate.side_effect = lambda msg: 100 if msg else 0
result = state_modifier(state, max_input_tokens=250)
assert len(result) < len(mock_messages)
assert isinstance(result[0], SystemMessage)
assert result[-1] == mock_messages[-1]
# This test has been moved to test_anthropic_token_limiter.py
def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
@ -291,17 +222,21 @@ def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
)
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository):
def test_create_agent_anthropic_token_limiting_enabled(
mock_model, mock_config_repository
):
"""Test create_agent sets up token limiting for Claude models when enabled."""
mock_config_repository.update({
mock_config_repository.update(
{
"provider": "anthropic",
"model": "claude-2",
"limit_tokens": True,
})
}
)
with (
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit,
):
mock_react.return_value = "react_agent"
mock_limit.return_value = 100000
@ -314,17 +249,21 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_r
assert callable(args[1]["state_modifier"])
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
def test_create_agent_anthropic_token_limiting_disabled(
mock_model, mock_config_repository
):
"""Test create_agent doesn't set up token limiting for Claude models when disabled."""
mock_config_repository.update({
mock_config_repository.update(
{
"provider": "anthropic",
"model": "claude-2",
"limit_tokens": False,
})
}
)
with (
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit,
):
mock_react.return_value = "react_agent"
mock_limit.return_value = 100000
@ -332,39 +271,12 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_
agent = create_agent(mock_model, [])
assert agent == "react_agent"
mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2")
mock_react.assert_called_once_with(
mock_model, [], interrupt_after=["tools"], version="v2", name="React"
)
def test_get_model_token_limit_research(mock_config_repository):
"""Test get_model_token_limit with research provider and model."""
config = {
"provider": "openai",
"model": "gpt-4",
"research_provider": "anthropic",
"research_model": "claude-2",
}
mock_config_repository.update(config)
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.return_value = {"max_input_tokens": 150000}
token_limit = get_model_token_limit(config, "research")
assert token_limit == 150000
def test_get_model_token_limit_planner(mock_config_repository):
"""Test get_model_token_limit with planner provider and model."""
config = {
"provider": "openai",
"model": "gpt-4",
"planner_provider": "deepseek",
"planner_model": "dsm-1",
}
mock_config_repository.update(config)
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.return_value = {"max_input_tokens": 120000}
token_limit = get_model_token_limit(config, "planner")
assert token_limit == 120000
# These tests have been moved to test_anthropic_token_limiter.py
# New tests for private helper methods in agent_utils.py
@ -629,7 +541,9 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repos
assert "Agent has crashed: Test crash message" in result
def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_repository):
def test_run_agent_with_retry_handles_badrequest_error(
monkeypatch, mock_config_repository
):
"""Test that run_agent_with_retry properly handles BadRequestError as unretryable."""
from ra_aid.agent_context import agent_context, is_crashed
from ra_aid.agent_utils import run_agent_with_retry
@ -687,7 +601,9 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_
assert is_crashed()
def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch, mock_config_repository):
def test_run_agent_with_retry_handles_api_badrequest_error(
monkeypatch, mock_config_repository
):
"""Test that run_agent_with_retry properly handles API BadRequestError as unretryable."""
# Import APIError from anthropic module and patch it on the agent_utils module
@ -760,5 +676,7 @@ def test_handle_api_error_resource_exhausted():
from ra_aid.agent_utils import _handle_api_error
# ResourceExhausted exception should be handled without raising
resource_exhausted_error = ResourceExhausted("429 Resource has been exhausted (e.g. check quota).")
resource_exhausted_error = ResourceExhausted(
"429 Resource has been exhausted (e.g. check quota)."
)
_handle_api_error(resource_exhausted_error, 0, 5, 1)

View File

@ -0,0 +1,507 @@
import unittest
from unittest.mock import MagicMock, patch
import litellm
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
ToolMessage
)
from langgraph.prebuilt.chat_agent_executor import AgentState
from ra_aid.anthropic_token_limiter import (
create_token_counter_wrapper,
estimate_messages_tokens,
get_model_token_limit,
state_modifier,
sonnet_35_state_modifier,
convert_message_to_litellm_format
)
from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT
class TestAnthropicTokenLimiter(unittest.TestCase):
def setUp(self):
from ra_aid.config import DEFAULT_MODEL
self.mock_model = MagicMock(spec=ChatAnthropic)
self.mock_model.model = DEFAULT_MODEL
# Sample messages for testing
self.system_message = SystemMessage(content="You are a helpful assistant.")
self.human_message = HumanMessage(content="Hello, can you help me with a task?")
self.ai_message = AIMessage(content="I'd be happy to help! What do you need?")
self.long_message = HumanMessage(content="A" * 1000) # Long message to test trimming
# Create more messages for testing
self.extra_messages = [
HumanMessage(content=f"Extra message {i}") for i in range(5)
]
# Mock state for testing state_modifier with many messages
self.state = AgentState(
messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages,
next=None,
)
# Create tool-related messages for testing
self.ai_with_tool_use = AIMessage(
content="I'll use a tool to help you",
additional_kwargs={"tool_calls": [{"name": "calculator", "input": {"expression": "2+2"}}]}
)
self.tool_message = ToolMessage(
content="4",
tool_call_id="tool_call_1",
name="calculator"
)
def test_convert_message_to_litellm_format(self):
"""Test conversion of BaseMessage to litellm format."""
# Test human message
human_result = convert_message_to_litellm_format(self.human_message)
self.assertEqual(human_result["role"], "human")
self.assertEqual(human_result["content"], "Hello, can you help me with a task?")
# Test system message
system_result = convert_message_to_litellm_format(self.system_message)
self.assertEqual(system_result["role"], "system")
self.assertEqual(system_result["content"], "You are a helpful assistant.")
# Test AI message
ai_result = convert_message_to_litellm_format(self.ai_message)
self.assertEqual(ai_result["role"], "ai")
self.assertEqual(ai_result["content"], "I'd be happy to help! What do you need?")
@patch("ra_aid.anthropic_token_limiter.token_counter")
def test_create_token_counter_wrapper(self, mock_token_counter):
from ra_aid.config import DEFAULT_MODEL
# Setup mock return values
mock_token_counter.return_value = 50
# Create the wrapper
wrapper = create_token_counter_wrapper(DEFAULT_MODEL)
# Test with BaseMessage objects
result = wrapper([self.human_message])
self.assertEqual(result, 50)
# Test with empty list
result = wrapper([])
self.assertEqual(result, 0)
# Verify the mock was called with the right parameters
mock_token_counter.assert_called_with(messages=unittest.mock.ANY, model=DEFAULT_MODEL)
@patch("ra_aid.anthropic_token_limiter.CiaynAgent._estimate_tokens")
def test_estimate_messages_tokens(self, mock_estimate_tokens):
# Setup mock to return different values for different messages
mock_estimate_tokens.side_effect = lambda msg: 10 if isinstance(msg, SystemMessage) else 20
# Test with multiple messages
messages = [self.system_message, self.human_message]
result = estimate_messages_tokens(messages)
# Should be sum of individual token counts (10 + 20)
self.assertEqual(result, 30)
# Test with empty list
result = estimate_messages_tokens([])
self.assertEqual(result, 0)
@patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
@patch("ra_aid.anthropic_token_limiter.print_messages_compact")
@patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
# Setup a proper token counter function that returns integers
def token_counter(msgs):
# Return token count based on number of messages
return len(msgs) * 10
# Configure the mock to return our token counter
mock_create_wrapper.return_value = token_counter
# Configure anthropic_trim_messages to return a subset of messages
mock_trim_messages.return_value = [self.system_message, self.human_message]
# Call state_modifier with a max token limit of 50
result = state_modifier(self.state, self.mock_model, max_input_tokens=50)
# Should return what anthropic_trim_messages returned
self.assertEqual(result, [self.system_message, self.human_message])
# Verify the wrapper was created with the right model
mock_create_wrapper.assert_called_with(self.mock_model.model)
# Verify anthropic_trim_messages was called with the right parameters
mock_trim_messages.assert_called_once()
def test_state_modifier_with_messages(self):
"""Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
# Create a state with messages
messages = [
SystemMessage(content="System prompt"),
HumanMessage(content="Human message 1"),
AIMessage(content="AI response 1"),
HumanMessage(content="Human message 2"),
AIMessage(content="AI response 2"),
]
state = AgentState(messages=messages)
model = MagicMock(spec=ChatAnthropic)
model.model = "claude-3-opus-20240229"
with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
# Setup mock to return a fixed token count per message
mock_wrapper.return_value = lambda msgs: len(msgs) * 100
# Setup mock to return a subset of messages
mock_trim.return_value = [messages[0], messages[-2], messages[-1]]
result = state_modifier(state, model, max_input_tokens=250)
# Should return what anthropic_trim_messages returned
self.assertEqual(len(result), 3)
self.assertEqual(result[0], messages[0]) # First message preserved
self.assertEqual(result[-1], messages[-1]) # Last message preserved
def test_sonnet_35_state_modifier(self):
"""Test the sonnet 35 state modifier function."""
# Create a state with messages
state = {"messages": [self.system_message, self.human_message, self.ai_message]}
# Test with empty messages
empty_state = {"messages": []}
# Instead of patching trim_messages which has complex internal logic,
# we'll directly patch the sonnet_35_state_modifier's call to trim_messages
with patch("ra_aid.anthropic_token_limiter.trim_messages") as mock_trim:
# Setup mock to return our desired messages
mock_trim.return_value = [self.human_message, self.ai_message]
# Test with empty messages
self.assertEqual(sonnet_35_state_modifier(empty_state), [])
# Test with messages under the limit
result = sonnet_35_state_modifier(state, max_input_tokens=10000)
# Should keep the first message and call trim_messages for the rest
self.assertEqual(len(result), 3)
self.assertEqual(result[0], self.system_message)
self.assertEqual(result[1:], [self.human_message, self.ai_message])
# Verify trim_messages was called with the right parameters
mock_trim.assert_called_once()
# We can check some of the key arguments
call_args = mock_trim.call_args[1]
# The actual value is based on the token estimation logic, not a hard-coded 9000
self.assertIn("max_tokens", call_args)
self.assertEqual(call_args["strategy"], "last")
self.assertEqual(call_args["strategy"], "last")
self.assertEqual(call_args["allow_partial"], False)
self.assertEqual(call_args["include_system"], True)
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
@patch("litellm.get_model_info")
def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
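"""Test that get_model_token_limit reads max_input_tokens from litellm's get_model_info."""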
from ra_aid.config import DEFAULT_MODEL
# Setup mocks
mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
mock_get_config_repo.return_value.get_all.return_value = mock_config
# Mock litellm's get_model_info to return a token limit
mock_get_model_info.return_value = {"max_input_tokens": 100000}
# Test getting token limit
result = get_model_token_limit(mock_config)
self.assertEqual(result, 100000)
# Verify get_model_info was called with the right model
mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
def test_get_model_token_limit_research(self):
"""Test get_model_token_limit with research provider and model."""
config = {
"provider": "openai",
"model": "gpt-4",
"research_provider": "anthropic",
"research_model": "claude-2",
}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
patch("litellm.get_model_info") as mock_get_info:
mock_get_config_repo.return_value.get_all.return_value = config
mock_get_info.return_value = {"max_input_tokens": 150000}
token_limit = get_model_token_limit(config, "research")
self.assertEqual(token_limit, 150000)
# Verify get_model_info was called with the research model
mock_get_info.assert_called_with("anthropic/claude-2")
def test_get_model_token_limit_planner(self):
"""Test get_model_token_limit with planner provider and model."""
config = {
"provider": "openai",
"model": "gpt-4",
"planner_provider": "deepseek",
"planner_model": "dsm-1",
}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
patch("litellm.get_model_info") as mock_get_info:
mock_get_config_repo.return_value.get_all.return_value = config
mock_get_info.return_value = {"max_input_tokens": 120000}
token_limit = get_model_token_limit(config, "planner")
self.assertEqual(token_limit, 120000)
# Verify get_model_info was called with the planner model
mock_get_info.assert_called_with("deepseek/dsm-1")
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
@patch("litellm.get_model_info")
def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
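"""Test that get_model_token_limit falls back to models_params when litellm raises an exception."""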
# Setup mocks
mock_config = {"provider": "anthropic", "model": "claude-2"}
mock_get_config_repo.return_value.get_all.return_value = mock_config
# Make litellm's get_model_info raise an exception to test fallback
mock_get_model_info.side_effect = Exception("Model not found")
# Test getting token limit from models_params fallback
with patch("ra_aid.anthropic_token_limiter.models_params", {
"anthropic": {
"claude2": {"token_limit": 100000}
}
}):
result = get_model_token_limit(mock_config)
self.assertEqual(result, 100000)
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
@patch("litellm.get_model_info")
def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
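"""Test that get_model_token_limit resolves the correct model for the default, research, and planner agent types."""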
from ra_aid.config import DEFAULT_MODEL
# Setup mocks for different agent types
mock_config = {
"provider": "anthropic",
"model": DEFAULT_MODEL,
"research_provider": "openai",
"research_model": "gpt-4",
"planner_provider": "anthropic",
"planner_model": "claude-3-sonnet-20240229"
}
mock_get_config_repo.return_value.get_all.return_value = mock_config
# Mock different returns for different models
def model_info_side_effect(model_name):
if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
return {"max_input_tokens": 200000}
elif "gpt-4" in model_name:
return {"max_input_tokens": 8192}
elif "claude-3-sonnet" in model_name:
return {"max_input_tokens": 100000}
else:
raise Exception(f"Unknown model: {model_name}")
mock_get_model_info.side_effect = model_info_side_effect
# Test default agent type
result = get_model_token_limit(mock_config, "default")
self.assertEqual(result, 200000)
# Test research agent type
result = get_model_token_limit(mock_config, "research")
self.assertEqual(result, 8192)
# Test planner agent type
result = get_model_token_limit(mock_config, "planner")
self.assertEqual(result, 100000)
def test_get_model_token_limit_anthropic(self):
"""Test get_model_token_limit with Anthropic model."""
config = {"provider": "anthropic", "model": "claude2"}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
mock_get_config_repo.return_value.get_all.return_value = config
token_limit = get_model_token_limit(config, "default")
self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
def test_get_model_token_limit_openai(self):
"""Test get_model_token_limit with OpenAI model."""
config = {"provider": "openai", "model": "gpt-4"}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
mock_get_config_repo.return_value.get_all.return_value = config
token_limit = get_model_token_limit(config, "default")
self.assertEqual(token_limit, models_params["openai"]["gpt-4"]["token_limit"])
def test_get_model_token_limit_unknown(self):
"""Test get_model_token_limit with unknown provider/model."""
config = {"provider": "unknown", "model": "unknown-model"}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
mock_get_config_repo.return_value.get_all.return_value = config
token_limit = get_model_token_limit(config, "default")
self.assertIsNone(token_limit)
def test_get_model_token_limit_missing_config(self):
"""Test get_model_token_limit with missing configuration."""
config = {}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
mock_get_config_repo.return_value.get_all.return_value = config
token_limit = get_model_token_limit(config, "default")
self.assertIsNone(token_limit)
def test_get_model_token_limit_litellm_success(self):
"""Test get_model_token_limit successfully getting limit from litellm."""
config = {"provider": "anthropic", "model": "claude-2"}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
patch("litellm.get_model_info") as mock_get_info:
mock_get_config_repo.return_value.get_all.return_value = config
mock_get_info.return_value = {"max_input_tokens": 100000}
token_limit = get_model_token_limit(config, "default")
self.assertEqual(token_limit, 100000)
mock_get_info.assert_called_with("anthropic/claude-2")
def test_get_model_token_limit_litellm_not_found(self):
"""Test fallback to models_tokens when litellm raises NotFoundError."""
config = {"provider": "anthropic", "model": "claude-2"}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
patch("litellm.get_model_info") as mock_get_info:
mock_get_config_repo.return_value.get_all.return_value = config
mock_get_info.side_effect = litellm.exceptions.NotFoundError(
message="Model not found", model="claude-2", llm_provider="anthropic"
)
token_limit = get_model_token_limit(config, "default")
self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
def test_get_model_token_limit_litellm_error(self):
"""Test fallback to models_tokens when litellm raises other exceptions."""
config = {"provider": "anthropic", "model": "claude-2"}
with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
patch("litellm.get_model_info") as mock_get_info:
mock_get_config_repo.return_value.get_all.return_value = config
mock_get_info.side_effect = Exception("Unknown error")
token_limit = get_model_token_limit(config, "default")
self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
def test_get_model_token_limit_unexpected_error(self):
"""Test returning None when unexpected errors occur."""
config = None # This will cause an attribute error when accessed
token_limit = get_model_token_limit(config, "default")
self.assertIsNone(token_limit)
def test_has_tool_use(self):
"""Test the has_tool_use function."""
# Test with regular AI message
self.assertFalse(has_tool_use(self.ai_message))
# Test with AI message containing tool_use in string content
ai_with_tool_str = AIMessage(content="I'll use a tool_use to help you")
self.assertTrue(has_tool_use(ai_with_tool_str))
# Test with AI message containing tool_use in structured content
ai_with_tool_dict = AIMessage(content=[
{"type": "text", "text": "I'll use a tool to help you"},
{"type": "tool_use", "tool_use": {"name": "calculator", "input": {"expression": "2+2"}}}
])
self.assertTrue(has_tool_use(ai_with_tool_dict))
# Test with AI message containing tool_calls in additional_kwargs
self.assertTrue(has_tool_use(self.ai_with_tool_use))
# Test with non-AI message
self.assertFalse(has_tool_use(self.human_message))
def test_is_tool_pair(self):
"""Test the is_tool_pair function."""
# Test with valid tool pair
self.assertTrue(is_tool_pair(self.ai_with_tool_use, self.tool_message))
# Test with non-tool pair (wrong order)
self.assertFalse(is_tool_pair(self.tool_message, self.ai_with_tool_use))
# Test with non-tool pair (wrong types)
self.assertFalse(is_tool_pair(self.ai_message, self.human_message))
# Test with non-tool pair (AI message without tool use)
self.assertFalse(is_tool_pair(self.ai_message, self.tool_message))
@patch("ra_aid.anthropic_message_utils.has_tool_use")
def test_anthropic_trim_messages_with_tool_use(self, mock_has_tool_use):
"""Test anthropic_trim_messages with a sequence of messages including tool use."""
from ra_aid.anthropic_message_utils import anthropic_trim_messages
# Setup mock for has_tool_use to return True for AI messages at even indices
def side_effect(msg):
if isinstance(msg, AIMessage) and hasattr(msg, 'test_index'):
return msg.test_index % 2 == 0 # Even indices have tool use
return False
mock_has_tool_use.side_effect = side_effect
# Create a sequence of alternating human and AI messages with tool use
messages = []
# Start with system message
system_msg = SystemMessage(content="You are a helpful assistant.")
messages.append(system_msg)
# Add alternating human and AI messages with tool use
for i in range(8):
if i % 2 == 0:
# Human message
msg = HumanMessage(content=f"Human message {i}")
messages.append(msg)
else:
# AI message, every other one has tool use
ai_msg = AIMessage(content=f"AI message {i}")
# Add a test_index attribute to track position
ai_msg.test_index = i
messages.append(ai_msg)
# If this AI message has tool use (even index), add a tool message after it
if i % 4 == 1: # 1, 5, etc.
tool_msg = ToolMessage(
content=f"Tool result {i}",
tool_call_id=f"tool_call_{i}",
name="test_tool"
)
messages.append(tool_msg)
# Define a token counter that returns a fixed value per message
def token_counter(msgs):
return len(msgs) * 1000
# Test with a token limit that will require trimming
result = anthropic_trim_messages(
messages,
token_counter=token_counter,
max_tokens=5000, # This will allow 5 messages
strategy="last",
allow_partial=False,
include_system=True,
num_messages_to_keep=2 # Keep system and first human message
)
# We should have kept the first 2 messages (system + human)
self.assertEqual(len(result), 5) # 2 kept + 3 more that fit in token limit
self.assertEqual(result[0], system_msg)
# Verify that we don't have any AI messages with tool use that aren't followed by a tool message
for i in range(len(result) - 1):
if isinstance(result[i], AIMessage) and mock_has_tool_use(result[i]):
self.assertTrue(isinstance(result[i+1], ToolMessage),
f"AI message with tool use at index {i} not followed by ToolMessage")
if __name__ == "__main__":
unittest.main()