Merge pull request #124 from ariel-frischer/fix-token-limiter
Fix Sonnet 3.7 Token Limiter API Errors
commit a9656552a9
@@ -39,35 +39,41 @@ from ra_aid.agents.research_agent import run_research_agent
|
|||
from ra_aid.agents import run_planning_agent
|
||||
from ra_aid.config import (
|
||||
DEFAULT_MAX_TEST_CMD_RETRIES,
|
||||
DEFAULT_MODEL,
|
||||
DEFAULT_RECURSION_LIMIT,
|
||||
DEFAULT_TEST_CMD_TIMEOUT,
|
||||
VALID_PROVIDERS,
|
||||
)
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
|
||||
from ra_aid.database.repositories.key_fact_repository import (
|
||||
KeyFactRepositoryManager,
|
||||
get_key_fact_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.key_snippet_repository import (
|
||||
KeySnippetRepositoryManager, get_key_snippet_repository
|
||||
KeySnippetRepositoryManager,
|
||||
get_key_snippet_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.human_input_repository import (
|
||||
HumanInputRepositoryManager, get_human_input_repository
|
||||
HumanInputRepositoryManager,
|
||||
get_human_input_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.research_note_repository import (
|
||||
ResearchNoteRepositoryManager, get_research_note_repository
|
||||
ResearchNoteRepositoryManager,
|
||||
get_research_note_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.trajectory_repository import (
|
||||
TrajectoryRepositoryManager, get_trajectory_repository
|
||||
TrajectoryRepositoryManager,
|
||||
get_trajectory_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.session_repository import (
|
||||
SessionRepositoryManager, get_session_repository
|
||||
)
|
||||
from ra_aid.database.repositories.related_files_repository import (
|
||||
RelatedFilesRepositoryManager
|
||||
)
|
||||
from ra_aid.database.repositories.work_log_repository import (
|
||||
WorkLogRepositoryManager
|
||||
RelatedFilesRepositoryManager,
|
||||
)
|
||||
from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager
|
||||
from ra_aid.database.repositories.config_repository import (
|
||||
ConfigRepositoryManager,
|
||||
get_config_repository
|
||||
get_config_repository,
|
||||
)
|
||||
from ra_aid.env_inv import EnvDiscovery
|
||||
from ra_aid.env_inv_context import EnvInvManager, get_env_inv
|
||||
|
|
@@ -103,7 +109,7 @@ def launch_webui(host: str, port: int):
|
|||
|
||||
|
||||
def parse_arguments(args=None):
|
||||
ANTHROPIC_DEFAULT_MODEL = "claude-3-7-sonnet-20250219"
|
||||
ANTHROPIC_DEFAULT_MODEL = DEFAULT_MODEL
|
||||
OPENAI_DEFAULT_MODEL = "gpt-4o"
|
||||
|
||||
# Case-insensitive log level argument type
|
||||
|
|
@@ -202,8 +208,10 @@ Examples:
|
|||
help="Enable chat mode with direct human interaction (implies --hil)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-mode", choices=["console", "file"], default="file",
|
||||
help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console"
|
||||
"--log-mode",
|
||||
choices=["console", "file"],
|
||||
default="file",
|
||||
help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretty-logger", action="store_true", help="Enable pretty logging output"
|
||||
|
|
@@ -423,7 +431,9 @@ def build_status():
|
|||
temperature = config_repo.get("temperature")
|
||||
expert_provider = config_repo.get("expert_provider", "")
|
||||
expert_model = config_repo.get("expert_model", "")
|
||||
experimental_fallback_handler = config_repo.get("experimental_fallback_handler", False)
|
||||
experimental_fallback_handler = config_repo.get(
|
||||
"experimental_fallback_handler", False
|
||||
)
|
||||
web_research_enabled = config_repo.get("web_research_enabled", False)
|
||||
|
||||
# Get the expert enabled status
|
||||
|
|
@@ -483,7 +493,9 @@ def build_status():
|
|||
logger.debug(f"Failed to get research notes count: {e}")
|
||||
|
||||
# Add memory statistics line with reset option note
|
||||
status.append(f"\n💾 Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes")
|
||||
status.append(
|
||||
f"\n💾 Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes"
|
||||
)
|
||||
if fact_count > 0 or snippet_count > 0 or note_count > 0:
|
||||
status.append(" (use --wipe-project-memory to reset)")
|
||||
|
||||
|
|
@@ -568,7 +580,9 @@ def main():
|
|||
expert_missing,
|
||||
web_research_enabled,
|
||||
web_research_missing,
|
||||
) = validate_environment(args) # Will exit if main env vars missing
|
||||
) = validate_environment(
|
||||
args
|
||||
) # Will exit if main env vars missing
|
||||
logger.debug("Environment validation successful")
|
||||
|
||||
# Validate model configuration early
|
||||
|
|
@@ -604,12 +618,16 @@ def main():
|
|||
config_repo.set("expert_provider", args.expert_provider)
|
||||
config_repo.set("expert_model", args.expert_model)
|
||||
config_repo.set("temperature", args.temperature)
|
||||
config_repo.set("experimental_fallback_handler", args.experimental_fallback_handler)
|
||||
config_repo.set(
|
||||
"experimental_fallback_handler", args.experimental_fallback_handler
|
||||
)
|
||||
config_repo.set("web_research_enabled", web_research_enabled)
|
||||
config_repo.set("show_thoughts", args.show_thoughts)
|
||||
config_repo.set("show_cost", args.show_cost)
|
||||
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
||||
config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
|
||||
config_repo.set(
|
||||
"disable_reasoning_assistance", args.no_reasoning_assistance
|
||||
)
|
||||
|
||||
# Build status panel with memory statistics
|
||||
status = build_status()
|
||||
|
|
@@ -684,7 +702,9 @@ def main():
|
|||
try:
|
||||
# Using get_human_input_repository() to access the repository from context
|
||||
human_input_repository = get_human_input_repository()
|
||||
human_input_repository.create(content=initial_request, source='chat')
|
||||
human_input_repository.create(
|
||||
content=initial_request, source="chat"
|
||||
)
|
||||
human_input_repository.garbage_collect()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record initial chat input: {str(e)}")
|
||||
|
|
@@ -742,8 +762,12 @@ def main():
|
|||
),
|
||||
working_directory=working_directory,
|
||||
current_date=current_date,
|
||||
key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
|
||||
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
|
||||
key_facts=format_key_facts_dict(
|
||||
get_key_fact_repository().get_facts_dict()
|
||||
),
|
||||
key_snippets=format_key_snippets_dict(
|
||||
get_key_snippet_repository().get_snippets_dict()
|
||||
),
|
||||
project_info=formatted_project_info,
|
||||
env_inv=get_env_inv(),
|
||||
),
|
||||
|
|
@@ -780,7 +804,7 @@ def main():
|
|||
try:
|
||||
# Using get_human_input_repository() to access the repository from context
|
||||
human_input_repository = get_human_input_repository()
|
||||
human_input_repository.create(content=base_task, source='cli')
|
||||
human_input_repository.create(content=base_task, source="cli")
|
||||
# Run garbage collection to ensure we don't exceed 100 inputs
|
||||
human_input_repository.garbage_collect()
|
||||
logger.debug(f"Recorded CLI input: {base_task}")
|
||||
|
|
@@ -814,11 +838,15 @@ def main():
|
|||
config_repo.set("expert_model", args.expert_model)
|
||||
|
||||
# Store planner config with fallback to base values
|
||||
config_repo.set("planner_provider", args.planner_provider or args.provider)
|
||||
config_repo.set(
|
||||
"planner_provider", args.planner_provider or args.provider
|
||||
)
|
||||
config_repo.set("planner_model", args.planner_model or args.model)
|
||||
|
||||
# Store research config with fallback to base values
|
||||
config_repo.set("research_provider", args.research_provider or args.provider)
|
||||
config_repo.set(
|
||||
"research_provider", args.research_provider or args.provider
|
||||
)
|
||||
config_repo.set("research_model", args.research_model or args.model)
|
||||
|
||||
# Store temperature in config
|
||||
|
|
@@ -826,7 +854,9 @@ def main():
|
|||
|
||||
# Store reasoning assistance flags
|
||||
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
||||
config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
|
||||
config_repo.set(
|
||||
"disable_reasoning_assistance", args.no_reasoning_assistance
|
||||
)
|
||||
|
||||
# Set modification tools based on use_aider flag
|
||||
set_modification_tools(args.use_aider)
|
||||
|
|
@@ -870,5 +900,6 @@ def main():
|
|||
print()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
@@ -1,19 +1,14 @@
|
|||
"""Utility functions for working with agents."""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
|
||||
|
||||
|
||||
import litellm
|
||||
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
||||
from openai import RateLimitError as OpenAIRateLimitError
|
||||
from litellm.exceptions import RateLimitError as LiteLLMRateLimitError
|
||||
|
|
@@ -23,28 +18,24 @@ from langchain_core.messages import (
|
|||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
trim_messages,
|
||||
)
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||
from litellm import get_model_info
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
from ra_aid.agent_context import (
|
||||
agent_context,
|
||||
get_depth,
|
||||
is_completed,
|
||||
reset_completion_flags,
|
||||
should_exit,
|
||||
)
|
||||
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
||||
from ra_aid.agents_alias import RAgents
|
||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
||||
from ra_aid.console.formatting import print_error, print_stage_header
|
||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES
|
||||
from ra_aid.console.formatting import print_error
|
||||
from ra_aid.console.output import print_agent_output
|
||||
from ra_aid.exceptions import (
|
||||
AgentInterrupt,
|
||||
|
|
@@ -53,77 +44,20 @@ from ra_aid.exceptions import (
|
|||
)
|
||||
from ra_aid.fallback_handler import FallbackHandler
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.llm import initialize_expert_llm
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||
from ra_aid.text.processing import process_thinking_content
|
||||
from ra_aid.project_info import (
|
||||
display_project_status,
|
||||
format_project_info,
|
||||
get_project_info,
|
||||
)
|
||||
from ra_aid.prompts.expert_prompts import (
|
||||
EXPERT_PROMPT_SECTION_IMPLEMENTATION,
|
||||
EXPERT_PROMPT_SECTION_PLANNING,
|
||||
EXPERT_PROMPT_SECTION_RESEARCH,
|
||||
)
|
||||
from ra_aid.prompts.human_prompts import (
|
||||
HUMAN_PROMPT_SECTION_IMPLEMENTATION,
|
||||
HUMAN_PROMPT_SECTION_PLANNING,
|
||||
HUMAN_PROMPT_SECTION_RESEARCH,
|
||||
)
|
||||
from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
|
||||
from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
|
||||
from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
|
||||
from ra_aid.prompts.reasoning_assist_prompt import (
|
||||
REASONING_ASSIST_PROMPT_PLANNING,
|
||||
REASONING_ASSIST_PROMPT_IMPLEMENTATION,
|
||||
REASONING_ASSIST_PROMPT_RESEARCH,
|
||||
)
|
||||
from ra_aid.prompts.research_prompts import (
|
||||
RESEARCH_ONLY_PROMPT,
|
||||
RESEARCH_PROMPT,
|
||||
)
|
||||
from ra_aid.prompts.web_research_prompts import (
|
||||
WEB_RESEARCH_PROMPT,
|
||||
WEB_RESEARCH_PROMPT_SECTION_CHAT,
|
||||
WEB_RESEARCH_PROMPT_SECTION_PLANNING,
|
||||
WEB_RESEARCH_PROMPT_SECTION_RESEARCH,
|
||||
)
|
||||
from ra_aid.tool_configs import (
|
||||
get_implementation_tools,
|
||||
get_planning_tools,
|
||||
get_research_tools,
|
||||
get_web_research_tools,
|
||||
)
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import (
|
||||
get_key_snippet_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.human_input_repository import (
|
||||
get_human_input_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.research_note_repository import (
|
||||
get_research_note_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
||||
from ra_aid.model_formatters import format_key_facts_dict
|
||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
||||
from ra_aid.tools.memory import (
|
||||
get_related_files,
|
||||
log_work_event,
|
||||
)
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.env_inv_context import get_env_inv
|
||||
from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
|
||||
|
||||
console = Console()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Import repositories using get_* functions
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@@ -133,131 +67,19 @@ def output_markdown_message(message: str) -> str:
|
|||
return "Message output."
|
||||
|
||||
|
||||
def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
"""Helper function to estimate total tokens in a sequence of messages.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages to count tokens for
|
||||
|
||||
Returns:
|
||||
Total estimated token count
|
||||
"""
|
||||
if not messages:
|
||||
return 0
|
||||
|
||||
estimate_tokens = CiaynAgent._estimate_tokens
|
||||
return sum(estimate_tokens(msg) for msg in messages)
|
||||
|
||||
|
||||
def state_modifier(
|
||||
state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
|
||||
) -> list[BaseMessage]:
|
||||
"""Given the agent state and max_tokens, return a trimmed list of messages.
|
||||
|
||||
Args:
|
||||
state: The current agent state containing messages
|
||||
max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
|
||||
|
||||
Returns:
|
||||
list[BaseMessage]: Trimmed list of messages that fits within token limit
|
||||
"""
|
||||
messages = state["messages"]
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
first_message = messages[0]
|
||||
remaining_messages = messages[1:]
|
||||
first_tokens = estimate_messages_tokens([first_message])
|
||||
new_max_tokens = max_input_tokens - first_tokens
|
||||
|
||||
trimmed_remaining = trim_messages(
|
||||
remaining_messages,
|
||||
token_counter=estimate_messages_tokens,
|
||||
max_tokens=new_max_tokens,
|
||||
strategy="last",
|
||||
allow_partial=False,
|
||||
)
|
||||
|
||||
return [first_message] + trimmed_remaining
|
||||
|
||||
|
||||
def get_model_token_limit(
|
||||
config: Dict[str, Any], agent_type: Literal["default", "research", "planner"]
|
||||
) -> Optional[int]:
|
||||
"""Get the token limit for the current model configuration based on agent type.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The token limit if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Try to get config from repository for production use
|
||||
try:
|
||||
config_from_repo = get_config_repository().get_all()
|
||||
# If we succeeded, use the repository config instead of passed config
|
||||
config = config_from_repo
|
||||
except RuntimeError:
|
||||
# In tests, this may fail because the repository isn't set up
|
||||
# So we'll use the passed config directly
|
||||
pass
|
||||
if agent_type == "research":
|
||||
provider = config.get("research_provider", "") or config.get("provider", "")
|
||||
model_name = config.get("research_model", "") or config.get("model", "")
|
||||
elif agent_type == "planner":
|
||||
provider = config.get("planner_provider", "") or config.get("provider", "")
|
||||
model_name = config.get("planner_model", "") or config.get("model", "")
|
||||
else:
|
||||
provider = config.get("provider", "")
|
||||
model_name = config.get("model", "")
|
||||
|
||||
try:
|
||||
provider_model = model_name if not provider else f"{provider}/{model_name}"
|
||||
model_info = get_model_info(provider_model)
|
||||
max_input_tokens = model_info.get("max_input_tokens")
|
||||
if max_input_tokens:
|
||||
logger.debug(
|
||||
f"Using litellm token limit for {model_name}: {max_input_tokens}"
|
||||
)
|
||||
return max_input_tokens
|
||||
except litellm.exceptions.NotFoundError:
|
||||
logger.debug(
|
||||
f"Model {model_name} not found in litellm, falling back to models_params"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Error getting model info from litellm: {e}, falling back to models_params"
|
||||
)
|
||||
|
||||
# Fallback to models_params dict
|
||||
# Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
|
||||
normalized_name = model_name.replace("-", "")
|
||||
provider_tokens = models_params.get(provider, {})
|
||||
if normalized_name in provider_tokens:
|
||||
max_input_tokens = provider_tokens[normalized_name]["token_limit"]
|
||||
logger.debug(
|
||||
f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
|
||||
)
|
||||
else:
|
||||
max_input_tokens = None
|
||||
logger.debug(f"Could not find token limit for {provider}/{model_name}")
|
||||
|
||||
return max_input_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get model token limit: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def build_agent_kwargs(
|
||||
checkpointer: Optional[Any] = None,
|
||||
model: ChatAnthropic = None,
|
||||
max_input_tokens: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build kwargs dictionary for agent creation.
|
||||
|
||||
Args:
|
||||
checkpointer: Optional memory checkpointer
|
||||
model: The language model to use for token counting
|
||||
max_input_tokens: Optional token limit for the model
|
||||
|
||||
Returns:
|
||||
Dictionary of kwargs for agent creation
|
||||
|
|
@@ -270,12 +92,20 @@ def build_agent_kwargs(
|
|||
agent_kwargs["checkpointer"] = checkpointer
|
||||
|
||||
config = get_config_repository().get_all()
|
||||
if config.get("limit_tokens", True) and is_anthropic_claude(config):
|
||||
if (
|
||||
config.get("limit_tokens", True)
|
||||
and is_anthropic_claude(config)
|
||||
and model is not None
|
||||
):
|
||||
|
||||
def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
|
||||
return state_modifier(state, max_input_tokens=max_input_tokens)
|
||||
if any(pattern in model.model for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]):
|
||||
return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
|
||||
|
||||
return state_modifier(state, model, max_input_tokens=max_input_tokens)
|
||||
|
||||
agent_kwargs["state_modifier"] = wrapped_state_modifier
|
||||
agent_kwargs["name"] = "React"
|
||||
|
||||
return agent_kwargs
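To make the new dispatch easier to follow, here is a self-contained sketch of the state-modifier wiring added above (the function names are taken from this diff; everything else, including the helper name make_state_modifier, is illustrative):

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage
from langgraph.prebuilt.chat_agent_executor import AgentState

from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier


def make_state_modifier(model: ChatAnthropic, max_input_tokens: int):
    # Claude 3.5 keeps the older heuristic trimming path; other Claude models
    # use the new litellm-backed state_modifier, which needs the model object
    # so it can pick the right tokenizer.
    def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
        if any(p in model.model for p in ["claude-3.5", "claude3.5", "claude-3-5"]):
            return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
        return state_modifier(state, model, max_input_tokens=max_input_tokens)

    return wrapped_state_modifier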
|
||||
|
||||
|
|
@@ -345,7 +175,8 @@ def create_agent(
|
|||
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN
|
||||
if is_anthropic_claude(config):
|
||||
logger.debug("Using create_react_agent to instantiate agent.")
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
|
||||
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
|
||||
return create_react_agent(
|
||||
model, tools, interrupt_after=["tools"], **agent_kwargs
|
||||
)
|
||||
|
|
@@ -358,16 +189,12 @@ def create_agent(
|
|||
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
|
||||
config = get_config_repository().get_all()
|
||||
max_input_tokens = get_model_token_limit(config, agent_type)
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
|
||||
return create_react_agent(
|
||||
model, tools, interrupt_after=["tools"], **agent_kwargs
|
||||
)
|
||||
|
||||
|
||||
from ra_aid.agents.research_agent import run_research_agent, run_web_research_agent
|
||||
from ra_aid.agents.implementation_agent import run_task_implementation_agent
|
||||
|
||||
|
||||
_CONTEXT_STACK = []
|
||||
_INTERRUPT_CONTEXT = None
|
||||
_FEEDBACK_MODE = False
@@ -0,0 +1,312 @@
|
|||
"""Utilities for handling Anthropic-specific message formats and trimming."""
|
||||
|
||||
from typing import Callable, List, Literal, Optional, Sequence, Union, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
|
||||
def _is_message_type(
|
||||
message: BaseMessage, message_types: Union[str, type, List[Union[str, type]]]
|
||||
) -> bool:
|
||||
"""Check if a message is of a specific type or types.
|
||||
|
||||
Args:
|
||||
message: The message to check
|
||||
message_types: Type(s) to check against (string name or class)
|
||||
|
||||
Returns:
|
||||
bool: True if message matches any of the specified types
|
||||
"""
|
||||
if not isinstance(message_types, list):
|
||||
message_types = [message_types]
|
||||
|
||||
types_str = [t for t in message_types if isinstance(t, str)]
|
||||
types_classes = tuple(t for t in message_types if isinstance(t, type))
|
||||
|
||||
return message.type in types_str or isinstance(message, types_classes)
|
||||
|
||||
|
||||
def has_tool_use(message: BaseMessage) -> bool:
|
||||
"""Check if a message contains tool use.
|
||||
|
||||
Args:
|
||||
message: The message to check
|
||||
|
||||
Returns:
|
||||
bool: True if the message contains tool use
|
||||
"""
|
||||
if not isinstance(message, AIMessage):
|
||||
return False
|
||||
|
||||
# Check content for tool_use
|
||||
if isinstance(message.content, str) and "tool_use" in message.content:
|
||||
return True
|
||||
|
||||
# Check content list for tool_use blocks
|
||||
if isinstance(message.content, list):
|
||||
for item in message.content:
|
||||
if isinstance(item, dict) and item.get("type") == "tool_use":
|
||||
return True
|
||||
|
||||
# Check additional_kwargs for tool_calls
|
||||
if hasattr(message, "additional_kwargs") and message.additional_kwargs.get(
|
||||
"tool_calls"
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool:
|
||||
"""Check if two messages form a tool use/result pair.
|
||||
|
||||
Args:
|
||||
message1: First message
|
||||
message2: Second message
|
||||
|
||||
Returns:
|
||||
bool: True if the messages form a tool use/result pair
|
||||
"""
|
||||
return (
|
||||
isinstance(message1, AIMessage)
|
||||
and isinstance(message2, ToolMessage)
|
||||
and has_tool_use(message1)
|
||||
)
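As a quick illustration (hypothetical messages, not part of the commit), the two helpers recognize Anthropic's tool-use structure like this:

from langchain_core.messages import AIMessage, ToolMessage

ai = AIMessage(
    content=[{"type": "tool_use", "id": "toolu_1", "name": "read_file", "input": {}}]
)
tool = ToolMessage(content="file contents here", tool_call_id="toolu_1")

assert has_tool_use(ai)        # finds the tool_use block in the content list
assert is_tool_pair(ai, tool)  # AIMessage with tool_use followed by its ToolMessage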
|
||||
|
||||
|
||||
|
||||
def anthropic_trim_messages(
|
||||
messages: Sequence[BaseMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
token_counter: Callable[[List[BaseMessage]], int],
|
||||
strategy: Literal["first", "last"] = "last",
|
||||
num_messages_to_keep: int = 2,
|
||||
allow_partial: bool = False,
|
||||
include_system: bool = True,
|
||||
start_on: Optional[Union[str, type, List[Union[str, type]]]] = None,
|
||||
) -> List[BaseMessage]:
|
||||
"""Trim messages to fit within a token limit, with Anthropic-specific handling.
|
||||
|
||||
Warning: not fully implemented. Only the "last" strategy is supported and
tested; allow_partial and the "first" strategy are not supported yet.
|
||||
This function is similar to langchain_core's trim_messages but with special
|
||||
handling for Anthropic message formats to avoid API errors.
|
||||
|
||||
It always keeps the first num_messages_to_keep messages.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages to trim
|
||||
max_tokens: Maximum number of tokens allowed
|
||||
token_counter: Function to count tokens in messages
|
||||
strategy: Whether to keep the "first" or "last" messages
|
||||
allow_partial: Whether to allow partial messages
|
||||
include_system: Whether to always include the system message
|
||||
start_on: Message type to start on (only for "last" strategy)
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: Trimmed messages that fit within token limit
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
messages = list(messages)
|
||||
|
||||
# Always keep the first num_messages_to_keep messages
|
||||
kept_messages = messages[:num_messages_to_keep]
|
||||
remaining_msgs = messages[num_messages_to_keep:]
|
||||
|
||||
|
||||
# For Anthropic, we need to maintain the conversation structure where:
|
||||
# 1. Every AIMessage with tool_use must be followed by a ToolMessage
|
||||
# 2. Every AIMessage that follows a ToolMessage must start with a tool_result
|
||||
|
||||
# First, check if we have any tool_use in the messages
|
||||
has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages)
|
||||
|
||||
# If we have tool_use anywhere, we need to be very careful about trimming
|
||||
if has_tool_use_anywhere:
|
||||
# For safety, just keep all messages if we're under the token limit
|
||||
if token_counter(messages) <= max_tokens:
|
||||
return messages
|
||||
|
||||
# We need to identify all tool_use/tool_result relationships
|
||||
# First, find all AIMessage+ToolMessage pairs
|
||||
pairs = []
|
||||
i = 0
|
||||
while i < len(messages) - 1:
|
||||
if is_tool_pair(messages[i], messages[i + 1]):
|
||||
pairs.append((i, i + 1))
|
||||
i += 2
|
||||
else:
|
||||
i += 1
|
||||
|
||||
# For Anthropic, we need to ensure that:
|
||||
# 1. If we include an AIMessage with tool_use, we must include the following ToolMessage
|
||||
# 2. If we include a ToolMessage, we must include the preceding AIMessage with tool_use
|
||||
|
||||
# The safest approach is to always keep complete AIMessage+ToolMessage pairs together
|
||||
# First, identify all complete pairs
|
||||
complete_pairs = []
|
||||
for start, end in pairs:
|
||||
complete_pairs.append((start, end))
|
||||
|
||||
# Now we'll build our result, starting with the kept_messages
|
||||
# But we need to be careful about the first message if it has tool_use
|
||||
result = []
|
||||
|
||||
# Check if the last message in kept_messages has tool_use
|
||||
if (
|
||||
kept_messages
|
||||
and isinstance(kept_messages[-1], AIMessage)
|
||||
and has_tool_use(kept_messages[-1])
|
||||
):
|
||||
# We need to find the corresponding ToolMessage
|
||||
for i, (ai_idx, tool_idx) in enumerate(pairs):
|
||||
if messages[ai_idx] is kept_messages[-1]:
|
||||
# Found the pair, add all kept_messages except the last one
|
||||
result.extend(kept_messages[:-1])
|
||||
# Add the AIMessage and ToolMessage as a pair
|
||||
result.extend([messages[ai_idx], messages[tool_idx]])
|
||||
# Remove this pair from the list of pairs to process later
|
||||
pairs = pairs[:i] + pairs[i + 1 :]
|
||||
break
|
||||
else:
|
||||
# If we didn't find a matching pair, just add all kept_messages
|
||||
result.extend(kept_messages)
|
||||
else:
|
||||
# No tool_use in the last kept message, just add all kept_messages
|
||||
result.extend(kept_messages)
|
||||
|
||||
# If we're using the "last" strategy, we'll try to include pairs from the end
|
||||
if strategy == "last":
|
||||
# First collect all pairs we can include within the token limit
|
||||
pairs_to_include = []
|
||||
|
||||
# Process pairs from the end (newest first)
|
||||
for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)):
|
||||
# Try adding this pair
|
||||
test_msgs = result.copy()
|
||||
|
||||
# Add all previously selected pairs
|
||||
for prev_ai_idx, prev_tool_idx in pairs_to_include:
|
||||
test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]])
|
||||
|
||||
# Add this pair
|
||||
test_msgs.extend([messages[ai_idx], messages[tool_idx]])
|
||||
|
||||
if token_counter(test_msgs) <= max_tokens:
|
||||
# This pair fits, add it to our list
|
||||
pairs_to_include.append((ai_idx, tool_idx))
|
||||
else:
|
||||
# This pair would exceed the token limit
|
||||
break
|
||||
|
||||
# Now add the pairs in the correct order
|
||||
# Sort by index to maintain the original conversation flow
|
||||
pairs_to_include.sort(key=lambda x: x[0])
|
||||
for ai_idx, tool_idx in pairs_to_include:
|
||||
result.extend([messages[ai_idx], messages[tool_idx]])
|
||||
|
||||
# No need to sort - we've already added messages in the correct order
|
||||
|
||||
return result
|
||||
|
||||
# If no tool_use, proceed with normal segmentation
|
||||
segments = []
|
||||
i = 0
|
||||
|
||||
# Group messages into segments
|
||||
while i < len(remaining_msgs):
|
||||
segments.append([remaining_msgs[i]])
|
||||
i += 1
|
||||
|
||||
# Now we have segments that maintain the required structure
|
||||
# We'll add segments from the end (for "last" strategy) or beginning (for "first")
|
||||
# until we hit the token limit
|
||||
|
||||
if strategy == "last":
|
||||
# If we have no segments, just return kept_messages
|
||||
if not segments:
|
||||
return kept_messages
|
||||
|
||||
result = []
|
||||
|
||||
# Process segments from the end
|
||||
for i, segment in enumerate(reversed(segments)):
|
||||
# Try adding this segment
|
||||
test_msgs = segment + result
|
||||
|
||||
if token_counter(kept_messages + test_msgs) <= max_tokens:
|
||||
result = segment + result
|
||||
else:
|
||||
# This segment would exceed the token limit
|
||||
break
|
||||
|
||||
final_result = kept_messages + result
|
||||
|
||||
# For Anthropic, we need to ensure the conversation follows a valid structure
|
||||
# We'll do a final check of the entire conversation
|
||||
|
||||
# Validate the conversation structure
|
||||
valid_result = []
|
||||
i = 0
|
||||
|
||||
# Process messages in order
|
||||
while i < len(final_result):
|
||||
current_msg = final_result[i]
|
||||
|
||||
# If this is an AIMessage with tool_use, it must be followed by a ToolMessage
|
||||
if (
|
||||
i < len(final_result) - 1
|
||||
and isinstance(current_msg, AIMessage)
|
||||
and has_tool_use(current_msg)
|
||||
):
|
||||
if isinstance(final_result[i + 1], ToolMessage):
|
||||
# This is a valid tool_use + tool_result pair
|
||||
valid_result.append(current_msg)
|
||||
valid_result.append(final_result[i + 1])
|
||||
i += 2
|
||||
else:
|
||||
# Invalid: AIMessage with tool_use not followed by ToolMessage
|
||||
# Skip this message to maintain valid structure
|
||||
i += 1
|
||||
else:
|
||||
# Regular message, just add it
|
||||
valid_result.append(current_msg)
|
||||
i += 1
|
||||
|
||||
# Final check: don't end with an AIMessage that has tool_use
|
||||
if (
|
||||
valid_result
|
||||
and isinstance(valid_result[-1], AIMessage)
|
||||
and has_tool_use(valid_result[-1])
|
||||
):
|
||||
valid_result.pop() # Remove the last message
|
||||
|
||||
return valid_result
|
||||
|
||||
elif strategy == "first":
|
||||
result = []
|
||||
|
||||
# Process segments from the beginning
|
||||
for i, segment in enumerate(segments):
|
||||
# Try adding this segment
|
||||
test_msgs = result + segment
|
||||
if token_counter(kept_messages + test_msgs) <= max_tokens:
|
||||
result = result + segment
|
||||
else:
|
||||
# This segment would exceed the token limit
|
||||
break
|
||||
|
||||
final_result = kept_messages + result
|
||||
|
||||
return final_result
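A minimal usage sketch of the function above, using a toy token counter (10 tokens per message) rather than the real litellm-backed one:

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

msgs = [
    SystemMessage("You are a coding agent."),
    HumanMessage("Fix the token limiter."),
    AIMessage("Looking into it."),
    HumanMessage("Any update?"),
    AIMessage("Almost done."),
]

trimmed = anthropic_trim_messages(
    msgs,
    max_tokens=30,
    token_counter=lambda ms: 10 * len(ms),
    strategy="last",
    num_messages_to_keep=2,
)
# The first two messages are always kept; remaining messages are added from
# the end while the total stays under max_tokens, so here trimmed holds three
# messages: the system message, the first human message, and "Almost done.".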
@@ -0,0 +1,236 @@
|
|||
"""Utilities for handling token limits with Anthropic models."""
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
RemoveMessage,
|
||||
ToolMessage,
|
||||
trim_messages,
|
||||
)
|
||||
from langchain_core.messages.base import message_to_dict
|
||||
|
||||
from ra_aid.anthropic_message_utils import (
|
||||
anthropic_trim_messages,
|
||||
has_tool_use,
|
||||
)
|
||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||
from litellm import token_counter
|
||||
|
||||
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||
from ra_aid.console.output import cpm, print_messages_compact
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
"""Helper function to estimate total tokens in a sequence of messages.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages to count tokens for
|
||||
|
||||
Returns:
|
||||
Total estimated token count
|
||||
"""
|
||||
if not messages:
|
||||
return 0
|
||||
|
||||
estimate_tokens = CiaynAgent._estimate_tokens
|
||||
return sum(estimate_tokens(msg) for msg in messages)
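For example (hypothetical call), the heuristic counter can be used directly on LangChain messages; it gives an estimate, not an exact Anthropic token count:

from langchain_core.messages import HumanMessage

approx = estimate_messages_tokens([HumanMessage("hello"), HumanMessage("world")])
# Sum of CiaynAgent._estimate_tokens over both messages.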
|
||||
|
||||
|
||||
def convert_message_to_litellm_format(message: BaseMessage) -> Dict:
|
||||
"""Convert a BaseMessage to the format expected by litellm.
|
||||
|
||||
Args:
|
||||
message: The BaseMessage to convert
|
||||
|
||||
Returns:
|
||||
Dict in litellm format
|
||||
"""
|
||||
message_dict = message_to_dict(message)
|
||||
return {
|
||||
"role": message_dict["type"],
|
||||
"content": message_dict["data"]["content"],
|
||||
}
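A quick illustration of the conversion; note the role string comes straight from LangChain's message type (e.g. "human"), which is sufficient for token counting:

from langchain_core.messages import HumanMessage

print(convert_message_to_litellm_format(HumanMessage("Summarize the repo")))
# {'role': 'human', 'content': 'Summarize the repo'}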
|
||||
|
||||
|
||||
def create_token_counter_wrapper(model: str):
|
||||
"""Create a wrapper for token counter that handles BaseMessage conversion.
|
||||
|
||||
Args:
|
||||
model: The model name to use for token counting
|
||||
|
||||
Returns:
|
||||
A function that accepts BaseMessage objects and returns token count
|
||||
"""
|
||||
|
||||
# Create a partial function that already has the model parameter set
|
||||
base_token_counter = partial(token_counter, model=model)
|
||||
|
||||
def wrapped_token_counter(messages: List[BaseMessage]) -> int:
|
||||
"""Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage.
|
||||
|
||||
Args:
|
||||
messages: List of BaseMessage objects
|
||||
|
||||
Returns:
|
||||
Token count for the messages
|
||||
"""
|
||||
if not messages:
|
||||
return 0
|
||||
|
||||
litellm_messages = [convert_message_to_litellm_format(msg) for msg in messages]
|
||||
result = base_token_counter(messages=litellm_messages)
|
||||
return result
|
||||
|
||||
return wrapped_token_counter
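Usage sketch (assuming litellm has a tokenizer registered for the given model name):

from langchain_core.messages import HumanMessage, SystemMessage

count_tokens = create_token_counter_wrapper("claude-3-7-sonnet-20250219")
n = count_tokens([SystemMessage("You are terse."), HumanMessage("Count me.")])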
|
||||
|
||||
|
||||
def state_modifier(
|
||||
state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
|
||||
) -> list[BaseMessage]:
|
||||
"""Given the agent state and max_tokens, return a trimmed list of messages.
|
||||
|
||||
This uses anthropic_trim_messages which always keeps the first 2 messages.
|
||||
|
||||
Args:
|
||||
state: The current agent state containing messages
|
||||
model: The language model to use for token counting
|
||||
max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
|
||||
|
||||
Returns:
|
||||
list[BaseMessage]: Trimmed list of messages that fits within token limit
|
||||
"""
|
||||
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
wrapped_token_counter = create_token_counter_wrapper(model.model)
|
||||
|
||||
result = anthropic_trim_messages(
|
||||
messages,
|
||||
token_counter=wrapped_token_counter,
|
||||
max_tokens=max_input_tokens,
|
||||
strategy="last",
|
||||
allow_partial=False,
|
||||
include_system=True,
|
||||
num_messages_to_keep=2,
|
||||
)
|
||||
|
||||
if len(result) < len(messages):
|
||||
logger.info(f"Anthropic Token Limiter Trimmed: {len(messages)} messages → {len(result)} messages")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def sonnet_35_state_modifier(
|
||||
state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
|
||||
) -> list[BaseMessage]:
|
||||
"""Given the agent state and max_tokens, return a trimmed list of messages.
|
||||
|
||||
Args:
|
||||
state: The current agent state containing messages
|
||||
max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
|
||||
|
||||
Returns:
|
||||
list[BaseMessage]: Trimmed list of messages that fits within token limit
|
||||
"""
|
||||
messages = state["messages"]
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
first_message = messages[0]
|
||||
remaining_messages = messages[1:]
|
||||
first_tokens = estimate_messages_tokens([first_message])
|
||||
new_max_tokens = max_input_tokens - first_tokens
|
||||
|
||||
trimmed_remaining = trim_messages(
|
||||
remaining_messages,
|
||||
token_counter=estimate_messages_tokens,
|
||||
max_tokens=new_max_tokens,
|
||||
strategy="last",
|
||||
allow_partial=False,
|
||||
include_system=True,
|
||||
)
|
||||
|
||||
result = [first_message] + trimmed_remaining
|
||||
|
||||
return result
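Because AgentState is a TypedDict, either modifier can be exercised directly with a plain dict, e.g. (hypothetical values):

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

state = {
    "messages": [
        SystemMessage("Base prompt"),
        HumanMessage("First request"),
        AIMessage("Earlier reply"),
        HumanMessage("Latest request"),
    ]
}
trimmed = sonnet_35_state_modifier(state, max_input_tokens=50_000)
# The first message is always kept; the remaining messages are dropped oldest
# first until the estimated total fits under max_input_tokens.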
|
||||
|
||||
|
||||
def get_model_token_limit(
|
||||
config: Dict[str, Any], agent_type: str = "default"
|
||||
) -> Optional[int]:
|
||||
"""Get the token limit for the current model configuration based on agent type.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary containing provider and model information
|
||||
agent_type: Type of agent ("default", "research", or "planner")
|
||||
|
||||
Returns:
|
||||
Optional[int]: The token limit if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Try to get config from repository for production use
|
||||
try:
|
||||
config_from_repo = get_config_repository().get_all()
|
||||
# If we succeeded, use the repository config instead of passed config
|
||||
config = config_from_repo
|
||||
except RuntimeError:
|
||||
# In tests, this may fail because the repository isn't set up
|
||||
# So we'll use the passed config directly
|
||||
pass
|
||||
if agent_type == "research":
|
||||
provider = config.get("research_provider", "") or config.get("provider", "")
|
||||
model_name = config.get("research_model", "") or config.get("model", "")
|
||||
elif agent_type == "planner":
|
||||
provider = config.get("planner_provider", "") or config.get("provider", "")
|
||||
model_name = config.get("planner_model", "") or config.get("model", "")
|
||||
else:
|
||||
provider = config.get("provider", "")
|
||||
model_name = config.get("model", "")
|
||||
|
||||
try:
|
||||
from litellm import get_model_info
|
||||
|
||||
provider_model = model_name if not provider else f"{provider}/{model_name}"
|
||||
model_info = get_model_info(provider_model)
|
||||
max_input_tokens = model_info.get("max_input_tokens")
|
||||
if max_input_tokens:
|
||||
logger.debug(
|
||||
f"Using litellm token limit for {model_name}: {max_input_tokens}"
|
||||
)
|
||||
return max_input_tokens
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Error getting model info from litellm: {e}, falling back to models_params"
|
||||
)
|
||||
|
||||
# Fallback to models_params dict
|
||||
# Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
|
||||
normalized_name = model_name.replace("-", "")
|
||||
provider_tokens = models_params.get(provider, {})
|
||||
if normalized_name in provider_tokens:
|
||||
max_input_tokens = provider_tokens[normalized_name]["token_limit"]
|
||||
logger.debug(
|
||||
f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
|
||||
)
|
||||
else:
|
||||
max_input_tokens = None
|
||||
logger.debug(f"Could not find token limit for {provider}/{model_name}")
|
||||
|
||||
return max_input_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get model token limit: {e}")
|
||||
return None
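For example (hypothetical call, as in tests where the config repository is not initialized, so the passed dict is used directly):

config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
limit = get_model_token_limit(config, agent_type="default")
# Resolved via litellm's get_model_info() when the model is known there,
# otherwise via the models_params fallback table; None if neither knows it.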
|
||||
|
|
@@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3
|
|||
FALLBACK_TOOL_MODEL_LIMIT = 5
|
||||
RETRY_FALLBACK_COUNT = 3
|
||||
DEFAULT_TEST_CMD_TIMEOUT = 60 * 5 # 5 minutes in seconds
|
||||
DEFAULT_MODEL = "claude-3-7-sonnet-20250219"
|
||||
DEFAULT_SHOW_COST = False
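Call sites elsewhere in this PR now fall back to the shared constant instead of a hard-coded model string, e.g.:

from ra_aid.config import DEFAULT_MODEL

config = {}  # e.g. values loaded from the config repository
model_name = config.get("model", DEFAULT_MODEL)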
@@ -1,6 +1,6 @@
|
|||
from typing import Any, Dict, Literal, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
|
|
@@ -98,3 +98,57 @@ def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -
|
|||
"""
|
||||
|
||||
console.print(Panel(Markdown(message), title=title, border_style=border_style))
|
||||
|
||||
|
||||
def print_messages_compact(messages: Sequence[BaseMessage]) -> None:
|
||||
"""Print a compact representation of a list of messages.
|
||||
|
||||
Warning: used mainly for debugging; do not delete this function even if it is not referenced anywhere!
|
||||
For all message types, only the first 30 characters of content are shown.
|
||||
|
||||
Args:
|
||||
messages: A sequence of BaseMessage objects to print
|
||||
"""
|
||||
if not messages:
|
||||
console.print("[italic]No messages[/italic]")
|
||||
return
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
msg_type = msg.__class__.__name__
|
||||
content = msg.content
|
||||
|
||||
# Process content based on its type
|
||||
if isinstance(content, str):
|
||||
display_content = f"{content[:30]}..." if len(content) > 30 else content
|
||||
elif isinstance(content, list):
|
||||
# Handle structured content (list of content blocks)
|
||||
content_preview = []
|
||||
for item in content[:2]: # Show first 2 items at most
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
content_preview.append(f"text: {text[:20]}..." if len(text) > 20 else f"text: {text}")
|
||||
elif item.get("type") == "tool_call":
|
||||
tool_name = item.get("tool_call", {}).get("name", "unknown")
|
||||
content_preview.append(f"tool_call: {tool_name}")
|
||||
else:
|
||||
content_preview.append(f"{item.get('type', 'unknown')}")
|
||||
|
||||
if len(content) > 2:
|
||||
content_preview.append(f"...({len(content)-2} more)")
|
||||
|
||||
display_content = ", ".join(content_preview)
|
||||
else:
|
||||
display_content = str(content)[:30] + "..." if len(str(content)) > 30 else str(content)
|
||||
|
||||
# Add additional tool message info if available
|
||||
additional_info = []
|
||||
if hasattr(msg, "tool_call_id") and msg.tool_call_id:
|
||||
additional_info.append(f"tool_call_id: {msg.tool_call_id}")
|
||||
if hasattr(msg, "name") and msg.name:
|
||||
additional_info.append(f"name: {msg.name}")
|
||||
if hasattr(msg, "status") and msg.status:
|
||||
additional_info.append(f"status: {msg.status}")
|
||||
|
||||
info_str = f" ({', '.join(additional_info)})" if additional_info else ""
|
||||
console.print(f"[{i}] [bold]{msg_type}{info_str}[/bold]: {display_content}")
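Usage sketch; each message prints as one line with its class name, any tool metadata, and a truncated content preview:

from langchain_core.messages import AIMessage, HumanMessage

print_messages_compact(
    [
        HumanMessage("Trim the conversation so Anthropic stops returning 400s"),
        AIMessage("Switching to anthropic_trim_messages for Claude models."),
    ]
)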
@@ -259,8 +259,9 @@ def create_llm_client(
|
|||
else:
|
||||
temp_kwargs = {}
|
||||
|
||||
thinking_kwargs = {}
|
||||
if supports_thinking:
|
||||
temp_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
|
||||
thinking_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
|
||||
|
||||
if provider == "deepseek":
|
||||
return create_deepseek_client(
|
||||
|
|
@@ -268,6 +269,7 @@ def create_llm_client(
|
|||
api_key=config["api_key"],
|
||||
base_url=config["base_url"],
|
||||
**temp_kwargs,
|
||||
**thinking_kwargs,
|
||||
is_expert=is_expert,
|
||||
)
|
||||
elif provider == "openrouter":
|
||||
|
|
@@ -275,6 +277,7 @@ def create_llm_client(
|
|||
model_name=model_name,
|
||||
api_key=config["api_key"],
|
||||
**temp_kwargs,
|
||||
**thinking_kwargs,
|
||||
is_expert=is_expert,
|
||||
)
|
||||
elif provider == "openai":
|
||||
|
|
@@ -301,6 +304,7 @@ def create_llm_client(
|
|||
max_retries=LLM_MAX_RETRIES,
|
||||
max_tokens=model_config.get("max_tokens", 64000),
|
||||
**temp_kwargs,
|
||||
**thinking_kwargs,
|
||||
)
|
||||
elif provider == "openai-compatible":
|
||||
return ChatOpenAI(
|
||||
|
|
@@ -310,6 +314,7 @@ def create_llm_client(
|
|||
timeout=LLM_REQUEST_TIMEOUT,
|
||||
max_retries=LLM_MAX_RETRIES,
|
||||
**temp_kwargs,
|
||||
**thinking_kwargs,
|
||||
)
|
||||
elif provider == "gemini":
|
||||
return ChatGoogleGenerativeAI(
|
||||
|
|
@@ -318,6 +323,7 @@ def create_llm_client(
|
|||
timeout=LLM_REQUEST_TIMEOUT,
|
||||
max_retries=LLM_MAX_RETRIES,
|
||||
**temp_kwargs,
|
||||
**thinking_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
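For context, the renamed thinking_kwargs simply carries Anthropic's extended-thinking settings and is splatted into whichever client gets built; a rough standalone equivalent (assuming a langchain_anthropic version that accepts the thinking parameter):

from langchain_anthropic import ChatAnthropic

thinking_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
llm = ChatAnthropic(model_name="claude-3-7-sonnet-20250219", **thinking_kwargs)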
@@ -14,8 +14,9 @@ from ra_aid.agent_context import (
|
|||
is_crashed,
|
||||
reset_completion_flags,
|
||||
)
|
||||
from ra_aid.config import DEFAULT_MODEL
|
||||
from ra_aid.console.formatting import print_error, print_task_header
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository, get_human_input_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
|
|
@@ -385,7 +386,7 @@ def request_task_implementation(task_spec: str) -> str:
|
|||
config = get_config_repository().get_all()
|
||||
model = initialize_llm(
|
||||
config.get("provider", "anthropic"),
|
||||
config.get("model", "claude-3-5-sonnet-20241022"),
|
||||
config.get("model",DEFAULT_MODEL),
|
||||
temperature=config.get("temperature"),
|
||||
)
|
||||
|
||||
|
|
@@ -552,7 +553,7 @@ def request_implementation(task_spec: str) -> str:
|
|||
config = get_config_repository().get_all()
|
||||
model = initialize_llm(
|
||||
config.get("provider", "anthropic"),
|
||||
config.get("model", "claude-3-5-sonnet-20241022"),
|
||||
config.get("model", DEFAULT_MODEL),
|
||||
temperature=config.get("temperature"),
|
||||
)
@@ -14,12 +14,18 @@ from ra_aid.agent_context import (
|
|||
from ra_aid.agent_utils import (
|
||||
AgentState,
|
||||
create_agent,
|
||||
get_model_token_limit,
|
||||
is_anthropic_claude,
|
||||
)
|
||||
from ra_aid.anthropic_token_limiter import (
|
||||
get_model_token_limit,
|
||||
state_modifier,
|
||||
)
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository, config_repo_var
|
||||
from ra_aid.database.repositories.config_repository import (
|
||||
ConfigRepositoryManager,
|
||||
get_config_repository,
|
||||
config_repo_var,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@@ -32,7 +38,9 @@ def mock_model():
|
|||
@pytest.fixture
|
||||
def mock_config_repository():
|
||||
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
||||
with patch(
|
||||
"ra_aid.database.repositories.config_repository.config_repo_var"
|
||||
) as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
|
|
@@ -42,6 +50,7 @@ def mock_config_repository():
|
|||
# Setup get method to return config values
|
||||
def get_config(key, default=None):
|
||||
return config.get(key, default)
|
||||
|
||||
mock_repo.get.side_effect = get_config
|
||||
|
||||
# Setup get_all method to return all config values
|
||||
|
|
@@ -50,11 +59,13 @@ def mock_config_repository():
|
|||
# Setup set method to update config values
|
||||
def set_config(key, value):
|
||||
config[key] = value
|
||||
|
||||
mock_repo.set.side_effect = set_config
|
||||
|
||||
# Setup update method to update multiple config values
|
||||
def update_config(update_dict):
|
||||
config.update(update_dict)
|
||||
|
||||
mock_repo.update.side_effect = update_config
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
|
|
@@ -66,7 +77,9 @@ def mock_config_repository():
|
|||
@pytest.fixture(autouse=True)
|
||||
def mock_trajectory_repository():
|
||||
"""Mock the TrajectoryRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var:
|
||||
with patch(
|
||||
"ra_aid.database.repositories.trajectory_repository.trajectory_repo_var"
|
||||
) as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
|
|
@@ -75,6 +88,7 @@ def mock_trajectory_repository():
|
|||
mock_trajectory = MagicMock()
|
||||
mock_trajectory.id = 1
|
||||
return mock_trajectory
|
||||
|
||||
mock_repo.create.side_effect = mock_create
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
|
|
@@ -86,7 +100,9 @@ def mock_trajectory_repository():
|
|||
@pytest.fixture(autouse=True)
|
||||
def mock_human_input_repository():
|
||||
"""Mock the HumanInputRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var:
|
||||
with patch(
|
||||
"ra_aid.database.repositories.human_input_repository.human_input_repo_var"
|
||||
) as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
|
|
@@ -99,87 +115,14 @@ def mock_human_input_repository():
|
|||
yield mock_repo
|
||||
|
||||
|
||||
def test_get_model_token_limit_anthropic(mock_config_repository):
|
||||
"""Test get_model_token_limit with Anthropic model."""
|
||||
config = {"provider": "anthropic", "model": "claude2"}
|
||||
mock_config_repository.update(config)
|
||||
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
|
||||
|
||||
|
||||
def test_get_model_token_limit_openai(mock_config_repository):
|
||||
"""Test get_model_token_limit with OpenAI model."""
|
||||
config = {"provider": "openai", "model": "gpt-4"}
|
||||
mock_config_repository.update(config)
|
||||
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
|
||||
|
||||
|
||||
def test_get_model_token_limit_unknown(mock_config_repository):
|
||||
"""Test get_model_token_limit with unknown provider/model."""
|
||||
config = {"provider": "unknown", "model": "unknown-model"}
|
||||
mock_config_repository.update(config)
|
||||
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
assert token_limit is None
|
||||
|
||||
|
||||
def test_get_model_token_limit_missing_config(mock_config_repository):
|
||||
"""Test get_model_token_limit with missing configuration."""
|
||||
config = {}
|
||||
mock_config_repository.update(config)
|
||||
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
assert token_limit is None
|
||||
|
||||
|
||||
def test_get_model_token_limit_litellm_success():
|
||||
"""Test get_model_token_limit successfully getting limit from litellm."""
|
||||
config = {"provider": "anthropic", "model": "claude-2"}
|
||||
|
||||
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
|
||||
mock_get_info.return_value = {"max_input_tokens": 100000}
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
assert token_limit == 100000
|
||||
|
||||
|
||||
def test_get_model_token_limit_litellm_not_found():
|
||||
"""Test fallback to models_tokens when litellm raises NotFoundError."""
|
||||
config = {"provider": "anthropic", "model": "claude-2"}
|
||||
|
||||
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
|
||||
mock_get_info.side_effect = litellm.exceptions.NotFoundError(
|
||||
message="Model not found", model="claude-2", llm_provider="anthropic"
|
||||
)
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
|
||||
|
||||
|
||||
def test_get_model_token_limit_litellm_error():
|
||||
"""Test fallback to models_tokens when litellm raises other exceptions."""
|
||||
config = {"provider": "anthropic", "model": "claude-2"}
|
||||
|
||||
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
|
||||
mock_get_info.side_effect = Exception("Unknown error")
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
|
||||
|
||||
|
||||
def test_get_model_token_limit_unexpected_error():
|
||||
"""Test returning None when unexpected errors occur."""
|
||||
config = None # This will cause an attribute error when accessed
|
||||
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
assert token_limit is None
|
||||
|
||||
|
||||
def test_create_agent_anthropic(mock_model, mock_config_repository):
|
||||
"""Test create_agent with Anthropic Claude model."""
|
||||
mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})
|
||||
|
||||
with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
|
||||
with (
|
||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||
patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier,
|
||||
):
|
||||
mock_react.return_value = "react_agent"
|
||||
agent = create_agent(mock_model, [])
|
||||
|
||||
|
|
@@ -187,9 +130,10 @@ def test_create_agent_anthropic(mock_model, mock_config_repository):
|
|||
mock_react.assert_called_once_with(
|
||||
mock_model,
|
||||
[],
|
||||
interrupt_after=['tools'],
|
||||
interrupt_after=["tools"],
|
||||
version="v2",
|
||||
state_modifier=mock_react.call_args[1]["state_modifier"],
|
||||
name="React",
|
||||
)
|
||||
|
||||
|
||||
|
|
@@ -257,20 +201,7 @@ def mock_messages():
|
|||
]
|
||||
|
||||
|
||||
def test_state_modifier(mock_messages):
|
||||
"""Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
|
||||
state = AgentState(messages=mock_messages)
|
||||
|
||||
with patch(
|
||||
"ra_aid.agent_backends.ciayn_agent.CiaynAgent._estimate_tokens"
|
||||
) as mock_estimate:
|
||||
mock_estimate.side_effect = lambda msg: 100 if msg else 0
|
||||
|
||||
result = state_modifier(state, max_input_tokens=250)
|
||||
|
||||
assert len(result) < len(mock_messages)
|
||||
assert isinstance(result[0], SystemMessage)
|
||||
assert result[-1] == mock_messages[-1]
|
||||
# This test has been moved to test_anthropic_token_limiter.py
|
||||
|
||||
|
||||
def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
|
||||
|
|
@@ -291,17 +222,21 @@ def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
|
|||
)
|
||||
|
||||
|
||||
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository):
|
||||
def test_create_agent_anthropic_token_limiting_enabled(
|
||||
mock_model, mock_config_repository
|
||||
):
|
||||
"""Test create_agent sets up token limiting for Claude models when enabled."""
|
||||
mock_config_repository.update({
|
||||
mock_config_repository.update(
|
||||
{
|
||||
"provider": "anthropic",
|
||||
"model": "claude-2",
|
||||
"limit_tokens": True,
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
with (
|
||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||
patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
|
||||
patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit,
|
||||
):
|
||||
mock_react.return_value = "react_agent"
|
||||
mock_limit.return_value = 100000
|
||||
|
|
@@ -314,17 +249,21 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_r
|
|||
assert callable(args[1]["state_modifier"])
|
||||
|
||||
|
||||
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
|
||||
def test_create_agent_anthropic_token_limiting_disabled(
|
||||
mock_model, mock_config_repository
|
||||
):
|
||||
"""Test create_agent doesn't set up token limiting for Claude models when disabled."""
|
||||
mock_config_repository.update({
|
||||
mock_config_repository.update(
|
||||
{
|
||||
"provider": "anthropic",
|
||||
"model": "claude-2",
|
||||
"limit_tokens": False,
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
with (
|
||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||
patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
|
||||
patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit,
|
||||
):
|
||||
mock_react.return_value = "react_agent"
|
||||
mock_limit.return_value = 100000
|
||||
|
|
@@ -332,39 +271,12 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
        agent = create_agent(mock_model, [])

        assert agent == "react_agent"
        mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2")
        mock_react.assert_called_once_with(
            mock_model, [], interrupt_after=["tools"], version="v2", name="React"
        )


def test_get_model_token_limit_research(mock_config_repository):
    """Test get_model_token_limit with research provider and model."""
    config = {
        "provider": "openai",
        "model": "gpt-4",
        "research_provider": "anthropic",
        "research_model": "claude-2",
    }
    mock_config_repository.update(config)

    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
        mock_get_info.return_value = {"max_input_tokens": 150000}
        token_limit = get_model_token_limit(config, "research")
        assert token_limit == 150000


def test_get_model_token_limit_planner(mock_config_repository):
    """Test get_model_token_limit with planner provider and model."""
    config = {
        "provider": "openai",
        "model": "gpt-4",
        "planner_provider": "deepseek",
        "planner_model": "dsm-1",
    }
    mock_config_repository.update(config)

    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
        mock_get_info.return_value = {"max_input_tokens": 120000}
        token_limit = get_model_token_limit(config, "planner")
        assert token_limit == 120000
# These tests have been moved to test_anthropic_token_limiter.py


# New tests for private helper methods in agent_utils.py
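The `get_model_token_limit` tests removed above (research/planner variants and the litellm fallbacks) now live in `test_anthropic_token_limiter.py` further down in this change. As a rough sketch of the behavior those tests pin down — per-agent-type provider/model resolution, a litellm lookup, then a `models_params` fallback — here is an illustrative helper; the name `sketch_get_model_token_limit`, the exact control flow, and the model-name normalization are assumptions, not the implementation shipped in `ra_aid.anthropic_token_limiter`:

```python
from typing import Optional

import litellm

from ra_aid.models_params import models_params


def sketch_get_model_token_limit(config: dict, agent_type: str = "default") -> Optional[int]:
    """Illustrative only: resolve the agent's model, then look up its input-token limit."""
    try:
        if agent_type == "research" and config.get("research_model"):
            provider, model = config.get("research_provider"), config.get("research_model")
        elif agent_type == "planner" and config.get("planner_model"):
            provider, model = config.get("planner_provider"), config.get("planner_model")
        else:
            provider, model = config.get("provider"), config.get("model")

        try:
            # litellm knows most hosted models, keyed as "provider/model".
            return litellm.get_model_info(f"{provider}/{model}")["max_input_tokens"]
        except Exception:
            # Fall back to the bundled table; some keys drop dashes ("claude-2" -> "claude2").
            for key in (model, model.replace("-", "")):
                entry = models_params.get(provider, {}).get(key)
                if entry and "token_limit" in entry:
                    return entry["token_limit"]
            return None
    except Exception:
        # e.g. config is None or otherwise malformed
        return None
```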
@@ -629,7 +541,9 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repository):
    assert "Agent has crashed: Test crash message" in result


def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_repository):
def test_run_agent_with_retry_handles_badrequest_error(
    monkeypatch, mock_config_repository
):
    """Test that run_agent_with_retry properly handles BadRequestError as unretryable."""
    from ra_aid.agent_context import agent_context, is_crashed
    from ra_aid.agent_utils import run_agent_with_retry
@@ -687,7 +601,9 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_repository):
    assert is_crashed()


def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch, mock_config_repository):
def test_run_agent_with_retry_handles_api_badrequest_error(
    monkeypatch, mock_config_repository
):
    """Test that run_agent_with_retry properly handles API BadRequestError as unretryable."""
    # Import APIError from anthropic module and patch it on the agent_utils module
@@ -760,5 +676,7 @@ def test_handle_api_error_resource_exhausted():
    from ra_aid.agent_utils import _handle_api_error

    # ResourceExhausted exception should be handled without raising
    resource_exhausted_error = ResourceExhausted("429 Resource has been exhausted (e.g. check quota).")
    resource_exhausted_error = ResourceExhausted(
        "429 Resource has been exhausted (e.g. check quota)."
    )
    _handle_api_error(resource_exhausted_error, 0, 5, 1)
@@ -0,0 +1,507 @@
import unittest
from unittest.mock import MagicMock, patch
import litellm

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage
)
from langgraph.prebuilt.chat_agent_executor import AgentState

from ra_aid.anthropic_token_limiter import (
    create_token_counter_wrapper,
    estimate_messages_tokens,
    get_model_token_limit,
    state_modifier,
    sonnet_35_state_modifier,
    convert_message_to_litellm_format
)
from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT
class TestAnthropicTokenLimiter(unittest.TestCase):
    def setUp(self):
        from ra_aid.config import DEFAULT_MODEL

        self.mock_model = MagicMock(spec=ChatAnthropic)
        self.mock_model.model = DEFAULT_MODEL

        # Sample messages for testing
        self.system_message = SystemMessage(content="You are a helpful assistant.")
        self.human_message = HumanMessage(content="Hello, can you help me with a task?")
        self.ai_message = AIMessage(content="I'd be happy to help! What do you need?")
        self.long_message = HumanMessage(content="A" * 1000)  # Long message to test trimming

        # Create more messages for testing
        self.extra_messages = [
            HumanMessage(content=f"Extra message {i}") for i in range(5)
        ]

        # Mock state for testing state_modifier with many messages
        self.state = AgentState(
            messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages,
            next=None,
        )

        # Create tool-related messages for testing
        self.ai_with_tool_use = AIMessage(
            content="I'll use a tool to help you",
            additional_kwargs={"tool_calls": [{"name": "calculator", "input": {"expression": "2+2"}}]}
        )
        self.tool_message = ToolMessage(
            content="4",
            tool_call_id="tool_call_1",
            name="calculator"
        )
    def test_convert_message_to_litellm_format(self):
        """Test conversion of BaseMessage to litellm format."""
        # Test human message
        human_result = convert_message_to_litellm_format(self.human_message)
        self.assertEqual(human_result["role"], "human")
        self.assertEqual(human_result["content"], "Hello, can you help me with a task?")

        # Test system message
        system_result = convert_message_to_litellm_format(self.system_message)
        self.assertEqual(system_result["role"], "system")
        self.assertEqual(system_result["content"], "You are a helpful assistant.")

        # Test AI message
        ai_result = convert_message_to_litellm_format(self.ai_message)
        self.assertEqual(ai_result["role"], "ai")
        self.assertEqual(ai_result["content"], "I'd be happy to help! What do you need?")
@patch("ra_aid.anthropic_token_limiter.token_counter")
|
||||
def test_create_token_counter_wrapper(self, mock_token_counter):
|
||||
from ra_aid.config import DEFAULT_MODEL
|
||||
|
||||
# Setup mock return values
|
||||
mock_token_counter.return_value = 50
|
||||
|
||||
# Create the wrapper
|
||||
wrapper = create_token_counter_wrapper(DEFAULT_MODEL)
|
||||
|
||||
# Test with BaseMessage objects
|
||||
result = wrapper([self.human_message])
|
||||
self.assertEqual(result, 50)
|
||||
|
||||
# Test with empty list
|
||||
result = wrapper([])
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
# Verify the mock was called with the right parameters
|
||||
mock_token_counter.assert_called_with(messages=unittest.mock.ANY, model=DEFAULT_MODEL)
|
||||
|
||||
@patch("ra_aid.anthropic_token_limiter.CiaynAgent._estimate_tokens")
|
||||
def test_estimate_messages_tokens(self, mock_estimate_tokens):
|
||||
# Setup mock to return different values for different messages
|
||||
mock_estimate_tokens.side_effect = lambda msg: 10 if isinstance(msg, SystemMessage) else 20
|
||||
|
||||
# Test with multiple messages
|
||||
messages = [self.system_message, self.human_message]
|
||||
result = estimate_messages_tokens(messages)
|
||||
|
||||
# Should be sum of individual token counts (10 + 20)
|
||||
self.assertEqual(result, 30)
|
||||
|
||||
# Test with empty list
|
||||
result = estimate_messages_tokens([])
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
@patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
|
||||
@patch("ra_aid.anthropic_token_limiter.print_messages_compact")
|
||||
@patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
|
||||
def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
|
||||
# Setup a proper token counter function that returns integers
|
||||
def token_counter(msgs):
|
||||
# Return token count based on number of messages
|
||||
return len(msgs) * 10
|
||||
|
||||
# Configure the mock to return our token counter
|
||||
mock_create_wrapper.return_value = token_counter
|
||||
|
||||
# Configure anthropic_trim_messages to return a subset of messages
|
||||
mock_trim_messages.return_value = [self.system_message, self.human_message]
|
||||
|
||||
# Call state_modifier with a max token limit of 50
|
||||
result = state_modifier(self.state, self.mock_model, max_input_tokens=50)
|
||||
|
||||
# Should return what anthropic_trim_messages returned
|
||||
self.assertEqual(result, [self.system_message, self.human_message])
|
||||
|
||||
# Verify the wrapper was created with the right model
|
||||
mock_create_wrapper.assert_called_with(self.mock_model.model)
|
||||
|
||||
# Verify anthropic_trim_messages was called with the right parameters
|
||||
mock_trim_messages.assert_called_once()
|
||||
|
||||
|
||||
    def test_state_modifier_with_messages(self):
        """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
        # Create a state with messages
        messages = [
            SystemMessage(content="System prompt"),
            HumanMessage(content="Human message 1"),
            AIMessage(content="AI response 1"),
            HumanMessage(content="Human message 2"),
            AIMessage(content="AI response 2"),
        ]
        state = AgentState(messages=messages)
        model = MagicMock(spec=ChatAnthropic)
        model.model = "claude-3-opus-20240229"

        with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
             patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
            # Setup mock to return a fixed token count per message
            mock_wrapper.return_value = lambda msgs: len(msgs) * 100
            # Setup mock to return a subset of messages
            mock_trim.return_value = [messages[0], messages[-2], messages[-1]]

            result = state_modifier(state, model, max_input_tokens=250)

            # Should return what anthropic_trim_messages returned
            self.assertEqual(len(result), 3)
            self.assertEqual(result[0], messages[0])  # First message preserved
            self.assertEqual(result[-1], messages[-1])  # Last message preserved
    def test_sonnet_35_state_modifier(self):
        """Test the sonnet 35 state modifier function."""
        # Create a state with messages
        state = {"messages": [self.system_message, self.human_message, self.ai_message]}

        # Test with empty messages
        empty_state = {"messages": []}

        # Instead of patching trim_messages which has complex internal logic,
        # we'll directly patch the sonnet_35_state_modifier's call to trim_messages
        with patch("ra_aid.anthropic_token_limiter.trim_messages") as mock_trim:
            # Setup mock to return our desired messages
            mock_trim.return_value = [self.human_message, self.ai_message]

            # Test with empty messages
            self.assertEqual(sonnet_35_state_modifier(empty_state), [])

            # Test with messages under the limit
            result = sonnet_35_state_modifier(state, max_input_tokens=10000)

            # Should keep the first message and call trim_messages for the rest
            self.assertEqual(len(result), 3)
            self.assertEqual(result[0], self.system_message)
            self.assertEqual(result[1:], [self.human_message, self.ai_message])

            # Verify trim_messages was called with the right parameters
            mock_trim.assert_called_once()
            # We can check some of the key arguments
            call_args = mock_trim.call_args[1]
            # The actual value is based on the token estimation logic, not a hard-coded 9000
            self.assertIn("max_tokens", call_args)
            self.assertEqual(call_args["strategy"], "last")
            self.assertEqual(call_args["allow_partial"], False)
            self.assertEqual(call_args["include_system"], True)
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
||||
@patch("litellm.get_model_info")
|
||||
def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
|
||||
from ra_aid.config import DEFAULT_MODEL
|
||||
|
||||
# Setup mocks
|
||||
mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
|
||||
mock_get_config_repo.return_value.get_all.return_value = mock_config
|
||||
|
||||
# Mock litellm's get_model_info to return a token limit
|
||||
mock_get_model_info.return_value = {"max_input_tokens": 100000}
|
||||
|
||||
# Test getting token limit
|
||||
result = get_model_token_limit(mock_config)
|
||||
self.assertEqual(result, 100000)
|
||||
|
||||
# Verify get_model_info was called with the right model
|
||||
mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
|
||||
|
||||
    def test_get_model_token_limit_research(self):
        """Test get_model_token_limit with research provider and model."""
        config = {
            "provider": "openai",
            "model": "gpt-4",
            "research_provider": "anthropic",
            "research_model": "claude-2",
        }

        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
             patch("litellm.get_model_info") as mock_get_info:
            mock_get_config_repo.return_value.get_all.return_value = config
            mock_get_info.return_value = {"max_input_tokens": 150000}
            token_limit = get_model_token_limit(config, "research")
            self.assertEqual(token_limit, 150000)
            # Verify get_model_info was called with the research model
            mock_get_info.assert_called_with("anthropic/claude-2")
    def test_get_model_token_limit_planner(self):
        """Test get_model_token_limit with planner provider and model."""
        config = {
            "provider": "openai",
            "model": "gpt-4",
            "planner_provider": "deepseek",
            "planner_model": "dsm-1",
        }

        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
             patch("litellm.get_model_info") as mock_get_info:
            mock_get_config_repo.return_value.get_all.return_value = config
            mock_get_info.return_value = {"max_input_tokens": 120000}
            token_limit = get_model_token_limit(config, "planner")
            self.assertEqual(token_limit, 120000)
            # Verify get_model_info was called with the planner model
            mock_get_info.assert_called_with("deepseek/dsm-1")
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
||||
@patch("litellm.get_model_info")
|
||||
def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
|
||||
# Setup mocks
|
||||
mock_config = {"provider": "anthropic", "model": "claude-2"}
|
||||
mock_get_config_repo.return_value.get_all.return_value = mock_config
|
||||
|
||||
# Make litellm's get_model_info raise an exception to test fallback
|
||||
mock_get_model_info.side_effect = Exception("Model not found")
|
||||
|
||||
# Test getting token limit from models_params fallback
|
||||
with patch("ra_aid.anthropic_token_limiter.models_params", {
|
||||
"anthropic": {
|
||||
"claude2": {"token_limit": 100000}
|
||||
}
|
||||
}):
|
||||
result = get_model_token_limit(mock_config)
|
||||
self.assertEqual(result, 100000)
|
||||
|
||||
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
||||
@patch("litellm.get_model_info")
|
||||
def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
|
||||
from ra_aid.config import DEFAULT_MODEL
|
||||
|
||||
# Setup mocks for different agent types
|
||||
mock_config = {
|
||||
"provider": "anthropic",
|
||||
"model": DEFAULT_MODEL,
|
||||
"research_provider": "openai",
|
||||
"research_model": "gpt-4",
|
||||
"planner_provider": "anthropic",
|
||||
"planner_model": "claude-3-sonnet-20240229"
|
||||
}
|
||||
mock_get_config_repo.return_value.get_all.return_value = mock_config
|
||||
|
||||
# Mock different returns for different models
|
||||
def model_info_side_effect(model_name):
|
||||
if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
|
||||
return {"max_input_tokens": 200000}
|
||||
elif "gpt-4" in model_name:
|
||||
return {"max_input_tokens": 8192}
|
||||
elif "claude-3-sonnet" in model_name:
|
||||
return {"max_input_tokens": 100000}
|
||||
else:
|
||||
raise Exception(f"Unknown model: {model_name}")
|
||||
|
||||
mock_get_model_info.side_effect = model_info_side_effect
|
||||
|
||||
# Test default agent type
|
||||
result = get_model_token_limit(mock_config, "default")
|
||||
self.assertEqual(result, 200000)
|
||||
|
||||
# Test research agent type
|
||||
result = get_model_token_limit(mock_config, "research")
|
||||
self.assertEqual(result, 8192)
|
||||
|
||||
# Test planner agent type
|
||||
result = get_model_token_limit(mock_config, "planner")
|
||||
self.assertEqual(result, 100000)
|
||||
|
||||
    def test_get_model_token_limit_anthropic(self):
        """Test get_model_token_limit with Anthropic model."""
        config = {"provider": "anthropic", "model": "claude2"}

        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
            mock_get_config_repo.return_value.get_all.return_value = config
            token_limit = get_model_token_limit(config, "default")
            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])

    def test_get_model_token_limit_openai(self):
        """Test get_model_token_limit with OpenAI model."""
        config = {"provider": "openai", "model": "gpt-4"}

        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
            mock_get_config_repo.return_value.get_all.return_value = config
            token_limit = get_model_token_limit(config, "default")
            self.assertEqual(token_limit, models_params["openai"]["gpt-4"]["token_limit"])

    def test_get_model_token_limit_unknown(self):
        """Test get_model_token_limit with unknown provider/model."""
        config = {"provider": "unknown", "model": "unknown-model"}

        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
            mock_get_config_repo.return_value.get_all.return_value = config
            token_limit = get_model_token_limit(config, "default")
            self.assertIsNone(token_limit)

    def test_get_model_token_limit_missing_config(self):
        """Test get_model_token_limit with missing configuration."""
        config = {}

        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
            mock_get_config_repo.return_value.get_all.return_value = config
            token_limit = get_model_token_limit(config, "default")
            self.assertIsNone(token_limit)

    def test_get_model_token_limit_litellm_success(self):
        """Test get_model_token_limit successfully getting limit from litellm."""
        config = {"provider": "anthropic", "model": "claude-2"}

        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
             patch("litellm.get_model_info") as mock_get_info:
            mock_get_config_repo.return_value.get_all.return_value = config
            mock_get_info.return_value = {"max_input_tokens": 100000}
            token_limit = get_model_token_limit(config, "default")
            self.assertEqual(token_limit, 100000)
            mock_get_info.assert_called_with("anthropic/claude-2")

    def test_get_model_token_limit_litellm_not_found(self):
        """Test fallback to models_tokens when litellm raises NotFoundError."""
        config = {"provider": "anthropic", "model": "claude-2"}

        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
             patch("litellm.get_model_info") as mock_get_info:
            mock_get_config_repo.return_value.get_all.return_value = config
            mock_get_info.side_effect = litellm.exceptions.NotFoundError(
                message="Model not found", model="claude-2", llm_provider="anthropic"
            )
            token_limit = get_model_token_limit(config, "default")
            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])

    def test_get_model_token_limit_litellm_error(self):
        """Test fallback to models_tokens when litellm raises other exceptions."""
        config = {"provider": "anthropic", "model": "claude-2"}

        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
             patch("litellm.get_model_info") as mock_get_info:
            mock_get_config_repo.return_value.get_all.return_value = config
            mock_get_info.side_effect = Exception("Unknown error")
            token_limit = get_model_token_limit(config, "default")
            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])

    def test_get_model_token_limit_unexpected_error(self):
        """Test returning None when unexpected errors occur."""
        config = None  # This will cause an attribute error when accessed

        token_limit = get_model_token_limit(config, "default")
        self.assertIsNone(token_limit)
    def test_has_tool_use(self):
        """Test the has_tool_use function."""
        # Test with regular AI message
        self.assertFalse(has_tool_use(self.ai_message))

        # Test with AI message containing tool_use in string content
        ai_with_tool_str = AIMessage(content="I'll use a tool_use to help you")
        self.assertTrue(has_tool_use(ai_with_tool_str))

        # Test with AI message containing tool_use in structured content
        ai_with_tool_dict = AIMessage(content=[
            {"type": "text", "text": "I'll use a tool to help you"},
            {"type": "tool_use", "tool_use": {"name": "calculator", "input": {"expression": "2+2"}}}
        ])
        self.assertTrue(has_tool_use(ai_with_tool_dict))

        # Test with AI message containing tool_calls in additional_kwargs
        self.assertTrue(has_tool_use(self.ai_with_tool_use))

        # Test with non-AI message
        self.assertFalse(has_tool_use(self.human_message))

    def test_is_tool_pair(self):
        """Test the is_tool_pair function."""
        # Test with valid tool pair
        self.assertTrue(is_tool_pair(self.ai_with_tool_use, self.tool_message))

        # Test with non-tool pair (wrong order)
        self.assertFalse(is_tool_pair(self.tool_message, self.ai_with_tool_use))

        # Test with non-tool pair (wrong types)
        self.assertFalse(is_tool_pair(self.ai_message, self.human_message))

        # Test with non-tool pair (AI message without tool use)
        self.assertFalse(is_tool_pair(self.ai_message, self.tool_message))
@patch("ra_aid.anthropic_message_utils.has_tool_use")
|
||||
def test_anthropic_trim_messages_with_tool_use(self, mock_has_tool_use):
|
||||
"""Test anthropic_trim_messages with a sequence of messages including tool use."""
|
||||
from ra_aid.anthropic_message_utils import anthropic_trim_messages
|
||||
|
||||
# Setup mock for has_tool_use to return True for AI messages at even indices
|
||||
def side_effect(msg):
|
||||
if isinstance(msg, AIMessage) and hasattr(msg, 'test_index'):
|
||||
return msg.test_index % 2 == 0 # Even indices have tool use
|
||||
return False
|
||||
|
||||
mock_has_tool_use.side_effect = side_effect
|
||||
|
||||
# Create a sequence of alternating human and AI messages with tool use
|
||||
messages = []
|
||||
|
||||
# Start with system message
|
||||
system_msg = SystemMessage(content="You are a helpful assistant.")
|
||||
messages.append(system_msg)
|
||||
|
||||
# Add alternating human and AI messages with tool use
|
||||
for i in range(8):
|
||||
if i % 2 == 0:
|
||||
# Human message
|
||||
msg = HumanMessage(content=f"Human message {i}")
|
||||
messages.append(msg)
|
||||
else:
|
||||
# AI message, every other one has tool use
|
||||
ai_msg = AIMessage(content=f"AI message {i}")
|
||||
# Add a test_index attribute to track position
|
||||
ai_msg.test_index = i
|
||||
messages.append(ai_msg)
|
||||
|
||||
# If this AI message has tool use (even index), add a tool message after it
|
||||
if i % 4 == 1: # 1, 5, etc.
|
||||
tool_msg = ToolMessage(
|
||||
content=f"Tool result {i}",
|
||||
tool_call_id=f"tool_call_{i}",
|
||||
name="test_tool"
|
||||
)
|
||||
messages.append(tool_msg)
|
||||
|
||||
# Define a token counter that returns a fixed value per message
|
||||
def token_counter(msgs):
|
||||
return len(msgs) * 1000
|
||||
|
||||
# Test with a token limit that will require trimming
|
||||
result = anthropic_trim_messages(
|
||||
messages,
|
||||
token_counter=token_counter,
|
||||
max_tokens=5000, # This will allow 5 messages
|
||||
strategy="last",
|
||||
allow_partial=False,
|
||||
include_system=True,
|
||||
num_messages_to_keep=2 # Keep system and first human message
|
||||
)
|
||||
|
||||
# We should have kept the first 2 messages (system + human)
|
||||
self.assertEqual(len(result), 5) # 2 kept + 3 more that fit in token limit
|
||||
self.assertEqual(result[0], system_msg)
|
||||
|
||||
# Verify that we don't have any AI messages with tool use that aren't followed by a tool message
|
||||
for i in range(len(result) - 1):
|
||||
if isinstance(result[i], AIMessage) and mock_has_tool_use(result[i]):
|
||||
self.assertTrue(isinstance(result[i+1], ToolMessage),
|
||||
f"AI message with tool use at index {i} not followed by ToolMessage")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    unittest.main()
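For context on how the pieces covered by this new test module are meant to fit together at runtime: the `create_agent` tests in `test_agent_utils.py` above assert that, when `limit_tokens` is enabled for an Anthropic model, a callable `state_modifier` ends up being passed to `create_react_agent`. Below is a minimal sketch of that wiring, assuming only the signatures these tests exercise; `sketch_create_agent` is a hypothetical name, and the real `create_agent` in `ra_aid.agent_utils` may differ in structure and defaults:

```python
from functools import partial

from langgraph.prebuilt import create_react_agent

from ra_aid.anthropic_token_limiter import get_model_token_limit, state_modifier
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT


def sketch_create_agent(model, tools, config):
    """Illustrative only: bind the Anthropic token limiter when configured."""
    kwargs = {"interrupt_after": ["tools"], "version": "v2", "name": "React"}

    if config.get("provider") == "anthropic" and config.get("limit_tokens", True):
        max_input_tokens = get_model_token_limit(config, "default") or DEFAULT_TOKEN_LIMIT
        # state_modifier(state, model, max_input_tokens=...) is the call the tests make;
        # partial() turns it into the single-argument callable the agent expects.
        kwargs["state_modifier"] = partial(
            state_modifier, model=model, max_input_tokens=max_input_tokens
        )

    return create_react_agent(model, tools, **kwargs)
```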