feat(main.py): refactor imports for better organization and readability
feat(main.py): add DEFAULT_MODEL constant to centralize model configuration
feat(main.py): enhance logging and error handling for better debugging
feat(main.py): implement state_modifier for managing token limits in agent state
feat(anthropic_token_limiter.py): create utilities for handling token limits with Anthropic models
feat(output.py): add print_messages_compact function for debugging message output
test(anthropic_token_limiter.py): add unit tests for token limit utilities and state management
Parent: b4b0fdd686
Commit: 5c9a1e81d2
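Summary of the behavioral change for reviewers: React agents built for Anthropic Claude models now trim their message history with a model-aware token counter instead of a rough character estimate, always keeping the first (system) message. A minimal sketch of how the new pieces fit together (names come from this diff; the surrounding config and model wiring is assumed):

    # Sketch only -- assumes `config` is the populated config dict and `model` is a ChatAnthropic instance.
    from ra_aid.anthropic_token_limiter import get_model_token_limit, state_modifier
    from ra_aid.models_params import DEFAULT_TOKEN_LIMIT

    limit = get_model_token_limit(config, "default") or DEFAULT_TOKEN_LIMIT

    def wrapped_state_modifier(state):
        # Keeps the first message and drops older messages from the rest until the limit fits.
        return state_modifier(state, model, max_input_tokens=limit)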
@@ -39,32 +39,38 @@ from ra_aid.agents.research_agent import run_research_agent
 from ra_aid.agents import run_planning_agent
 from ra_aid.config import (
     DEFAULT_MAX_TEST_CMD_RETRIES,
+    DEFAULT_MODEL,
     DEFAULT_RECURSION_LIMIT,
     DEFAULT_TEST_CMD_TIMEOUT,
     VALID_PROVIDERS,
 )
-from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository
+from ra_aid.database.repositories.key_fact_repository import (
+    KeyFactRepositoryManager,
+    get_key_fact_repository,
+)
 from ra_aid.database.repositories.key_snippet_repository import (
-    KeySnippetRepositoryManager, get_key_snippet_repository
+    KeySnippetRepositoryManager,
+    get_key_snippet_repository,
 )
 from ra_aid.database.repositories.human_input_repository import (
-    HumanInputRepositoryManager, get_human_input_repository
+    HumanInputRepositoryManager,
+    get_human_input_repository,
 )
 from ra_aid.database.repositories.research_note_repository import (
-    ResearchNoteRepositoryManager, get_research_note_repository
+    ResearchNoteRepositoryManager,
+    get_research_note_repository,
 )
 from ra_aid.database.repositories.trajectory_repository import (
-    TrajectoryRepositoryManager, get_trajectory_repository
+    TrajectoryRepositoryManager,
+    get_trajectory_repository,
 )
 from ra_aid.database.repositories.related_files_repository import (
-    RelatedFilesRepositoryManager
+    RelatedFilesRepositoryManager,
 )
-from ra_aid.database.repositories.work_log_repository import (
-    WorkLogRepositoryManager
-)
+from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager
 from ra_aid.database.repositories.config_repository import (
     ConfigRepositoryManager,
-    get_config_repository
+    get_config_repository,
 )
 from ra_aid.env_inv import EnvDiscovery
 from ra_aid.env_inv_context import EnvInvManager, get_env_inv
@@ -100,9 +106,9 @@ def launch_webui(host: str, port: int):
 
 
 def parse_arguments(args=None):
-    ANTHROPIC_DEFAULT_MODEL = "claude-3-7-sonnet-20250219"
+    ANTHROPIC_DEFAULT_MODEL = DEFAULT_MODEL
     OPENAI_DEFAULT_MODEL = "gpt-4o"
 
     # Case-insensitive log level argument type
     def log_level_type(value):
         value = value.lower()
@ -199,8 +205,10 @@ Examples:
|
||||||
help="Enable chat mode with direct human interaction (implies --hil)",
|
help="Enable chat mode with direct human interaction (implies --hil)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-mode", choices=["console", "file"], default="file",
|
"--log-mode",
|
||||||
help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console"
|
choices=["console", "file"],
|
||||||
|
default="file",
|
||||||
|
help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pretty-logger", action="store_true", help="Enable pretty logging output"
|
"--pretty-logger", action="store_true", help="Enable pretty logging output"
|
||||||
|
|
@ -378,20 +386,20 @@ def is_stage_requested(stage: str) -> bool:
|
||||||
|
|
||||||
def wipe_project_memory():
|
def wipe_project_memory():
|
||||||
"""Delete the project database file to wipe all stored memory.
|
"""Delete the project database file to wipe all stored memory.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: A message indicating the result of the operation
|
str: A message indicating the result of the operation
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
cwd = os.getcwd()
|
cwd = os.getcwd()
|
||||||
ra_aid_dir = Path(os.path.join(cwd, ".ra-aid"))
|
ra_aid_dir = Path(os.path.join(cwd, ".ra-aid"))
|
||||||
db_path = os.path.join(ra_aid_dir, "pk.db")
|
db_path = os.path.join(ra_aid_dir, "pk.db")
|
||||||
|
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return "No project memory found to wipe."
|
return "No project memory found to wipe."
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.remove(db_path)
|
os.remove(db_path)
|
||||||
return "Project memory wiped successfully."
|
return "Project memory wiped successfully."
|
||||||
|
|
@ -403,11 +411,11 @@ def wipe_project_memory():
|
||||||
|
|
||||||
def build_status():
|
def build_status():
|
||||||
"""Build status panel with model and feature information.
|
"""Build status panel with model and feature information.
|
||||||
|
|
||||||
Includes memory statistics at the bottom with counts of key facts, snippets, and research notes.
|
Includes memory statistics at the bottom with counts of key facts, snippets, and research notes.
|
||||||
"""
|
"""
|
||||||
status = Text()
|
status = Text()
|
||||||
|
|
||||||
# Get the config repository to get model/provider information
|
# Get the config repository to get model/provider information
|
||||||
config_repo = get_config_repository()
|
config_repo = get_config_repository()
|
||||||
provider = config_repo.get("provider", "")
|
provider = config_repo.get("provider", "")
|
||||||
|
|
@ -415,12 +423,14 @@ def build_status():
|
||||||
temperature = config_repo.get("temperature")
|
temperature = config_repo.get("temperature")
|
||||||
expert_provider = config_repo.get("expert_provider", "")
|
expert_provider = config_repo.get("expert_provider", "")
|
||||||
expert_model = config_repo.get("expert_model", "")
|
expert_model = config_repo.get("expert_model", "")
|
||||||
experimental_fallback_handler = config_repo.get("experimental_fallback_handler", False)
|
experimental_fallback_handler = config_repo.get(
|
||||||
|
"experimental_fallback_handler", False
|
||||||
|
)
|
||||||
web_research_enabled = config_repo.get("web_research_enabled", False)
|
web_research_enabled = config_repo.get("web_research_enabled", False)
|
||||||
|
|
||||||
# Get the expert enabled status
|
# Get the expert enabled status
|
||||||
expert_enabled = bool(expert_provider and expert_model)
|
expert_enabled = bool(expert_provider and expert_model)
|
||||||
|
|
||||||
# Basic model information
|
# Basic model information
|
||||||
status.append("🤖 ")
|
status.append("🤖 ")
|
||||||
status.append(f"{provider}/{model}")
|
status.append(f"{provider}/{model}")
|
||||||
|
|
@ -452,39 +462,41 @@ def build_status():
|
||||||
[fb_handler._format_model(m) for m in fb_handler.fallback_tool_models]
|
[fb_handler._format_model(m) for m in fb_handler.fallback_tool_models]
|
||||||
)
|
)
|
||||||
status.append(msg)
|
status.append(msg)
|
||||||
|
|
||||||
# Add memory statistics
|
# Add memory statistics
|
||||||
# Get counts of key facts, snippets, and research notes with error handling
|
# Get counts of key facts, snippets, and research notes with error handling
|
||||||
fact_count = 0
|
fact_count = 0
|
||||||
snippet_count = 0
|
snippet_count = 0
|
||||||
note_count = 0
|
note_count = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
fact_count = len(get_key_fact_repository().get_all())
|
fact_count = len(get_key_fact_repository().get_all())
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.debug(f"Failed to get key facts count: {e}")
|
logger.debug(f"Failed to get key facts count: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
snippet_count = len(get_key_snippet_repository().get_all())
|
snippet_count = len(get_key_snippet_repository().get_all())
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.debug(f"Failed to get key snippets count: {e}")
|
logger.debug(f"Failed to get key snippets count: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
note_count = len(get_research_note_repository().get_all())
|
note_count = len(get_research_note_repository().get_all())
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.debug(f"Failed to get research notes count: {e}")
|
logger.debug(f"Failed to get research notes count: {e}")
|
||||||
|
|
||||||
# Add memory statistics line with reset option note
|
# Add memory statistics line with reset option note
|
||||||
status.append(f"\n💾 Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes")
|
status.append(
|
||||||
|
f"\n💾 Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes"
|
||||||
|
)
|
||||||
if fact_count > 0 or snippet_count > 0 or note_count > 0:
|
if fact_count > 0 or snippet_count > 0 or note_count > 0:
|
||||||
status.append(" (use --wipe-project-memory to reset)")
|
status.append(" (use --wipe-project-memory to reset)")
|
||||||
|
|
||||||
# Check for newer version
|
# Check for newer version
|
||||||
version_message = check_for_newer_version()
|
version_message = check_for_newer_version()
|
||||||
if version_message:
|
if version_message:
|
||||||
status.append("\n\n")
|
status.append("\n\n")
|
||||||
status.append(version_message, style="yellow")
|
status.append(version_message, style="yellow")
|
||||||
|
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -493,7 +505,7 @@ def main():
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
setup_logging(args.log_mode, args.pretty_logger, args.log_level)
|
setup_logging(args.log_mode, args.pretty_logger, args.log_level)
|
||||||
logger.debug("Starting RA.Aid with arguments: %s", args)
|
logger.debug("Starting RA.Aid with arguments: %s", args)
|
||||||
|
|
||||||
# Check if we need to wipe project memory before starting
|
# Check if we need to wipe project memory before starting
|
||||||
if args.wipe_project_memory:
|
if args.wipe_project_memory:
|
||||||
result = wipe_project_memory()
|
result = wipe_project_memory()
|
||||||
|
|
@ -519,22 +531,24 @@ def main():
|
||||||
|
|
||||||
# Initialize empty config dictionary to be populated later
|
# Initialize empty config dictionary to be populated later
|
||||||
config = {}
|
config = {}
|
||||||
|
|
||||||
# Initialize repositories with database connection
|
# Initialize repositories with database connection
|
||||||
# Create environment inventory data
|
# Create environment inventory data
|
||||||
env_discovery = EnvDiscovery()
|
env_discovery = EnvDiscovery()
|
||||||
env_discovery.discover()
|
env_discovery.discover()
|
||||||
env_data = env_discovery.format_markdown()
|
env_data = env_discovery.format_markdown()
|
||||||
|
|
||||||
with KeyFactRepositoryManager(db) as key_fact_repo, \
|
with (
|
||||||
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
KeyFactRepositoryManager(db) as key_fact_repo,
|
||||||
HumanInputRepositoryManager(db) as human_input_repo, \
|
KeySnippetRepositoryManager(db) as key_snippet_repo,
|
||||||
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
HumanInputRepositoryManager(db) as human_input_repo,
|
||||||
RelatedFilesRepositoryManager() as related_files_repo, \
|
ResearchNoteRepositoryManager(db) as research_note_repo,
|
||||||
TrajectoryRepositoryManager(db) as trajectory_repo, \
|
RelatedFilesRepositoryManager() as related_files_repo,
|
||||||
WorkLogRepositoryManager() as work_log_repo, \
|
TrajectoryRepositoryManager(db) as trajectory_repo,
|
||||||
ConfigRepositoryManager(config) as config_repo, \
|
WorkLogRepositoryManager() as work_log_repo,
|
||||||
EnvInvManager(env_data) as env_inv:
|
ConfigRepositoryManager(config) as config_repo,
|
||||||
|
EnvInvManager(env_data) as env_inv,
|
||||||
|
):
|
||||||
# This initializes all repositories and makes them available via their respective get methods
|
# This initializes all repositories and makes them available via their respective get methods
|
||||||
logger.debug("Initialized KeyFactRepository")
|
logger.debug("Initialized KeyFactRepository")
|
||||||
logger.debug("Initialized KeySnippetRepository")
|
logger.debug("Initialized KeySnippetRepository")
|
||||||
|
|
@ -554,7 +568,9 @@ def main():
|
||||||
expert_missing,
|
expert_missing,
|
||||||
web_research_enabled,
|
web_research_enabled,
|
||||||
web_research_missing,
|
web_research_missing,
|
||||||
) = validate_environment(args) # Will exit if main env vars missing
|
) = validate_environment(
|
||||||
|
args
|
||||||
|
) # Will exit if main env vars missing
|
||||||
logger.debug("Environment validation successful")
|
logger.debug("Environment validation successful")
|
||||||
|
|
||||||
# Validate model configuration early
|
# Validate model configuration early
|
||||||
|
|
@ -590,11 +606,15 @@ def main():
|
||||||
config_repo.set("expert_provider", args.expert_provider)
|
config_repo.set("expert_provider", args.expert_provider)
|
||||||
config_repo.set("expert_model", args.expert_model)
|
config_repo.set("expert_model", args.expert_model)
|
||||||
config_repo.set("temperature", args.temperature)
|
config_repo.set("temperature", args.temperature)
|
||||||
config_repo.set("experimental_fallback_handler", args.experimental_fallback_handler)
|
config_repo.set(
|
||||||
|
"experimental_fallback_handler", args.experimental_fallback_handler
|
||||||
|
)
|
||||||
config_repo.set("web_research_enabled", web_research_enabled)
|
config_repo.set("web_research_enabled", web_research_enabled)
|
||||||
config_repo.set("show_thoughts", args.show_thoughts)
|
config_repo.set("show_thoughts", args.show_thoughts)
|
||||||
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
||||||
config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
|
config_repo.set(
|
||||||
|
"disable_reasoning_assistance", args.no_reasoning_assistance
|
||||||
|
)
|
||||||
|
|
||||||
# Build status panel with memory statistics
|
# Build status panel with memory statistics
|
||||||
status = build_status()
|
status = build_status()
|
||||||
|
|
@ -633,13 +653,15 @@ def main():
|
||||||
initial_request = ask_human.invoke(
|
initial_request = ask_human.invoke(
|
||||||
{"question": "What would you like help with?"}
|
{"question": "What would you like help with?"}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Record chat input in database (redundant as ask_human already records it,
|
# Record chat input in database (redundant as ask_human already records it,
|
||||||
# but needed in case the ask_human implementation changes)
|
# but needed in case the ask_human implementation changes)
|
||||||
try:
|
try:
|
||||||
# Using get_human_input_repository() to access the repository from context
|
# Using get_human_input_repository() to access the repository from context
|
||||||
human_input_repository = get_human_input_repository()
|
human_input_repository = get_human_input_repository()
|
||||||
human_input_repository.create(content=initial_request, source='chat')
|
human_input_repository.create(
|
||||||
|
content=initial_request, source="chat"
|
||||||
|
)
|
||||||
human_input_repository.garbage_collect()
|
human_input_repository.garbage_collect()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to record initial chat input: {str(e)}")
|
logger.error(f"Failed to record initial chat input: {str(e)}")
|
||||||
|
|
@ -668,8 +690,12 @@ def main():
|
||||||
config_repo.set("expert_model", args.expert_model)
|
config_repo.set("expert_model", args.expert_model)
|
||||||
config_repo.set("temperature", args.temperature)
|
config_repo.set("temperature", args.temperature)
|
||||||
config_repo.set("show_thoughts", args.show_thoughts)
|
config_repo.set("show_thoughts", args.show_thoughts)
|
||||||
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
config_repo.set(
|
||||||
config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
|
"force_reasoning_assistance", args.reasoning_assistance
|
||||||
|
)
|
||||||
|
config_repo.set(
|
||||||
|
"disable_reasoning_assistance", args.no_reasoning_assistance
|
||||||
|
)
|
||||||
|
|
||||||
# Set modification tools based on use_aider flag
|
# Set modification tools based on use_aider flag
|
||||||
set_modification_tools(args.use_aider)
|
set_modification_tools(args.use_aider)
|
||||||
|
|
@ -696,8 +722,12 @@ def main():
|
||||||
),
|
),
|
||||||
working_directory=working_directory,
|
working_directory=working_directory,
|
||||||
current_date=current_date,
|
current_date=current_date,
|
||||||
key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()),
|
key_facts=format_key_facts_dict(
|
||||||
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
|
get_key_fact_repository().get_facts_dict()
|
||||||
|
),
|
||||||
|
key_snippets=format_key_snippets_dict(
|
||||||
|
get_key_snippet_repository().get_snippets_dict()
|
||||||
|
),
|
||||||
project_info=formatted_project_info,
|
project_info=formatted_project_info,
|
||||||
env_inv=get_env_inv(),
|
env_inv=get_env_inv(),
|
||||||
),
|
),
|
||||||
|
|
@ -711,12 +741,12 @@ def main():
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
base_task = args.message
|
base_task = args.message
|
||||||
|
|
||||||
# Record CLI input in database
|
# Record CLI input in database
|
||||||
try:
|
try:
|
||||||
# Using get_human_input_repository() to access the repository from context
|
# Using get_human_input_repository() to access the repository from context
|
||||||
human_input_repository = get_human_input_repository()
|
human_input_repository = get_human_input_repository()
|
||||||
human_input_repository.create(content=base_task, source='cli')
|
human_input_repository.create(content=base_task, source="cli")
|
||||||
# Run garbage collection to ensure we don't exceed 100 inputs
|
# Run garbage collection to ensure we don't exceed 100 inputs
|
||||||
human_input_repository.garbage_collect()
|
human_input_repository.garbage_collect()
|
||||||
logger.debug(f"Recorded CLI input: {base_task}")
|
logger.debug(f"Recorded CLI input: {base_task}")
|
||||||
|
|
@ -750,19 +780,25 @@ def main():
|
||||||
config_repo.set("expert_model", args.expert_model)
|
config_repo.set("expert_model", args.expert_model)
|
||||||
|
|
||||||
# Store planner config with fallback to base values
|
# Store planner config with fallback to base values
|
||||||
config_repo.set("planner_provider", args.planner_provider or args.provider)
|
config_repo.set(
|
||||||
|
"planner_provider", args.planner_provider or args.provider
|
||||||
|
)
|
||||||
config_repo.set("planner_model", args.planner_model or args.model)
|
config_repo.set("planner_model", args.planner_model or args.model)
|
||||||
|
|
||||||
# Store research config with fallback to base values
|
# Store research config with fallback to base values
|
||||||
config_repo.set("research_provider", args.research_provider or args.provider)
|
config_repo.set(
|
||||||
|
"research_provider", args.research_provider or args.provider
|
||||||
|
)
|
||||||
config_repo.set("research_model", args.research_model or args.model)
|
config_repo.set("research_model", args.research_model or args.model)
|
||||||
|
|
||||||
# Store temperature in config
|
# Store temperature in config
|
||||||
config_repo.set("temperature", args.temperature)
|
config_repo.set("temperature", args.temperature)
|
||||||
|
|
||||||
# Store reasoning assistance flags
|
# Store reasoning assistance flags
|
||||||
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
||||||
config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
|
config_repo.set(
|
||||||
|
"disable_reasoning_assistance", args.no_reasoning_assistance
|
||||||
|
)
|
||||||
|
|
||||||
# Set modification tools based on use_aider flag
|
# Set modification tools based on use_aider flag
|
||||||
set_modification_tools(args.use_aider)
|
set_modification_tools(args.use_aider)
|
||||||
|
|
@ -794,5 +830,6 @@ def main():
|
||||||
print()
|
print()
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,14 @@
|
||||||
"""Utility functions for working with agents."""
|
"""Utility functions for working with agents."""
|
||||||
|
|
||||||
import inspect
|
|
||||||
import os
|
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
|
||||||
|
|
||||||
|
from langchain_anthropic import ChatAnthropic
|
||||||
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
|
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
import litellm
|
|
||||||
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
||||||
from openai import RateLimitError as OpenAIRateLimitError
|
from openai import RateLimitError as OpenAIRateLimitError
|
||||||
from litellm.exceptions import RateLimitError as LiteLLMRateLimitError
|
from litellm.exceptions import RateLimitError as LiteLLMRateLimitError
|
||||||
|
|
@ -23,28 +18,24 @@ from langchain_core.messages import (
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
trim_messages,
|
|
||||||
)
|
)
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langgraph.prebuilt import create_react_agent
|
||||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||||
from litellm import get_model_info
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
from ra_aid.agent_context import (
|
from ra_aid.agent_context import (
|
||||||
agent_context,
|
agent_context,
|
||||||
get_depth,
|
|
||||||
is_completed,
|
is_completed,
|
||||||
reset_completion_flags,
|
reset_completion_flags,
|
||||||
should_exit,
|
should_exit,
|
||||||
)
|
)
|
||||||
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
||||||
from ra_aid.agents_alias import RAgents
|
from ra_aid.agents_alias import RAgents
|
||||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES
|
||||||
from ra_aid.console.formatting import print_error, print_stage_header
|
from ra_aid.console.formatting import print_error
|
||||||
from ra_aid.console.output import print_agent_output
|
from ra_aid.console.output import print_agent_output
|
||||||
from ra_aid.exceptions import (
|
from ra_aid.exceptions import (
|
||||||
AgentInterrupt,
|
AgentInterrupt,
|
||||||
|
|
@ -53,76 +44,16 @@ from ra_aid.exceptions import (
|
||||||
)
|
)
|
||||||
from ra_aid.fallback_handler import FallbackHandler
|
from ra_aid.fallback_handler import FallbackHandler
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.llm import initialize_expert_llm
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
|
||||||
from ra_aid.text.processing import process_thinking_content
|
|
||||||
from ra_aid.project_info import (
|
|
||||||
display_project_status,
|
|
||||||
format_project_info,
|
|
||||||
get_project_info,
|
|
||||||
)
|
|
||||||
from ra_aid.prompts.expert_prompts import (
|
|
||||||
EXPERT_PROMPT_SECTION_IMPLEMENTATION,
|
|
||||||
EXPERT_PROMPT_SECTION_PLANNING,
|
|
||||||
EXPERT_PROMPT_SECTION_RESEARCH,
|
|
||||||
)
|
|
||||||
from ra_aid.prompts.human_prompts import (
|
|
||||||
HUMAN_PROMPT_SECTION_IMPLEMENTATION,
|
|
||||||
HUMAN_PROMPT_SECTION_PLANNING,
|
|
||||||
HUMAN_PROMPT_SECTION_RESEARCH,
|
|
||||||
)
|
|
||||||
from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
|
|
||||||
from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
|
|
||||||
from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
|
|
||||||
from ra_aid.prompts.reasoning_assist_prompt import (
|
|
||||||
REASONING_ASSIST_PROMPT_PLANNING,
|
|
||||||
REASONING_ASSIST_PROMPT_IMPLEMENTATION,
|
|
||||||
REASONING_ASSIST_PROMPT_RESEARCH,
|
|
||||||
)
|
|
||||||
from ra_aid.prompts.research_prompts import (
|
|
||||||
RESEARCH_ONLY_PROMPT,
|
|
||||||
RESEARCH_PROMPT,
|
|
||||||
)
|
|
||||||
from ra_aid.prompts.web_research_prompts import (
|
|
||||||
WEB_RESEARCH_PROMPT,
|
|
||||||
WEB_RESEARCH_PROMPT_SECTION_CHAT,
|
|
||||||
WEB_RESEARCH_PROMPT_SECTION_PLANNING,
|
|
||||||
WEB_RESEARCH_PROMPT_SECTION_RESEARCH,
|
|
||||||
)
|
|
||||||
from ra_aid.tool_configs import (
|
|
||||||
get_implementation_tools,
|
|
||||||
get_planning_tools,
|
|
||||||
get_research_tools,
|
|
||||||
get_web_research_tools,
|
|
||||||
)
|
|
||||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
|
||||||
from ra_aid.database.repositories.key_snippet_repository import (
|
|
||||||
get_key_snippet_repository,
|
|
||||||
)
|
|
||||||
from ra_aid.database.repositories.human_input_repository import (
|
|
||||||
get_human_input_repository,
|
|
||||||
)
|
|
||||||
from ra_aid.database.repositories.research_note_repository import (
|
|
||||||
get_research_note_repository,
|
|
||||||
)
|
|
||||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
|
||||||
from ra_aid.model_formatters import format_key_facts_dict
|
|
||||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
|
||||||
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
|
||||||
from ra_aid.tools.memory import (
|
|
||||||
get_related_files,
|
|
||||||
log_work_event,
|
|
||||||
)
|
|
||||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
from ra_aid.env_inv_context import get_env_inv
|
from ra_aid.anthropic_token_limiter import state_modifier, get_model_token_limit
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
# Import repositories using get_* functions
|
# Import repositories using get_* functions
|
||||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
|
@ -132,131 +63,19 @@ def output_markdown_message(message: str) -> str:
|
||||||
return "Message output."
|
return "Message output."
|
||||||
|
|
||||||
|
|
||||||
def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
|
|
||||||
"""Helper function to estimate total tokens in a sequence of messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Sequence of messages to count tokens for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Total estimated token count
|
|
||||||
"""
|
|
||||||
if not messages:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
estimate_tokens = CiaynAgent._estimate_tokens
|
|
||||||
return sum(estimate_tokens(msg) for msg in messages)
|
|
||||||
|
|
||||||
|
|
||||||
def state_modifier(
|
|
||||||
state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
|
|
||||||
) -> list[BaseMessage]:
|
|
||||||
"""Given the agent state and max_tokens, return a trimmed list of messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: The current agent state containing messages
|
|
||||||
max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[BaseMessage]: Trimmed list of messages that fits within token limit
|
|
||||||
"""
|
|
||||||
messages = state["messages"]
|
|
||||||
|
|
||||||
if not messages:
|
|
||||||
return []
|
|
||||||
|
|
||||||
first_message = messages[0]
|
|
||||||
remaining_messages = messages[1:]
|
|
||||||
first_tokens = estimate_messages_tokens([first_message])
|
|
||||||
new_max_tokens = max_input_tokens - first_tokens
|
|
||||||
|
|
||||||
trimmed_remaining = trim_messages(
|
|
||||||
remaining_messages,
|
|
||||||
token_counter=estimate_messages_tokens,
|
|
||||||
max_tokens=new_max_tokens,
|
|
||||||
strategy="last",
|
|
||||||
allow_partial=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return [first_message] + trimmed_remaining
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_token_limit(
|
|
||||||
config: Dict[str, Any], agent_type: Literal["default", "research", "planner"]
|
|
||||||
) -> Optional[int]:
|
|
||||||
"""Get the token limit for the current model configuration based on agent type.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[int]: The token limit if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Try to get config from repository for production use
|
|
||||||
try:
|
|
||||||
config_from_repo = get_config_repository().get_all()
|
|
||||||
# If we succeeded, use the repository config instead of passed config
|
|
||||||
config = config_from_repo
|
|
||||||
except RuntimeError:
|
|
||||||
# In tests, this may fail because the repository isn't set up
|
|
||||||
# So we'll use the passed config directly
|
|
||||||
pass
|
|
||||||
if agent_type == "research":
|
|
||||||
provider = config.get("research_provider", "") or config.get("provider", "")
|
|
||||||
model_name = config.get("research_model", "") or config.get("model", "")
|
|
||||||
elif agent_type == "planner":
|
|
||||||
provider = config.get("planner_provider", "") or config.get("provider", "")
|
|
||||||
model_name = config.get("planner_model", "") or config.get("model", "")
|
|
||||||
else:
|
|
||||||
provider = config.get("provider", "")
|
|
||||||
model_name = config.get("model", "")
|
|
||||||
|
|
||||||
try:
|
|
||||||
provider_model = model_name if not provider else f"{provider}/{model_name}"
|
|
||||||
model_info = get_model_info(provider_model)
|
|
||||||
max_input_tokens = model_info.get("max_input_tokens")
|
|
||||||
if max_input_tokens:
|
|
||||||
logger.debug(
|
|
||||||
f"Using litellm token limit for {model_name}: {max_input_tokens}"
|
|
||||||
)
|
|
||||||
return max_input_tokens
|
|
||||||
except litellm.exceptions.NotFoundError:
|
|
||||||
logger.debug(
|
|
||||||
f"Model {model_name} not found in litellm, falling back to models_params"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(
|
|
||||||
f"Error getting model info from litellm: {e}, falling back to models_params"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fallback to models_params dict
|
|
||||||
# Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
|
|
||||||
normalized_name = model_name.replace("-", "")
|
|
||||||
provider_tokens = models_params.get(provider, {})
|
|
||||||
if normalized_name in provider_tokens:
|
|
||||||
max_input_tokens = provider_tokens[normalized_name]["token_limit"]
|
|
||||||
logger.debug(
|
|
||||||
f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
max_input_tokens = None
|
|
||||||
logger.debug(f"Could not find token limit for {provider}/{model_name}")
|
|
||||||
|
|
||||||
return max_input_tokens
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get model token limit: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def build_agent_kwargs(
|
def build_agent_kwargs(
|
||||||
checkpointer: Optional[Any] = None,
|
checkpointer: Optional[Any] = None,
|
||||||
|
model: ChatAnthropic = None,
|
||||||
max_input_tokens: Optional[int] = None,
|
max_input_tokens: Optional[int] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Build kwargs dictionary for agent creation.
|
"""Build kwargs dictionary for agent creation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
checkpointer: Optional memory checkpointer
|
checkpointer: Optional memory checkpointer
|
||||||
config: Optional configuration dictionary
|
model: The language model to use for token counting
|
||||||
token_limit: Optional token limit for the model
|
max_input_tokens: Optional token limit for the model
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of kwargs for agent creation
|
Dictionary of kwargs for agent creation
|
||||||
|
|
@ -269,12 +88,17 @@ def build_agent_kwargs(
|
||||||
agent_kwargs["checkpointer"] = checkpointer
|
agent_kwargs["checkpointer"] = checkpointer
|
||||||
|
|
||||||
config = get_config_repository().get_all()
|
config = get_config_repository().get_all()
|
||||||
if config.get("limit_tokens", True) and is_anthropic_claude(config):
|
if (
|
||||||
|
config.get("limit_tokens", True)
|
||||||
|
and is_anthropic_claude(config)
|
||||||
|
and model is not None
|
||||||
|
):
|
||||||
|
|
||||||
def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
|
def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
|
||||||
return state_modifier(state, max_input_tokens=max_input_tokens)
|
return state_modifier(state, model, max_input_tokens=max_input_tokens)
|
||||||
|
|
||||||
agent_kwargs["state_modifier"] = wrapped_state_modifier
|
agent_kwargs["state_modifier"] = wrapped_state_modifier
|
||||||
|
agent_kwargs["name"] = "React"
|
||||||
|
|
||||||
return agent_kwargs
|
return agent_kwargs
|
||||||
|
|
||||||
|
|
@ -340,11 +164,13 @@ def create_agent(
|
||||||
max_input_tokens = (
|
max_input_tokens = (
|
||||||
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
|
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
|
||||||
)
|
)
|
||||||
|
print(f"max_input_tokens={max_input_tokens}")
|
||||||
|
|
||||||
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN
|
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN
|
||||||
if is_anthropic_claude(config):
|
if is_anthropic_claude(config):
|
||||||
logger.debug("Using create_react_agent to instantiate agent.")
|
logger.debug("Using create_react_agent to instantiate agent.")
|
||||||
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
|
|
||||||
|
agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
|
||||||
return create_react_agent(
|
return create_react_agent(
|
||||||
model, tools, interrupt_after=["tools"], **agent_kwargs
|
model, tools, interrupt_after=["tools"], **agent_kwargs
|
||||||
)
|
)
|
||||||
|
|
@ -357,16 +183,12 @@ def create_agent(
|
||||||
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
|
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
|
||||||
config = get_config_repository().get_all()
|
config = get_config_repository().get_all()
|
||||||
max_input_tokens = get_model_token_limit(config, agent_type)
|
max_input_tokens = get_model_token_limit(config, agent_type)
|
||||||
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
|
agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
|
||||||
return create_react_agent(
|
return create_react_agent(
|
||||||
model, tools, interrupt_after=["tools"], **agent_kwargs
|
model, tools, interrupt_after=["tools"], **agent_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from ra_aid.agents.research_agent import run_research_agent, run_web_research_agent
|
|
||||||
from ra_aid.agents.implementation_agent import run_task_implementation_agent
|
|
||||||
|
|
||||||
|
|
||||||
_CONTEXT_STACK = []
|
_CONTEXT_STACK = []
|
||||||
_INTERRUPT_CONTEXT = None
|
_INTERRUPT_CONTEXT = None
|
||||||
_FEEDBACK_MODE = False
|
_FEEDBACK_MODE = False
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,210 @@
|
||||||
|
"""Utilities for handling token limits with Anthropic models."""
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
from langchain_core.messages import BaseMessage, trim_messages
|
||||||
|
from langchain_core.messages.base import messages_to_dict
|
||||||
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||||
|
from litellm import token_counter
|
||||||
|
|
||||||
|
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
from ra_aid.logging_config import get_logger
|
||||||
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||||
|
from ra_aid.console.output import print_messages_compact
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||||
|
"""Helper function to estimate total tokens in a sequence of messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Sequence of messages to count tokens for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total estimated token count
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
estimate_tokens = CiaynAgent._estimate_tokens
|
||||||
|
return sum(estimate_tokens(msg) for msg in messages)
|
||||||
|
|
||||||
|
|
||||||
|
def create_token_counter_wrapper(model: str):
|
||||||
|
"""Create a wrapper for token counter that handles BaseMessage conversion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model name to use for token counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A function that accepts BaseMessage objects and returns token count
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a partial function that already has the model parameter set
|
||||||
|
base_token_counter = partial(token_counter, model=model)
|
||||||
|
|
||||||
|
def wrapped_token_counter(messages: List[Union[BaseMessage, Dict]]) -> int:
|
||||||
|
"""Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages (either BaseMessage objects or dicts)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token count for the messages
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if isinstance(messages[0], BaseMessage):
|
||||||
|
messages_dicts = [msg["data"] for msg in messages_to_dict(messages)]
|
||||||
|
return base_token_counter(messages=messages_dicts)
|
||||||
|
else:
|
||||||
|
# Already in dict format
|
||||||
|
return base_token_counter(messages=messages)
|
||||||
|
|
||||||
|
return wrapped_token_counter
|
||||||
|
|
||||||
|
|
||||||
|
def state_modifier(
|
||||||
|
state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
|
||||||
|
) -> list[BaseMessage]:
|
||||||
|
"""Given the agent state and max_tokens, return a trimmed list of messages but always keep the first message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The current agent state containing messages
|
||||||
|
model: The language model to use for token counting
|
||||||
|
max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[BaseMessage]: Trimmed list of messages that fits within token limit
|
||||||
|
"""
|
||||||
|
messages = state["messages"]
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
first_message = messages[0]
|
||||||
|
remaining_messages = messages[1:]
|
||||||
|
|
||||||
|
|
||||||
|
wrapped_token_counter = create_token_counter_wrapper(model.model)
|
||||||
|
|
||||||
|
first_tokens = wrapped_token_counter([first_message])
|
||||||
|
new_max_tokens = max_input_tokens - first_tokens
|
||||||
|
|
||||||
|
print_messages_compact(messages)
|
||||||
|
|
||||||
|
trimmed_remaining = trim_messages(
|
||||||
|
remaining_messages,
|
||||||
|
token_counter=wrapped_token_counter,
|
||||||
|
max_tokens=new_max_tokens,
|
||||||
|
strategy="last",
|
||||||
|
allow_partial=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [first_message] + trimmed_remaining
|
||||||
|
|
||||||
|
|
||||||
|
def sonnet_3_5_state_modifier(
|
||||||
|
state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT
|
||||||
|
) -> list[BaseMessage]:
|
||||||
|
"""Given the agent state and max_tokens, return a trimmed list of messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The current agent state containing messages
|
||||||
|
max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[BaseMessage]: Trimmed list of messages that fits within token limit
|
||||||
|
"""
|
||||||
|
messages = state["messages"]
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
first_message = messages[0]
|
||||||
|
remaining_messages = messages[1:]
|
||||||
|
first_tokens = estimate_messages_tokens([first_message])
|
||||||
|
new_max_tokens = max_input_tokens - first_tokens
|
||||||
|
|
||||||
|
trimmed_remaining = trim_messages(
|
||||||
|
remaining_messages,
|
||||||
|
token_counter=estimate_messages_tokens,
|
||||||
|
max_tokens=new_max_tokens,
|
||||||
|
strategy="last",
|
||||||
|
allow_partial=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [first_message] + trimmed_remaining
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_token_limit(
|
||||||
|
config: Dict[str, Any], agent_type: str = "default"
|
||||||
|
) -> Optional[int]:
|
||||||
|
"""Get the token limit for the current model configuration based on agent type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration dictionary containing provider and model information
|
||||||
|
agent_type: Type of agent ("default", "research", or "planner")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[int]: The token limit if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to get config from repository for production use
|
||||||
|
try:
|
||||||
|
config_from_repo = get_config_repository().get_all()
|
||||||
|
# If we succeeded, use the repository config instead of passed config
|
||||||
|
config = config_from_repo
|
||||||
|
except RuntimeError:
|
||||||
|
# In tests, this may fail because the repository isn't set up
|
||||||
|
# So we'll use the passed config directly
|
||||||
|
pass
|
||||||
|
if agent_type == "research":
|
||||||
|
provider = config.get("research_provider", "") or config.get("provider", "")
|
||||||
|
model_name = config.get("research_model", "") or config.get("model", "")
|
||||||
|
elif agent_type == "planner":
|
||||||
|
provider = config.get("planner_provider", "") or config.get("provider", "")
|
||||||
|
model_name = config.get("planner_model", "") or config.get("model", "")
|
||||||
|
else:
|
||||||
|
provider = config.get("provider", "")
|
||||||
|
model_name = config.get("model", "")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from litellm import get_model_info
|
||||||
|
|
||||||
|
provider_model = model_name if not provider else f"{provider}/{model_name}"
|
||||||
|
model_info = get_model_info(provider_model)
|
||||||
|
max_input_tokens = model_info.get("max_input_tokens")
|
||||||
|
if max_input_tokens:
|
||||||
|
logger.debug(
|
||||||
|
f"Using litellm token limit for {model_name}: {max_input_tokens}"
|
||||||
|
)
|
||||||
|
return max_input_tokens
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(
|
||||||
|
f"Error getting model info from litellm: {e}, falling back to models_params"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback to models_params dict
|
||||||
|
# Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
|
||||||
|
normalized_name = model_name.replace("-", "")
|
||||||
|
provider_tokens = models_params.get(provider, {})
|
||||||
|
if normalized_name in provider_tokens:
|
||||||
|
max_input_tokens = provider_tokens[normalized_name]["token_limit"]
|
||||||
|
logger.debug(
|
||||||
|
f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
max_input_tokens = None
|
||||||
|
logger.debug(f"Could not find token limit for {provider}/{model_name}")
|
||||||
|
|
||||||
|
return max_input_tokens
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get model token limit: {e}")
|
||||||
|
return None
|
||||||
|
|
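For context, a small usage sketch of the token counter wrapper introduced in the new anthropic_token_limiter.py above (the model id shown is this repo's DEFAULT_MODEL; any litellm-known model id works the same way):

    # Sketch only -- illustrative messages, not taken from a real run.
    from langchain_core.messages import HumanMessage, SystemMessage
    from ra_aid.anthropic_token_limiter import create_token_counter_wrapper

    counter = create_token_counter_wrapper("claude-3-7-sonnet-20250219")
    tokens = counter([SystemMessage(content="You are helpful."), HumanMessage(content="Hi")])
    # Returns an int; BaseMessage objects are converted to dicts for litellm's token_counter.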
@@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3
 FALLBACK_TOOL_MODEL_LIMIT = 5
 RETRY_FALLBACK_COUNT = 3
 DEFAULT_TEST_CMD_TIMEOUT = 60 * 5  # 5 minutes in seconds
+DEFAULT_MODEL="claude-3-7-sonnet-20250219"
 
 
 VALID_PROVIDERS = [
@@ -1,6 +1,6 @@
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Sequence
 
-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 from rich.markdown import Markdown
 from rich.panel import Panel
@@ -94,3 +94,57 @@ def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -
     """
 
     console.print(Panel(Markdown(message), title=title, border_style=border_style))
+
+
+def print_messages_compact(messages: Sequence[BaseMessage]) -> None:
+    """Print a compact representation of a list of messages.
+
+    Warning: Used mainly for debugging purposes so do not delete if not referenced anywhere!
+    For all message types, only the first 30 characters of content are shown.
+
+    Args:
+        messages: A sequence of BaseMessage objects to print
+    """
+    if not messages:
+        console.print("[italic]No messages[/italic]")
+        return
+
+    for i, msg in enumerate(messages):
+        msg_type = msg.__class__.__name__
+        content = msg.content
+
+        # Process content based on its type
+        if isinstance(content, str):
+            display_content = f"{content[:30]}..." if len(content) > 30 else content
+        elif isinstance(content, list):
+            # Handle structured content (list of content blocks)
+            content_preview = []
+            for item in content[:2]:  # Show first 2 items at most
+                if isinstance(item, dict):
+                    if item.get("type") == "text":
+                        text = item.get("text", "")
+                        content_preview.append(f"text: {text[:20]}..." if len(text) > 20 else f"text: {text}")
+                    elif item.get("type") == "tool_call":
+                        tool_name = item.get("tool_call", {}).get("name", "unknown")
+                        content_preview.append(f"tool_call: {tool_name}")
+                    else:
+                        content_preview.append(f"{item.get('type', 'unknown')}")
+
+            if len(content) > 2:
+                content_preview.append(f"...({len(content)-2} more)")
+
+            display_content = ", ".join(content_preview)
+        else:
+            display_content = str(content)[:30] + "..." if len(str(content)) > 30 else str(content)
+
+        # Add additional tool message info if available
+        additional_info = []
+        if hasattr(msg, "tool_call_id") and msg.tool_call_id:
+            additional_info.append(f"tool_call_id: {msg.tool_call_id}")
+        if hasattr(msg, "name") and msg.name:
+            additional_info.append(f"name: {msg.name}")
+        if hasattr(msg, "status") and msg.status:
+            additional_info.append(f"status: {msg.status}")
+
+        info_str = f" ({', '.join(additional_info)})" if additional_info else ""
+        console.print(f"[{i}] [bold]{msg_type}{info_str}[/bold]: {display_content}")
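A quick usage note on the debugging helper added above; the messages and output shown are illustrative only, not captured from a real run:

    # Sketch only -- made-up messages for illustration.
    from langchain_core.messages import AIMessage, HumanMessage
    from ra_aid.console.output import print_messages_compact

    print_messages_compact([
        HumanMessage(content="Summarize the design document for me, please."),
        AIMessage(content="Sure, here is a summary..."),
    ])
    # Prints one line per message, roughly: "[0] HumanMessage: Summarize the design document..."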
@@ -241,8 +241,9 @@ def create_llm_client(
     else:
         temp_kwargs = {}
 
+    thinking_kwargs = {}
     if supports_thinking:
-        temp_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
+        thinking_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}
 
     if provider == "deepseek":
         return create_deepseek_client(
@ -250,6 +251,7 @@ def create_llm_client(
|
||||||
api_key=config["api_key"],
|
api_key=config["api_key"],
|
||||||
base_url=config["base_url"],
|
base_url=config["base_url"],
|
||||||
**temp_kwargs,
|
**temp_kwargs,
|
||||||
|
**thinking_kwargs,
|
||||||
is_expert=is_expert,
|
is_expert=is_expert,
|
||||||
)
|
)
|
||||||
elif provider == "openrouter":
|
elif provider == "openrouter":
|
||||||
|
|
@ -257,6 +259,7 @@ def create_llm_client(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config["api_key"],
|
api_key=config["api_key"],
|
||||||
**temp_kwargs,
|
**temp_kwargs,
|
||||||
|
**thinking_kwargs,
|
||||||
is_expert=is_expert,
|
is_expert=is_expert,
|
||||||
)
|
)
|
||||||
elif provider == "openai":
|
elif provider == "openai":
|
||||||
|
|
@ -271,6 +274,7 @@ def create_llm_client(
|
||||||
return ChatOpenAI(
|
return ChatOpenAI(
|
||||||
**{
|
**{
|
||||||
**openai_kwargs,
|
**openai_kwargs,
|
||||||
|
**thinking_kwargs,
|
||||||
"timeout": LLM_REQUEST_TIMEOUT,
|
"timeout": LLM_REQUEST_TIMEOUT,
|
||||||
"max_retries": LLM_MAX_RETRIES,
|
"max_retries": LLM_MAX_RETRIES,
|
||||||
}
|
}
|
||||||
|
|
@ -283,6 +287,7 @@ def create_llm_client(
|
||||||
max_retries=LLM_MAX_RETRIES,
|
max_retries=LLM_MAX_RETRIES,
|
||||||
max_tokens=model_config.get("max_tokens", 64000),
|
max_tokens=model_config.get("max_tokens", 64000),
|
||||||
**temp_kwargs,
|
**temp_kwargs,
|
||||||
|
**thinking_kwargs,
|
||||||
)
|
)
|
||||||
elif provider == "openai-compatible":
|
elif provider == "openai-compatible":
|
||||||
return ChatOpenAI(
|
return ChatOpenAI(
|
||||||
|
|
@ -292,6 +297,7 @@ def create_llm_client(
|
||||||
timeout=LLM_REQUEST_TIMEOUT,
|
timeout=LLM_REQUEST_TIMEOUT,
|
||||||
max_retries=LLM_MAX_RETRIES,
|
max_retries=LLM_MAX_RETRIES,
|
||||||
**temp_kwargs,
|
**temp_kwargs,
|
||||||
|
**thinking_kwargs,
|
||||||
)
|
)
|
||||||
elif provider == "gemini":
|
elif provider == "gemini":
|
||||||
return ChatGoogleGenerativeAI(
|
return ChatGoogleGenerativeAI(
|
||||||
|
|
@ -300,6 +306,7 @@ def create_llm_client(
|
||||||
timeout=LLM_REQUEST_TIMEOUT,
|
timeout=LLM_REQUEST_TIMEOUT,
|
||||||
max_retries=LLM_MAX_RETRIES,
|
max_retries=LLM_MAX_RETRIES,
|
||||||
**temp_kwargs,
|
**temp_kwargs,
|
||||||
|
**thinking_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported provider: {provider}")
|
raise ValueError(f"Unsupported provider: {provider}")
|
||||||
|
|
|
||||||
|
|
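The llm.py hunks above separate the Anthropic "thinking" configuration from temperature handling and pass both kwarg dicts to every provider client. A condensed sketch of the resulting pattern (the deepseek branch is shown; the model_name parameter is assumed, and the other provider branches in this diff follow the same shape):

    # Sketch only -- condensed from the hunks above.
    thinking_kwargs = {}
    if supports_thinking:
        thinking_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}}

    client = create_deepseek_client(
        model_name=model_name,          # assumed parameter, not shown in the hunk context
        api_key=config["api_key"],
        base_url=config["base_url"],
        **temp_kwargs,
        **thinking_kwargs,
        is_expert=is_expert,
    )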
@@ -14,6 +14,7 @@ from ra_aid.agent_context import (
     is_crashed,
     reset_completion_flags,
 )
+from ra_aid.config import DEFAULT_MODEL
 from ra_aid.console.formatting import print_error
 from ra_aid.database.repositories.human_input_repository import HumanInputRepository
 from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
@@ -337,7 +338,7 @@ def request_task_implementation(task_spec: str) -> str:
     config = get_config_repository().get_all()
     model = initialize_llm(
         config.get("provider", "anthropic"),
-        config.get("model", "claude-3-5-sonnet-20241022"),
+        config.get("model",DEFAULT_MODEL),
         temperature=config.get("temperature"),
     )
 
@@ -475,7 +476,7 @@ def request_implementation(task_spec: str) -> str:
     config = get_config_repository().get_all()
     model = initialize_llm(
         config.get("provider", "anthropic"),
-        config.get("model", "claude-3-5-sonnet-20241022"),
+        config.get("model", DEFAULT_MODEL),
         temperature=config.get("temperature"),
     )
@ -592,4 +593,4 @@ def request_implementation(task_spec: str) -> str:
|
||||||
# Join all parts into a single markdown string
|
# Join all parts into a single markdown string
|
||||||
markdown_output = "".join(markdown_parts)
|
markdown_output = "".join(markdown_parts)
|
||||||
|
|
||||||
return markdown_output
|
return markdown_output
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,198 @@
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||||
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||||
|
|
||||||
|
from ra_aid.anthropic_token_limiter import (
|
||||||
|
create_token_counter_wrapper,
|
||||||
|
estimate_messages_tokens,
|
||||||
|
get_model_token_limit,
|
||||||
|
state_modifier,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnthropicTokenLimiter(unittest.TestCase):
    def setUp(self):
        from ra_aid.config import DEFAULT_MODEL

        self.mock_model = MagicMock(spec=ChatAnthropic)
        self.mock_model.model = DEFAULT_MODEL

        # Sample messages for testing
        self.system_message = SystemMessage(content="You are a helpful assistant.")
        self.human_message = HumanMessage(content="Hello, can you help me with a task?")
        self.long_message = HumanMessage(content="A" * 1000)  # Long message to test trimming

        # Create more messages for testing
        self.extra_messages = [
            HumanMessage(content=f"Extra message {i}") for i in range(5)
        ]

        # Mock state for testing state_modifier with many messages
        self.state = AgentState(
            messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages,
            next=None,
        )

@patch("ra_aid.anthropic_token_limiter.token_counter")
|
||||||
|
def test_create_token_counter_wrapper(self, mock_token_counter):
|
||||||
|
from ra_aid.config import DEFAULT_MODEL
|
||||||
|
|
||||||
|
# Setup mock return values
|
||||||
|
mock_token_counter.return_value = 50
|
||||||
|
|
||||||
|
# Create the wrapper
|
||||||
|
wrapper = create_token_counter_wrapper(DEFAULT_MODEL)
|
||||||
|
|
||||||
|
# Test with BaseMessage objects
|
||||||
|
result = wrapper([self.human_message])
|
||||||
|
self.assertEqual(result, 50)
|
||||||
|
|
||||||
|
# Test with empty list
|
||||||
|
result = wrapper([])
|
||||||
|
self.assertEqual(result, 0)
|
||||||
|
|
||||||
|
# Verify the mock was called with the right parameters
|
||||||
|
mock_token_counter.assert_called_with(messages=unittest.mock.ANY, model=DEFAULT_MODEL)
|
||||||
|
|
||||||
|
@patch("ra_aid.anthropic_token_limiter.CiaynAgent._estimate_tokens")
|
||||||
|
def test_estimate_messages_tokens(self, mock_estimate_tokens):
|
||||||
|
# Setup mock to return different values for different messages
|
||||||
|
mock_estimate_tokens.side_effect = lambda msg: 10 if isinstance(msg, SystemMessage) else 20
|
||||||
|
|
||||||
|
# Test with multiple messages
|
||||||
|
messages = [self.system_message, self.human_message]
|
||||||
|
result = estimate_messages_tokens(messages)
|
||||||
|
|
||||||
|
# Should be sum of individual token counts (10 + 20)
|
||||||
|
self.assertEqual(result, 30)
|
||||||
|
|
||||||
|
# Test with empty list
|
||||||
|
result = estimate_messages_tokens([])
|
||||||
|
self.assertEqual(result, 0)
|
||||||
|
|
||||||
|
@patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper")
|
||||||
|
@patch("ra_aid.anthropic_token_limiter.print_messages_compact")
|
||||||
|
def test_state_modifier(self, mock_print, mock_create_wrapper):
|
||||||
|
# Setup a proper token counter function that returns integers
|
||||||
|
# This function needs to return values that will cause trim_messages to keep only the first message
|
||||||
|
def token_counter(msgs):
|
||||||
|
# For a single message, return a small token count
|
||||||
|
if len(msgs) == 1:
|
||||||
|
return 10
|
||||||
|
# For two messages (first + one more), return a value under our limit
|
||||||
|
elif len(msgs) == 2:
|
||||||
|
return 30 # This is under our 40 token remaining budget (50-10)
|
||||||
|
# For three messages, return a value just under our limit
|
||||||
|
elif len(msgs) == 3:
|
||||||
|
return 40 # This is exactly at our 40 token remaining budget (50-10)
|
||||||
|
# For four messages, return a value just at our limit
|
||||||
|
elif len(msgs) == 4:
|
||||||
|
return 40 # This is exactly at our 40 token remaining budget (50-10)
|
||||||
|
# For five messages, return a value that exceeds our 40 token budget
|
||||||
|
elif len(msgs) == 5:
|
||||||
|
return 60 # This exceeds our 40 token budget, forcing only 4 more messages
|
||||||
|
# For more messages, return a value over our limit
|
||||||
|
else:
|
||||||
|
return 100 # This exceeds our limit
|
||||||
|
|
||||||
|
# Don't use side_effect here, directly return the function
|
||||||
|
mock_create_wrapper.return_value = token_counter
|
||||||
|
|
||||||
|
# Call state_modifier with a max token limit of 50
|
||||||
|
result = state_modifier(self.state, self.mock_model, max_input_tokens=50)
|
||||||
|
|
||||||
|
# Should keep first message and some of the others (up to 5 total)
|
||||||
|
self.assertEqual(len(result), 5) # First message plus four more
|
||||||
|
self.assertEqual(result[0], self.system_message) # First message is preserved
|
||||||
|
|
||||||
|
# Verify the wrapper was created with the right model
|
||||||
|
mock_create_wrapper.assert_called_with(self.mock_model.model)
|
||||||
|
|
||||||
|
# Verify print_messages_compact was called
|
||||||
|
mock_print.assert_called_once()
|
||||||
|
|
||||||
|
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
||||||
|
@patch("litellm.get_model_info")
|
||||||
|
def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo):
|
||||||
|
from ra_aid.config import DEFAULT_MODEL
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL}
|
||||||
|
mock_get_config_repo.return_value.get_all.return_value = mock_config
|
||||||
|
|
||||||
|
# Mock litellm's get_model_info to return a token limit
|
||||||
|
mock_get_model_info.return_value = {"max_input_tokens": 100000}
|
||||||
|
|
||||||
|
# Test getting token limit
|
||||||
|
result = get_model_token_limit(mock_config)
|
||||||
|
self.assertEqual(result, 100000)
|
||||||
|
|
||||||
|
# Verify get_model_info was called with the right model
|
||||||
|
mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}")
|
||||||
|
|
||||||
|
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
||||||
|
@patch("litellm.get_model_info")
|
||||||
|
def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo):
|
||||||
|
# Setup mocks
|
||||||
|
mock_config = {"provider": "anthropic", "model": "claude-2"}
|
||||||
|
mock_get_config_repo.return_value.get_all.return_value = mock_config
|
||||||
|
|
||||||
|
# Make litellm's get_model_info raise an exception to test fallback
|
||||||
|
mock_get_model_info.side_effect = Exception("Model not found")
|
||||||
|
|
||||||
|
# Test getting token limit from models_params fallback
|
||||||
|
with patch("ra_aid.anthropic_token_limiter.models_params", {
|
||||||
|
"anthropic": {
|
||||||
|
"claude2": {"token_limit": 100000}
|
||||||
|
}
|
||||||
|
}):
|
||||||
|
result = get_model_token_limit(mock_config)
|
||||||
|
self.assertEqual(result, 100000)
|
||||||
|
|
||||||
|
@patch("ra_aid.anthropic_token_limiter.get_config_repository")
|
||||||
|
@patch("litellm.get_model_info")
|
||||||
|
def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo):
|
||||||
|
from ra_aid.config import DEFAULT_MODEL
|
||||||
|
|
||||||
|
# Setup mocks for different agent types
|
||||||
|
mock_config = {
|
||||||
|
"provider": "anthropic",
|
||||||
|
"model": DEFAULT_MODEL,
|
||||||
|
"research_provider": "openai",
|
||||||
|
"research_model": "gpt-4",
|
||||||
|
"planner_provider": "anthropic",
|
||||||
|
"planner_model": "claude-3-sonnet-20240229"
|
||||||
|
}
|
||||||
|
mock_get_config_repo.return_value.get_all.return_value = mock_config
|
||||||
|
|
||||||
|
# Mock different returns for different models
|
||||||
|
def model_info_side_effect(model_name):
|
||||||
|
if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name:
|
||||||
|
return {"max_input_tokens": 200000}
|
||||||
|
elif "gpt-4" in model_name:
|
||||||
|
return {"max_input_tokens": 8192}
|
||||||
|
elif "claude-3-sonnet" in model_name:
|
||||||
|
return {"max_input_tokens": 100000}
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown model: {model_name}")
|
||||||
|
|
||||||
|
mock_get_model_info.side_effect = model_info_side_effect
|
||||||
|
|
||||||
|
# Test default agent type
|
||||||
|
result = get_model_token_limit(mock_config, "default")
|
||||||
|
self.assertEqual(result, 200000)
|
||||||
|
|
||||||
|
# Test research agent type
|
||||||
|
result = get_model_token_limit(mock_config, "research")
|
||||||
|
self.assertEqual(result, 8192)
|
||||||
|
|
||||||
|
# Test planner agent type
|
||||||
|
result = get_model_token_limit(mock_config, "planner")
|
||||||
|
self.assertEqual(result, 100000)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
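test_state_modifier above pins down the trimming contract: the first (system) message is always kept, and the remaining messages are trimmed from the oldest end until what is left fits the leftover token budget. The sketch below shows one plausible shape for state_modifier built on langchain_core's trim_messages, consistent with that test but not necessarily the committed implementation; the print_messages_compact import path is assumed, and in the real module these helpers live side by side rather than being imported.

# Hypothetical sketch only -- consistent with test_state_modifier, not the committed code.
from langchain_core.messages import trim_messages
from langgraph.prebuilt.chat_agent_executor import AgentState

from ra_aid.anthropic_token_limiter import create_token_counter_wrapper
from ra_aid.console.output import print_messages_compact  # import path assumed


def state_modifier(state: AgentState, model, max_input_tokens: int = 100000):
    """Trim state["messages"] to fit max_input_tokens, always keeping the first message."""
    messages = state["messages"]
    if not messages:
        return messages

    wrapped_token_counter = create_token_counter_wrapper(model.model)
    first_message = messages[0]
    remaining_budget = max_input_tokens - wrapped_token_counter([first_message])

    # Keep the most recent messages that still fit within the remaining budget.
    trimmed_rest = trim_messages(
        messages[1:],
        max_tokens=remaining_budget,
        token_counter=wrapped_token_counter,
        strategy="last",
        allow_partial=False,
    )

    result = [first_message] + trimmed_rest
    print_messages_compact(result)  # debug output introduced in this commit
    return result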
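Finally, the three get_model_token_limit tests imply a lookup order: resolve the provider and model for the requested agent type, ask litellm's model registry first, and fall back to ra_aid's own models_params table. A rough sketch of that logic under those assumptions follows; the models_params import path and the exact model-name normalization are guesses, grounded only in what the tests assert.

# Hypothetical sketch only -- inferred from the token-limit tests above.
from typing import Optional

import litellm

from ra_aid.models_params import models_params  # import path assumed


def get_model_token_limit(config: dict, agent_type: str = "default") -> Optional[int]:
    """Return max input tokens for the model configured for the given agent type."""
    if agent_type == "research":
        provider = config.get("research_provider") or config.get("provider")
        model = config.get("research_model") or config.get("model")
    elif agent_type == "planner":
        provider = config.get("planner_provider") or config.get("provider")
        model = config.get("planner_model") or config.get("model")
    else:
        provider = config.get("provider")
        model = config.get("model")

    try:
        # Preferred source: litellm's model registry.
        info = litellm.get_model_info(f"{provider}/{model}")
        return info.get("max_input_tokens")
    except Exception:
        # Fallback: ra_aid's models_params table, keyed by a normalized model name
        # (e.g. "claude-2" -> "claude2"), as exercised by test_get_model_token_limit_fallback.
        normalized = (model or "").replace("-", "").replace(".", "")
        return models_params.get(provider, {}).get(normalized, {}).get("token_limit")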