feat(main.py): refactor imports for better organization and readability
feat(main.py): add DEFAULT_MODEL constant to centralize model configuration
feat(main.py): enhance logging and error handling for better debugging
feat(main.py): implement state_modifier for managing token limits in agent state
feat(anthropic_token_limiter.py): create utilities for handling token limits with Anthropic models
feat(output.py): add print_messages_compact function for debugging message output
test(anthropic_token_limiter.py): add unit tests for token limit utilities and state management
parent b4b0fdd686
commit 5c9a1e81d2
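A minimal sketch of how the pieces added in this commit are meant to fit together (illustrative only; the real wiring lives in build_agent_kwargs/create_agent below, and the fallback limit of 100_000 here is an assumption, not a value from the diff):

    # Illustrative sketch, not part of the commit.
    from langgraph.prebuilt import create_react_agent
    from langgraph.prebuilt.chat_agent_executor import AgentState
    from langchain_core.messages import BaseMessage
    from ra_aid.anthropic_token_limiter import state_modifier, get_model_token_limit

    def make_agent(model, tools, config):
        # Resolve the model's input-token budget; fall back to an assumed default.
        max_input_tokens = get_model_token_limit(config, "default") or 100_000

        def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
            # Trim history to the budget while always keeping the first message.
            return state_modifier(state, model, max_input_tokens=max_input_tokens)

        return create_react_agent(model, tools, state_modifier=wrapped_state_modifier)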
@@ -39,32 +39,38 @@ from ra_aid.agents.research_agent import run_research_agent
from ra_aid.agents import run_planning_agent
from ra_aid.config import (
    DEFAULT_MAX_TEST_CMD_RETRIES,
    DEFAULT_MODEL,
    DEFAULT_RECURSION_LIMIT,
    DEFAULT_TEST_CMD_TIMEOUT,
    VALID_PROVIDERS,
)
from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
from ra_aid.database.repositories.key_fact_repository import (
    KeyFactRepositoryManager,
    get_key_fact_repository,
)
from ra_aid.database.repositories.key_snippet_repository import (
    KeySnippetRepositoryManager, get_key_snippet_repository
    KeySnippetRepositoryManager,
    get_key_snippet_repository,
)
from ra_aid.database.repositories.human_input_repository import (
    HumanInputRepositoryManager, get_human_input_repository
    HumanInputRepositoryManager,
    get_human_input_repository,
)
from ra_aid.database.repositories.research_note_repository import (
    ResearchNoteRepositoryManager, get_research_note_repository
    ResearchNoteRepositoryManager,
    get_research_note_repository,
)
from ra_aid.database.repositories.trajectory_repository import (
    TrajectoryRepositoryManager, get_trajectory_repository
    TrajectoryRepositoryManager,
    get_trajectory_repository,
)
from ra_aid.database.repositories.related_files_repository import (
    RelatedFilesRepositoryManager
)
from ra_aid.database.repositories.work_log_repository import (
    WorkLogRepositoryManager
    RelatedFilesRepositoryManager,
)
from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager
from ra_aid.database.repositories.config_repository import (
    ConfigRepositoryManager,
    get_config_repository
    get_config_repository,
)
from ra_aid.env_inv import EnvDiscovery
from ra_aid.env_inv_context import EnvInvManager, get_env_inv
@@ -100,9 +106,9 @@ def launch_webui(host: str, port: int):


def parse_arguments(args=None):
    ANTHROPIC_DEFAULT_MODEL = "claude-3-7-sonnet-20250219"
    ANTHROPIC_DEFAULT_MODEL = DEFAULT_MODEL
    OPENAI_DEFAULT_MODEL = "gpt-4o"

    # Case-insensitive log level argument type
    def log_level_type(value):
        value = value.lower()
@@ -199,8 +205,10 @@ Examples:
        help="Enable chat mode with direct human interaction (implies --hil)",
    )
    parser.add_argument(
        "--log-mode", choices=["console", "file"], default="file",
        help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console"
        "--log-mode",
        choices=["console", "file"],
        default="file",
        help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console",
    )
    parser.add_argument(
        "--pretty-logger", action="store_true", help="Enable pretty logging output"
@@ -378,20 +386,20 @@ def is_stage_requested(stage: str) -> bool:
def wipe_project_memory():
    """Delete the project database file to wipe all stored memory.

    Returns:
        str: A message indicating the result of the operation
    """
    import os
    from pathlib import Path

    cwd = os.getcwd()
    ra_aid_dir = Path(os.path.join(cwd, ".ra-aid"))
    db_path = os.path.join(ra_aid_dir, "pk.db")

    if not os.path.exists(db_path):
        return "No project memory found to wipe."

    try:
        os.remove(db_path)
        return "Project memory wiped successfully."
@@ -403,11 +411,11 @@ def wipe_project_memory():
def build_status():
    """Build status panel with model and feature information.

    Includes memory statistics at the bottom with counts of key facts, snippets, and research notes.
    """
    status = Text()

    # Get the config repository to get model/provider information
    config_repo = get_config_repository()
    provider = config_repo.get("provider", "")
@@ -415,12 +423,14 @@ def build_status():
    temperature = config_repo.get("temperature")
    expert_provider = config_repo.get("expert_provider", "")
    expert_model = config_repo.get("expert_model", "")
    experimental_fallback_handler = config_repo.get("experimental_fallback_handler", False)
    experimental_fallback_handler = config_repo.get(
        "experimental_fallback_handler", False
    )
    web_research_enabled = config_repo.get("web_research_enabled", False)

    # Get the expert enabled status
    expert_enabled = bool(expert_provider and expert_model)

    # Basic model information
    status.append("🤖 ")
    status.append(f"{provider}/{model}")
@@ -452,39 +462,41 @@ def build_status():
            [fb_handler._format_model(m) for m in fb_handler.fallback_tool_models]
        )
        status.append(msg)

    # Add memory statistics
    # Get counts of key facts, snippets, and research notes with error handling
    fact_count = 0
    snippet_count = 0
    note_count = 0

    try:
        fact_count = len(get_key_fact_repository().get_all())
    except RuntimeError as e:
        logger.debug(f"Failed to get key facts count: {e}")

    try:
        snippet_count = len(get_key_snippet_repository().get_all())
    except RuntimeError as e:
        logger.debug(f"Failed to get key snippets count: {e}")

    try:
        note_count = len(get_research_note_repository().get_all())
    except RuntimeError as e:
        logger.debug(f"Failed to get research notes count: {e}")

    # Add memory statistics line with reset option note
    status.append(f"\n💾 Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes")
    status.append(
        f"\n💾 Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes"
    )
    if fact_count > 0 or snippet_count > 0 or note_count > 0:
        status.append(" (use --wipe-project-memory to reset)")

    # Check for newer version
    version_message = check_for_newer_version()
    if version_message:
        status.append("\n\n")
        status.append(version_message, style="yellow")

    return status
@@ -493,7 +505,7 @@ def main():
    args = parse_arguments()
    setup_logging(args.log_mode, args.pretty_logger, args.log_level)
    logger.debug("Starting RA.Aid with arguments: %s", args)

    # Check if we need to wipe project memory before starting
    if args.wipe_project_memory:
        result = wipe_project_memory()
@@ -519,22 +531,24 @@ def main():
    # Initialize empty config dictionary to be populated later
    config = {}

    # Initialize repositories with database connection
    # Create environment inventory data
    env_discovery = EnvDiscovery()
    env_discovery.discover()
    env_data = env_discovery.format_markdown()

    with KeyFactRepositoryManager(db) as key_fact_repo, \
        KeySnippetRepositoryManager(db) as key_snippet_repo, \
        HumanInputRepositoryManager(db) as human_input_repo, \
        ResearchNoteRepositoryManager(db) as research_note_repo, \
        RelatedFilesRepositoryManager() as related_files_repo, \
        TrajectoryRepositoryManager(db) as trajectory_repo, \
        WorkLogRepositoryManager() as work_log_repo, \
        ConfigRepositoryManager(config) as config_repo, \
        EnvInvManager(env_data) as env_inv:
    with (
        KeyFactRepositoryManager(db) as key_fact_repo,
        KeySnippetRepositoryManager(db) as key_snippet_repo,
        HumanInputRepositoryManager(db) as human_input_repo,
        ResearchNoteRepositoryManager(db) as research_note_repo,
        RelatedFilesRepositoryManager() as related_files_repo,
        TrajectoryRepositoryManager(db) as trajectory_repo,
        WorkLogRepositoryManager() as work_log_repo,
        ConfigRepositoryManager(config) as config_repo,
        EnvInvManager(env_data) as env_inv,
    ):
        # This initializes all repositories and makes them available via their respective get methods
        logger.debug("Initialized KeyFactRepository")
        logger.debug("Initialized KeySnippetRepository")
@@ -554,7 +568,9 @@ def main():
            expert_missing,
            web_research_enabled,
            web_research_missing,
        ) = validate_environment(args)  # Will exit if main env vars missing
        ) = validate_environment(
            args
        )  # Will exit if main env vars missing
        logger.debug("Environment validation successful")

        # Validate model configuration early
@@ -590,11 +606,15 @@ def main():
        config_repo.set("expert_provider", args.expert_provider)
        config_repo.set("expert_model", args.expert_model)
        config_repo.set("temperature", args.temperature)
        config_repo.set("experimental_fallback_handler", args.experimental_fallback_handler)
        config_repo.set(
            "experimental_fallback_handler", args.experimental_fallback_handler
        )
        config_repo.set("web_research_enabled", web_research_enabled)
        config_repo.set("show_thoughts", args.show_thoughts)
        config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
        config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
        config_repo.set(
            "disable_reasoning_assistance", args.no_reasoning_assistance
        )

        # Build status panel with memory statistics
        status = build_status()
@@ -633,13 +653,15 @@ def main():
            initial_request = ask_human.invoke(
                {"question": "What would you like help with?"}
            )

            # Record chat input in database (redundant as ask_human already records it,
            # but needed in case the ask_human implementation changes)
            try:
                # Using get_human_input_repository() to access the repository from context
                human_input_repository = get_human_input_repository()
                human_input_repository.create(content=initial_request, source='chat')
                human_input_repository.create(
                    content=initial_request, source="chat"
                )
                human_input_repository.garbage_collect()
            except Exception as e:
                logger.error(f"Failed to record initial chat input: {str(e)}")
@@ -668,8 +690,12 @@ def main():
            config_repo.set("expert_model", args.expert_model)
            config_repo.set("temperature", args.temperature)
            config_repo.set("show_thoughts", args.show_thoughts)
            config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
            config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
            config_repo.set(
                "force_reasoning_assistance", args.reasoning_assistance
            )
            config_repo.set(
                "disable_reasoning_assistance", args.no_reasoning_assistance
            )

            # Set modification tools based on use_aider flag
            set_modification_tools(args.use_aider)
@@ -696,8 +722,12 @@ def main():
                ),
                working_directory=working_directory,
                current_date=current_date,
                key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
                key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
                key_facts=format_key_facts_dict(
                    get_key_fact_repository().get_facts_dict()
                ),
                key_snippets=format_key_snippets_dict(
                    get_key_snippet_repository().get_snippets_dict()
                ),
                project_info=formatted_project_info,
                env_inv=get_env_inv(),
            ),
@@ -711,12 +741,12 @@ def main():
            sys.exit(1)

        base_task = args.message

        # Record CLI input in database
        try:
            # Using get_human_input_repository() to access the repository from context
            human_input_repository = get_human_input_repository()
            human_input_repository.create(content=base_task, source='cli')
            human_input_repository.create(content=base_task, source="cli")
            # Run garbage collection to ensure we don't exceed 100 inputs
            human_input_repository.garbage_collect()
            logger.debug(f"Recorded CLI input: {base_task}")
@@ -750,19 +780,25 @@ def main():
        config_repo.set("expert_model", args.expert_model)

        # Store planner config with fallback to base values
        config_repo.set("planner_provider", args.planner_provider or args.provider)
        config_repo.set(
            "planner_provider", args.planner_provider or args.provider
        )
        config_repo.set("planner_model", args.planner_model or args.model)

        # Store research config with fallback to base values
        config_repo.set("research_provider", args.research_provider or args.provider)
        config_repo.set(
            "research_provider", args.research_provider or args.provider
        )
        config_repo.set("research_model", args.research_model or args.model)

        # Store temperature in config
        config_repo.set("temperature", args.temperature)

        # Store reasoning assistance flags
        config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
        config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
        config_repo.set(
            "disable_reasoning_assistance", args.no_reasoning_assistance
        )

        # Set modification tools based on use_aider flag
        set_modification_tools(args.use_aider)
@@ -794,5 +830,6 @@ def main():
        print()
        sys.exit(0)


if __name__ == "__main__":
    main()
@@ -1,19 +1,14 @@
"""Utility functions for working with agents."""

import inspect
import os
import signal
import sys
import threading
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Sequence
from typing import Any, Dict, List, Literal, Optional

from langchain_anthropic import ChatAnthropic
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler

import litellm
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
from openai import RateLimitError as OpenAIRateLimitError
from litellm.exceptions import RateLimitError as LiteLLMRateLimitError
@@ -23,28 +18,24 @@ from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from langgraph.prebuilt.chat_agent_executor import AgentState
from litellm import get_model_info
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel

from ra_aid.agent_context import (
    agent_context,
    get_depth,
    is_completed,
    reset_completion_flags,
    should_exit,
)
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
from ra_aid.agents_alias import RAgents
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
from ra_aid.console.formatting import print_error, print_stage_header
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES
from ra_aid.console.formatting import print_error
from ra_aid.console.output import print_agent_output
from ra_aid.exceptions import (
    AgentInterrupt,
@@ -53,76 +44,16 @@ from ra_aid.exceptions import (
)
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.logging_config import get_logger
from ra_aid.llm import initialize_expert_llm
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.text.processing import process_thinking_content
from ra_aid.project_info import (
    display_project_status,
    format_project_info,
    get_project_info,
)
from ra_aid.prompts.expert_prompts import (
    EXPERT_PROMPT_SECTION_IMPLEMENTATION,
    EXPERT_PROMPT_SECTION_PLANNING,
    EXPERT_PROMPT_SECTION_RESEARCH,
)
from ra_aid.prompts.human_prompts import (
    HUMAN_PROMPT_SECTION_IMPLEMENTATION,
    HUMAN_PROMPT_SECTION_PLANNING,
    HUMAN_PROMPT_SECTION_RESEARCH,
)
from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
from ra_aid.prompts.reasoning_assist_prompt import (
    REASONING_ASSIST_PROMPT_PLANNING,
    REASONING_ASSIST_PROMPT_IMPLEMENTATION,
    REASONING_ASSIST_PROMPT_RESEARCH,
)
from ra_aid.prompts.research_prompts import (
    RESEARCH_ONLY_PROMPT,
    RESEARCH_PROMPT,
)
from ra_aid.prompts.web_research_prompts import (
    WEB_RESEARCH_PROMPT,
    WEB_RESEARCH_PROMPT_SECTION_CHAT,
    WEB_RESEARCH_PROMPT_SECTION_PLANNING,
    WEB_RESEARCH_PROMPT_SECTION_RESEARCH,
)
from ra_aid.tool_configs import (
    get_implementation_tools,
    get_planning_tools,
    get_research_tools,
    get_web_research_tools,
)
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import (
    get_key_snippet_repository,
)
from ra_aid.database.repositories.human_input_repository import (
    get_human_input_repository,
)
from ra_aid.database.repositories.research_note_repository import (
    get_research_note_repository,
)
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
from ra_aid.tools.memory import (
    get_related_files,
    log_work_event,
)
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.env_inv_context import get_env_inv
from ra_aid.anthropic_token_limiter import state_modifier, get_model_token_limit

console = Console()

logger = get_logger(__name__)

# Import repositories using get_* functions
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository


@tool
@@ -132,131 +63,19 @@ def output_markdown_message(message: str) -> str:
    return "Message output."


def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
    """Helper function to estimate total tokens in a sequence of messages.

    Args:
        messages: Sequence of messages to count tokens for

    Returns:
        Total estimated token count
    """
    if not messages:
        return 0

    estimate_tokens = CiaynAgent._estimate_tokens
    return sum(estimate_tokens(msg) for msg in messages)


def state_modifier(
    state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
    """Given the agent state and max_tokens, return a trimmed list of messages.

    Args:
        state: The current agent state containing messages
        max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)

    Returns:
        list[BaseMessage]: Trimmed list of messages that fits within token limit
    """
    messages = state["messages"]

    if not messages:
        return []

    first_message = messages[0]
    remaining_messages = messages[1:]
    first_tokens = estimate_messages_tokens([first_message])
    new_max_tokens = max_input_tokens - first_tokens

    trimmed_remaining = trim_messages(
        remaining_messages,
        token_counter=estimate_messages_tokens,
        max_tokens=new_max_tokens,
        strategy="last",
        allow_partial=False,
    )

    return [first_message] + trimmed_remaining


def get_model_token_limit(
    config: Dict[str, Any], agent_type: Literal["default", "research", "planner"]
) -> Optional[int]:
    """Get the token limit for the current model configuration based on agent type.

    Returns:
        Optional[int]: The token limit if found, None otherwise
    """
    try:
        # Try to get config from repository for production use
        try:
            config_from_repo = get_config_repository().get_all()
            # If we succeeded, use the repository config instead of passed config
            config = config_from_repo
        except RuntimeError:
            # In tests, this may fail because the repository isn't set up
            # So we'll use the passed config directly
            pass
        if agent_type == "research":
            provider = config.get("research_provider", "") or config.get("provider", "")
            model_name = config.get("research_model", "") or config.get("model", "")
        elif agent_type == "planner":
            provider = config.get("planner_provider", "") or config.get("provider", "")
            model_name = config.get("planner_model", "") or config.get("model", "")
        else:
            provider = config.get("provider", "")
            model_name = config.get("model", "")

        try:
            provider_model = model_name if not provider else f"{provider}/{model_name}"
            model_info = get_model_info(provider_model)
            max_input_tokens = model_info.get("max_input_tokens")
            if max_input_tokens:
                logger.debug(
                    f"Using litellm token limit for {model_name}: {max_input_tokens}"
                )
                return max_input_tokens
        except litellm.exceptions.NotFoundError:
            logger.debug(
                f"Model {model_name} not found in litellm, falling back to models_params"
            )
        except Exception as e:
            logger.debug(
                f"Error getting model info from litellm: {e}, falling back to models_params"
            )

        # Fallback to models_params dict
        # Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
        normalized_name = model_name.replace("-", "")
        provider_tokens = models_params.get(provider, {})
        if normalized_name in provider_tokens:
            max_input_tokens = provider_tokens[normalized_name]["token_limit"]
            logger.debug(
                f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
            )
        else:
            max_input_tokens = None
            logger.debug(f"Could not find token limit for {provider}/{model_name}")

        return max_input_tokens

    except Exception as e:
        logger.warning(f"Failed to get model token limit: {e}")
        return None


def build_agent_kwargs(
    checkpointer: Optional[Any] = None,
    model: ChatAnthropic = None,
    max_input_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """Build kwargs dictionary for agent creation.

    Args:
        checkpointer: Optional memory checkpointer
        config: Optional configuration dictionary
        token_limit: Optional token limit for the model
        model: The language model to use for token counting
        max_input_tokens: Optional token limit for the model

    Returns:
        Dictionary of kwargs for agent creation
@@ -269,12 +88,17 @@ def build_agent_kwargs(
        agent_kwargs["checkpointer"] = checkpointer

    config = get_config_repository().get_all()
    if config.get("limit_tokens", True) and is_anthropic_claude(config):
    if (
        config.get("limit_tokens", True)
        and is_anthropic_claude(config)
        and model is not None
    ):

        def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
            return state_modifier(state, max_input_tokens=max_input_tokens)
            return state_modifier(state, model, max_input_tokens=max_input_tokens)

        agent_kwargs["state_modifier"] = wrapped_state_modifier
        agent_kwargs["name"] = "React"

    return agent_kwargs
@@ -340,11 +164,13 @@ def create_agent(
        max_input_tokens = (
            get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
        )
        print(f"max_input_tokens={max_input_tokens}")

        # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
        if is_anthropic_claude(config):
            logger.debug("Using create_react_agent to instantiate agent.")
            agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
            agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
            return create_react_agent(
                model, tools, interrupt_after=["tools"], **agent_kwargs
            )
@@ -357,16 +183,12 @@ def create_agent(
        logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
        config = get_config_repository().get_all()
        max_input_tokens = get_model_token_limit(config, agent_type)
        agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
        agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
        return create_react_agent(
            model, tools, interrupt_after=["tools"], **agent_kwargs
        )


from ra_aid.agents.research_agent import run_research_agent, run_web_research_agent
from ra_aid.agents.implementation_agent import run_task_implementation_agent


_CONTEXT_STACK = []
_INTERRUPT_CONTEXT = None
_FEEDBACK_MODE = False
@@ -0,0 +1,210 @@
"""Utilities for handling token limits with Anthropic models."""

from functools import partial
from typing import Any, Dict, List, Optional, Sequence, Union

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage, trim_messages
from langchain_core.messages.base import messages_to_dict
from langgraph.prebuilt.chat_agent_executor import AgentState
from litellm import token_counter

from ra_aid.agent_backends.ciayn_agent import CiaynAgent
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.console.output import print_messages_compact

logger = get_logger(__name__)


def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
    """Helper function to estimate total tokens in a sequence of messages.

    Args:
        messages: Sequence of messages to count tokens for

    Returns:
        Total estimated token count
    """
    if not messages:
        return 0

    estimate_tokens = CiaynAgent._estimate_tokens
    return sum(estimate_tokens(msg) for msg in messages)


def create_token_counter_wrapper(model: str):
    """Create a wrapper for token counter that handles BaseMessage conversion.

    Args:
        model: The model name to use for token counting

    Returns:
        A function that accepts BaseMessage objects and returns token count
    """

    # Create a partial function that already has the model parameter set
    base_token_counter = partial(token_counter, model=model)

    def wrapped_token_counter(messages: List[Union[BaseMessage, Dict]]) -> int:
        """Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage.

        Args:
            messages: List of messages (either BaseMessage objects or dicts)

        Returns:
            Token count for the messages
        """
        if not messages:
            return 0

        if isinstance(messages[0], BaseMessage):
            messages_dicts = [msg["data"] for msg in messages_to_dict(messages)]
            return base_token_counter(messages=messages_dicts)
        else:
            # Already in dict format
            return base_token_counter(messages=messages)

    return wrapped_token_counter


def state_modifier(
    state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
    """Given the agent state and max_tokens, return a trimmed list of messages but always keep the first message.

    Args:
        state: The current agent state containing messages
        model: The language model to use for token counting
        max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)

    Returns:
        list[BaseMessage]: Trimmed list of messages that fits within token limit
    """
    messages = state["messages"]

    if not messages:
        return []

    first_message = messages[0]
    remaining_messages = messages[1:]

    wrapped_token_counter = create_token_counter_wrapper(model.model)

    first_tokens = wrapped_token_counter([first_message])
    new_max_tokens = max_input_tokens - first_tokens

    print_messages_compact(messages)

    trimmed_remaining = trim_messages(
        remaining_messages,
        token_counter=wrapped_token_counter,
        max_tokens=new_max_tokens,
        strategy="last",
        allow_partial=False,
    )

    return [first_message] + trimmed_remaining


def sonnet_3_5_state_modifier(
    state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
    """Given the agent state and max_tokens, return a trimmed list of messages.

    Args:
        state: The current agent state containing messages
        max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)

    Returns:
        list[BaseMessage]: Trimmed list of messages that fits within token limit
    """
    messages = state["messages"]

    if not messages:
        return []

    first_message = messages[0]
    remaining_messages = messages[1:]
    first_tokens = estimate_messages_tokens([first_message])
    new_max_tokens = max_input_tokens - first_tokens

    trimmed_remaining = trim_messages(
        remaining_messages,
        token_counter=estimate_messages_tokens,
        max_tokens=new_max_tokens,
        strategy="last",
        allow_partial=False,
    )

    return [first_message] + trimmed_remaining


def get_model_token_limit(
    config: Dict[str, Any], agent_type: str = "default"
) -> Optional[int]:
    """Get the token limit for the current model configuration based on agent type.

    Args:
        config: Configuration dictionary containing provider and model information
        agent_type: Type of agent ("default", "research", or "planner")

    Returns:
        Optional[int]: The token limit if found, None otherwise
    """
    try:
        # Try to get config from repository for production use
        try:
            config_from_repo = get_config_repository().get_all()
            # If we succeeded, use the repository config instead of passed config
            config = config_from_repo
        except RuntimeError:
            # In tests, this may fail because the repository isn't set up
            # So we'll use the passed config directly
            pass
        if agent_type == "research":
            provider = config.get("research_provider", "") or config.get("provider", "")
            model_name = config.get("research_model", "") or config.get("model", "")
        elif agent_type == "planner":
            provider = config.get("planner_provider", "") or config.get("provider", "")
            model_name = config.get("planner_model", "") or config.get("model", "")
        else:
            provider = config.get("provider", "")
            model_name = config.get("model", "")

        try:
            from litellm import get_model_info

            provider_model = model_name if not provider else f"{provider}/{model_name}"
            model_info = get_model_info(provider_model)
            max_input_tokens = model_info.get("max_input_tokens")
            if max_input_tokens:
                logger.debug(
                    f"Using litellm token limit for {model_name}: {max_input_tokens}"
                )
                return max_input_tokens
        except Exception as e:
            logger.debug(
                f"Error getting model info from litellm: {e}, falling back to models_params"
            )

        # Fallback to models_params dict
        # Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
        normalized_name = model_name.replace("-", "")
        provider_tokens = models_params.get(provider, {})
        if normalized_name in provider_tokens:
            max_input_tokens = provider_tokens[normalized_name]["token_limit"]
            logger.debug(
                f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
            )
        else:
            max_input_tokens = None
            logger.debug(f"Could not find token limit for {provider}/{model_name}")

        return max_input_tokens

    except Exception as e:
        logger.warning(f"Failed to get model token limit: {e}")
        return None
@@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3
FALLBACK_TOOL_MODEL_LIMIT = 5
RETRY_FALLBACK_COUNT = 3
DEFAULT_TEST_CMD_TIMEOUT = 60 * 5  # 5 minutes in seconds
DEFAULT_MODEL="claude-3-7-sonnet-20250219"

VALID_PROVIDERS = [
@@ -1,6 +1,6 @@
from typing import Any, Dict, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, Sequence

from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from rich.markdown import Markdown
from rich.panel import Panel
@@ -94,3 +94,57 @@ def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -
    """
    console.print(Panel(Markdown(message), title=title, border_style=border_style))


def print_messages_compact(messages: Sequence[BaseMessage]) -> None:
    """Print a compact representation of a list of messages.

    Warning: Used mainly for debugging purposes so do not delete if not referenced anywhere!
    For all message types, only the first 30 characters of content are shown.

    Args:
        messages: A sequence of BaseMessage objects to print
    """
    if not messages:
        console.print("[italic]No messages[/italic]")
        return

    for i, msg in enumerate(messages):
        msg_type = msg.__class__.__name__
        content = msg.content

        # Process content based on its type
        if isinstance(content, str):
            display_content = f"{content[:30]}..." if len(content) > 30 else content
        elif isinstance(content, list):
            # Handle structured content (list of content blocks)
            content_preview = []
            for item in content[:2]:  # Show first 2 items at most
                if isinstance(item, dict):
                    if item.get("type") == "text":
                        text = item.get("text", "")
                        content_preview.append(f"text: {text[:20]}..." if len(text) > 20 else f"text: {text}")
                    elif item.get("type") == "tool_call":
                        tool_name = item.get("tool_call", {}).get("name", "unknown")
                        content_preview.append(f"tool_call: {tool_name}")
                    else:
                        content_preview.append(f"{item.get('type', 'unknown')}")

            if len(content) > 2:
                content_preview.append(f"...({len(content)-2} more)")

            display_content = ", ".join(content_preview)
        else:
            display_content = str(content)[:30] + "..." if len(str(content)) > 30 else str(content)

        # Add additional tool message info if available
        additional_info = []
        if hasattr(msg, "tool_call_id") and msg.tool_call_id:
            additional_info.append(f"tool_call_id: {msg.tool_call_id}")
        if hasattr(msg, "name") and msg.name:
            additional_info.append(f"name: {msg.name}")
        if hasattr(msg, "status") and msg.status:
            additional_info.append(f"status: {msg.status}")

        info_str = f" ({', '.join(additional_info)})" if additional_info else ""
        console.print(f"[{i}] [bold]{msg_type}{info_str}[/bold]: {display_content}")
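A quick way to see what print_messages_compact emits (illustrative sketch; the message contents here are made up):

    # Illustrative: one line per message, with type, any tool metadata, and a
    # truncated content preview, printed via the module's rich console.
    from langchain_core.messages import AIMessage, HumanMessage
    from ra_aid.console.output import print_messages_compact

    print_messages_compact([
        HumanMessage(content="Hello, can you help me with a task?"),
        AIMessage(content="Sure, here is a fairly long reply that will be truncated in the preview."),
    ])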
@@ -241,8 +241,9 @@ def create_llm_client(
    else:
        temp_kwargs = {}

    thinking_kwargs = {}
    if supports_thinking:
        temp_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
        thinking_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}

    if provider == "deepseek":
        return create_deepseek_client(
@@ -250,6 +251,7 @@ def create_llm_client(
            api_key=config["api_key"],
            base_url=config["base_url"],
            **temp_kwargs,
            **thinking_kwargs,
            is_expert=is_expert,
        )
    elif provider == "openrouter":
@@ -257,6 +259,7 @@ def create_llm_client(
            model_name=model_name,
            api_key=config["api_key"],
            **temp_kwargs,
            **thinking_kwargs,
            is_expert=is_expert,
        )
    elif provider == "openai":
@@ -271,6 +274,7 @@ def create_llm_client(
        return ChatOpenAI(
            **{
                **openai_kwargs,
                **thinking_kwargs,
                "timeout": LLM_REQUEST_TIMEOUT,
                "max_retries": LLM_MAX_RETRIES,
            }
@@ -283,6 +287,7 @@ def create_llm_client(
            max_retries=LLM_MAX_RETRIES,
            max_tokens=model_config.get("max_tokens", 64000),
            **temp_kwargs,
            **thinking_kwargs,
        )
    elif provider == "openai-compatible":
        return ChatOpenAI(
@@ -292,6 +297,7 @@ def create_llm_client(
            timeout=LLM_REQUEST_TIMEOUT,
            max_retries=LLM_MAX_RETRIES,
            **temp_kwargs,
            **thinking_kwargs,
        )
    elif provider == "gemini":
        return ChatGoogleGenerativeAI(
@@ -300,6 +306,7 @@ def create_llm_client(
            timeout=LLM_REQUEST_TIMEOUT,
            max_retries=LLM_MAX_RETRIES,
            **temp_kwargs,
            **thinking_kwargs,
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")
@@ -14,6 +14,7 @@ from ra_aid.agent_context import (
    is_crashed,
    reset_completion_flags,
)
from ra_aid.config import DEFAULT_MODEL
from ra_aid.console.formatting import print_error
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
@@ -337,7 +338,7 @@ def request_task_implementation(task_spec: str) -> str:
    config = get_config_repository().get_all()
    model = initialize_llm(
        config.get("provider", "anthropic"),
        config.get("model", "claude-3-5-sonnet-20241022"),
        config.get("model",DEFAULT_MODEL),
        temperature=config.get("temperature"),
    )
@@ -475,7 +476,7 @@ def request_implementation(task_spec: str) -> str:
    config = get_config_repository().get_all()
    model = initialize_llm(
        config.get("provider", "anthropic"),
        config.get("model", "claude-3-5-sonnet-20241022"),
        config.get("model", DEFAULT_MODEL),
        temperature=config.get("temperature"),
    )
@@ -592,4 +593,4 @@ def request_implementation(task_spec: str) -> str:
    # Join all parts into a single markdown string
    markdown_output = "".join(markdown_parts)

    return markdown_output
    return markdown_output
@@ -0,0 +1,198 @@
import unittest
from unittest.mock import MagicMock, patch

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.prebuilt.chat_agent_executor import AgentState

from ra_aid.anthropic_token_limiter import (
    create_token_counter_wrapper,
    estimate_messages_tokens,
    get_model_token_limit,
    state_modifier,
)


class TestAnthropicTokenLimiter(unittest.TestCase):
    def setUp(self):
        from ra_aid.config import DEFAULT_MODEL

        self.mock_model = MagicMock(spec=ChatAnthropic)
        self.mock_model.model = DEFAULT_MODEL

        # Sample messages for testing
        self.system_message = SystemMessage(content="You are a helpful assistant.")
        self.human_message = HumanMessage(content="Hello, can you help me with a task?")
        self.long_message = HumanMessage(content="A" * 1000)  # Long message to test trimming

        # Create more messages for testing
        self.extra_messages = [
            HumanMessage(content=f"Extra message {i}") for i in range(5)
        ]

        # Mock state for testing state_modifier with many messages
        self.state = AgentState(
            messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages,
            next=None,
        )

    @patch("ra_aid.anthropic_token_limiter.token_counter")
    def test_create_token_counter_wrapper(self, mock_token_counter):
        from ra_aid.config import DEFAULT_MODEL

        # Setup mock return values
        mock_token_counter.return_value = 50

        # Create the wrapper
        wrapper = create_token_counter_wrapper(DEFAULT_MODEL)

        # Test with BaseMessage objects
        result = wrapper([self.human_message])
        self.assertEqual(result, 50)

        # Test with empty list
        result = wrapper([])
        self.assertEqual(result, 0)

        # Verify the mock was called with the right parameters
        mock_token_counter.assert_called_with(messages=unittest.mock.ANY, model=DEFAULT_MODEL)

    @patch("ra_aid.anthropic_token_limiter.CiaynAgent._estimate_tokens")
    def test_estimate_messages_tokens(self, mock_estimate_tokens):
        # Setup mock to return different values for different messages
        mock_estimate_tokens.side_effect = lambda msg: 10 if isinstance(msg, SystemMessage) else 20

        # Test with multiple messages
        messages = [self.system_message, self.human_message]
        result = estimate_messages_tokens(messages)

        # Should be sum of individual token counts (10 + 20)
        self.assertEqual(result, 30)

        # Test with empty list
        result = estimate_messages_tokens([])
        self.assertEqual(result, 0)

    @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
    @patch("ra_aid.anthropic_token_limiter.print_messages_compact")
    def test_state_modifier(self, mock_print, mock_create_wrapper):
        # Setup a proper token counter function that returns integers
        # This function needs to return values that will cause trim_messages to keep only the first message
        def token_counter(msgs):
            # For a single message, return a small token count
            if len(msgs) == 1:
                return 10
            # For two messages (first + one more), return a value under our limit
            elif len(msgs) == 2:
                return 30  # This is under our 40 token remaining budget (50-10)
            # For three messages, return a value just under our limit
            elif len(msgs) == 3:
                return 40  # This is exactly at our 40 token remaining budget (50-10)
            # For four messages, return a value just at our limit
            elif len(msgs) == 4:
                return 40  # This is exactly at our 40 token remaining budget (50-10)
            # For five messages, return a value that exceeds our 40 token budget
            elif len(msgs) == 5:
                return 60  # This exceeds our 40 token budget, forcing only 4 more messages
            # For more messages, return a value over our limit
            else:
                return 100  # This exceeds our limit

        # Don't use side_effect here, directly return the function
        mock_create_wrapper.return_value = token_counter

        # Call state_modifier with a max token limit of 50
        result = state_modifier(self.state, self.mock_model, max_input_tokens=50)

        # Should keep first message and some of the others (up to 5 total)
        self.assertEqual(len(result), 5)  # First message plus four more
        self.assertEqual(result[0], self.system_message)  # First message is preserved

        # Verify the wrapper was created with the right model
        mock_create_wrapper.assert_called_with(self.mock_model.model)

        # Verify print_messages_compact was called
        mock_print.assert_called_once()

    @patch("ra_aid.anthropic_token_limiter.get_config_repository")
    @patch("litellm.get_model_info")
    def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
        from ra_aid.config import DEFAULT_MODEL

        # Setup mocks
        mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
        mock_get_config_repo.return_value.get_all.return_value = mock_config

        # Mock litellm's get_model_info to return a token limit
        mock_get_model_info.return_value = {"max_input_tokens": 100000}

        # Test getting token limit
        result = get_model_token_limit(mock_config)
        self.assertEqual(result, 100000)

        # Verify get_model_info was called with the right model
        mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")

    @patch("ra_aid.anthropic_token_limiter.get_config_repository")
    @patch("litellm.get_model_info")
    def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
        # Setup mocks
        mock_config = {"provider": "anthropic", "model": "claude-2"}
        mock_get_config_repo.return_value.get_all.return_value = mock_config

        # Make litellm's get_model_info raise an exception to test fallback
        mock_get_model_info.side_effect = Exception("Model not found")

        # Test getting token limit from models_params fallback
        with patch("ra_aid.anthropic_token_limiter.models_params", {
            "anthropic": {
                "claude2": {"token_limit": 100000}
            }
        }):
            result = get_model_token_limit(mock_config)
            self.assertEqual(result, 100000)

    @patch("ra_aid.anthropic_token_limiter.get_config_repository")
    @patch("litellm.get_model_info")
    def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
        from ra_aid.config import DEFAULT_MODEL

        # Setup mocks for different agent types
        mock_config = {
            "provider": "anthropic",
            "model": DEFAULT_MODEL,
            "research_provider": "openai",
            "research_model": "gpt-4",
            "planner_provider": "anthropic",
            "planner_model": "claude-3-sonnet-20240229"
        }
        mock_get_config_repo.return_value.get_all.return_value = mock_config

        # Mock different returns for different models
        def model_info_side_effect(model_name):
            if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
                return {"max_input_tokens": 200000}
            elif "gpt-4" in model_name:
                return {"max_input_tokens": 8192}
            elif "claude-3-sonnet" in model_name:
                return {"max_input_tokens": 100000}
            else:
                raise Exception(f"Unknown model: {model_name}")

        mock_get_model_info.side_effect = model_info_side_effect

        # Test default agent type
        result = get_model_token_limit(mock_config, "default")
        self.assertEqual(result, 200000)

        # Test research agent type
        result = get_model_token_limit(mock_config, "research")
        self.assertEqual(result, 8192)

        # Test planner agent type
        result = get_model_token_limit(mock_config, "planner")
        self.assertEqual(result, 100000)


if __name__ == "__main__":
    unittest.main()