Merge pull request #124 from ariel-frischer/fix-token-limiter

Fix Sonnet 3.7 Token Limiter API Errors

commit a9656552a9
@@ -39,35 +39,41 @@ from ra_aid.agents.research_agent import run_research_agent
 from ra_aid.agents import run_planning_agent
 from ra_aid.config import (
     DEFAULT_MAX_TEST_CMD_RETRIES,
+    DEFAULT_MODEL,
     DEFAULT_RECURSION_LIMIT,
     DEFAULT_TEST_CMD_TIMEOUT,
     VALID_PROVIDERS,
 )
-from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
+from ra_aid.database.repositories.key_fact_repository import (
+    KeyFactRepositoryManager,
+    get_key_fact_repository,
+)
 from ra_aid.database.repositories.key_snippet_repository import (
-    KeySnippetRepositoryManager, get_key_snippet_repository
+    KeySnippetRepositoryManager,
+    get_key_snippet_repository,
 )
 from ra_aid.database.repositories.human_input_repository import (
-    HumanInputRepositoryManager, get_human_input_repository
+    HumanInputRepositoryManager,
+    get_human_input_repository,
 )
 from ra_aid.database.repositories.research_note_repository import (
-    ResearchNoteRepositoryManager, get_research_note_repository
+    ResearchNoteRepositoryManager,
+    get_research_note_repository,
 )
 from ra_aid.database.repositories.trajectory_repository import (
-    TrajectoryRepositoryManager, get_trajectory_repository
+    TrajectoryRepositoryManager,
+    get_trajectory_repository,
 )
 from ra_aid.database.repositories.session_repository import (
     SessionRepositoryManager, get_session_repository
 )
 from ra_aid.database.repositories.related_files_repository import (
-    RelatedFilesRepositoryManager
-)
-from ra_aid.database.repositories.work_log_repository import (
-    WorkLogRepositoryManager
+    RelatedFilesRepositoryManager,
 )
+from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager
 from ra_aid.database.repositories.config_repository import (
     ConfigRepositoryManager,
-    get_config_repository
+    get_config_repository,
 )
 from ra_aid.env_inv import EnvDiscovery
 from ra_aid.env_inv_context import EnvInvManager, get_env_inv
@@ -103,9 +109,9 @@ def launch_webui(host: str, port: int):


 def parse_arguments(args=None):
-    ANTHROPIC_DEFAULT_MODEL = "claude-3-7-sonnet-20250219"
+    ANTHROPIC_DEFAULT_MODEL = DEFAULT_MODEL
     OPENAI_DEFAULT_MODEL = "gpt-4o"

     # Case-insensitive log level argument type
     def log_level_type(value):
         value = value.lower()
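The default-model change above is paired with a new shared constant in ra_aid/config.py, shown near the end of this diff. A minimal sketch of how the two pieces relate, using only names that appear in the diff:

    # ra_aid/config.py
    DEFAULT_MODEL = "claude-3-7-sonnet-20250219"

    # ra_aid/__main__.py
    from ra_aid.config import DEFAULT_MODEL

    def parse_arguments(args=None):
        # the CLI default for Anthropic now follows the shared constant
        ANTHROPIC_DEFAULT_MODEL = DEFAULT_MODEL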
@@ -202,8 +208,10 @@ Examples:
         help="Enable chat mode with direct human interaction (implies --hil)",
     )
     parser.add_argument(
-        "--log-mode", choices=["console", "file"], default="file",
-        help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console"
+        "--log-mode",
+        choices=["console", "file"],
+        default="file",
+        help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console",
     )
     parser.add_argument(
         "--pretty-logger", action="store_true", help="Enable pretty logging output"
@@ -386,20 +394,20 @@ def is_stage_requested(stage: str) -> bool:
|
||||||
|
|
||||||
def wipe_project_memory():
|
def wipe_project_memory():
|
||||||
"""Delete the project database file to wipe all stored memory.
|
"""Delete the project database file to wipe all stored memory.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: A message indicating the result of the operation
|
str: A message indicating the result of the operation
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
cwd = os.getcwd()
|
cwd = os.getcwd()
|
||||||
ra_aid_dir = Path(os.path.join(cwd, ".ra-aid"))
|
ra_aid_dir = Path(os.path.join(cwd, ".ra-aid"))
|
||||||
db_path = os.path.join(ra_aid_dir, "pk.db")
|
db_path = os.path.join(ra_aid_dir, "pk.db")
|
||||||
|
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return "No project memory found to wipe."
|
return "No project memory found to wipe."
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.remove(db_path)
|
os.remove(db_path)
|
||||||
return "Project memory wiped successfully."
|
return "Project memory wiped successfully."
|
||||||
|
|
@@ -411,11 +419,11 @@ def wipe_project_memory():
|
||||||
|
|
||||||
def build_status():
|
def build_status():
|
||||||
"""Build status panel with model and feature information.
|
"""Build status panel with model and feature information.
|
||||||
|
|
||||||
Includes memory statistics at the bottom with counts of key facts, snippets, and research notes.
|
Includes memory statistics at the bottom with counts of key facts, snippets, and research notes.
|
||||||
"""
|
"""
|
||||||
status = Text()
|
status = Text()
|
||||||
|
|
||||||
# Get the config repository to get model/provider information
|
# Get the config repository to get model/provider information
|
||||||
config_repo = get_config_repository()
|
config_repo = get_config_repository()
|
||||||
provider = config_repo.get("provider", "")
|
provider = config_repo.get("provider", "")
|
||||||
|
|
@@ -423,12 +431,14 @@ def build_status():
|
||||||
temperature = config_repo.get("temperature")
|
temperature = config_repo.get("temperature")
|
||||||
expert_provider = config_repo.get("expert_provider", "")
|
expert_provider = config_repo.get("expert_provider", "")
|
||||||
expert_model = config_repo.get("expert_model", "")
|
expert_model = config_repo.get("expert_model", "")
|
||||||
experimental_fallback_handler = config_repo.get("experimental_fallback_handler", False)
|
experimental_fallback_handler = config_repo.get(
|
||||||
|
"experimental_fallback_handler", False
|
||||||
|
)
|
||||||
web_research_enabled = config_repo.get("web_research_enabled", False)
|
web_research_enabled = config_repo.get("web_research_enabled", False)
|
||||||
|
|
||||||
# Get the expert enabled status
|
# Get the expert enabled status
|
||||||
expert_enabled = bool(expert_provider and expert_model)
|
expert_enabled = bool(expert_provider and expert_model)
|
||||||
|
|
||||||
# Basic model information
|
# Basic model information
|
||||||
status.append("🤖 ")
|
status.append("🤖 ")
|
||||||
status.append(f"{provider}/{model}")
|
status.append(f"{provider}/{model}")
|
||||||
|
|
@@ -460,39 +470,41 @@ def build_status():
|
||||||
[fb_handler._format_model(m) for m in fb_handler.fallback_tool_models]
|
[fb_handler._format_model(m) for m in fb_handler.fallback_tool_models]
|
||||||
)
|
)
|
||||||
status.append(msg)
|
status.append(msg)
|
||||||
|
|
||||||
# Add memory statistics
|
# Add memory statistics
|
||||||
# Get counts of key facts, snippets, and research notes with error handling
|
# Get counts of key facts, snippets, and research notes with error handling
|
||||||
fact_count = 0
|
fact_count = 0
|
||||||
snippet_count = 0
|
snippet_count = 0
|
||||||
note_count = 0
|
note_count = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
fact_count = len(get_key_fact_repository().get_all())
|
fact_count = len(get_key_fact_repository().get_all())
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.debug(f"Failed to get key facts count: {e}")
|
logger.debug(f"Failed to get key facts count: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
snippet_count = len(get_key_snippet_repository().get_all())
|
snippet_count = len(get_key_snippet_repository().get_all())
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.debug(f"Failed to get key snippets count: {e}")
|
logger.debug(f"Failed to get key snippets count: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
note_count = len(get_research_note_repository().get_all())
|
note_count = len(get_research_note_repository().get_all())
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.debug(f"Failed to get research notes count: {e}")
|
logger.debug(f"Failed to get research notes count: {e}")
|
||||||
|
|
||||||
# Add memory statistics line with reset option note
|
# Add memory statistics line with reset option note
|
||||||
status.append(f"\n💾 Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes")
|
status.append(
|
||||||
|
f"\n💾 Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes"
|
||||||
|
)
|
||||||
if fact_count > 0 or snippet_count > 0 or note_count > 0:
|
if fact_count > 0 or snippet_count > 0 or note_count > 0:
|
||||||
status.append(" (use --wipe-project-memory to reset)")
|
status.append(" (use --wipe-project-memory to reset)")
|
||||||
|
|
||||||
# Check for newer version
|
# Check for newer version
|
||||||
version_message = check_for_newer_version()
|
version_message = check_for_newer_version()
|
||||||
if version_message:
|
if version_message:
|
||||||
status.append("\n\n")
|
status.append("\n\n")
|
||||||
status.append(version_message, style="yellow")
|
status.append(version_message, style="yellow")
|
||||||
|
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
|
||||||
|
|
@@ -501,7 +513,7 @@ def main():
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
setup_logging(args.log_mode, args.pretty_logger, args.log_level)
|
setup_logging(args.log_mode, args.pretty_logger, args.log_level)
|
||||||
logger.debug("Starting RA.Aid with arguments: %s", args)
|
logger.debug("Starting RA.Aid with arguments: %s", args)
|
||||||
|
|
||||||
# Check if we need to wipe project memory before starting
|
# Check if we need to wipe project memory before starting
|
||||||
if args.wipe_project_memory:
|
if args.wipe_project_memory:
|
||||||
result = wipe_project_memory()
|
result = wipe_project_memory()
|
||||||
|
|
@@ -527,7 +539,7 @@ def main():
|
||||||
|
|
||||||
# Initialize empty config dictionary to be populated later
|
# Initialize empty config dictionary to be populated later
|
||||||
config = {}
|
config = {}
|
||||||
|
|
||||||
# Initialize repositories with database connection
|
# Initialize repositories with database connection
|
||||||
# Create environment inventory data
|
# Create environment inventory data
|
||||||
env_discovery = EnvDiscovery()
|
env_discovery = EnvDiscovery()
|
||||||
|
|
@@ -568,7 +580,9 @@ def main():
|
||||||
expert_missing,
|
expert_missing,
|
||||||
web_research_enabled,
|
web_research_enabled,
|
||||||
web_research_missing,
|
web_research_missing,
|
||||||
) = validate_environment(args) # Will exit if main env vars missing
|
) = validate_environment(
|
||||||
|
args
|
||||||
|
) # Will exit if main env vars missing
|
||||||
logger.debug("Environment validation successful")
|
logger.debug("Environment validation successful")
|
||||||
|
|
||||||
# Validate model configuration early
|
# Validate model configuration early
|
||||||
|
|
@@ -604,12 +618,16 @@ def main():
|
||||||
config_repo.set("expert_provider", args.expert_provider)
|
config_repo.set("expert_provider", args.expert_provider)
|
||||||
config_repo.set("expert_model", args.expert_model)
|
config_repo.set("expert_model", args.expert_model)
|
||||||
config_repo.set("temperature", args.temperature)
|
config_repo.set("temperature", args.temperature)
|
||||||
config_repo.set("experimental_fallback_handler", args.experimental_fallback_handler)
|
config_repo.set(
|
||||||
|
"experimental_fallback_handler", args.experimental_fallback_handler
|
||||||
|
)
|
||||||
config_repo.set("web_research_enabled", web_research_enabled)
|
config_repo.set("web_research_enabled", web_research_enabled)
|
||||||
config_repo.set("show_thoughts", args.show_thoughts)
|
config_repo.set("show_thoughts", args.show_thoughts)
|
||||||
config_repo.set("show_cost", args.show_cost)
|
config_repo.set("show_cost", args.show_cost)
|
||||||
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
||||||
config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
|
config_repo.set(
|
||||||
|
"disable_reasoning_assistance", args.no_reasoning_assistance
|
||||||
|
)
|
||||||
|
|
||||||
# Build status panel with memory statistics
|
# Build status panel with memory statistics
|
||||||
status = build_status()
|
status = build_status()
|
||||||
|
|
@@ -678,13 +696,15 @@ def main():
|
||||||
initial_request = ask_human.invoke(
|
initial_request = ask_human.invoke(
|
||||||
{"question": "What would you like help with?"}
|
{"question": "What would you like help with?"}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Record chat input in database (redundant as ask_human already records it,
|
# Record chat input in database (redundant as ask_human already records it,
|
||||||
# but needed in case the ask_human implementation changes)
|
# but needed in case the ask_human implementation changes)
|
||||||
try:
|
try:
|
||||||
# Using get_human_input_repository() to access the repository from context
|
# Using get_human_input_repository() to access the repository from context
|
||||||
human_input_repository = get_human_input_repository()
|
human_input_repository = get_human_input_repository()
|
||||||
human_input_repository.create(content=initial_request, source='chat')
|
human_input_repository.create(
|
||||||
|
content=initial_request, source="chat"
|
||||||
|
)
|
||||||
human_input_repository.garbage_collect()
|
human_input_repository.garbage_collect()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to record initial chat input: {str(e)}")
|
logger.error(f"Failed to record initial chat input: {str(e)}")
|
||||||
|
|
@@ -742,8 +762,12 @@ def main():
|
||||||
),
|
),
|
||||||
working_directory=working_directory,
|
working_directory=working_directory,
|
||||||
current_date=current_date,
|
current_date=current_date,
|
||||||
key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
|
key_facts=format_key_facts_dict(
|
||||||
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
|
get_key_fact_repository().get_facts_dict()
|
||||||
|
),
|
||||||
|
key_snippets=format_key_snippets_dict(
|
||||||
|
get_key_snippet_repository().get_snippets_dict()
|
||||||
|
),
|
||||||
project_info=formatted_project_info,
|
project_info=formatted_project_info,
|
||||||
env_inv=get_env_inv(),
|
env_inv=get_env_inv(),
|
||||||
),
|
),
|
||||||
|
|
@@ -775,12 +799,12 @@ def main():
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
base_task = args.message
|
base_task = args.message
|
||||||
|
|
||||||
# Record CLI input in database
|
# Record CLI input in database
|
||||||
try:
|
try:
|
||||||
# Using get_human_input_repository() to access the repository from context
|
# Using get_human_input_repository() to access the repository from context
|
||||||
human_input_repository = get_human_input_repository()
|
human_input_repository = get_human_input_repository()
|
||||||
human_input_repository.create(content=base_task, source='cli')
|
human_input_repository.create(content=base_task, source="cli")
|
||||||
# Run garbage collection to ensure we don't exceed 100 inputs
|
# Run garbage collection to ensure we don't exceed 100 inputs
|
||||||
human_input_repository.garbage_collect()
|
human_input_repository.garbage_collect()
|
||||||
logger.debug(f"Recorded CLI input: {base_task}")
|
logger.debug(f"Recorded CLI input: {base_task}")
|
||||||
|
|
@@ -814,19 +838,25 @@ def main():
|
||||||
config_repo.set("expert_model", args.expert_model)
|
config_repo.set("expert_model", args.expert_model)
|
||||||
|
|
||||||
# Store planner config with fallback to base values
|
# Store planner config with fallback to base values
|
||||||
config_repo.set("planner_provider", args.planner_provider or args.provider)
|
config_repo.set(
|
||||||
|
"planner_provider", args.planner_provider or args.provider
|
||||||
|
)
|
||||||
config_repo.set("planner_model", args.planner_model or args.model)
|
config_repo.set("planner_model", args.planner_model or args.model)
|
||||||
|
|
||||||
# Store research config with fallback to base values
|
# Store research config with fallback to base values
|
||||||
config_repo.set("research_provider", args.research_provider or args.provider)
|
config_repo.set(
|
||||||
|
"research_provider", args.research_provider or args.provider
|
||||||
|
)
|
||||||
config_repo.set("research_model", args.research_model or args.model)
|
config_repo.set("research_model", args.research_model or args.model)
|
||||||
|
|
||||||
# Store temperature in config
|
# Store temperature in config
|
||||||
config_repo.set("temperature", args.temperature)
|
config_repo.set("temperature", args.temperature)
|
||||||
|
|
||||||
# Store reasoning assistance flags
|
# Store reasoning assistance flags
|
||||||
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
||||||
config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
|
config_repo.set(
|
||||||
|
"disable_reasoning_assistance", args.no_reasoning_assistance
|
||||||
|
)
|
||||||
|
|
||||||
# Set modification tools based on use_aider flag
|
# Set modification tools based on use_aider flag
|
||||||
set_modification_tools(args.use_aider)
|
set_modification_tools(args.use_aider)
|
||||||
|
|
@@ -870,5 +900,6 @@ def main():
|
||||||
print()
|
print()
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@@ -1,19 +1,14 @@
 """Utility functions for working with agents."""

-import inspect
-import os
 import signal
 import sys
 import threading
 import time
-import uuid
-from datetime import datetime
-from typing import Any, Dict, List, Literal, Optional, Sequence
+from typing import Any, Dict, List, Literal, Optional

+from langchain_anthropic import ChatAnthropic
 from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler

-
-import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
 from openai import RateLimitError as OpenAIRateLimitError
 from litellm.exceptions import RateLimitError as LiteLLMRateLimitError
@@ -23,28 +18,24 @@ from langchain_core.messages import (
     BaseMessage,
     HumanMessage,
     SystemMessage,
-    trim_messages,
 )
 from langchain_core.tools import tool
-from langgraph.checkpoint.memory import MemorySaver
 from langgraph.prebuilt import create_react_agent
 from langgraph.prebuilt.chat_agent_executor import AgentState
-from litellm import get_model_info
 from rich.console import Console
 from rich.markdown import Markdown
 from rich.panel import Panel

 from ra_aid.agent_context import (
     agent_context,
-    get_depth,
     is_completed,
     reset_completion_flags,
     should_exit,
 )
 from ra_aid.agent_backends.ciayn_agent import CiaynAgent
 from ra_aid.agents_alias import RAgents
-from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
-from ra_aid.console.formatting import print_error, print_stage_header
+from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES
+from ra_aid.console.formatting import print_error
 from ra_aid.console.output import print_agent_output
 from ra_aid.exceptions import (
     AgentInterrupt,
@@ -53,77 +44,20 @@ from ra_aid.exceptions import (
|
||||||
)
|
)
|
||||||
from ra_aid.fallback_handler import FallbackHandler
|
from ra_aid.fallback_handler import FallbackHandler
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.llm import initialize_expert_llm
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
|
||||||
from ra_aid.text.processing import process_thinking_content
|
|
||||||
from ra_aid.project_info import (
|
|
||||||
display_project_status,
|
|
||||||
format_project_info,
|
|
||||||
get_project_info,
|
|
||||||
)
|
|
||||||
from ra_aid.prompts.expert_prompts import (
|
|
||||||
EXPERT_PROMPT_SECTION_IMPLEMENTATION,
|
|
||||||
EXPERT_PROMPT_SECTION_PLANNING,
|
|
||||||
EXPERT_PROMPT_SECTION_RESEARCH,
|
|
||||||
)
|
|
||||||
from ra_aid.prompts.human_prompts import (
|
|
||||||
HUMAN_PROMPT_SECTION_IMPLEMENTATION,
|
|
||||||
HUMAN_PROMPT_SECTION_PLANNING,
|
|
||||||
HUMAN_PROMPT_SECTION_RESEARCH,
|
|
||||||
)
|
|
||||||
from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
|
|
||||||
from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
|
|
||||||
from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
|
|
||||||
from ra_aid.prompts.reasoning_assist_prompt import (
|
|
||||||
REASONING_ASSIST_PROMPT_PLANNING,
|
|
||||||
REASONING_ASSIST_PROMPT_IMPLEMENTATION,
|
|
||||||
REASONING_ASSIST_PROMPT_RESEARCH,
|
|
||||||
)
|
|
||||||
from ra_aid.prompts.research_prompts import (
|
|
||||||
RESEARCH_ONLY_PROMPT,
|
|
||||||
RESEARCH_PROMPT,
|
|
||||||
)
|
|
||||||
from ra_aid.prompts.web_research_prompts import (
|
|
||||||
WEB_RESEARCH_PROMPT,
|
|
||||||
WEB_RESEARCH_PROMPT_SECTION_CHAT,
|
|
||||||
WEB_RESEARCH_PROMPT_SECTION_PLANNING,
|
|
||||||
WEB_RESEARCH_PROMPT_SECTION_RESEARCH,
|
|
||||||
)
|
|
||||||
from ra_aid.tool_configs import (
|
|
||||||
get_implementation_tools,
|
|
||||||
get_planning_tools,
|
|
||||||
get_research_tools,
|
|
||||||
get_web_research_tools,
|
|
||||||
)
|
|
||||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
|
||||||
from ra_aid.database.repositories.key_snippet_repository import (
|
|
||||||
get_key_snippet_repository,
|
|
||||||
)
|
|
||||||
from ra_aid.database.repositories.human_input_repository import (
|
from ra_aid.database.repositories.human_input_repository import (
|
||||||
get_human_input_repository,
|
get_human_input_repository,
|
||||||
)
|
)
|
||||||
from ra_aid.database.repositories.research_note_repository import (
|
|
||||||
get_research_note_repository,
|
|
||||||
)
|
|
||||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
|
||||||
from ra_aid.model_formatters import format_key_facts_dict
|
|
||||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
|
||||||
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
|
||||||
from ra_aid.tools.memory import (
|
|
||||||
get_related_files,
|
|
||||||
log_work_event,
|
|
||||||
)
|
|
||||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
-from ra_aid.env_inv_context import get_env_inv
+from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
# Import repositories using get_* functions
|
# Import repositories using get_* functions
|
||||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
|
@@ -133,131 +67,19 @@ def output_markdown_message(message: str) -> str:
|
||||||
return "Message output."
|
return "Message output."
|
||||||
|
|
||||||
|
|
||||||
def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
|
|
||||||
"""Helper function to estimate total tokens in a sequence of messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Sequence of messages to count tokens for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Total estimated token count
|
|
||||||
"""
|
|
||||||
if not messages:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
estimate_tokens = CiaynAgent._estimate_tokens
|
|
||||||
return sum(estimate_tokens(msg) for msg in messages)
|
|
||||||
|
|
||||||
|
|
||||||
def state_modifier(
|
|
||||||
state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
|
|
||||||
) -> list[BaseMessage]:
|
|
||||||
"""Given the agent state and max_tokens, return a trimmed list of messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: The current agent state containing messages
|
|
||||||
max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[BaseMessage]: Trimmed list of messages that fits within token limit
|
|
||||||
"""
|
|
||||||
messages = state["messages"]
|
|
||||||
|
|
||||||
if not messages:
|
|
||||||
return []
|
|
||||||
|
|
||||||
first_message = messages[0]
|
|
||||||
remaining_messages = messages[1:]
|
|
||||||
first_tokens = estimate_messages_tokens([first_message])
|
|
||||||
new_max_tokens = max_input_tokens - first_tokens
|
|
||||||
|
|
||||||
trimmed_remaining = trim_messages(
|
|
||||||
remaining_messages,
|
|
||||||
token_counter=estimate_messages_tokens,
|
|
||||||
max_tokens=new_max_tokens,
|
|
||||||
strategy="last",
|
|
||||||
allow_partial=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return [first_message] + trimmed_remaining
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_token_limit(
|
|
||||||
config: Dict[str, Any], agent_type: Literal["default", "research", "planner"]
|
|
||||||
) -> Optional[int]:
|
|
||||||
"""Get the token limit for the current model configuration based on agent type.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[int]: The token limit if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Try to get config from repository for production use
|
|
||||||
try:
|
|
||||||
config_from_repo = get_config_repository().get_all()
|
|
||||||
# If we succeeded, use the repository config instead of passed config
|
|
||||||
config = config_from_repo
|
|
||||||
except RuntimeError:
|
|
||||||
# In tests, this may fail because the repository isn't set up
|
|
||||||
# So we'll use the passed config directly
|
|
||||||
pass
|
|
||||||
if agent_type == "research":
|
|
||||||
provider = config.get("research_provider", "") or config.get("provider", "")
|
|
||||||
model_name = config.get("research_model", "") or config.get("model", "")
|
|
||||||
elif agent_type == "planner":
|
|
||||||
provider = config.get("planner_provider", "") or config.get("provider", "")
|
|
||||||
model_name = config.get("planner_model", "") or config.get("model", "")
|
|
||||||
else:
|
|
||||||
provider = config.get("provider", "")
|
|
||||||
model_name = config.get("model", "")
|
|
||||||
|
|
||||||
try:
|
|
||||||
provider_model = model_name if not provider else f"{provider}/{model_name}"
|
|
||||||
model_info = get_model_info(provider_model)
|
|
||||||
max_input_tokens = model_info.get("max_input_tokens")
|
|
||||||
if max_input_tokens:
|
|
||||||
logger.debug(
|
|
||||||
f"Using litellm token limit for {model_name}: {max_input_tokens}"
|
|
||||||
)
|
|
||||||
return max_input_tokens
|
|
||||||
except litellm.exceptions.NotFoundError:
|
|
||||||
logger.debug(
|
|
||||||
f"Model {model_name} not found in litellm, falling back to models_params"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(
|
|
||||||
f"Error getting model info from litellm: {e}, falling back to models_params"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fallback to models_params dict
|
|
||||||
# Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
|
|
||||||
normalized_name = model_name.replace("-", "")
|
|
||||||
provider_tokens = models_params.get(provider, {})
|
|
||||||
if normalized_name in provider_tokens:
|
|
||||||
max_input_tokens = provider_tokens[normalized_name]["token_limit"]
|
|
||||||
logger.debug(
|
|
||||||
f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
max_input_tokens = None
|
|
||||||
logger.debug(f"Could not find token limit for {provider}/{model_name}")
|
|
||||||
|
|
||||||
return max_input_tokens
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get model token limit: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
 def build_agent_kwargs(
     checkpointer: Optional[Any] = None,
+    model: ChatAnthropic = None,
     max_input_tokens: Optional[int] = None,
 ) -> Dict[str, Any]:
     """Build kwargs dictionary for agent creation.

     Args:
         checkpointer: Optional memory checkpointer
-        config: Optional configuration dictionary
-        token_limit: Optional token limit for the model
+        model: The language model to use for token counting
+        max_input_tokens: Optional token limit for the model

     Returns:
         Dictionary of kwargs for agent creation

@@ -270,12 +92,20 @@ def build_agent_kwargs(
         agent_kwargs["checkpointer"] = checkpointer

     config = get_config_repository().get_all()
-    if config.get("limit_tokens", True) and is_anthropic_claude(config):
+    if (
+        config.get("limit_tokens", True)
+        and is_anthropic_claude(config)
+        and model is not None
+    ):

         def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
-            return state_modifier(state, max_input_tokens=max_input_tokens)
+            if any(pattern in model.model for pattern in ["claude-3.5", "claude3.5", "claude-3-5"]):
+                return sonnet_35_state_modifier(state, max_input_tokens=max_input_tokens)
+
+            return state_modifier(state, model, max_input_tokens=max_input_tokens)

         agent_kwargs["state_modifier"] = wrapped_state_modifier
+        agent_kwargs["name"] = "React"

     return agent_kwargs

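A rough usage sketch of the change above: build_agent_kwargs now receives the model so the state modifier can pick the right trimmer per Claude version. The ChatAnthropic instance and token limit below are illustrative, and build_agent_kwargs still expects RA.Aid's config repository to be initialized:

    from langchain_anthropic import ChatAnthropic
    from langgraph.prebuilt import create_react_agent

    model = ChatAnthropic(model="claude-3-7-sonnet-20250219")  # example model id
    tools = []  # whatever tools the agent should get
    max_input_tokens = get_model_token_limit({}, "default")  # from ra_aid.anthropic_token_limiter

    # Claude 3.5 models get sonnet_35_state_modifier, everything else gets the
    # litellm-backed state_modifier; both are wired in through agent_kwargs.
    agent_kwargs = build_agent_kwargs(checkpointer=None, model=model, max_input_tokens=max_input_tokens)
    agent = create_react_agent(model, tools, interrupt_after=["tools"], **agent_kwargs)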
@@ -345,7 +175,8 @@ def create_agent(
         # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
         if is_anthropic_claude(config):
             logger.debug("Using create_react_agent to instantiate agent.")
-            agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
+
+            agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
             return create_react_agent(
                 model, tools, interrupt_after=["tools"], **agent_kwargs
             )
@@ -358,16 +189,12 @@ def create_agent(
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
         config = get_config_repository().get_all()
         max_input_tokens = get_model_token_limit(config, agent_type)
-        agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
+        agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
         return create_react_agent(
             model, tools, interrupt_after=["tools"], **agent_kwargs
         )


-from ra_aid.agents.research_agent import run_research_agent, run_web_research_agent
-from ra_aid.agents.implementation_agent import run_task_implementation_agent
-
-
 _CONTEXT_STACK = []
 _INTERRUPT_CONTEXT = None
 _FEEDBACK_MODE = False
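The module added below exists because the Anthropic messages API requires every tool call to stay paired with its tool result when history is trimmed. A small illustration of that invariant (the message contents and ids here are invented):

    from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

    history = [
        HumanMessage(content="What files are in this repo?"),
        AIMessage(content=[{"type": "tool_use", "id": "tu_1", "name": "list_files", "input": {}}]),
        ToolMessage(content="src/, tests/", tool_call_id="tu_1"),
        AIMessage(content="The repo contains src/ and tests/."),
    ]

    # Trimming that keeps history[1] but drops history[2] (or vice versa) leaves an
    # unmatched tool_use/tool_result and can be rejected by the Anthropic API, which
    # is why the helpers below always keep AIMessage/ToolMessage pairs together.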
@@ -0,0 +1,312 @@
|
||||||
|
"""Utilities for handling Anthropic-specific message formats and trimming."""
|
||||||
|
|
||||||
|
from typing import Callable, List, Literal, Optional, Sequence, Union, cast
|
||||||
|
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
ChatMessage,
|
||||||
|
FunctionMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
ToolMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_message_type(
|
||||||
|
message: BaseMessage, message_types: Union[str, type, List[Union[str, type]]]
|
||||||
|
) -> bool:
|
||||||
|
"""Check if a message is of a specific type or types.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The message to check
|
||||||
|
message_types: Type(s) to check against (string name or class)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if message matches any of the specified types
|
||||||
|
"""
|
||||||
|
if not isinstance(message_types, list):
|
||||||
|
message_types = [message_types]
|
||||||
|
|
||||||
|
types_str = [t for t in message_types if isinstance(t, str)]
|
||||||
|
types_classes = tuple(t for t in message_types if isinstance(t, type))
|
||||||
|
|
||||||
|
return message.type in types_str or isinstance(message, types_classes)
|
||||||
|
|
||||||
|
|
||||||
|
def has_tool_use(message: BaseMessage) -> bool:
|
||||||
|
"""Check if a message contains tool use.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The message to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the message contains tool use
|
||||||
|
"""
|
||||||
|
if not isinstance(message, AIMessage):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check content for tool_use
|
||||||
|
if isinstance(message.content, str) and "tool_use" in message.content:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check content list for tool_use blocks
|
||||||
|
if isinstance(message.content, list):
|
||||||
|
for item in message.content:
|
||||||
|
if isinstance(item, dict) and item.get("type") == "tool_use":
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check additional_kwargs for tool_calls
|
||||||
|
if hasattr(message, "additional_kwargs") and message.additional_kwargs.get(
|
||||||
|
"tool_calls"
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_tool_pair(message1: BaseMessage, message2: BaseMessage) -> bool:
|
||||||
|
"""Check if two messages form a tool use/result pair.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message1: First message
|
||||||
|
message2: Second message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the messages form a tool use/result pair
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
isinstance(message1, AIMessage)
|
||||||
|
and isinstance(message2, ToolMessage)
|
||||||
|
and has_tool_use(message1)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def anthropic_trim_messages(
|
||||||
|
messages: Sequence[BaseMessage],
|
||||||
|
*,
|
||||||
|
max_tokens: int,
|
||||||
|
token_counter: Callable[[List[BaseMessage]], int],
|
||||||
|
strategy: Literal["first", "last"] = "last",
|
||||||
|
num_messages_to_keep: int = 2,
|
||||||
|
allow_partial: bool = False,
|
||||||
|
include_system: bool = True,
|
||||||
|
start_on: Optional[Union[str, type, List[Union[str, type]]]] = None,
|
||||||
|
) -> List[BaseMessage]:
|
||||||
|
"""Trim messages to fit within a token limit, with Anthropic-specific handling.
|
||||||
|
|
||||||
|
Warning: not fully implemented. Only the "last" strategy is supported and tested;
allow_partial and the "first" strategy are not.
|
||||||
|
This function is similar to langchain_core's trim_messages but with special
|
||||||
|
handling for Anthropic message formats to avoid API errors.
|
||||||
|
|
||||||
|
It always keeps the first num_messages_to_keep messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Sequence of messages to trim
|
||||||
|
max_tokens: Maximum number of tokens allowed
|
||||||
|
token_counter: Function to count tokens in messages
|
||||||
|
strategy: Whether to keep the "first" or "last" messages
|
||||||
|
allow_partial: Whether to allow partial messages
|
||||||
|
include_system: Whether to always include the system message
|
||||||
|
start_on: Message type to start on (only for "last" strategy)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[BaseMessage]: Trimmed messages that fit within token limit
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
messages = list(messages)
|
||||||
|
|
||||||
|
# Always keep the first num_messages_to_keep messages
|
||||||
|
kept_messages = messages[:num_messages_to_keep]
|
||||||
|
remaining_msgs = messages[num_messages_to_keep:]
|
||||||
|
|
||||||
|
|
||||||
|
# For Anthropic, we need to maintain the conversation structure where:
|
||||||
|
# 1. Every AIMessage with tool_use must be followed by a ToolMessage
|
||||||
|
# 2. Every AIMessage that follows a ToolMessage must start with a tool_result
|
||||||
|
|
||||||
|
# First, check if we have any tool_use in the messages
|
||||||
|
has_tool_use_anywhere = any(has_tool_use(msg) for msg in messages)
|
||||||
|
|
||||||
|
# If we have tool_use anywhere, we need to be very careful about trimming
|
||||||
|
if has_tool_use_anywhere:
|
||||||
|
# For safety, just keep all messages if we're under the token limit
|
||||||
|
if token_counter(messages) <= max_tokens:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
# We need to identify all tool_use/tool_result relationships
|
||||||
|
# First, find all AIMessage+ToolMessage pairs
|
||||||
|
pairs = []
|
||||||
|
i = 0
|
||||||
|
while i < len(messages) - 1:
|
||||||
|
if is_tool_pair(messages[i], messages[i + 1]):
|
||||||
|
pairs.append((i, i + 1))
|
||||||
|
i += 2
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# For Anthropic, we need to ensure that:
|
||||||
|
# 1. If we include an AIMessage with tool_use, we must include the following ToolMessage
|
||||||
|
# 2. If we include a ToolMessage, we must include the preceding AIMessage with tool_use
|
||||||
|
|
||||||
|
# The safest approach is to always keep complete AIMessage+ToolMessage pairs together
|
||||||
|
# First, identify all complete pairs
|
||||||
|
complete_pairs = []
|
||||||
|
for start, end in pairs:
|
||||||
|
complete_pairs.append((start, end))
|
||||||
|
|
||||||
|
# Now we'll build our result, starting with the kept_messages
|
||||||
|
# But we need to be careful about the first message if it has tool_use
|
||||||
|
result = []
|
||||||
|
|
||||||
|
# Check if the last message in kept_messages has tool_use
|
||||||
|
if (
|
||||||
|
kept_messages
|
||||||
|
and isinstance(kept_messages[-1], AIMessage)
|
||||||
|
and has_tool_use(kept_messages[-1])
|
||||||
|
):
|
||||||
|
# We need to find the corresponding ToolMessage
|
||||||
|
for i, (ai_idx, tool_idx) in enumerate(pairs):
|
||||||
|
if messages[ai_idx] is kept_messages[-1]:
|
||||||
|
# Found the pair, add all kept_messages except the last one
|
||||||
|
result.extend(kept_messages[:-1])
|
||||||
|
# Add the AIMessage and ToolMessage as a pair
|
||||||
|
result.extend([messages[ai_idx], messages[tool_idx]])
|
||||||
|
# Remove this pair from the list of pairs to process later
|
||||||
|
pairs = pairs[:i] + pairs[i + 1 :]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# If we didn't find a matching pair, just add all kept_messages
|
||||||
|
result.extend(kept_messages)
|
||||||
|
else:
|
||||||
|
# No tool_use in the last kept message, just add all kept_messages
|
||||||
|
result.extend(kept_messages)
|
||||||
|
|
||||||
|
# If we're using the "last" strategy, we'll try to include pairs from the end
|
||||||
|
if strategy == "last":
|
||||||
|
# First collect all pairs we can include within the token limit
|
||||||
|
pairs_to_include = []
|
||||||
|
|
||||||
|
# Process pairs from the end (newest first)
|
||||||
|
for pair_idx, (ai_idx, tool_idx) in enumerate(reversed(complete_pairs)):
|
||||||
|
# Try adding this pair
|
||||||
|
test_msgs = result.copy()
|
||||||
|
|
||||||
|
# Add all previously selected pairs
|
||||||
|
for prev_ai_idx, prev_tool_idx in pairs_to_include:
|
||||||
|
test_msgs.extend([messages[prev_ai_idx], messages[prev_tool_idx]])
|
||||||
|
|
||||||
|
# Add this pair
|
||||||
|
test_msgs.extend([messages[ai_idx], messages[tool_idx]])
|
||||||
|
|
||||||
|
if token_counter(test_msgs) <= max_tokens:
|
||||||
|
# This pair fits, add it to our list
|
||||||
|
pairs_to_include.append((ai_idx, tool_idx))
|
||||||
|
else:
|
||||||
|
# This pair would exceed the token limit
|
||||||
|
break
|
||||||
|
|
||||||
|
# Now add the pairs in the correct order
|
||||||
|
# Sort by index to maintain the original conversation flow
|
||||||
|
pairs_to_include.sort(key=lambda x: x[0])
|
||||||
|
for ai_idx, tool_idx in pairs_to_include:
|
||||||
|
result.extend([messages[ai_idx], messages[tool_idx]])
|
||||||
|
|
||||||
|
# No need to sort - we've already added messages in the correct order
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
# If no tool_use, proceed with normal segmentation
|
||||||
|
segments = []
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
# Group messages into segments
|
||||||
|
while i < len(remaining_msgs):
|
||||||
|
segments.append([remaining_msgs[i]])
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# Now we have segments that maintain the required structure
|
||||||
|
# We'll add segments from the end (for "last" strategy) or beginning (for "first")
|
||||||
|
# until we hit the token limit
|
||||||
|
|
||||||
|
if strategy == "last":
|
||||||
|
# If we have no segments, just return kept_messages
|
||||||
|
if not segments:
|
||||||
|
return kept_messages
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
# Process segments from the end
|
||||||
|
for i, segment in enumerate(reversed(segments)):
|
||||||
|
# Try adding this segment
|
||||||
|
test_msgs = segment + result
|
||||||
|
|
||||||
|
if token_counter(kept_messages + test_msgs) <= max_tokens:
|
||||||
|
result = segment + result
|
||||||
|
else:
|
||||||
|
# This segment would exceed the token limit
|
||||||
|
break
|
||||||
|
|
||||||
|
final_result = kept_messages + result
|
||||||
|
|
||||||
|
# For Anthropic, we need to ensure the conversation follows a valid structure
|
||||||
|
# We'll do a final check of the entire conversation
|
||||||
|
|
||||||
|
# Validate the conversation structure
|
||||||
|
valid_result = []
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
# Process messages in order
|
||||||
|
while i < len(final_result):
|
||||||
|
current_msg = final_result[i]
|
||||||
|
|
||||||
|
# If this is an AIMessage with tool_use, it must be followed by a ToolMessage
|
||||||
|
if (
|
||||||
|
i < len(final_result) - 1
|
||||||
|
and isinstance(current_msg, AIMessage)
|
||||||
|
and has_tool_use(current_msg)
|
||||||
|
):
|
||||||
|
if isinstance(final_result[i + 1], ToolMessage):
|
||||||
|
# This is a valid tool_use + tool_result pair
|
||||||
|
valid_result.append(current_msg)
|
||||||
|
valid_result.append(final_result[i + 1])
|
||||||
|
i += 2
|
||||||
|
else:
|
||||||
|
# Invalid: AIMessage with tool_use not followed by ToolMessage
|
||||||
|
# Skip this message to maintain valid structure
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
# Regular message, just add it
|
||||||
|
valid_result.append(current_msg)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# Final check: don't end with an AIMessage that has tool_use
|
||||||
|
if (
|
||||||
|
valid_result
|
||||||
|
and isinstance(valid_result[-1], AIMessage)
|
||||||
|
and has_tool_use(valid_result[-1])
|
||||||
|
):
|
||||||
|
valid_result.pop() # Remove the last message
|
||||||
|
|
||||||
|
return valid_result
|
||||||
|
|
||||||
|
elif strategy == "first":
|
||||||
|
result = []
|
||||||
|
|
||||||
|
# Process segments from the beginning
|
||||||
|
for i, segment in enumerate(segments):
|
||||||
|
# Try adding this segment
|
||||||
|
test_msgs = result + segment
|
||||||
|
if token_counter(kept_messages + test_msgs) <= max_tokens:
|
||||||
|
result = result + segment
|
||||||
|
else:
|
||||||
|
# This segment would exceed the token limit
|
||||||
|
break
|
||||||
|
|
||||||
|
final_result = kept_messages + result
|
||||||
|
|
||||||
|
return final_result
|
||||||
|
|
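A minimal usage sketch for anthropic_trim_messages as defined above; the token counter here is a crude character-based stand-in for a real tokenizer:

    from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

    def rough_counter(msgs):
        # assumption for the sketch: roughly 4 characters per token
        return sum(len(str(m.content)) // 4 for m in msgs)

    msgs = [
        SystemMessage(content="You are a coding agent."),
        HumanMessage(content="Start the task."),
        AIMessage(content="Step 1 output. " * 200),
        HumanMessage(content="Continue."),
        AIMessage(content="Done."),
    ]

    trimmed = anthropic_trim_messages(
        msgs,
        token_counter=rough_counter,
        max_tokens=200,
        strategy="last",
        num_messages_to_keep=2,  # the system prompt and first human message always survive
    )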
@@ -0,0 +1,236 @@
|
||||||
|
"""Utilities for handling token limits with Anthropic models."""
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
RemoveMessage,
|
||||||
|
ToolMessage,
|
||||||
|
trim_messages,
|
||||||
|
)
|
||||||
|
from langchain_core.messages.base import message_to_dict
|
||||||
|
|
||||||
|
from ra_aid.anthropic_message_utils import (
|
||||||
|
anthropic_trim_messages,
|
||||||
|
has_tool_use,
|
||||||
|
)
|
||||||
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||||
|
from litellm import token_counter
|
||||||
|
|
||||||
|
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
from ra_aid.logging_config import get_logger
|
||||||
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||||
|
from ra_aid.console.output import cpm, print_messages_compact
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||||
|
"""Helper function to estimate total tokens in a sequence of messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Sequence of messages to count tokens for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total estimated token count
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
estimate_tokens = CiaynAgent._estimate_tokens
|
||||||
|
return sum(estimate_tokens(msg) for msg in messages)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_message_to_litellm_format(message: BaseMessage) -> Dict:
|
||||||
|
"""Convert a BaseMessage to the format expected by litellm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The BaseMessage to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict in litellm format
|
||||||
|
"""
|
||||||
|
message_dict = message_to_dict(message)
|
||||||
|
return {
|
||||||
|
"role": message_dict["type"],
|
||||||
|
"content": message_dict["data"]["content"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_token_counter_wrapper(model: str):
|
||||||
|
"""Create a wrapper for token counter that handles BaseMessage conversion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model name to use for token counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A function that accepts BaseMessage objects and returns token count
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a partial function that already has the model parameter set
|
||||||
|
base_token_counter = partial(token_counter, model=model)
|
||||||
|
|
||||||
|
def wrapped_token_counter(messages: List[BaseMessage]) -> int:
|
||||||
|
"""Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of BaseMessage objects
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token count for the messages
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
litellm_messages = [convert_message_to_litellm_format(msg) for msg in messages]
|
||||||
|
result = base_token_counter(messages=litellm_messages)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapped_token_counter
|
||||||
|
|
||||||
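A usage sketch for create_token_counter_wrapper above; the model id is an example and litellm has to recognize it for counting to work:

    from langchain_core.messages import AIMessage, HumanMessage

    counter = create_token_counter_wrapper("claude-3-7-sonnet-20250219")
    msgs = [HumanMessage(content="hello"), AIMessage(content="hi there")]
    print(counter(msgs))  # token count computed via litellm.token_counter
    print(counter([]))    # empty input short-circuits to 0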
|
|
||||||
|
def state_modifier(
|
||||||
|
state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
|
||||||
|
) -> list[BaseMessage]:
|
||||||
|
"""Given the agent state and max_tokens, return a trimmed list of messages.
|
||||||
|
|
||||||
|
This uses anthropic_trim_messages which always keeps the first 2 messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The current agent state containing messages
|
||||||
|
model: The language model to use for token counting
|
||||||
|
max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[BaseMessage]: Trimmed list of messages that fits within token limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
messages = state["messages"]
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
wrapped_token_counter = create_token_counter_wrapper(model.model)
|
||||||
|
|
||||||
|
result = anthropic_trim_messages(
|
||||||
|
messages,
|
||||||
|
token_counter=wrapped_token_counter,
|
||||||
|
max_tokens=max_input_tokens,
|
||||||
|
strategy="last",
|
||||||
|
allow_partial=False,
|
||||||
|
include_system=True,
|
||||||
|
num_messages_to_keep=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(result) < len(messages):
|
||||||
|
logger.info(f"Anthropic Token Limiter Trimmed: {len(messages)} messages → {len(result)} messages")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def sonnet_35_state_modifier(
|
||||||
|
state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
|
||||||
|
) -> list[BaseMessage]:
|
||||||
|
"""Given the agent state and max_tokens, return a trimmed list of messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The current agent state containing messages
|
||||||
|
max_input_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
|
+    Returns:
+        list[BaseMessage]: Trimmed list of messages that fits within token limit
+    """
+    messages = state["messages"]
+
+    if not messages:
+        return []
+
+    first_message = messages[0]
+    remaining_messages = messages[1:]
+    first_tokens = estimate_messages_tokens([first_message])
+    new_max_tokens = max_input_tokens - first_tokens
+
+    trimmed_remaining = trim_messages(
+        remaining_messages,
+        token_counter=estimate_messages_tokens,
+        max_tokens=new_max_tokens,
+        strategy="last",
+        allow_partial=False,
+        include_system=True,
+    )
+
+    result = [first_message] + trimmed_remaining
+
+    return result
+
+
+def get_model_token_limit(
+    config: Dict[str, Any], agent_type: str = "default"
+) -> Optional[int]:
+    """Get the token limit for the current model configuration based on agent type.
+
+    Args:
+        config: Configuration dictionary containing provider and model information
+        agent_type: Type of agent ("default", "research", or "planner")
+
+    Returns:
+        Optional[int]: The token limit if found, None otherwise
+    """
+    try:
+        # Try to get config from repository for production use
+        try:
+            config_from_repo = get_config_repository().get_all()
+            # If we succeeded, use the repository config instead of passed config
+            config = config_from_repo
+        except RuntimeError:
+            # In tests, this may fail because the repository isn't set up
+            # So we'll use the passed config directly
+            pass
+        if agent_type == "research":
+            provider = config.get("research_provider", "") or config.get("provider", "")
+            model_name = config.get("research_model", "") or config.get("model", "")
+        elif agent_type == "planner":
+            provider = config.get("planner_provider", "") or config.get("provider", "")
+            model_name = config.get("planner_model", "") or config.get("model", "")
+        else:
+            provider = config.get("provider", "")
+            model_name = config.get("model", "")
+
+        try:
+            from litellm import get_model_info
+
+            provider_model = model_name if not provider else f"{provider}/{model_name}"
+            model_info = get_model_info(provider_model)
+            max_input_tokens = model_info.get("max_input_tokens")
+            if max_input_tokens:
+                logger.debug(
+                    f"Using litellm token limit for {model_name}: {max_input_tokens}"
+                )
+                return max_input_tokens
+        except Exception as e:
+            logger.debug(
+                f"Error getting model info from litellm: {e}, falling back to models_params"
+            )
+
+        # Fallback to models_params dict
+        # Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
+        normalized_name = model_name.replace("-", "")
+        provider_tokens = models_params.get(provider, {})
+        if normalized_name in provider_tokens:
+            max_input_tokens = provider_tokens[normalized_name]["token_limit"]
+            logger.debug(
+                f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
+            )
+        else:
+            max_input_tokens = None
+            logger.debug(f"Could not find token limit for {provider}/{model_name}")
+
+        return max_input_tokens
+
+    except Exception as e:
+        logger.warning(f"Failed to get model token limit: {e}")
+        return None
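For orientation, a minimal usage sketch of the two helpers above (not part of this diff); it assumes the ra_aid package is importable and that the functions keep the signatures shown here:

# Hypothetical wiring of the helpers above; config keys follow this patch.
from langchain_core.messages import HumanMessage, SystemMessage

from ra_aid.anthropic_token_limiter import (
    get_model_token_limit,
    sonnet_35_state_modifier,
)

config = {"provider": "anthropic", "model": "claude-3-7-sonnet-20250219"}
limit = get_model_token_limit(config, agent_type="default")  # int from litellm, or None

state = {"messages": [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Please plan the next step."),
]}
if limit is not None:
    kept = sonnet_35_state_modifier(state, max_input_tokens=limit)  # trimmed message list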
@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3
 FALLBACK_TOOL_MODEL_LIMIT = 5
 RETRY_FALLBACK_COUNT = 3
 DEFAULT_TEST_CMD_TIMEOUT = 60 * 5  # 5 minutes in seconds
+DEFAULT_MODEL="claude-3-7-sonnet-20250219"
 DEFAULT_SHOW_COST = False


@ -16,4 +17,4 @@ VALID_PROVIDERS = [
     "openai-compatible",
     "deepseek",
     "gemini",
 ]
@ -1,6 +1,6 @@
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Sequence

-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 from rich.markdown import Markdown
 from rich.panel import Panel

@ -98,3 +98,57 @@ def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -
     """

     console.print(Panel(Markdown(message), title=title, border_style=border_style))
+
+
+def print_messages_compact(messages: Sequence[BaseMessage]) -> None:
+    """Print a compact representation of a list of messages.
+
+    Warning: used mainly for debugging, so do not delete even if it is not referenced anywhere!
+    For all message types, only the first 30 characters of content are shown.
+
+    Args:
+        messages: A sequence of BaseMessage objects to print
+    """
+    if not messages:
+        console.print("[italic]No messages[/italic]")
+        return
+
+    for i, msg in enumerate(messages):
+        msg_type = msg.__class__.__name__
+        content = msg.content
+
+        # Process content based on its type
+        if isinstance(content, str):
+            display_content = f"{content[:30]}..." if len(content) > 30 else content
+        elif isinstance(content, list):
+            # Handle structured content (list of content blocks)
+            content_preview = []
+            for item in content[:2]:  # Show first 2 items at most
+                if isinstance(item, dict):
+                    if item.get("type") == "text":
+                        text = item.get("text", "")
+                        content_preview.append(f"text: {text[:20]}..." if len(text) > 20 else f"text: {text}")
+                    elif item.get("type") == "tool_call":
+                        tool_name = item.get("tool_call", {}).get("name", "unknown")
+                        content_preview.append(f"tool_call: {tool_name}")
+                    else:
+                        content_preview.append(f"{item.get('type', 'unknown')}")
+
+            if len(content) > 2:
+                content_preview.append(f"...({len(content)-2} more)")
+
+            display_content = ", ".join(content_preview)
+        else:
+            display_content = str(content)[:30] + "..." if len(str(content)) > 30 else str(content)
+
+        # Add additional tool message info if available
+        additional_info = []
+        if hasattr(msg, "tool_call_id") and msg.tool_call_id:
+            additional_info.append(f"tool_call_id: {msg.tool_call_id}")
+        if hasattr(msg, "name") and msg.name:
+            additional_info.append(f"name: {msg.name}")
+        if hasattr(msg, "status") and msg.status:
+            additional_info.append(f"status: {msg.status}")
+
+        info_str = f" ({', '.join(additional_info)})" if additional_info else ""
+        console.print(f"[{i}] [bold]{msg_type}{info_str}[/bold]: {display_content}")
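A small usage sketch of the debugging helper added above; the import path is an assumption (whichever console module defines cpm and the shared console is assumed to export print_messages_compact as well):

# Hypothetical usage of the new debugging helper; module path is assumed.
from langchain_core.messages import AIMessage, HumanMessage

from ra_aid.console.formatting import print_messages_compact

print_messages_compact([
    HumanMessage(content="Summarize the token limiter changes, please."),
    AIMessage(content="The limiter now trims history based on the model's input limit."),
])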
@ -259,8 +259,9 @@ def create_llm_client(
     else:
         temp_kwargs = {}

+    thinking_kwargs = {}
     if supports_thinking:
-        temp_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
+        thinking_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}

     if provider == "deepseek":
         return create_deepseek_client(

@ -268,6 +269,7 @@ def create_llm_client(
             api_key=config["api_key"],
             base_url=config["base_url"],
             **temp_kwargs,
+            **thinking_kwargs,
             is_expert=is_expert,
         )
     elif provider == "openrouter":

@ -275,6 +277,7 @@ def create_llm_client(
             model_name=model_name,
             api_key=config["api_key"],
             **temp_kwargs,
+            **thinking_kwargs,
             is_expert=is_expert,
         )
     elif provider == "openai":

@ -301,6 +304,7 @@ def create_llm_client(
             max_retries=LLM_MAX_RETRIES,
             max_tokens=model_config.get("max_tokens", 64000),
             **temp_kwargs,
+            **thinking_kwargs,
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(

@ -310,6 +314,7 @@ def create_llm_client(
             timeout=LLM_REQUEST_TIMEOUT,
             max_retries=LLM_MAX_RETRIES,
             **temp_kwargs,
+            **thinking_kwargs,
         )
     elif provider == "gemini":
         return ChatGoogleGenerativeAI(

@ -318,6 +323,7 @@ def create_llm_client(
             timeout=LLM_REQUEST_TIMEOUT,
             max_retries=LLM_MAX_RETRIES,
             **temp_kwargs,
+            **thinking_kwargs,
         )
     else:
         raise ValueError(f"Unsupported provider: {provider}")
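The hunks above all follow the same pattern: build the extended-thinking options in their own dict and unpack them alongside the existing temperature kwargs, so providers that do not support thinking simply receive nothing extra. A minimal, self-contained sketch of that pattern (the client constructor here is a stand-in, not the real ChatAnthropic call):

# Illustrative only: separate option groups, each unpacked into the constructor.
supports_thinking = True  # would normally come from the model's capability flags

temp_kwargs = {}
thinking_kwargs = {}
if supports_thinking:
    thinking_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}

def make_client(**kwargs):
    # Stand-in for ChatAnthropic / ChatOpenAI / ChatGoogleGenerativeAI; echoes kwargs.
    return kwargs

client_kwargs = make_client(
    model="claude-3-7-sonnet-20250219",
    **temp_kwargs,
    **thinking_kwargs,
)
print(client_kwargs)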
@ -14,8 +14,9 @@ from ra_aid.agent_context import (
     is_crashed,
     reset_completion_flags,
 )
+from ra_aid.config import DEFAULT_MODEL
 from ra_aid.console.formatting import print_error, print_task_header
-from ra_aid.database.repositories.human_input_repository import HumanInputRepository, get_human_input_repository
+from ra_aid.database.repositories.human_input_repository import get_human_input_repository
 from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
 from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
 from ra_aid.database.repositories.config_repository import get_config_repository

@ -385,7 +386,7 @@ def request_task_implementation(task_spec: str) -> str:
     config = get_config_repository().get_all()
     model = initialize_llm(
         config.get("provider", "anthropic"),
-        config.get("model", "claude-3-5-sonnet-20241022"),
+        config.get("model", DEFAULT_MODEL),
         temperature=config.get("temperature"),
     )

@ -552,7 +553,7 @@ def request_implementation(task_spec: str) -> str:
     config = get_config_repository().get_all()
     model = initialize_llm(
         config.get("provider", "anthropic"),
-        config.get("model", "claude-3-5-sonnet-20241022"),
+        config.get("model", DEFAULT_MODEL),
         temperature=config.get("temperature"),
     )

@ -685,4 +686,4 @@ def request_implementation(task_spec: str) -> str:
     # Join all parts into a single markdown string
     markdown_output = "".join(markdown_parts)

     return markdown_output
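In isolation, the fallback pattern the two hunks above switch to, assuming only that ra_aid.config exposes the new constant:

# Minimal sketch; in the real code the config dict comes from get_config_repository().
from ra_aid.config import DEFAULT_MODEL

config = {}  # no explicit model configured
provider = config.get("provider", "anthropic")
model_name = config.get("model", DEFAULT_MODEL)  # falls back to the Sonnet 3.7 default
print(provider, model_name)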
@ -14,12 +14,18 @@ from ra_aid.agent_context import (
 from ra_aid.agent_utils import (
     AgentState,
     create_agent,
-    get_model_token_limit,
     is_anthropic_claude,
+)
+from ra_aid.anthropic_token_limiter import (
+    get_model_token_limit,
     state_modifier,
 )
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
-from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository, config_repo_var
+from ra_aid.database.repositories.config_repository import (
+    ConfigRepositoryManager,
+    get_config_repository,
+    config_repo_var,
+)


 @pytest.fixture

@ -32,154 +38,91 @@ def mock_model():
 @pytest.fixture
 def mock_config_repository():
     """Mock the ConfigRepository to avoid database operations during tests"""
-    with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
+    with patch(
+        "ra_aid.database.repositories.config_repository.config_repo_var"
+    ) as mock_repo_var:
         # Setup a mock repository
         mock_repo = MagicMock()

         # Create a dictionary to simulate config
         config = {}

         # Setup get method to return config values
         def get_config(key, default=None):
             return config.get(key, default)

         mock_repo.get.side_effect = get_config

         # Setup get_all method to return all config values
         mock_repo.get_all.return_value = config

         # Setup set method to update config values
         def set_config(key, value):
             config[key] = value

         mock_repo.set.side_effect = set_config

         # Setup update method to update multiple config values
         def update_config(update_dict):
             config.update(update_dict)

         mock_repo.update.side_effect = update_config

         # Make the mock context var return our mock repo
         mock_repo_var.get.return_value = mock_repo

         yield mock_repo


 @pytest.fixture(autouse=True)
 def mock_trajectory_repository():
     """Mock the TrajectoryRepository to avoid database operations during tests"""
-    with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var:
+    with patch(
+        "ra_aid.database.repositories.trajectory_repository.trajectory_repo_var"
+    ) as mock_repo_var:
         # Setup a mock repository
         mock_repo = MagicMock()

         # Setup create method to return a mock trajectory
         def mock_create(**kwargs):
             mock_trajectory = MagicMock()
             mock_trajectory.id = 1
             return mock_trajectory

         mock_repo.create.side_effect = mock_create

         # Make the mock context var return our mock repo
         mock_repo_var.get.return_value = mock_repo

         yield mock_repo


 @pytest.fixture(autouse=True)
 def mock_human_input_repository():
     """Mock the HumanInputRepository to avoid database operations during tests"""
-    with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var:
+    with patch(
+        "ra_aid.database.repositories.human_input_repository.human_input_repo_var"
+    ) as mock_repo_var:
         # Setup a mock repository
         mock_repo = MagicMock()

         # Setup get_most_recent_id method to return a dummy ID
         mock_repo.get_most_recent_id.return_value = 1

         # Make the mock context var return our mock repo
         mock_repo_var.get.return_value = mock_repo

         yield mock_repo


-def test_get_model_token_limit_anthropic(mock_config_repository):
-    """Test get_model_token_limit with Anthropic model."""
-    config = {"provider": "anthropic", "model": "claude2"}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
-
-
-def test_get_model_token_limit_openai(mock_config_repository):
-    """Test get_model_token_limit with OpenAI model."""
-    config = {"provider": "openai", "model": "gpt-4"}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
-
-
-def test_get_model_token_limit_unknown(mock_config_repository):
-    """Test get_model_token_limit with unknown provider/model."""
-    config = {"provider": "unknown", "model": "unknown-model"}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit is None
-
-
-def test_get_model_token_limit_missing_config(mock_config_repository):
-    """Test get_model_token_limit with missing configuration."""
-    config = {}
-    mock_config_repository.update(config)
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit is None
-
-
-def test_get_model_token_limit_litellm_success():
-    """Test get_model_token_limit successfully getting limit from litellm."""
-    config = {"provider": "anthropic", "model": "claude-2"}
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.return_value = {"max_input_tokens": 100000}
-        token_limit = get_model_token_limit(config, "default")
-        assert token_limit == 100000
-
-
-def test_get_model_token_limit_litellm_not_found():
-    """Test fallback to models_tokens when litellm raises NotFoundError."""
-    config = {"provider": "anthropic", "model": "claude-2"}
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.side_effect = litellm.exceptions.NotFoundError(
-            message="Model not found", model="claude-2", llm_provider="anthropic"
-        )
-        token_limit = get_model_token_limit(config, "default")
-        assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
-
-
-def test_get_model_token_limit_litellm_error():
-    """Test fallback to models_tokens when litellm raises other exceptions."""
-    config = {"provider": "anthropic", "model": "claude-2"}
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.side_effect = Exception("Unknown error")
-        token_limit = get_model_token_limit(config, "default")
-        assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
-
-
-def test_get_model_token_limit_unexpected_error():
-    """Test returning None when unexpected errors occur."""
-    config = None  # This will cause an attribute error when accessed
-
-    token_limit = get_model_token_limit(config, "default")
-    assert token_limit is None
-
-
 def test_create_agent_anthropic(mock_model, mock_config_repository):
     """Test create_agent with Anthropic Claude model."""
     mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})

-    with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
+    with (
+        patch("ra_aid.agent_utils.create_react_agent") as mock_react,
+        patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier,
+    ):
         mock_react.return_value = "react_agent"
         agent = create_agent(mock_model, [])

@ -187,9 +130,10 @@ def test_create_agent_anthropic(mock_model, mock_config_repository):
         mock_react.assert_called_once_with(
             mock_model,
             [],
-            interrupt_after=['tools'],
+            interrupt_after=["tools"],
             version="v2",
             state_modifier=mock_react.call_args[1]["state_modifier"],
+            name="React",
         )

@ -257,20 +201,7 @@ def mock_messages():
     ]


-def test_state_modifier(mock_messages):
-    """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
-    state = AgentState(messages=mock_messages)
-
-    with patch(
-        "ra_aid.agent_backends.ciayn_agent.CiaynAgent._estimate_tokens"
-    ) as mock_estimate:
-        mock_estimate.side_effect = lambda msg: 100 if msg else 0
-
-        result = state_modifier(state, max_input_tokens=250)
-
-        assert len(result) < len(mock_messages)
-        assert isinstance(result[0], SystemMessage)
-        assert result[-1] == mock_messages[-1]
+# This test has been moved to test_anthropic_token_limiter.py


 def test_create_agent_with_checkpointer(mock_model, mock_config_repository):

@ -291,17 +222,21 @@ def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
     )


-def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository):
+def test_create_agent_anthropic_token_limiting_enabled(
+    mock_model, mock_config_repository
+):
     """Test create_agent sets up token limiting for Claude models when enabled."""
-    mock_config_repository.update({
-        "provider": "anthropic",
-        "model": "claude-2",
-        "limit_tokens": True,
-    })
+    mock_config_repository.update(
+        {
+            "provider": "anthropic",
+            "model": "claude-2",
+            "limit_tokens": True,
+        }
+    )

     with (
         patch("ra_aid.agent_utils.create_react_agent") as mock_react,
-        patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
+        patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit,
     ):
         mock_react.return_value = "react_agent"
         mock_limit.return_value = 100000

@ -314,17 +249,21 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_r
         assert callable(args[1]["state_modifier"])


-def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
+def test_create_agent_anthropic_token_limiting_disabled(
+    mock_model, mock_config_repository
+):
     """Test create_agent doesn't set up token limiting for Claude models when disabled."""
-    mock_config_repository.update({
-        "provider": "anthropic",
-        "model": "claude-2",
-        "limit_tokens": False,
-    })
+    mock_config_repository.update(
+        {
+            "provider": "anthropic",
+            "model": "claude-2",
+            "limit_tokens": False,
+        }
+    )

     with (
         patch("ra_aid.agent_utils.create_react_agent") as mock_react,
-        patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
+        patch("ra_aid.anthropic_token_limiter.get_model_token_limit") as mock_limit,
     ):
         mock_react.return_value = "react_agent"
         mock_limit.return_value = 100000

@ -332,39 +271,12 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_
         agent = create_agent(mock_model, [])

         assert agent == "react_agent"
-        mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2")
+        mock_react.assert_called_once_with(
+            mock_model, [], interrupt_after=["tools"], version="v2", name="React"
+        )


-def test_get_model_token_limit_research(mock_config_repository):
-    """Test get_model_token_limit with research provider and model."""
-    config = {
-        "provider": "openai",
-        "model": "gpt-4",
-        "research_provider": "anthropic",
-        "research_model": "claude-2",
-    }
-    mock_config_repository.update(config)
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.return_value = {"max_input_tokens": 150000}
-        token_limit = get_model_token_limit(config, "research")
-        assert token_limit == 150000
-
-
-def test_get_model_token_limit_planner(mock_config_repository):
-    """Test get_model_token_limit with planner provider and model."""
-    config = {
-        "provider": "openai",
-        "model": "gpt-4",
-        "planner_provider": "deepseek",
-        "planner_model": "dsm-1",
-    }
-    mock_config_repository.update(config)
-
-    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
-        mock_get_info.return_value = {"max_input_tokens": 120000}
-        token_limit = get_model_token_limit(config, "planner")
-        assert token_limit == 120000
+# These tests have been moved to test_anthropic_token_limiter.py


 # New tests for private helper methods in agent_utils.py

@ -396,11 +308,11 @@ def test_agent_context_depth():
     with agent_context() as ctx1:
         assert get_depth() == 0  # Root context has depth 0
         assert ctx1.depth == 0

         with agent_context() as ctx2:
             assert get_depth() == 1  # Nested context has depth 1
             assert ctx2.depth == 1

             with agent_context() as ctx3:
                 assert get_depth() == 2  # Doubly nested context has depth 2
                 assert ctx3.depth == 2

@ -418,7 +330,7 @@ def test_run_agent_stream(monkeypatch, mock_config_repository):
     class DummyAgent:
         def stream(self, input_data, cfg: dict):
             yield {"content": "chunk1"}

         def get_state(self, state_config=None):
             # Return an object with a next property set to None
             return State()

@ -469,28 +381,28 @@ def test_handle_api_error_valueerror():
     # ValueError not containing "code" or rate limit phrases should be re-raised
     with pytest.raises(ValueError):
         _handle_api_error(ValueError("some unrelated error"), 0, 5, 1)

     # ValueError with "429" should be handled without raising
     _handle_api_error(ValueError("error code 429"), 0, 5, 1)

     # ValueError with "rate limit" phrase should be handled without raising
     _handle_api_error(ValueError("hit rate limit"), 0, 5, 1)

     # ValueError with "too many requests" phrase should be handled without raising
     _handle_api_error(ValueError("too many requests, try later"), 0, 5, 1)

     # ValueError with "quota exceeded" phrase should be handled without raising
     _handle_api_error(ValueError("quota exceeded for this month"), 0, 5, 1)


 def test_handle_api_error_status_code():
     from ra_aid.agent_utils import _handle_api_error

     # Error with status_code=429 attribute should be handled without raising
     error_with_status = Exception("Rate limited")
     error_with_status.status_code = 429
     _handle_api_error(error_with_status, 0, 5, 1)

     # Error with http_status=429 attribute should be handled without raising
     error_with_http_status = Exception("Too many requests")
     error_with_http_status.http_status = 429

@ -499,16 +411,16 @@ def test_handle_api_error_status_code():
 def test_handle_api_error_rate_limit_phrases():
     from ra_aid.agent_utils import _handle_api_error

     # Generic exception with "rate limit" phrase should be handled without raising
     _handle_api_error(Exception("You have exceeded your rate limit"), 0, 5, 1)

     # Generic exception with "too many requests" phrase should be handled without raising
     _handle_api_error(Exception("Too many requests, please slow down"), 0, 5, 1)

     # Generic exception with "quota exceeded" phrase should be handled without raising
     _handle_api_error(Exception("API quota exceeded for this billing period"), 0, 5, 1)

     # Generic exception with "rate" and "limit" separate but in message should be handled
     _handle_api_error(Exception("You hit the rate at which we limit requests"), 0, 5, 1)

@ -629,7 +541,9 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repos
     assert "Agent has crashed: Test crash message" in result


-def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_repository):
+def test_run_agent_with_retry_handles_badrequest_error(
+    monkeypatch, mock_config_repository
+):
     """Test that run_agent_with_retry properly handles BadRequestError as unretryable."""
     from ra_aid.agent_context import agent_context, is_crashed
     from ra_aid.agent_utils import run_agent_with_retry

@ -687,7 +601,9 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_
     assert is_crashed()


-def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch, mock_config_repository):
+def test_run_agent_with_retry_handles_api_badrequest_error(
+    monkeypatch, mock_config_repository
+):
     """Test that run_agent_with_retry properly handles API BadRequestError as unretryable."""
     # Import APIError from anthropic module and patch it on the agent_utils module

@ -758,7 +674,9 @@ def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch, mock_con
 def test_handle_api_error_resource_exhausted():
     from google.api_core.exceptions import ResourceExhausted
     from ra_aid.agent_utils import _handle_api_error

     # ResourceExhausted exception should be handled without raising
-    resource_exhausted_error = ResourceExhausted("429 Resource has been exhausted (e.g. check quota).")
+    resource_exhausted_error = ResourceExhausted(
+        "429 Resource has been exhausted (e.g. check quota)."
+    )
     _handle_api_error(resource_exhausted_error, 0, 5, 1)
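For orientation, a sketch of how the retry helper exercised by these tests might be driven; the meanings of the positional arguments (attempt index, retry limit, base delay) are inferred from the test calls and may not match the real parameter names:

# Hypothetical driver loop around _handle_api_error; semantics here are assumptions.
from ra_aid.agent_utils import _handle_api_error

def run_once():
    # Placeholder for the agent step being retried.
    raise ValueError("error code 429")

max_retries = 5
for attempt in range(max_retries):
    try:
        run_once()
        break
    except Exception as exc:
        # Re-raises unretryable errors; otherwise backs off and lets the loop retry.
        _handle_api_error(exc, attempt, max_retries, 1)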
@ -0,0 +1,507 @@
+import unittest
+from unittest.mock import MagicMock, patch
+import litellm
+
+from langchain_anthropic import ChatAnthropic
+from langchain_core.messages import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+    ToolMessage
+)
+from langgraph.prebuilt.chat_agent_executor import AgentState
+
+from ra_aid.anthropic_token_limiter import (
+    create_token_counter_wrapper,
+    estimate_messages_tokens,
+    get_model_token_limit,
+    state_modifier,
+    sonnet_35_state_modifier,
+    convert_message_to_litellm_format
+)
+from ra_aid.anthropic_message_utils import has_tool_use, is_tool_pair
+from ra_aid.models_params import models_params, DEFAULT_TOKEN_LIMIT
+
+
+class TestAnthropicTokenLimiter(unittest.TestCase):
+    def setUp(self):
+        from ra_aid.config import DEFAULT_MODEL
+
+        self.mock_model = MagicMock(spec=ChatAnthropic)
+        self.mock_model.model = DEFAULT_MODEL
+
+        # Sample messages for testing
+        self.system_message = SystemMessage(content="You are a helpful assistant.")
+        self.human_message = HumanMessage(content="Hello, can you help me with a task?")
+        self.ai_message = AIMessage(content="I'd be happy to help! What do you need?")
+        self.long_message = HumanMessage(content="A" * 1000)  # Long message to test trimming
+
+        # Create more messages for testing
+        self.extra_messages = [
+            HumanMessage(content=f"Extra message {i}") for i in range(5)
+        ]
+
+        # Mock state for testing state_modifier with many messages
+        self.state = AgentState(
+            messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages,
+            next=None,
+        )
+
+        # Create tool-related messages for testing
+        self.ai_with_tool_use = AIMessage(
+            content="I'll use a tool to help you",
+            additional_kwargs={"tool_calls": [{"name": "calculator", "input": {"expression": "2+2"}}]}
+        )
+        self.tool_message = ToolMessage(
+            content="4",
+            tool_call_id="tool_call_1",
+            name="calculator"
+        )
+
+    def test_convert_message_to_litellm_format(self):
+        """Test conversion of BaseMessage to litellm format."""
+        # Test human message
+        human_result = convert_message_to_litellm_format(self.human_message)
+        self.assertEqual(human_result["role"], "human")
+        self.assertEqual(human_result["content"], "Hello, can you help me with a task?")
+
+        # Test system message
+        system_result = convert_message_to_litellm_format(self.system_message)
+        self.assertEqual(system_result["role"], "system")
+        self.assertEqual(system_result["content"], "You are a helpful assistant.")
+
+        # Test AI message
+        ai_result = convert_message_to_litellm_format(self.ai_message)
+        self.assertEqual(ai_result["role"], "ai")
+        self.assertEqual(ai_result["content"], "I'd be happy to help! What do you need?")
+
+    @patch("ra_aid.anthropic_token_limiter.token_counter")
+    def test_create_token_counter_wrapper(self, mock_token_counter):
+        from ra_aid.config import DEFAULT_MODEL
+
+        # Setup mock return values
+        mock_token_counter.return_value = 50
+
+        # Create the wrapper
+        wrapper = create_token_counter_wrapper(DEFAULT_MODEL)
+
+        # Test with BaseMessage objects
+        result = wrapper([self.human_message])
+        self.assertEqual(result, 50)
+
+        # Test with empty list
+        result = wrapper([])
+        self.assertEqual(result, 0)
+
+        # Verify the mock was called with the right parameters
+        mock_token_counter.assert_called_with(messages=unittest.mock.ANY, model=DEFAULT_MODEL)
+
+    @patch("ra_aid.anthropic_token_limiter.CiaynAgent._estimate_tokens")
+    def test_estimate_messages_tokens(self, mock_estimate_tokens):
+        # Setup mock to return different values for different messages
+        mock_estimate_tokens.side_effect = lambda msg: 10 if isinstance(msg, SystemMessage) else 20
+
+        # Test with multiple messages
+        messages = [self.system_message, self.human_message]
+        result = estimate_messages_tokens(messages)
+
+        # Should be sum of individual token counts (10 + 20)
+        self.assertEqual(result, 30)
+
+        # Test with empty list
+        result = estimate_messages_tokens([])
+        self.assertEqual(result, 0)
+
+    @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
+    @patch("ra_aid.anthropic_token_limiter.print_messages_compact")
+    @patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages")
+    def test_state_modifier(self, mock_trim_messages, mock_print, mock_create_wrapper):
+        # Setup a proper token counter function that returns integers
+        def token_counter(msgs):
+            # Return token count based on number of messages
+            return len(msgs) * 10
+
+        # Configure the mock to return our token counter
+        mock_create_wrapper.return_value = token_counter
+
+        # Configure anthropic_trim_messages to return a subset of messages
+        mock_trim_messages.return_value = [self.system_message, self.human_message]
+
+        # Call state_modifier with a max token limit of 50
+        result = state_modifier(self.state, self.mock_model, max_input_tokens=50)
+
+        # Should return what anthropic_trim_messages returned
+        self.assertEqual(result, [self.system_message, self.human_message])
+
+        # Verify the wrapper was created with the right model
+        mock_create_wrapper.assert_called_with(self.mock_model.model)
+
+        # Verify anthropic_trim_messages was called with the right parameters
+        mock_trim_messages.assert_called_once()
+
+    def test_state_modifier_with_messages(self):
+        """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
+        # Create a state with messages
+        messages = [
+            SystemMessage(content="System prompt"),
+            HumanMessage(content="Human message 1"),
+            AIMessage(content="AI response 1"),
+            HumanMessage(content="Human message 2"),
+            AIMessage(content="AI response 2"),
+        ]
+        state = AgentState(messages=messages)
+        model = MagicMock(spec=ChatAnthropic)
+        model.model = "claude-3-opus-20240229"
+
+        with patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") as mock_wrapper, \
+             patch("ra_aid.anthropic_token_limiter.anthropic_trim_messages") as mock_trim, \
+             patch("ra_aid.anthropic_token_limiter.print_messages_compact"):
+            # Setup mock to return a fixed token count per message
+            mock_wrapper.return_value = lambda msgs: len(msgs) * 100
+            # Setup mock to return a subset of messages
+            mock_trim.return_value = [messages[0], messages[-2], messages[-1]]
+
+            result = state_modifier(state, model, max_input_tokens=250)
+
+            # Should return what anthropic_trim_messages returned
+            self.assertEqual(len(result), 3)
+            self.assertEqual(result[0], messages[0])  # First message preserved
+            self.assertEqual(result[-1], messages[-1])  # Last message preserved
+
+    def test_sonnet_35_state_modifier(self):
+        """Test the sonnet 35 state modifier function."""
+        # Create a state with messages
+        state = {"messages": [self.system_message, self.human_message, self.ai_message]}
+
+        # Test with empty messages
+        empty_state = {"messages": []}
+
+        # Instead of patching trim_messages which has complex internal logic,
+        # we'll directly patch the sonnet_35_state_modifier's call to trim_messages
+        with patch("ra_aid.anthropic_token_limiter.trim_messages") as mock_trim:
+            # Setup mock to return our desired messages
+            mock_trim.return_value = [self.human_message, self.ai_message]
+
+            # Test with empty messages
+            self.assertEqual(sonnet_35_state_modifier(empty_state), [])
+
+            # Test with messages under the limit
+            result = sonnet_35_state_modifier(state, max_input_tokens=10000)
+
+            # Should keep the first message and call trim_messages for the rest
+            self.assertEqual(len(result), 3)
+            self.assertEqual(result[0], self.system_message)
+            self.assertEqual(result[1:], [self.human_message, self.ai_message])
+
+            # Verify trim_messages was called with the right parameters
+            mock_trim.assert_called_once()
+            # We can check some of the key arguments
+            call_args = mock_trim.call_args[1]
+            # The actual value is based on the token estimation logic, not a hard-coded 9000
+            self.assertIn("max_tokens", call_args)
+            self.assertEqual(call_args["strategy"], "last")
+            self.assertEqual(call_args["allow_partial"], False)
+            self.assertEqual(call_args["include_system"], True)
+
+    @patch("ra_aid.anthropic_token_limiter.get_config_repository")
+    @patch("litellm.get_model_info")
+    def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
+        from ra_aid.config import DEFAULT_MODEL
+
+        # Setup mocks
+        mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
+        mock_get_config_repo.return_value.get_all.return_value = mock_config
+
+        # Mock litellm's get_model_info to return a token limit
+        mock_get_model_info.return_value = {"max_input_tokens": 100000}
+
+        # Test getting token limit
+        result = get_model_token_limit(mock_config)
+        self.assertEqual(result, 100000)
+
+        # Verify get_model_info was called with the right model
+        mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
+
+    def test_get_model_token_limit_research(self):
+        """Test get_model_token_limit with research provider and model."""
+        config = {
+            "provider": "openai",
+            "model": "gpt-4",
+            "research_provider": "anthropic",
+            "research_model": "claude-2",
+        }
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 150000}
+            token_limit = get_model_token_limit(config, "research")
+            self.assertEqual(token_limit, 150000)
+            # Verify get_model_info was called with the research model
+            mock_get_info.assert_called_with("anthropic/claude-2")
+
+    def test_get_model_token_limit_planner(self):
+        """Test get_model_token_limit with planner provider and model."""
+        config = {
+            "provider": "openai",
+            "model": "gpt-4",
+            "planner_provider": "deepseek",
+            "planner_model": "dsm-1",
+        }
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 120000}
+            token_limit = get_model_token_limit(config, "planner")
+            self.assertEqual(token_limit, 120000)
+            # Verify get_model_info was called with the planner model
+            mock_get_info.assert_called_with("deepseek/dsm-1")
+
+    @patch("ra_aid.anthropic_token_limiter.get_config_repository")
+    @patch("litellm.get_model_info")
+    def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
+        # Setup mocks
+        mock_config = {"provider": "anthropic", "model": "claude-2"}
+        mock_get_config_repo.return_value.get_all.return_value = mock_config
+
+        # Make litellm's get_model_info raise an exception to test fallback
+        mock_get_model_info.side_effect = Exception("Model not found")
+
+        # Test getting token limit from models_params fallback
+        with patch("ra_aid.anthropic_token_limiter.models_params", {
+            "anthropic": {
+                "claude2": {"token_limit": 100000}
+            }
+        }):
+            result = get_model_token_limit(mock_config)
+            self.assertEqual(result, 100000)
+
+    @patch("ra_aid.anthropic_token_limiter.get_config_repository")
+    @patch("litellm.get_model_info")
+    def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
+        from ra_aid.config import DEFAULT_MODEL
+
+        # Setup mocks for different agent types
+        mock_config = {
+            "provider": "anthropic",
+            "model": DEFAULT_MODEL,
+            "research_provider": "openai",
+            "research_model": "gpt-4",
+            "planner_provider": "anthropic",
+            "planner_model": "claude-3-sonnet-20240229"
+        }
+        mock_get_config_repo.return_value.get_all.return_value = mock_config
+
+        # Mock different returns for different models
+        def model_info_side_effect(model_name):
+            if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
+                return {"max_input_tokens": 200000}
+            elif "gpt-4" in model_name:
+                return {"max_input_tokens": 8192}
+            elif "claude-3-sonnet" in model_name:
+                return {"max_input_tokens": 100000}
+            else:
+                raise Exception(f"Unknown model: {model_name}")
+
+        mock_get_model_info.side_effect = model_info_side_effect
+
+        # Test default agent type
+        result = get_model_token_limit(mock_config, "default")
+        self.assertEqual(result, 200000)
+
+        # Test research agent type
+        result = get_model_token_limit(mock_config, "research")
+        self.assertEqual(result, 8192)
+
+        # Test planner agent type
+        result = get_model_token_limit(mock_config, "planner")
+        self.assertEqual(result, 100000)
+
+    def test_get_model_token_limit_anthropic(self):
+        """Test get_model_token_limit with Anthropic model."""
+        config = {"provider": "anthropic", "model": "claude2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_openai(self):
+        """Test get_model_token_limit with OpenAI model."""
+        config = {"provider": "openai", "model": "gpt-4"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["openai"]["gpt-4"]["token_limit"])
+
+    def test_get_model_token_limit_unknown(self):
+        """Test get_model_token_limit with unknown provider/model."""
+        config = {"provider": "unknown", "model": "unknown-model"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertIsNone(token_limit)
+
+    def test_get_model_token_limit_missing_config(self):
+        """Test get_model_token_limit with missing configuration."""
+        config = {}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            token_limit = get_model_token_limit(config, "default")
+            self.assertIsNone(token_limit)
+
+    def test_get_model_token_limit_litellm_success(self):
+        """Test get_model_token_limit successfully getting limit from litellm."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.return_value = {"max_input_tokens": 100000}
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, 100000)
+            mock_get_info.assert_called_with("anthropic/claude-2")
+
+    def test_get_model_token_limit_litellm_not_found(self):
+        """Test fallback to models_tokens when litellm raises NotFoundError."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = litellm.exceptions.NotFoundError(
+                message="Model not found", model="claude-2", llm_provider="anthropic"
+            )
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_litellm_error(self):
+        """Test fallback to models_tokens when litellm raises other exceptions."""
+        config = {"provider": "anthropic", "model": "claude-2"}
+
+        with patch("ra_aid.anthropic_token_limiter.get_config_repository") as mock_get_config_repo, \
+             patch("litellm.get_model_info") as mock_get_info:
+            mock_get_config_repo.return_value.get_all.return_value = config
+            mock_get_info.side_effect = Exception("Unknown error")
+            token_limit = get_model_token_limit(config, "default")
+            self.assertEqual(token_limit, models_params["anthropic"]["claude2"]["token_limit"])
+
+    def test_get_model_token_limit_unexpected_error(self):
+        """Test returning None when unexpected errors occur."""
+        config = None  # This will cause an attribute error when accessed
+
+        token_limit = get_model_token_limit(config, "default")
+        self.assertIsNone(token_limit)
+
+    def test_has_tool_use(self):
+        """Test the has_tool_use function."""
+        # Test with regular AI message
+        self.assertFalse(has_tool_use(self.ai_message))
+
+        # Test with AI message containing tool_use in string content
+        ai_with_tool_str = AIMessage(content="I'll use a tool_use to help you")
+        self.assertTrue(has_tool_use(ai_with_tool_str))
+
+        # Test with AI message containing tool_use in structured content
+        ai_with_tool_dict = AIMessage(content=[
+            {"type": "text", "text": "I'll use a tool to help you"},
+            {"type": "tool_use", "tool_use": {"name": "calculator", "input": {"expression": "2+2"}}}
+        ])
+        self.assertTrue(has_tool_use(ai_with_tool_dict))
+
+        # Test with AI message containing tool_calls in additional_kwargs
+        self.assertTrue(has_tool_use(self.ai_with_tool_use))
+
+        # Test with non-AI message
+        self.assertFalse(has_tool_use(self.human_message))
+
+    def test_is_tool_pair(self):
+        """Test the is_tool_pair function."""
+        # Test with valid tool pair
+        self.assertTrue(is_tool_pair(self.ai_with_tool_use, self.tool_message))
+
+        # Test with non-tool pair (wrong order)
+        self.assertFalse(is_tool_pair(self.tool_message, self.ai_with_tool_use))
+
+        # Test with non-tool pair (wrong types)
+        self.assertFalse(is_tool_pair(self.ai_message, self.human_message))
+
+        # Test with non-tool pair (AI message without tool use)
+        self.assertFalse(is_tool_pair(self.ai_message, self.tool_message))
+
+    @patch("ra_aid.anthropic_message_utils.has_tool_use")
+    def test_anthropic_trim_messages_with_tool_use(self, mock_has_tool_use):
+        """Test anthropic_trim_messages with a sequence of messages including tool use."""
+        from ra_aid.anthropic_message_utils import anthropic_trim_messages
+
+        # Setup mock for has_tool_use to return True for AI messages at even indices
+        def side_effect(msg):
+            if isinstance(msg, AIMessage) and hasattr(msg, 'test_index'):
+                return msg.test_index % 2 == 0  # Even indices have tool use
+            return False
+
+        mock_has_tool_use.side_effect = side_effect
+
+        # Create a sequence of alternating human and AI messages with tool use
+        messages = []
+
+        # Start with system message
+        system_msg = SystemMessage(content="You are a helpful assistant.")
+        messages.append(system_msg)
+
+        # Add alternating human and AI messages with tool use
+        for i in range(8):
+            if i % 2 == 0:
+                # Human message
+                msg = HumanMessage(content=f"Human message {i}")
+                messages.append(msg)
+            else:
+                # AI message, every other one has tool use
+                ai_msg = AIMessage(content=f"AI message {i}")
+                # Add a test_index attribute to track position
+                ai_msg.test_index = i
+                messages.append(ai_msg)
+
+                # If this AI message has tool use (even index), add a tool message after it
+                if i % 4 == 1:  # 1, 5, etc.
+                    tool_msg = ToolMessage(
+                        content=f"Tool result {i}",
+                        tool_call_id=f"tool_call_{i}",
+                        name="test_tool"
+                    )
+                    messages.append(tool_msg)
+
+        # Define a token counter that returns a fixed value per message
+        def token_counter(msgs):
+            return len(msgs) * 1000
+
+        # Test with a token limit that will require trimming
+        result = anthropic_trim_messages(
+            messages,
+            token_counter=token_counter,
+            max_tokens=5000,  # This will allow 5 messages
+            strategy="last",
+            allow_partial=False,
+            include_system=True,
+            num_messages_to_keep=2  # Keep system and first human message
+        )
+
+        # We should have kept the first 2 messages (system + human)
+        self.assertEqual(len(result), 5)  # 2 kept + 3 more that fit in token limit
+        self.assertEqual(result[0], system_msg)
+
+        # Verify that we don't have any AI messages with tool use that aren't followed by a tool message
+        for i in range(len(result) - 1):
+            if isinstance(result[i], AIMessage) and mock_has_tool_use(result[i]):
+                self.assertTrue(isinstance(result[i+1], ToolMessage),
+                                f"AI message with tool use at index {i} not followed by ToolMessage")
+
+
+if __name__ == "__main__":
+    unittest.main()
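A minimal usage sketch assembled only from the call patterns the tests above exercise; the model name and messages are placeholders and this is not an excerpt from the library:

# Hypothetical usage mirroring the tests above.
from unittest.mock import MagicMock

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.prebuilt.chat_agent_executor import AgentState

from ra_aid.anthropic_token_limiter import create_token_counter_wrapper, state_modifier

model = MagicMock(spec=ChatAnthropic)
model.model = "claude-3-7-sonnet-20250219"

counter = create_token_counter_wrapper(model.model)  # callable: list[BaseMessage] -> int
state = AgentState(
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello."),
    ],
    next=None,
)
trimmed = state_modifier(state, model, max_input_tokens=100_000)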