feat(main.py): refactor imports for better organization and readability

feat(main.py): add DEFAULT_MODEL constant to centralize model configuration
feat(main.py): enhance logging and error handling for better debugging
feat(main.py): implement state_modifier for managing token limits in agent state
feat(anthropic_token_limiter.py): create utilities for handling token limits with Anthropic models
feat(output.py): add print_messages_compact function for debugging message output
test(anthropic_token_limiter.py): add unit tests for token limit utilities and state management
Ariel Frischer 2025-03-11 14:03:18 -07:00
parent b4b0fdd686
commit 5c9a1e81d2
8 changed files with 592 additions and 262 deletions

View File

@@ -39,32 +39,38 @@ from ra_aid.agents.research_agent import run_research_agent
 from ra_aid.agents import run_planning_agent
 from ra_aid.config import (
     DEFAULT_MAX_TEST_CMD_RETRIES,
+    DEFAULT_MODEL,
     DEFAULT_RECURSION_LIMIT,
     DEFAULT_TEST_CMD_TIMEOUT,
     VALID_PROVIDERS,
 )
-from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
+from ra_aid.database.repositories.key_fact_repository import (
+    KeyFactRepositoryManager,
+    get_key_fact_repository,
+)
 from ra_aid.database.repositories.key_snippet_repository import (
-    KeySnippetRepositoryManager, get_key_snippet_repository
+    KeySnippetRepositoryManager,
+    get_key_snippet_repository,
 )
 from ra_aid.database.repositories.human_input_repository import (
-    HumanInputRepositoryManager, get_human_input_repository
+    HumanInputRepositoryManager,
+    get_human_input_repository,
 )
 from ra_aid.database.repositories.research_note_repository import (
-    ResearchNoteRepositoryManager, get_research_note_repository
+    ResearchNoteRepositoryManager,
+    get_research_note_repository,
 )
 from ra_aid.database.repositories.trajectory_repository import (
-    TrajectoryRepositoryManager, get_trajectory_repository
+    TrajectoryRepositoryManager,
+    get_trajectory_repository,
 )
 from ra_aid.database.repositories.related_files_repository import (
-    RelatedFilesRepositoryManager
+    RelatedFilesRepositoryManager,
 )
-from ra_aid.database.repositories.work_log_repository import (
-    WorkLogRepositoryManager
-)
+from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager
 from ra_aid.database.repositories.config_repository import (
     ConfigRepositoryManager,
-    get_config_repository
+    get_config_repository,
 )
 from ra_aid.env_inv import EnvDiscovery
 from ra_aid.env_inv_context import EnvInvManager, get_env_inv
@@ -100,9 +106,9 @@ def launch_webui(host: str, port: int):
 def parse_arguments(args=None):
-    ANTHROPIC_DEFAULT_MODEL = "claude-3-7-sonnet-20250219"
+    ANTHROPIC_DEFAULT_MODEL = DEFAULT_MODEL
     OPENAI_DEFAULT_MODEL = "gpt-4o"

     # Case-insensitive log level argument type
     def log_level_type(value):
         value = value.lower()
@@ -199,8 +205,10 @@ Examples:
         help="Enable chat mode with direct human interaction (implies --hil)",
     )
     parser.add_argument(
-        "--log-mode", choices=["console", "file"], default="file",
-        help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console"
+        "--log-mode",
+        choices=["console", "file"],
+        default="file",
+        help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console",
     )
     parser.add_argument(
         "--pretty-logger", action="store_true", help="Enable pretty logging output"
@@ -378,20 +386,20 @@ def is_stage_requested(stage: str) -> bool:
 def wipe_project_memory():
     """Delete the project database file to wipe all stored memory.

     Returns:
         str: A message indicating the result of the operation
     """
     import os
     from pathlib import Path

     cwd = os.getcwd()
     ra_aid_dir = Path(os.path.join(cwd, ".ra-aid"))
     db_path = os.path.join(ra_aid_dir, "pk.db")

     if not os.path.exists(db_path):
         return "No project memory found to wipe."

     try:
         os.remove(db_path)
         return "Project memory wiped successfully."
@@ -403,11 +411,11 @@ def wipe_project_memory():
 def build_status():
     """Build status panel with model and feature information.

     Includes memory statistics at the bottom with counts of key facts, snippets, and research notes.
     """
     status = Text()

     # Get the config repository to get model/provider information
     config_repo = get_config_repository()
     provider = config_repo.get("provider", "")
@@ -415,12 +423,14 @@ def build_status():
     temperature = config_repo.get("temperature")
     expert_provider = config_repo.get("expert_provider", "")
     expert_model = config_repo.get("expert_model", "")
-    experimental_fallback_handler = config_repo.get("experimental_fallback_handler", False)
+    experimental_fallback_handler = config_repo.get(
+        "experimental_fallback_handler", False
+    )
     web_research_enabled = config_repo.get("web_research_enabled", False)

     # Get the expert enabled status
     expert_enabled = bool(expert_provider and expert_model)

     # Basic model information
     status.append("🤖 ")
     status.append(f"{provider}/{model}")
@@ -452,39 +462,41 @@ def build_status():
             [fb_handler._format_model(m) for m in fb_handler.fallback_tool_models]
         )
         status.append(msg)

     # Add memory statistics
     # Get counts of key facts, snippets, and research notes with error handling
     fact_count = 0
     snippet_count = 0
     note_count = 0

     try:
         fact_count = len(get_key_fact_repository().get_all())
     except RuntimeError as e:
         logger.debug(f"Failed to get key facts count: {e}")

     try:
         snippet_count = len(get_key_snippet_repository().get_all())
     except RuntimeError as e:
         logger.debug(f"Failed to get key snippets count: {e}")

     try:
         note_count = len(get_research_note_repository().get_all())
     except RuntimeError as e:
         logger.debug(f"Failed to get research notes count: {e}")

     # Add memory statistics line with reset option note
-    status.append(f"\n💾 Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes")
+    status.append(
+        f"\n💾 Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes"
+    )
     if fact_count > 0 or snippet_count > 0 or note_count > 0:
         status.append(" (use --wipe-project-memory to reset)")

     # Check for newer version
     version_message = check_for_newer_version()
     if version_message:
         status.append("\n\n")
         status.append(version_message, style="yellow")

     return status
@@ -493,7 +505,7 @@ def main():
     args = parse_arguments()
     setup_logging(args.log_mode, args.pretty_logger, args.log_level)
     logger.debug("Starting RA.Aid with arguments: %s", args)

     # Check if we need to wipe project memory before starting
     if args.wipe_project_memory:
         result = wipe_project_memory()
@@ -519,22 +531,24 @@ def main():
     # Initialize empty config dictionary to be populated later
     config = {}

     # Initialize repositories with database connection
     # Create environment inventory data
     env_discovery = EnvDiscovery()
     env_discovery.discover()
     env_data = env_discovery.format_markdown()

-    with KeyFactRepositoryManager(db) as key_fact_repo, \
-        KeySnippetRepositoryManager(db) as key_snippet_repo, \
-        HumanInputRepositoryManager(db) as human_input_repo, \
-        ResearchNoteRepositoryManager(db) as research_note_repo, \
-        RelatedFilesRepositoryManager() as related_files_repo, \
-        TrajectoryRepositoryManager(db) as trajectory_repo, \
-        WorkLogRepositoryManager() as work_log_repo, \
-        ConfigRepositoryManager(config) as config_repo, \
-        EnvInvManager(env_data) as env_inv:
+    with (
+        KeyFactRepositoryManager(db) as key_fact_repo,
+        KeySnippetRepositoryManager(db) as key_snippet_repo,
+        HumanInputRepositoryManager(db) as human_input_repo,
+        ResearchNoteRepositoryManager(db) as research_note_repo,
+        RelatedFilesRepositoryManager() as related_files_repo,
+        TrajectoryRepositoryManager(db) as trajectory_repo,
+        WorkLogRepositoryManager() as work_log_repo,
+        ConfigRepositoryManager(config) as config_repo,
+        EnvInvManager(env_data) as env_inv,
+    ):
         # This initializes all repositories and makes them available via their respective get methods
         logger.debug("Initialized KeyFactRepository")
         logger.debug("Initialized KeySnippetRepository")
@@ -554,7 +568,9 @@ def main():
             expert_missing,
             web_research_enabled,
             web_research_missing,
-        ) = validate_environment(args)  # Will exit if main env vars missing
+        ) = validate_environment(
+            args
+        )  # Will exit if main env vars missing
         logger.debug("Environment validation successful")

         # Validate model configuration early
@@ -590,11 +606,15 @@ def main():
         config_repo.set("expert_provider", args.expert_provider)
         config_repo.set("expert_model", args.expert_model)
         config_repo.set("temperature", args.temperature)
-        config_repo.set("experimental_fallback_handler", args.experimental_fallback_handler)
+        config_repo.set(
+            "experimental_fallback_handler", args.experimental_fallback_handler
+        )
         config_repo.set("web_research_enabled", web_research_enabled)
         config_repo.set("show_thoughts", args.show_thoughts)
         config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
-        config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
+        config_repo.set(
+            "disable_reasoning_assistance", args.no_reasoning_assistance
+        )

         # Build status panel with memory statistics
         status = build_status()
@@ -633,13 +653,15 @@ def main():
             initial_request = ask_human.invoke(
                 {"question": "What would you like help with?"}
             )

             # Record chat input in database (redundant as ask_human already records it,
             # but needed in case the ask_human implementation changes)
             try:
                 # Using get_human_input_repository() to access the repository from context
                 human_input_repository = get_human_input_repository()
-                human_input_repository.create(content=initial_request, source='chat')
+                human_input_repository.create(
+                    content=initial_request, source="chat"
+                )
                 human_input_repository.garbage_collect()
             except Exception as e:
                 logger.error(f"Failed to record initial chat input: {str(e)}")
@@ -668,8 +690,12 @@ def main():
             config_repo.set("expert_model", args.expert_model)
             config_repo.set("temperature", args.temperature)
             config_repo.set("show_thoughts", args.show_thoughts)
-            config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
-            config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
+            config_repo.set(
+                "force_reasoning_assistance", args.reasoning_assistance
+            )
+            config_repo.set(
+                "disable_reasoning_assistance", args.no_reasoning_assistance
+            )

             # Set modification tools based on use_aider flag
             set_modification_tools(args.use_aider)
@@ -696,8 +722,12 @@ def main():
                     ),
                     working_directory=working_directory,
                     current_date=current_date,
-                    key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
-                    key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
+                    key_facts=format_key_facts_dict(
+                        get_key_fact_repository().get_facts_dict()
+                    ),
+                    key_snippets=format_key_snippets_dict(
+                        get_key_snippet_repository().get_snippets_dict()
+                    ),
                     project_info=formatted_project_info,
                     env_inv=get_env_inv(),
                 ),
@@ -711,12 +741,12 @@ def main():
             sys.exit(1)

         base_task = args.message

         # Record CLI input in database
         try:
             # Using get_human_input_repository() to access the repository from context
             human_input_repository = get_human_input_repository()
-            human_input_repository.create(content=base_task, source='cli')
+            human_input_repository.create(content=base_task, source="cli")
             # Run garbage collection to ensure we don't exceed 100 inputs
             human_input_repository.garbage_collect()
             logger.debug(f"Recorded CLI input: {base_task}")
@@ -750,19 +780,25 @@ def main():
         config_repo.set("expert_model", args.expert_model)

         # Store planner config with fallback to base values
-        config_repo.set("planner_provider", args.planner_provider or args.provider)
+        config_repo.set(
+            "planner_provider", args.planner_provider or args.provider
+        )
         config_repo.set("planner_model", args.planner_model or args.model)

         # Store research config with fallback to base values
-        config_repo.set("research_provider", args.research_provider or args.provider)
+        config_repo.set(
+            "research_provider", args.research_provider or args.provider
+        )
        config_repo.set("research_model", args.research_model or args.model)

         # Store temperature in config
         config_repo.set("temperature", args.temperature)

         # Store reasoning assistance flags
         config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
-        config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
+        config_repo.set(
+            "disable_reasoning_assistance", args.no_reasoning_assistance
+        )

         # Set modification tools based on use_aider flag
         set_modification_tools(args.use_aider)
@@ -794,5 +830,6 @@ def main():
         print()
         sys.exit(0)


 if __name__ == "__main__":
     main()
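
Aside: the repository managers above are now grouped in a single parenthesized with statement instead of backslash continuations; that syntax is officially supported starting with Python 3.10, which matters for the project's minimum-version constraints. A minimal standalone sketch of the same pattern, using hypothetical stand-in context managers rather than the real repository classes:

    from contextlib import contextmanager

    @contextmanager
    def resource(name: str):
        # Hypothetical stand-in for the repository managers used in main.py
        print(f"open {name}")
        try:
            yield name
        finally:
            print(f"close {name}")

    # Parenthesized multi-context-manager form (Python 3.10+), mirroring the diff above
    with (
        resource("key_facts") as facts,
        resource("config") as config,
    ):
        print(facts, config)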

View File

@@ -1,19 +1,14 @@
 """Utility functions for working with agents."""

-import inspect
-import os
 import signal
 import sys
 import threading
 import time
-import uuid
-from datetime import datetime
-from typing import Any, Dict, List, Literal, Optional, Sequence
+from typing import Any, Dict, List, Literal, Optional

 from langchain_anthropic import ChatAnthropic
 from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler

-import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
 from openai import RateLimitError as OpenAIRateLimitError
 from litellm.exceptions import RateLimitError as LiteLLMRateLimitError
@@ -23,28 +18,24 @@ from langchain_core.messages import (
     BaseMessage,
     HumanMessage,
     SystemMessage,
-    trim_messages,
 )
 from langchain_core.tools import tool
-from langgraph.checkpoint.memory import MemorySaver
 from langgraph.prebuilt import create_react_agent
 from langgraph.prebuilt.chat_agent_executor import AgentState
-from litellm import get_model_info
 from rich.console import Console
 from rich.markdown import Markdown
 from rich.panel import Panel

 from ra_aid.agent_context import (
     agent_context,
-    get_depth,
     is_completed,
     reset_completion_flags,
     should_exit,
 )
 from ra_aid.agent_backends.ciayn_agent import CiaynAgent
 from ra_aid.agents_alias import RAgents
-from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
-from ra_aid.console.formatting import print_error, print_stage_header
+from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES
+from ra_aid.console.formatting import print_error
 from ra_aid.console.output import print_agent_output
 from ra_aid.exceptions import (
     AgentInterrupt,
@@ -53,76 +44,16 @@ from ra_aid.exceptions import (
 )
 from ra_aid.fallback_handler import FallbackHandler
 from ra_aid.logging_config import get_logger
-from ra_aid.llm import initialize_expert_llm
-from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
-from ra_aid.text.processing import process_thinking_content
-from ra_aid.project_info import (
-    display_project_status,
-    format_project_info,
-    get_project_info,
-)
-from ra_aid.prompts.expert_prompts import (
-    EXPERT_PROMPT_SECTION_IMPLEMENTATION,
-    EXPERT_PROMPT_SECTION_PLANNING,
-    EXPERT_PROMPT_SECTION_RESEARCH,
-)
-from ra_aid.prompts.human_prompts import (
-    HUMAN_PROMPT_SECTION_IMPLEMENTATION,
-    HUMAN_PROMPT_SECTION_PLANNING,
-    HUMAN_PROMPT_SECTION_RESEARCH,
-)
-from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
-from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
-from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
-from ra_aid.prompts.reasoning_assist_prompt import (
-    REASONING_ASSIST_PROMPT_PLANNING,
-    REASONING_ASSIST_PROMPT_IMPLEMENTATION,
-    REASONING_ASSIST_PROMPT_RESEARCH,
-)
-from ra_aid.prompts.research_prompts import (
-    RESEARCH_ONLY_PROMPT,
-    RESEARCH_PROMPT,
-)
-from ra_aid.prompts.web_research_prompts import (
-    WEB_RESEARCH_PROMPT,
-    WEB_RESEARCH_PROMPT_SECTION_CHAT,
-    WEB_RESEARCH_PROMPT_SECTION_PLANNING,
-    WEB_RESEARCH_PROMPT_SECTION_RESEARCH,
-)
-from ra_aid.tool_configs import (
-    get_implementation_tools,
-    get_planning_tools,
-    get_research_tools,
-    get_web_research_tools,
-)
+from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
 from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
-from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
-from ra_aid.database.repositories.key_snippet_repository import (
-    get_key_snippet_repository,
-)
-from ra_aid.database.repositories.human_input_repository import (
-    get_human_input_repository,
-)
-from ra_aid.database.repositories.research_note_repository import (
-    get_research_note_repository,
-)
-from ra_aid.database.repositories.work_log_repository import get_work_log_repository
-from ra_aid.model_formatters import format_key_facts_dict
-from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
-from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
-from ra_aid.tools.memory import (
-    get_related_files,
-    log_work_event,
-)
 from ra_aid.database.repositories.config_repository import get_config_repository
-from ra_aid.env_inv_context import get_env_inv
+from ra_aid.anthropic_token_limiter import state_modifier, get_model_token_limit

 console = Console()
 logger = get_logger(__name__)

 # Import repositories using get_* functions
 from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository

 @tool
@@ -132,131 +63,19 @@ def output_markdown_message(message: str) -> str:
     return "Message output."

-
-def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
-    """Helper function to estimate total tokens in a sequence of messages.
-
-    Args:
-        messages: Sequence of messages to count tokens for
-
-    Returns:
-        Total estimated token count
-    """
-    if not messages:
-        return 0
-
-    estimate_tokens = CiaynAgent._estimate_tokens
-    return sum(estimate_tokens(msg) for msg in messages)
-
-
-def state_modifier(
-    state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
-) -> list[BaseMessage]:
-    """Given the agent state and max_tokens, return a trimmed list of messages.
-
-    Args:
-        state: The current agent state containing messages
-        max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
-
-    Returns:
-        list[BaseMessage]: Trimmed list of messages that fits within token limit
-    """
-    messages = state["messages"]
-
-    if not messages:
-        return []
-
-    first_message = messages[0]
-    remaining_messages = messages[1:]
-
-    first_tokens = estimate_messages_tokens([first_message])
-    new_max_tokens = max_input_tokens - first_tokens
-
-    trimmed_remaining = trim_messages(
-        remaining_messages,
-        token_counter=estimate_messages_tokens,
-        max_tokens=new_max_tokens,
-        strategy="last",
-        allow_partial=False,
-    )
-
-    return [first_message] + trimmed_remaining
-
-
-def get_model_token_limit(
-    config: Dict[str, Any], agent_type: Literal["default", "research", "planner"]
-) -> Optional[int]:
-    """Get the token limit for the current model configuration based on agent type.
-
-    Returns:
-        Optional[int]: The token limit if found, None otherwise
-    """
-    try:
-        # Try to get config from repository for production use
-        try:
-            config_from_repo = get_config_repository().get_all()
-            # If we succeeded, use the repository config instead of passed config
-            config = config_from_repo
-        except RuntimeError:
-            # In tests, this may fail because the repository isn't set up
-            # So we'll use the passed config directly
-            pass
-
-        if agent_type == "research":
-            provider = config.get("research_provider", "") or config.get("provider", "")
-            model_name = config.get("research_model", "") or config.get("model", "")
-        elif agent_type == "planner":
-            provider = config.get("planner_provider", "") or config.get("provider", "")
-            model_name = config.get("planner_model", "") or config.get("model", "")
-        else:
-            provider = config.get("provider", "")
-            model_name = config.get("model", "")
-
-        try:
-            provider_model = model_name if not provider else f"{provider}/{model_name}"
-            model_info = get_model_info(provider_model)
-            max_input_tokens = model_info.get("max_input_tokens")
-            if max_input_tokens:
-                logger.debug(
-                    f"Using litellm token limit for {model_name}: {max_input_tokens}"
-                )
-                return max_input_tokens
-        except litellm.exceptions.NotFoundError:
-            logger.debug(
-                f"Model {model_name} not found in litellm, falling back to models_params"
-            )
-        except Exception as e:
-            logger.debug(
-                f"Error getting model info from litellm: {e}, falling back to models_params"
-            )
-
-        # Fallback to models_params dict
-        # Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
-        normalized_name = model_name.replace("-", "")
-        provider_tokens = models_params.get(provider, {})
-        if normalized_name in provider_tokens:
-            max_input_tokens = provider_tokens[normalized_name]["token_limit"]
-            logger.debug(
-                f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
-            )
-        else:
-            max_input_tokens = None
-            logger.debug(f"Could not find token limit for {provider}/{model_name}")
-
-        return max_input_tokens
-    except Exception as e:
-        logger.warning(f"Failed to get model token limit: {e}")
-        return None
-
-
 def build_agent_kwargs(
     checkpointer: Optional[Any] = None,
+    model: ChatAnthropic = None,
     max_input_tokens: Optional[int] = None,
 ) -> Dict[str, Any]:
     """Build kwargs dictionary for agent creation.

     Args:
         checkpointer: Optional memory checkpointer
-        config: Optional configuration dictionary
-        token_limit: Optional token limit for the model
+        model: The language model to use for token counting
+        max_input_tokens: Optional token limit for the model

     Returns:
         Dictionary of kwargs for agent creation
@@ -269,12 +88,17 @@ def build_agent_kwargs(
         agent_kwargs["checkpointer"] = checkpointer

     config = get_config_repository().get_all()
-    if config.get("limit_tokens", True) and is_anthropic_claude(config):
+    if (
+        config.get("limit_tokens", True)
+        and is_anthropic_claude(config)
+        and model is not None
+    ):

         def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
-            return state_modifier(state, max_input_tokens=max_input_tokens)
+            return state_modifier(state, model, max_input_tokens=max_input_tokens)

         agent_kwargs["state_modifier"] = wrapped_state_modifier

+    agent_kwargs["name"] = "React"
     return agent_kwargs
@@ -340,11 +164,13 @@ def create_agent(
             max_input_tokens = (
                 get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
             )
+            print(f"max_input_tokens={max_input_tokens}")

         # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
         if is_anthropic_claude(config):
             logger.debug("Using create_react_agent to instantiate agent.")
-            agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
+            agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
             return create_react_agent(
                 model, tools, interrupt_after=["tools"], **agent_kwargs
             )
@@ -357,16 +183,12 @@ def create_agent(
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
         config = get_config_repository().get_all()
         max_input_tokens = get_model_token_limit(config, agent_type)
-        agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
+        agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
         return create_react_agent(
             model, tools, interrupt_after=["tools"], **agent_kwargs
         )

-from ra_aid.agents.research_agent import run_research_agent, run_web_research_agent
-from ra_aid.agents.implementation_agent import run_task_implementation_agent

 _CONTEXT_STACK = []
 _INTERRUPT_CONTEXT = None
 _FEEDBACK_MODE = False
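
Aside: a rough sketch of how the trimming ends up wired into the LangGraph agent after this refactor. It mirrors build_agent_kwargs/create_agent from the diff above; the model name, token limit, and dummy tool are placeholders, and the state_modifier keyword follows the langgraph version used by the diff, so treat this as illustrative rather than the exact call made by ra-aid:

    from langchain_anthropic import ChatAnthropic
    from langchain_core.tools import tool
    from langgraph.prebuilt import create_react_agent
    from langgraph.prebuilt.chat_agent_executor import AgentState

    from ra_aid.anthropic_token_limiter import state_modifier

    model = ChatAnthropic(model="claude-3-7-sonnet-20250219")  # placeholder; needs ANTHROPIC_API_KEY
    max_input_tokens = 100_000  # placeholder limit

    @tool
    def echo(text: str) -> str:
        """Echo the input back (dummy tool for the sketch)."""
        return text

    def wrapped_state_modifier(state: AgentState):
        # Keep the first message, trim the rest to the remaining token budget
        return state_modifier(state, model, max_input_tokens=max_input_tokens)

    agent = create_react_agent(
        model,
        [echo],
        interrupt_after=["tools"],
        state_modifier=wrapped_state_modifier,
    )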

View File

@@ -0,0 +1,210 @@
"""Utilities for handling token limits with Anthropic models."""

from functools import partial
from typing import Any, Dict, List, Optional, Sequence, Union

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage, trim_messages
from langchain_core.messages.base import messages_to_dict
from langgraph.prebuilt.chat_agent_executor import AgentState
from litellm import token_counter

from ra_aid.agent_backends.ciayn_agent import CiaynAgent
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.console.output import print_messages_compact

logger = get_logger(__name__)


def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
    """Helper function to estimate total tokens in a sequence of messages.

    Args:
        messages: Sequence of messages to count tokens for

    Returns:
        Total estimated token count
    """
    if not messages:
        return 0

    estimate_tokens = CiaynAgent._estimate_tokens
    return sum(estimate_tokens(msg) for msg in messages)


def create_token_counter_wrapper(model: str):
    """Create a wrapper for token counter that handles BaseMessage conversion.

    Args:
        model: The model name to use for token counting

    Returns:
        A function that accepts BaseMessage objects and returns token count
    """
    # Create a partial function that already has the model parameter set
    base_token_counter = partial(token_counter, model=model)

    def wrapped_token_counter(messages: List[Union[BaseMessage, Dict]]) -> int:
        """Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage.

        Args:
            messages: List of messages (either BaseMessage objects or dicts)

        Returns:
            Token count for the messages
        """
        if not messages:
            return 0

        if isinstance(messages[0], BaseMessage):
            messages_dicts = [msg["data"] for msg in messages_to_dict(messages)]
            return base_token_counter(messages=messages_dicts)
        else:
            # Already in dict format
            return base_token_counter(messages=messages)

    return wrapped_token_counter


def state_modifier(
    state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
    """Given the agent state and max_tokens, return a trimmed list of messages but always keep the first message.

    Args:
        state: The current agent state containing messages
        model: The language model to use for token counting
        max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)

    Returns:
        list[BaseMessage]: Trimmed list of messages that fits within token limit
    """
    messages = state["messages"]

    if not messages:
        return []

    first_message = messages[0]
    remaining_messages = messages[1:]
    wrapped_token_counter = create_token_counter_wrapper(model.model)

    first_tokens = wrapped_token_counter([first_message])
    new_max_tokens = max_input_tokens - first_tokens

    print_messages_compact(messages)

    trimmed_remaining = trim_messages(
        remaining_messages,
        token_counter=wrapped_token_counter,
        max_tokens=new_max_tokens,
        strategy="last",
        allow_partial=False,
    )

    return [first_message] + trimmed_remaining


def sonnet_3_5_state_modifier(
    state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
    """Given the agent state and max_tokens, return a trimmed list of messages.

    Args:
        state: The current agent state containing messages
        max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)

    Returns:
        list[BaseMessage]: Trimmed list of messages that fits within token limit
    """
    messages = state["messages"]

    if not messages:
        return []

    first_message = messages[0]
    remaining_messages = messages[1:]

    first_tokens = estimate_messages_tokens([first_message])
    new_max_tokens = max_input_tokens - first_tokens

    trimmed_remaining = trim_messages(
        remaining_messages,
        token_counter=estimate_messages_tokens,
        max_tokens=new_max_tokens,
        strategy="last",
        allow_partial=False,
    )

    return [first_message] + trimmed_remaining


def get_model_token_limit(
    config: Dict[str, Any], agent_type: str = "default"
) -> Optional[int]:
    """Get the token limit for the current model configuration based on agent type.

    Args:
        config: Configuration dictionary containing provider and model information
        agent_type: Type of agent ("default", "research", or "planner")

    Returns:
        Optional[int]: The token limit if found, None otherwise
    """
    try:
        # Try to get config from repository for production use
        try:
            config_from_repo = get_config_repository().get_all()
            # If we succeeded, use the repository config instead of passed config
            config = config_from_repo
        except RuntimeError:
            # In tests, this may fail because the repository isn't set up
            # So we'll use the passed config directly
            pass

        if agent_type == "research":
            provider = config.get("research_provider", "") or config.get("provider", "")
            model_name = config.get("research_model", "") or config.get("model", "")
        elif agent_type == "planner":
            provider = config.get("planner_provider", "") or config.get("provider", "")
            model_name = config.get("planner_model", "") or config.get("model", "")
        else:
            provider = config.get("provider", "")
            model_name = config.get("model", "")

        try:
            from litellm import get_model_info

            provider_model = model_name if not provider else f"{provider}/{model_name}"
            model_info = get_model_info(provider_model)
            max_input_tokens = model_info.get("max_input_tokens")
            if max_input_tokens:
                logger.debug(
                    f"Using litellm token limit for {model_name}: {max_input_tokens}"
                )
                return max_input_tokens
        except Exception as e:
            logger.debug(
                f"Error getting model info from litellm: {e}, falling back to models_params"
            )

        # Fallback to models_params dict
        # Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
        normalized_name = model_name.replace("-", "")
        provider_tokens = models_params.get(provider, {})
        if normalized_name in provider_tokens:
            max_input_tokens = provider_tokens[normalized_name]["token_limit"]
            logger.debug(
                f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
            )
        else:
            max_input_tokens = None
            logger.debug(f"Could not find token limit for {provider}/{model_name}")

        return max_input_tokens
    except Exception as e:
        logger.warning(f"Failed to get model token limit: {e}")
        return None
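
Aside: a small usage sketch of the counter wrapper defined above. The model string is the commit's default Claude model; the printed number comes from litellm's tokenizer data, so treat it as illustrative:

    from langchain_core.messages import HumanMessage, SystemMessage

    from ra_aid.anthropic_token_limiter import create_token_counter_wrapper

    counter = create_token_counter_wrapper("claude-3-7-sonnet-20250219")
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Summarize the token limiting approach."),
    ]
    # BaseMessage objects are converted to dicts internally before litellm counts them
    print(counter(messages))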

View File

@@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3
 FALLBACK_TOOL_MODEL_LIMIT = 5
 RETRY_FALLBACK_COUNT = 3
 DEFAULT_TEST_CMD_TIMEOUT = 60 * 5  # 5 minutes in seconds
+DEFAULT_MODEL="claude-3-7-sonnet-20250219"

 VALID_PROVIDERS = [

View File

@@ -1,6 +1,6 @@
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Sequence

-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 from rich.markdown import Markdown
 from rich.panel import Panel
@@ -94,3 +94,57 @@ def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -
     """
     console.print(Panel(Markdown(message), title=title, border_style=border_style))
+
+
+def print_messages_compact(messages: Sequence[BaseMessage]) -> None:
+    """Print a compact representation of a list of messages.
+
+    Warning: Used mainly for debugging purposes so do not delete if not referenced anywhere!
+
+    For all message types, only the first 30 characters of content are shown.
+
+    Args:
+        messages: A sequence of BaseMessage objects to print
+    """
+    if not messages:
+        console.print("[italic]No messages[/italic]")
+        return
+
+    for i, msg in enumerate(messages):
+        msg_type = msg.__class__.__name__
+        content = msg.content
+
+        # Process content based on its type
+        if isinstance(content, str):
+            display_content = f"{content[:30]}..." if len(content) > 30 else content
+        elif isinstance(content, list):
+            # Handle structured content (list of content blocks)
+            content_preview = []
+            for item in content[:2]:  # Show first 2 items at most
+                if isinstance(item, dict):
+                    if item.get("type") == "text":
+                        text = item.get("text", "")
+                        content_preview.append(f"text: {text[:20]}..." if len(text) > 20 else f"text: {text}")
+                    elif item.get("type") == "tool_call":
+                        tool_name = item.get("tool_call", {}).get("name", "unknown")
+                        content_preview.append(f"tool_call: {tool_name}")
+                    else:
+                        content_preview.append(f"{item.get('type', 'unknown')}")
+            if len(content) > 2:
+                content_preview.append(f"...({len(content)-2} more)")
+            display_content = ", ".join(content_preview)
+        else:
+            display_content = str(content)[:30] + "..." if len(str(content)) > 30 else str(content)
+
+        # Add additional tool message info if available
+        additional_info = []
+        if hasattr(msg, "tool_call_id") and msg.tool_call_id:
+            additional_info.append(f"tool_call_id: {msg.tool_call_id}")
+        if hasattr(msg, "name") and msg.name:
+            additional_info.append(f"name: {msg.name}")
+        if hasattr(msg, "status") and msg.status:
+            additional_info.append(f"status: {msg.status}")
+
+        info_str = f" ({', '.join(additional_info)})" if additional_info else ""
+        console.print(f"[{i}] [bold]{msg_type}{info_str}[/bold]: {display_content}")

View File

@@ -241,8 +241,9 @@ def create_llm_client(
     else:
         temp_kwargs = {}

+    thinking_kwargs = {}
     if supports_thinking:
-        temp_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
+        thinking_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}

     if provider == "deepseek":
         return create_deepseek_client(
@@ -250,6 +251,7 @@ def create_llm_client(
             api_key=config["api_key"],
             base_url=config["base_url"],
             **temp_kwargs,
+            **thinking_kwargs,
             is_expert=is_expert,
         )
     elif provider == "openrouter":
@@ -257,6 +259,7 @@ def create_llm_client(
             model_name=model_name,
             api_key=config["api_key"],
             **temp_kwargs,
+            **thinking_kwargs,
             is_expert=is_expert,
         )
     elif provider == "openai":
@@ -271,6 +274,7 @@ def create_llm_client(
         return ChatOpenAI(
             **{
                 **openai_kwargs,
+                **thinking_kwargs,
                 "timeout": LLM_REQUEST_TIMEOUT,
                 "max_retries": LLM_MAX_RETRIES,
             }
@@ -283,6 +287,7 @@ def create_llm_client(
             max_retries=LLM_MAX_RETRIES,
             max_tokens=model_config.get("max_tokens", 64000),
             **temp_kwargs,
+            **thinking_kwargs,
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(
@@ -292,6 +297,7 @@ def create_llm_client(
             timeout=LLM_REQUEST_TIMEOUT,
             max_retries=LLM_MAX_RETRIES,
             **temp_kwargs,
+            **thinking_kwargs,
         )
     elif provider == "gemini":
         return ChatGoogleGenerativeAI(
@@ -300,6 +306,7 @@ def create_llm_client(
             timeout=LLM_REQUEST_TIMEOUT,
             max_retries=LLM_MAX_RETRIES,
             **temp_kwargs,
+            **thinking_kwargs,
        )
     else:
         raise ValueError(f"Unsupported provider: {provider}")

View File

@@ -14,6 +14,7 @@ from ra_aid.agent_context import (
     is_crashed,
     reset_completion_flags,
 )
+from ra_aid.config import DEFAULT_MODEL
 from ra_aid.console.formatting import print_error
 from ra_aid.database.repositories.human_input_repository import HumanInputRepository
 from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
@@ -337,7 +338,7 @@ def request_task_implementation(task_spec: str) -> str:
     config = get_config_repository().get_all()
     model = initialize_llm(
         config.get("provider", "anthropic"),
-        config.get("model", "claude-3-5-sonnet-20241022"),
+        config.get("model",DEFAULT_MODEL),
         temperature=config.get("temperature"),
     )
@@ -475,7 +476,7 @@ def request_implementation(task_spec: str) -> str:
     config = get_config_repository().get_all()
     model = initialize_llm(
         config.get("provider", "anthropic"),
-        config.get("model", "claude-3-5-sonnet-20241022"),
+        config.get("model", DEFAULT_MODEL),
         temperature=config.get("temperature"),
     )
@@ -592,4 +593,4 @@ def request_implementation(task_spec: str) -> str:
     # Join all parts into a single markdown string
     markdown_output = "".join(markdown_parts)
     return markdown_output

View File

@@ -0,0 +1,198 @@
import unittest
from unittest.mock import MagicMock, patch

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.prebuilt.chat_agent_executor import AgentState

from ra_aid.anthropic_token_limiter import (
    create_token_counter_wrapper,
    estimate_messages_tokens,
    get_model_token_limit,
    state_modifier,
)


class TestAnthropicTokenLimiter(unittest.TestCase):
    def setUp(self):
        from ra_aid.config import DEFAULT_MODEL

        self.mock_model = MagicMock(spec=ChatAnthropic)
        self.mock_model.model = DEFAULT_MODEL

        # Sample messages for testing
        self.system_message = SystemMessage(content="You are a helpful assistant.")
        self.human_message = HumanMessage(content="Hello, can you help me with a task?")
        self.long_message = HumanMessage(content="A" * 1000)  # Long message to test trimming

        # Create more messages for testing
        self.extra_messages = [
            HumanMessage(content=f"Extra message {i}") for i in range(5)
        ]

        # Mock state for testing state_modifier with many messages
        self.state = AgentState(
            messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages,
            next=None,
        )

    @patch("ra_aid.anthropic_token_limiter.token_counter")
    def test_create_token_counter_wrapper(self, mock_token_counter):
        from ra_aid.config import DEFAULT_MODEL

        # Setup mock return values
        mock_token_counter.return_value = 50

        # Create the wrapper
        wrapper = create_token_counter_wrapper(DEFAULT_MODEL)

        # Test with BaseMessage objects
        result = wrapper([self.human_message])
        self.assertEqual(result, 50)

        # Test with empty list
        result = wrapper([])
        self.assertEqual(result, 0)

        # Verify the mock was called with the right parameters
        mock_token_counter.assert_called_with(messages=unittest.mock.ANY, model=DEFAULT_MODEL)

    @patch("ra_aid.anthropic_token_limiter.CiaynAgent._estimate_tokens")
    def test_estimate_messages_tokens(self, mock_estimate_tokens):
        # Setup mock to return different values for different messages
        mock_estimate_tokens.side_effect = lambda msg: 10 if isinstance(msg, SystemMessage) else 20

        # Test with multiple messages
        messages = [self.system_message, self.human_message]
        result = estimate_messages_tokens(messages)

        # Should be sum of individual token counts (10 + 20)
        self.assertEqual(result, 30)

        # Test with empty list
        result = estimate_messages_tokens([])
        self.assertEqual(result, 0)

    @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
    @patch("ra_aid.anthropic_token_limiter.print_messages_compact")
    def test_state_modifier(self, mock_print, mock_create_wrapper):
        # Setup a proper token counter function that returns integers
        # This function needs to return values that will cause trim_messages to keep only the first message
        def token_counter(msgs):
            # For a single message, return a small token count
            if len(msgs) == 1:
                return 10
            # For two messages (first + one more), return a value under our limit
            elif len(msgs) == 2:
                return 30  # This is under our 40 token remaining budget (50-10)
            # For three messages, return a value just under our limit
            elif len(msgs) == 3:
                return 40  # This is exactly at our 40 token remaining budget (50-10)
            # For four messages, return a value just at our limit
            elif len(msgs) == 4:
                return 40  # This is exactly at our 40 token remaining budget (50-10)
            # For five messages, return a value that exceeds our 40 token budget
            elif len(msgs) == 5:
                return 60  # This exceeds our 40 token budget, forcing only 4 more messages
            # For more messages, return a value over our limit
            else:
                return 100  # This exceeds our limit

        # Don't use side_effect here, directly return the function
        mock_create_wrapper.return_value = token_counter

        # Call state_modifier with a max token limit of 50
        result = state_modifier(self.state, self.mock_model, max_input_tokens=50)

        # Should keep first message and some of the others (up to 5 total)
        self.assertEqual(len(result), 5)  # First message plus four more
        self.assertEqual(result[0], self.system_message)  # First message is preserved

        # Verify the wrapper was created with the right model
        mock_create_wrapper.assert_called_with(self.mock_model.model)

        # Verify print_messages_compact was called
        mock_print.assert_called_once()

    @patch("ra_aid.anthropic_token_limiter.get_config_repository")
    @patch("litellm.get_model_info")
    def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
        from ra_aid.config import DEFAULT_MODEL

        # Setup mocks
        mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
        mock_get_config_repo.return_value.get_all.return_value = mock_config

        # Mock litellm's get_model_info to return a token limit
        mock_get_model_info.return_value = {"max_input_tokens": 100000}

        # Test getting token limit
        result = get_model_token_limit(mock_config)
        self.assertEqual(result, 100000)

        # Verify get_model_info was called with the right model
        mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")

    @patch("ra_aid.anthropic_token_limiter.get_config_repository")
    @patch("litellm.get_model_info")
    def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
        # Setup mocks
        mock_config = {"provider": "anthropic", "model": "claude-2"}
        mock_get_config_repo.return_value.get_all.return_value = mock_config

        # Make litellm's get_model_info raise an exception to test fallback
        mock_get_model_info.side_effect = Exception("Model not found")

        # Test getting token limit from models_params fallback
        with patch("ra_aid.anthropic_token_limiter.models_params", {
            "anthropic": {
                "claude2": {"token_limit": 100000}
            }
        }):
            result = get_model_token_limit(mock_config)
            self.assertEqual(result, 100000)

    @patch("ra_aid.anthropic_token_limiter.get_config_repository")
    @patch("litellm.get_model_info")
    def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
        from ra_aid.config import DEFAULT_MODEL

        # Setup mocks for different agent types
        mock_config = {
            "provider": "anthropic",
            "model": DEFAULT_MODEL,
            "research_provider": "openai",
            "research_model": "gpt-4",
            "planner_provider": "anthropic",
            "planner_model": "claude-3-sonnet-20240229",
        }
        mock_get_config_repo.return_value.get_all.return_value = mock_config

        # Mock different returns for different models
        def model_info_side_effect(model_name):
            if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
                return {"max_input_tokens": 200000}
            elif "gpt-4" in model_name:
                return {"max_input_tokens": 8192}
            elif "claude-3-sonnet" in model_name:
                return {"max_input_tokens": 100000}
            else:
                raise Exception(f"Unknown model: {model_name}")

        mock_get_model_info.side_effect = model_info_side_effect

        # Test default agent type
        result = get_model_token_limit(mock_config, "default")
        self.assertEqual(result, 200000)

        # Test research agent type
        result = get_model_token_limit(mock_config, "research")
        self.assertEqual(result, 8192)

        # Test planner agent type
        result = get_model_token_limit(mock_config, "planner")
        self.assertEqual(result, 100000)


if __name__ == "__main__":
    unittest.main()