diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index bf42ec8..837f261 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -39,32 +39,38 @@ from ra_aid.agents.research_agent import run_research_agent from ra_aid.agents import run_planning_agent from ra_aid.config import ( DEFAULT_MAX_TEST_CMD_RETRIES, + DEFAULT_MODEL, DEFAULT_RECURSION_LIMIT, DEFAULT_TEST_CMD_TIMEOUT, VALID_PROVIDERS, ) -from ra_aid.database.repositories.key_fact_repository import KeyFactRepositoryManager, get_key_fact_repository +from ra_aid.database.repositories.key_fact_repository import ( + KeyFactRepositoryManager, + get_key_fact_repository, +) from ra_aid.database.repositories.key_snippet_repository import ( - KeySnippetRepositoryManager, get_key_snippet_repository + KeySnippetRepositoryManager, + get_key_snippet_repository, ) from ra_aid.database.repositories.human_input_repository import ( - HumanInputRepositoryManager, get_human_input_repository + HumanInputRepositoryManager, + get_human_input_repository, ) from ra_aid.database.repositories.research_note_repository import ( - ResearchNoteRepositoryManager, get_research_note_repository + ResearchNoteRepositoryManager, + get_research_note_repository, ) from ra_aid.database.repositories.trajectory_repository import ( - TrajectoryRepositoryManager, get_trajectory_repository + TrajectoryRepositoryManager, + get_trajectory_repository, ) from ra_aid.database.repositories.related_files_repository import ( - RelatedFilesRepositoryManager -) -from ra_aid.database.repositories.work_log_repository import ( - WorkLogRepositoryManager + RelatedFilesRepositoryManager, ) +from ra_aid.database.repositories.work_log_repository import WorkLogRepositoryManager from ra_aid.database.repositories.config_repository import ( ConfigRepositoryManager, - get_config_repository + get_config_repository, ) from ra_aid.env_inv import EnvDiscovery from ra_aid.env_inv_context import EnvInvManager, get_env_inv @@ -100,9 +106,9 @@ def launch_webui(host: str, port: int): def parse_arguments(args=None): - ANTHROPIC_DEFAULT_MODEL = "claude-3-7-sonnet-20250219" + ANTHROPIC_DEFAULT_MODEL = DEFAULT_MODEL OPENAI_DEFAULT_MODEL = "gpt-4o" - + # Case-insensitive log level argument type def log_level_type(value): value = value.lower() @@ -199,8 +205,10 @@ Examples: help="Enable chat mode with direct human interaction (implies --hil)", ) parser.add_argument( - "--log-mode", choices=["console", "file"], default="file", - help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console" + "--log-mode", + choices=["console", "file"], + default="file", + help="Logging mode: 'console' shows all logs in console, 'file' logs to file with only warnings+ in console", ) parser.add_argument( "--pretty-logger", action="store_true", help="Enable pretty logging output" @@ -378,20 +386,20 @@ def is_stage_requested(stage: str) -> bool: def wipe_project_memory(): """Delete the project database file to wipe all stored memory. - + Returns: str: A message indicating the result of the operation """ import os from pathlib import Path - + cwd = os.getcwd() ra_aid_dir = Path(os.path.join(cwd, ".ra-aid")) db_path = os.path.join(ra_aid_dir, "pk.db") - + if not os.path.exists(db_path): return "No project memory found to wipe." - + try: os.remove(db_path) return "Project memory wiped successfully." @@ -403,11 +411,11 @@ def wipe_project_memory(): def build_status(): """Build status panel with model and feature information. - + Includes memory statistics at the bottom with counts of key facts, snippets, and research notes. """ status = Text() - + # Get the config repository to get model/provider information config_repo = get_config_repository() provider = config_repo.get("provider", "") @@ -415,12 +423,14 @@ def build_status(): temperature = config_repo.get("temperature") expert_provider = config_repo.get("expert_provider", "") expert_model = config_repo.get("expert_model", "") - experimental_fallback_handler = config_repo.get("experimental_fallback_handler", False) + experimental_fallback_handler = config_repo.get( + "experimental_fallback_handler", False + ) web_research_enabled = config_repo.get("web_research_enabled", False) - + # Get the expert enabled status expert_enabled = bool(expert_provider and expert_model) - + # Basic model information status.append("šŸ¤– ") status.append(f"{provider}/{model}") @@ -452,39 +462,41 @@ def build_status(): [fb_handler._format_model(m) for m in fb_handler.fallback_tool_models] ) status.append(msg) - + # Add memory statistics # Get counts of key facts, snippets, and research notes with error handling fact_count = 0 snippet_count = 0 note_count = 0 - + try: fact_count = len(get_key_fact_repository().get_all()) except RuntimeError as e: logger.debug(f"Failed to get key facts count: {e}") - + try: snippet_count = len(get_key_snippet_repository().get_all()) except RuntimeError as e: logger.debug(f"Failed to get key snippets count: {e}") - + try: note_count = len(get_research_note_repository().get_all()) except RuntimeError as e: logger.debug(f"Failed to get research notes count: {e}") - + # Add memory statistics line with reset option note - status.append(f"\nšŸ’¾ Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes") + status.append( + f"\nšŸ’¾ Memory: {fact_count} facts, {snippet_count} snippets, {note_count} notes" + ) if fact_count > 0 or snippet_count > 0 or note_count > 0: status.append(" (use --wipe-project-memory to reset)") - + # Check for newer version version_message = check_for_newer_version() if version_message: status.append("\n\n") status.append(version_message, style="yellow") - + return status @@ -493,7 +505,7 @@ def main(): args = parse_arguments() setup_logging(args.log_mode, args.pretty_logger, args.log_level) logger.debug("Starting RA.Aid with arguments: %s", args) - + # Check if we need to wipe project memory before starting if args.wipe_project_memory: result = wipe_project_memory() @@ -519,22 +531,24 @@ def main(): # Initialize empty config dictionary to be populated later config = {} - + # Initialize repositories with database connection # Create environment inventory data env_discovery = EnvDiscovery() env_discovery.discover() env_data = env_discovery.format_markdown() - - with KeyFactRepositoryManager(db) as key_fact_repo, \ - KeySnippetRepositoryManager(db) as key_snippet_repo, \ - HumanInputRepositoryManager(db) as human_input_repo, \ - ResearchNoteRepositoryManager(db) as research_note_repo, \ - RelatedFilesRepositoryManager() as related_files_repo, \ - TrajectoryRepositoryManager(db) as trajectory_repo, \ - WorkLogRepositoryManager() as work_log_repo, \ - ConfigRepositoryManager(config) as config_repo, \ - EnvInvManager(env_data) as env_inv: + + with ( + KeyFactRepositoryManager(db) as key_fact_repo, + KeySnippetRepositoryManager(db) as key_snippet_repo, + HumanInputRepositoryManager(db) as human_input_repo, + ResearchNoteRepositoryManager(db) as research_note_repo, + RelatedFilesRepositoryManager() as related_files_repo, + TrajectoryRepositoryManager(db) as trajectory_repo, + WorkLogRepositoryManager() as work_log_repo, + ConfigRepositoryManager(config) as config_repo, + EnvInvManager(env_data) as env_inv, + ): # This initializes all repositories and makes them available via their respective get methods logger.debug("Initialized KeyFactRepository") logger.debug("Initialized KeySnippetRepository") @@ -554,7 +568,9 @@ def main(): expert_missing, web_research_enabled, web_research_missing, - ) = validate_environment(args) # Will exit if main env vars missing + ) = validate_environment( + args + ) # Will exit if main env vars missing logger.debug("Environment validation successful") # Validate model configuration early @@ -590,11 +606,15 @@ def main(): config_repo.set("expert_provider", args.expert_provider) config_repo.set("expert_model", args.expert_model) config_repo.set("temperature", args.temperature) - config_repo.set("experimental_fallback_handler", args.experimental_fallback_handler) + config_repo.set( + "experimental_fallback_handler", args.experimental_fallback_handler + ) config_repo.set("web_research_enabled", web_research_enabled) config_repo.set("show_thoughts", args.show_thoughts) config_repo.set("force_reasoning_assistance", args.reasoning_assistance) - config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance) + config_repo.set( + "disable_reasoning_assistance", args.no_reasoning_assistance + ) # Build status panel with memory statistics status = build_status() @@ -633,13 +653,15 @@ def main(): initial_request = ask_human.invoke( {"question": "What would you like help with?"} ) - + # Record chat input in database (redundant as ask_human already records it, # but needed in case the ask_human implementation changes) try: # Using get_human_input_repository() to access the repository from context human_input_repository = get_human_input_repository() - human_input_repository.create(content=initial_request, source='chat') + human_input_repository.create( + content=initial_request, source="chat" + ) human_input_repository.garbage_collect() except Exception as e: logger.error(f"Failed to record initial chat input: {str(e)}") @@ -668,8 +690,12 @@ def main(): config_repo.set("expert_model", args.expert_model) config_repo.set("temperature", args.temperature) config_repo.set("show_thoughts", args.show_thoughts) - config_repo.set("force_reasoning_assistance", args.reasoning_assistance) - config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance) + config_repo.set( + "force_reasoning_assistance", args.reasoning_assistance + ) + config_repo.set( + "disable_reasoning_assistance", args.no_reasoning_assistance + ) # Set modification tools based on use_aider flag set_modification_tools(args.use_aider) @@ -696,8 +722,12 @@ def main(): ), working_directory=working_directory, current_date=current_date, - key_facts=format_key_facts_dict(get_key_fact_repository().get_facts_dict()), - key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()), + key_facts=format_key_facts_dict( + get_key_fact_repository().get_facts_dict() + ), + key_snippets=format_key_snippets_dict( + get_key_snippet_repository().get_snippets_dict() + ), project_info=formatted_project_info, env_inv=get_env_inv(), ), @@ -711,12 +741,12 @@ def main(): sys.exit(1) base_task = args.message - + # Record CLI input in database try: # Using get_human_input_repository() to access the repository from context human_input_repository = get_human_input_repository() - human_input_repository.create(content=base_task, source='cli') + human_input_repository.create(content=base_task, source="cli") # Run garbage collection to ensure we don't exceed 100 inputs human_input_repository.garbage_collect() logger.debug(f"Recorded CLI input: {base_task}") @@ -750,19 +780,25 @@ def main(): config_repo.set("expert_model", args.expert_model) # Store planner config with fallback to base values - config_repo.set("planner_provider", args.planner_provider or args.provider) + config_repo.set( + "planner_provider", args.planner_provider or args.provider + ) config_repo.set("planner_model", args.planner_model or args.model) # Store research config with fallback to base values - config_repo.set("research_provider", args.research_provider or args.provider) + config_repo.set( + "research_provider", args.research_provider or args.provider + ) config_repo.set("research_model", args.research_model or args.model) # Store temperature in config config_repo.set("temperature", args.temperature) - + # Store reasoning assistance flags config_repo.set("force_reasoning_assistance", args.reasoning_assistance) - config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance) + config_repo.set( + "disable_reasoning_assistance", args.no_reasoning_assistance + ) # Set modification tools based on use_aider flag set_modification_tools(args.use_aider) @@ -794,5 +830,6 @@ def main(): print() sys.exit(0) + if __name__ == "__main__": main() diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index e325745..0b18d34 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -1,19 +1,14 @@ """Utility functions for working with agents.""" -import inspect -import os import signal import sys import threading import time -import uuid -from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Sequence +from typing import Any, Dict, List, Literal, Optional +from langchain_anthropic import ChatAnthropic from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler - -import litellm from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError from openai import RateLimitError as OpenAIRateLimitError from litellm.exceptions import RateLimitError as LiteLLMRateLimitError @@ -23,28 +18,24 @@ from langchain_core.messages import ( BaseMessage, HumanMessage, SystemMessage, - trim_messages, ) from langchain_core.tools import tool -from langgraph.checkpoint.memory import MemorySaver from langgraph.prebuilt import create_react_agent from langgraph.prebuilt.chat_agent_executor import AgentState -from litellm import get_model_info from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel from ra_aid.agent_context import ( agent_context, - get_depth, is_completed, reset_completion_flags, should_exit, ) from ra_aid.agent_backends.ciayn_agent import CiaynAgent from ra_aid.agents_alias import RAgents -from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT -from ra_aid.console.formatting import print_error, print_stage_header +from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES +from ra_aid.console.formatting import print_error from ra_aid.console.output import print_agent_output from ra_aid.exceptions import ( AgentInterrupt, @@ -53,76 +44,16 @@ from ra_aid.exceptions import ( ) from ra_aid.fallback_handler import FallbackHandler from ra_aid.logging_config import get_logger -from ra_aid.llm import initialize_expert_llm -from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params -from ra_aid.text.processing import process_thinking_content -from ra_aid.project_info import ( - display_project_status, - format_project_info, - get_project_info, -) -from ra_aid.prompts.expert_prompts import ( - EXPERT_PROMPT_SECTION_IMPLEMENTATION, - EXPERT_PROMPT_SECTION_PLANNING, - EXPERT_PROMPT_SECTION_RESEARCH, -) -from ra_aid.prompts.human_prompts import ( - HUMAN_PROMPT_SECTION_IMPLEMENTATION, - HUMAN_PROMPT_SECTION_PLANNING, - HUMAN_PROMPT_SECTION_RESEARCH, -) -from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT -from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS -from ra_aid.prompts.planning_prompts import PLANNING_PROMPT -from ra_aid.prompts.reasoning_assist_prompt import ( - REASONING_ASSIST_PROMPT_PLANNING, - REASONING_ASSIST_PROMPT_IMPLEMENTATION, - REASONING_ASSIST_PROMPT_RESEARCH, -) -from ra_aid.prompts.research_prompts import ( - RESEARCH_ONLY_PROMPT, - RESEARCH_PROMPT, -) -from ra_aid.prompts.web_research_prompts import ( - WEB_RESEARCH_PROMPT, - WEB_RESEARCH_PROMPT_SECTION_CHAT, - WEB_RESEARCH_PROMPT_SECTION_PLANNING, - WEB_RESEARCH_PROMPT_SECTION_RESEARCH, -) -from ra_aid.tool_configs import ( - get_implementation_tools, - get_planning_tools, - get_research_tools, - get_web_research_tools, -) +from ra_aid.models_params import DEFAULT_TOKEN_LIMIT from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command -from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository -from ra_aid.database.repositories.key_snippet_repository import ( - get_key_snippet_repository, -) -from ra_aid.database.repositories.human_input_repository import ( - get_human_input_repository, -) -from ra_aid.database.repositories.research_note_repository import ( - get_research_note_repository, -) -from ra_aid.database.repositories.work_log_repository import get_work_log_repository -from ra_aid.model_formatters import format_key_facts_dict -from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict -from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict -from ra_aid.tools.memory import ( - get_related_files, - log_work_event, -) from ra_aid.database.repositories.config_repository import get_config_repository -from ra_aid.env_inv_context import get_env_inv +from ra_aid.anthropic_token_limiter import state_modifier, get_model_token_limit console = Console() logger = get_logger(__name__) # Import repositories using get_* functions -from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository @tool @@ -132,131 +63,19 @@ def output_markdown_message(message: str) -> str: return "Message output." -def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int: - """Helper function to estimate total tokens in a sequence of messages. - - Args: - messages: Sequence of messages to count tokens for - - Returns: - Total estimated token count - """ - if not messages: - return 0 - - estimate_tokens = CiaynAgent._estimate_tokens - return sum(estimate_tokens(msg) for msg in messages) - - -def state_modifier( - state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT -) -> list[BaseMessage]: - """Given the agent state and max_tokens, return a trimmed list of messages. - - Args: - state: The current agent state containing messages - max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT) - - Returns: - list[BaseMessage]: Trimmed list of messages that fits within token limit - """ - messages = state["messages"] - - if not messages: - return [] - - first_message = messages[0] - remaining_messages = messages[1:] - first_tokens = estimate_messages_tokens([first_message]) - new_max_tokens = max_input_tokens - first_tokens - - trimmed_remaining = trim_messages( - remaining_messages, - token_counter=estimate_messages_tokens, - max_tokens=new_max_tokens, - strategy="last", - allow_partial=False, - ) - - return [first_message] + trimmed_remaining - - -def get_model_token_limit( - config: Dict[str, Any], agent_type: Literal["default", "research", "planner"] -) -> Optional[int]: - """Get the token limit for the current model configuration based on agent type. - - Returns: - Optional[int]: The token limit if found, None otherwise - """ - try: - # Try to get config from repository for production use - try: - config_from_repo = get_config_repository().get_all() - # If we succeeded, use the repository config instead of passed config - config = config_from_repo - except RuntimeError: - # In tests, this may fail because the repository isn't set up - # So we'll use the passed config directly - pass - if agent_type == "research": - provider = config.get("research_provider", "") or config.get("provider", "") - model_name = config.get("research_model", "") or config.get("model", "") - elif agent_type == "planner": - provider = config.get("planner_provider", "") or config.get("provider", "") - model_name = config.get("planner_model", "") or config.get("model", "") - else: - provider = config.get("provider", "") - model_name = config.get("model", "") - - try: - provider_model = model_name if not provider else f"{provider}/{model_name}" - model_info = get_model_info(provider_model) - max_input_tokens = model_info.get("max_input_tokens") - if max_input_tokens: - logger.debug( - f"Using litellm token limit for {model_name}: {max_input_tokens}" - ) - return max_input_tokens - except litellm.exceptions.NotFoundError: - logger.debug( - f"Model {model_name} not found in litellm, falling back to models_params" - ) - except Exception as e: - logger.debug( - f"Error getting model info from litellm: {e}, falling back to models_params" - ) - - # Fallback to models_params dict - # Normalize model name for fallback lookup (e.g. claude-2 -> claude2) - normalized_name = model_name.replace("-", "") - provider_tokens = models_params.get(provider, {}) - if normalized_name in provider_tokens: - max_input_tokens = provider_tokens[normalized_name]["token_limit"] - logger.debug( - f"Found token limit for {provider}/{model_name}: {max_input_tokens}" - ) - else: - max_input_tokens = None - logger.debug(f"Could not find token limit for {provider}/{model_name}") - - return max_input_tokens - - except Exception as e: - logger.warning(f"Failed to get model token limit: {e}") - return None def build_agent_kwargs( checkpointer: Optional[Any] = None, + model: ChatAnthropic = None, max_input_tokens: Optional[int] = None, ) -> Dict[str, Any]: """Build kwargs dictionary for agent creation. Args: checkpointer: Optional memory checkpointer - config: Optional configuration dictionary - token_limit: Optional token limit for the model + model: The language model to use for token counting + max_input_tokens: Optional token limit for the model Returns: Dictionary of kwargs for agent creation @@ -269,12 +88,17 @@ def build_agent_kwargs( agent_kwargs["checkpointer"] = checkpointer config = get_config_repository().get_all() - if config.get("limit_tokens", True) and is_anthropic_claude(config): + if ( + config.get("limit_tokens", True) + and is_anthropic_claude(config) + and model is not None + ): def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]: - return state_modifier(state, max_input_tokens=max_input_tokens) + return state_modifier(state, model, max_input_tokens=max_input_tokens) agent_kwargs["state_modifier"] = wrapped_state_modifier + agent_kwargs["name"] = "React" return agent_kwargs @@ -340,11 +164,13 @@ def create_agent( max_input_tokens = ( get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT ) + print(f"max_input_tokens={max_input_tokens}") # Use REACT agent for Anthropic Claude models, otherwise use CIAYN if is_anthropic_claude(config): logger.debug("Using create_react_agent to instantiate agent.") - agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens) + + agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens) return create_react_agent( model, tools, interrupt_after=["tools"], **agent_kwargs ) @@ -357,16 +183,12 @@ def create_agent( logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.") config = get_config_repository().get_all() max_input_tokens = get_model_token_limit(config, agent_type) - agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens) + agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens) return create_react_agent( model, tools, interrupt_after=["tools"], **agent_kwargs ) -from ra_aid.agents.research_agent import run_research_agent, run_web_research_agent -from ra_aid.agents.implementation_agent import run_task_implementation_agent - - _CONTEXT_STACK = [] _INTERRUPT_CONTEXT = None _FEEDBACK_MODE = False diff --git a/ra_aid/anthropic_token_limiter.py b/ra_aid/anthropic_token_limiter.py new file mode 100644 index 0000000..97f235c --- /dev/null +++ b/ra_aid/anthropic_token_limiter.py @@ -0,0 +1,210 @@ +"""Utilities for handling token limits with Anthropic models.""" + +from functools import partial +from typing import Any, Dict, List, Optional, Sequence, Union + +from langchain_anthropic import ChatAnthropic +from langchain_core.messages import BaseMessage, trim_messages +from langchain_core.messages.base import messages_to_dict +from langgraph.prebuilt.chat_agent_executor import AgentState +from litellm import token_counter + +from ra_aid.agent_backends.ciayn_agent import CiaynAgent +from ra_aid.database.repositories.config_repository import get_config_repository +from ra_aid.logging_config import get_logger +from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params +from ra_aid.console.output import print_messages_compact + +logger = get_logger(__name__) + + +def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int: + """Helper function to estimate total tokens in a sequence of messages. + + Args: + messages: Sequence of messages to count tokens for + + Returns: + Total estimated token count + """ + if not messages: + return 0 + + estimate_tokens = CiaynAgent._estimate_tokens + return sum(estimate_tokens(msg) for msg in messages) + + +def create_token_counter_wrapper(model: str): + """Create a wrapper for token counter that handles BaseMessage conversion. + + Args: + model: The model name to use for token counting + + Returns: + A function that accepts BaseMessage objects and returns token count + """ + + # Create a partial function that already has the model parameter set + base_token_counter = partial(token_counter, model=model) + + def wrapped_token_counter(messages: List[Union[BaseMessage, Dict]]) -> int: + """Count tokens in a list of messages, converting BaseMessage to dict for litellm token counter usage. + + Args: + messages: List of messages (either BaseMessage objects or dicts) + + Returns: + Token count for the messages + """ + if not messages: + return 0 + + if isinstance(messages[0], BaseMessage): + messages_dicts = [msg["data"] for msg in messages_to_dict(messages)] + return base_token_counter(messages=messages_dicts) + else: + # Already in dict format + return base_token_counter(messages=messages) + + return wrapped_token_counter + + +def state_modifier( + state: AgentState, model: ChatAnthropic, max_input_tokens: int = DEFAULT_TOKEN_LIMIT +) -> list[BaseMessage]: + """Given the agent state and max_tokens, return a trimmed list of messages but always keep the first message. + + Args: + state: The current agent state containing messages + model: The language model to use for token counting + max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT) + + Returns: + list[BaseMessage]: Trimmed list of messages that fits within token limit + """ + messages = state["messages"] + + if not messages: + return [] + + first_message = messages[0] + remaining_messages = messages[1:] + + + wrapped_token_counter = create_token_counter_wrapper(model.model) + + first_tokens = wrapped_token_counter([first_message]) + new_max_tokens = max_input_tokens - first_tokens + + print_messages_compact(messages) + + trimmed_remaining = trim_messages( + remaining_messages, + token_counter=wrapped_token_counter, + max_tokens=new_max_tokens, + strategy="last", + allow_partial=False, + ) + + return [first_message] + trimmed_remaining + + +def sonnet_3_5_state_modifier( + state: AgentState, max_input_tokens: int = DEFAULT_TOKEN_LIMIT +) -> list[BaseMessage]: + """Given the agent state and max_tokens, return a trimmed list of messages. + + Args: + state: The current agent state containing messages + max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT) + + Returns: + list[BaseMessage]: Trimmed list of messages that fits within token limit + """ + messages = state["messages"] + + if not messages: + return [] + + first_message = messages[0] + remaining_messages = messages[1:] + first_tokens = estimate_messages_tokens([first_message]) + new_max_tokens = max_input_tokens - first_tokens + + trimmed_remaining = trim_messages( + remaining_messages, + token_counter=estimate_messages_tokens, + max_tokens=new_max_tokens, + strategy="last", + allow_partial=False, + ) + + return [first_message] + trimmed_remaining + + +def get_model_token_limit( + config: Dict[str, Any], agent_type: str = "default" +) -> Optional[int]: + """Get the token limit for the current model configuration based on agent type. + + Args: + config: Configuration dictionary containing provider and model information + agent_type: Type of agent ("default", "research", or "planner") + + Returns: + Optional[int]: The token limit if found, None otherwise + """ + try: + # Try to get config from repository for production use + try: + config_from_repo = get_config_repository().get_all() + # If we succeeded, use the repository config instead of passed config + config = config_from_repo + except RuntimeError: + # In tests, this may fail because the repository isn't set up + # So we'll use the passed config directly + pass + if agent_type == "research": + provider = config.get("research_provider", "") or config.get("provider", "") + model_name = config.get("research_model", "") or config.get("model", "") + elif agent_type == "planner": + provider = config.get("planner_provider", "") or config.get("provider", "") + model_name = config.get("planner_model", "") or config.get("model", "") + else: + provider = config.get("provider", "") + model_name = config.get("model", "") + + try: + from litellm import get_model_info + + provider_model = model_name if not provider else f"{provider}/{model_name}" + model_info = get_model_info(provider_model) + max_input_tokens = model_info.get("max_input_tokens") + if max_input_tokens: + logger.debug( + f"Using litellm token limit for {model_name}: {max_input_tokens}" + ) + return max_input_tokens + except Exception as e: + logger.debug( + f"Error getting model info from litellm: {e}, falling back to models_params" + ) + + # Fallback to models_params dict + # Normalize model name for fallback lookup (e.g. claude-2 -> claude2) + normalized_name = model_name.replace("-", "") + provider_tokens = models_params.get(provider, {}) + if normalized_name in provider_tokens: + max_input_tokens = provider_tokens[normalized_name]["token_limit"] + logger.debug( + f"Found token limit for {provider}/{model_name}: {max_input_tokens}" + ) + else: + max_input_tokens = None + logger.debug(f"Could not find token limit for {provider}/{model_name}") + + return max_input_tokens + + except Exception as e: + logger.warning(f"Failed to get model token limit: {e}") + return None diff --git a/ra_aid/config.py b/ra_aid/config.py index c414f3b..ac69365 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -6,6 +6,7 @@ DEFAULT_MAX_TOOL_FAILURES = 3 FALLBACK_TOOL_MODEL_LIMIT = 5 RETRY_FALLBACK_COUNT = 3 DEFAULT_TEST_CMD_TIMEOUT = 60 * 5 # 5 minutes in seconds +DEFAULT_MODEL="claude-3-7-sonnet-20250219" VALID_PROVIDERS = [ diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index 8a45fec..14ef069 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -1,6 +1,6 @@ -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Sequence -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from rich.markdown import Markdown from rich.panel import Panel @@ -94,3 +94,57 @@ def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") - """ console.print(Panel(Markdown(message), title=title, border_style=border_style)) + + +def print_messages_compact(messages: Sequence[BaseMessage]) -> None: + """Print a compact representation of a list of messages. + + Warning: Used mainly for debugging purposes so do not delete if not referenced anywhere! + For all message types, only the first 30 characters of content are shown. + + Args: + messages: A sequence of BaseMessage objects to print + """ + if not messages: + console.print("[italic]No messages[/italic]") + return + + for i, msg in enumerate(messages): + msg_type = msg.__class__.__name__ + content = msg.content + + # Process content based on its type + if isinstance(content, str): + display_content = f"{content[:30]}..." if len(content) > 30 else content + elif isinstance(content, list): + # Handle structured content (list of content blocks) + content_preview = [] + for item in content[:2]: # Show first 2 items at most + if isinstance(item, dict): + if item.get("type") == "text": + text = item.get("text", "") + content_preview.append(f"text: {text[:20]}..." if len(text) > 20 else f"text: {text}") + elif item.get("type") == "tool_call": + tool_name = item.get("tool_call", {}).get("name", "unknown") + content_preview.append(f"tool_call: {tool_name}") + else: + content_preview.append(f"{item.get('type', 'unknown')}") + + if len(content) > 2: + content_preview.append(f"...({len(content)-2} more)") + + display_content = ", ".join(content_preview) + else: + display_content = str(content)[:30] + "..." if len(str(content)) > 30 else str(content) + + # Add additional tool message info if available + additional_info = [] + if hasattr(msg, "tool_call_id") and msg.tool_call_id: + additional_info.append(f"tool_call_id: {msg.tool_call_id}") + if hasattr(msg, "name") and msg.name: + additional_info.append(f"name: {msg.name}") + if hasattr(msg, "status") and msg.status: + additional_info.append(f"status: {msg.status}") + + info_str = f" ({', '.join(additional_info)})" if additional_info else "" + console.print(f"[{i}] [bold]{msg_type}{info_str}[/bold]: {display_content}") diff --git a/ra_aid/llm.py b/ra_aid/llm.py index 3091260..baedf02 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -241,8 +241,9 @@ def create_llm_client( else: temp_kwargs = {} + thinking_kwargs = {} if supports_thinking: - temp_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}} + thinking_kwargs = {"thinking": {"type": "enabled", "budget_tokens": 12000}} if provider == "deepseek": return create_deepseek_client( @@ -250,6 +251,7 @@ def create_llm_client( api_key=config["api_key"], base_url=config["base_url"], **temp_kwargs, + **thinking_kwargs, is_expert=is_expert, ) elif provider == "openrouter": @@ -257,6 +259,7 @@ def create_llm_client( model_name=model_name, api_key=config["api_key"], **temp_kwargs, + **thinking_kwargs, is_expert=is_expert, ) elif provider == "openai": @@ -271,6 +274,7 @@ def create_llm_client( return ChatOpenAI( **{ **openai_kwargs, + **thinking_kwargs, "timeout": LLM_REQUEST_TIMEOUT, "max_retries": LLM_MAX_RETRIES, } @@ -283,6 +287,7 @@ def create_llm_client( max_retries=LLM_MAX_RETRIES, max_tokens=model_config.get("max_tokens", 64000), **temp_kwargs, + **thinking_kwargs, ) elif provider == "openai-compatible": return ChatOpenAI( @@ -292,6 +297,7 @@ def create_llm_client( timeout=LLM_REQUEST_TIMEOUT, max_retries=LLM_MAX_RETRIES, **temp_kwargs, + **thinking_kwargs, ) elif provider == "gemini": return ChatGoogleGenerativeAI( @@ -300,6 +306,7 @@ def create_llm_client( timeout=LLM_REQUEST_TIMEOUT, max_retries=LLM_MAX_RETRIES, **temp_kwargs, + **thinking_kwargs, ) else: raise ValueError(f"Unsupported provider: {provider}") diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 26190e3..c2c09be 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -14,6 +14,7 @@ from ra_aid.agent_context import ( is_crashed, reset_completion_flags, ) +from ra_aid.config import DEFAULT_MODEL from ra_aid.console.formatting import print_error from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository @@ -337,7 +338,7 @@ def request_task_implementation(task_spec: str) -> str: config = get_config_repository().get_all() model = initialize_llm( config.get("provider", "anthropic"), - config.get("model", "claude-3-5-sonnet-20241022"), + config.get("model",DEFAULT_MODEL), temperature=config.get("temperature"), ) @@ -475,7 +476,7 @@ def request_implementation(task_spec: str) -> str: config = get_config_repository().get_all() model = initialize_llm( config.get("provider", "anthropic"), - config.get("model", "claude-3-5-sonnet-20241022"), + config.get("model", DEFAULT_MODEL), temperature=config.get("temperature"), ) @@ -592,4 +593,4 @@ def request_implementation(task_spec: str) -> str: # Join all parts into a single markdown string markdown_output = "".join(markdown_parts) - return markdown_output \ No newline at end of file + return markdown_output diff --git a/tests/ra_aid/test_anthropic_token_limiter.py b/tests/ra_aid/test_anthropic_token_limiter.py new file mode 100644 index 0000000..933c73e --- /dev/null +++ b/tests/ra_aid/test_anthropic_token_limiter.py @@ -0,0 +1,198 @@ +import unittest +from unittest.mock import MagicMock, patch + +from langchain_anthropic import ChatAnthropic +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langgraph.prebuilt.chat_agent_executor import AgentState + +from ra_aid.anthropic_token_limiter import ( + create_token_counter_wrapper, + estimate_messages_tokens, + get_model_token_limit, + state_modifier, +) + + +class TestAnthropicTokenLimiter(unittest.TestCase): + def setUp(self): + from ra_aid.config import DEFAULT_MODEL + + self.mock_model = MagicMock(spec=ChatAnthropic) + self.mock_model.model = DEFAULT_MODEL + + # Sample messages for testing + self.system_message = SystemMessage(content="You are a helpful assistant.") + self.human_message = HumanMessage(content="Hello, can you help me with a task?") + self.long_message = HumanMessage(content="A" * 1000) # Long message to test trimming + + # Create more messages for testing + self.extra_messages = [ + HumanMessage(content=f"Extra message {i}") for i in range(5) + ] + + # Mock state for testing state_modifier with many messages + self.state = AgentState( + messages=[self.system_message, self.human_message, self.long_message] + self.extra_messages, + next=None, + ) + + @patch("ra_aid.anthropic_token_limiter.token_counter") + def test_create_token_counter_wrapper(self, mock_token_counter): + from ra_aid.config import DEFAULT_MODEL + + # Setup mock return values + mock_token_counter.return_value = 50 + + # Create the wrapper + wrapper = create_token_counter_wrapper(DEFAULT_MODEL) + + # Test with BaseMessage objects + result = wrapper([self.human_message]) + self.assertEqual(result, 50) + + # Test with empty list + result = wrapper([]) + self.assertEqual(result, 0) + + # Verify the mock was called with the right parameters + mock_token_counter.assert_called_with(messages=unittest.mock.ANY, model=DEFAULT_MODEL) + + @patch("ra_aid.anthropic_token_limiter.CiaynAgent._estimate_tokens") + def test_estimate_messages_tokens(self, mock_estimate_tokens): + # Setup mock to return different values for different messages + mock_estimate_tokens.side_effect = lambda msg: 10 if isinstance(msg, SystemMessage) else 20 + + # Test with multiple messages + messages = [self.system_message, self.human_message] + result = estimate_messages_tokens(messages) + + # Should be sum of individual token counts (10 + 20) + self.assertEqual(result, 30) + + # Test with empty list + result = estimate_messages_tokens([]) + self.assertEqual(result, 0) + + @patch("ra_aid.anthropic_token_limiter.create_token_counter_wrapper") + @patch("ra_aid.anthropic_token_limiter.print_messages_compact") + def test_state_modifier(self, mock_print, mock_create_wrapper): + # Setup a proper token counter function that returns integers + # This function needs to return values that will cause trim_messages to keep only the first message + def token_counter(msgs): + # For a single message, return a small token count + if len(msgs) == 1: + return 10 + # For two messages (first + one more), return a value under our limit + elif len(msgs) == 2: + return 30 # This is under our 40 token remaining budget (50-10) + # For three messages, return a value just under our limit + elif len(msgs) == 3: + return 40 # This is exactly at our 40 token remaining budget (50-10) + # For four messages, return a value just at our limit + elif len(msgs) == 4: + return 40 # This is exactly at our 40 token remaining budget (50-10) + # For five messages, return a value that exceeds our 40 token budget + elif len(msgs) == 5: + return 60 # This exceeds our 40 token budget, forcing only 4 more messages + # For more messages, return a value over our limit + else: + return 100 # This exceeds our limit + + # Don't use side_effect here, directly return the function + mock_create_wrapper.return_value = token_counter + + # Call state_modifier with a max token limit of 50 + result = state_modifier(self.state, self.mock_model, max_input_tokens=50) + + # Should keep first message and some of the others (up to 5 total) + self.assertEqual(len(result), 5) # First message plus four more + self.assertEqual(result[0], self.system_message) # First message is preserved + + # Verify the wrapper was created with the right model + mock_create_wrapper.assert_called_with(self.mock_model.model) + + # Verify print_messages_compact was called + mock_print.assert_called_once() + + @patch("ra_aid.anthropic_token_limiter.get_config_repository") + @patch("litellm.get_model_info") + def test_get_model_token_limit_from_litellm(self, mock_get_model_info, mock_get_config_repo): + from ra_aid.config import DEFAULT_MODEL + + # Setup mocks + mock_config = {"provider": "anthropic", "model": DEFAULT_MODEL} + mock_get_config_repo.return_value.get_all.return_value = mock_config + + # Mock litellm's get_model_info to return a token limit + mock_get_model_info.return_value = {"max_input_tokens": 100000} + + # Test getting token limit + result = get_model_token_limit(mock_config) + self.assertEqual(result, 100000) + + # Verify get_model_info was called with the right model + mock_get_model_info.assert_called_with(f"anthropic/{DEFAULT_MODEL}") + + @patch("ra_aid.anthropic_token_limiter.get_config_repository") + @patch("litellm.get_model_info") + def test_get_model_token_limit_fallback(self, mock_get_model_info, mock_get_config_repo): + # Setup mocks + mock_config = {"provider": "anthropic", "model": "claude-2"} + mock_get_config_repo.return_value.get_all.return_value = mock_config + + # Make litellm's get_model_info raise an exception to test fallback + mock_get_model_info.side_effect = Exception("Model not found") + + # Test getting token limit from models_params fallback + with patch("ra_aid.anthropic_token_limiter.models_params", { + "anthropic": { + "claude2": {"token_limit": 100000} + } + }): + result = get_model_token_limit(mock_config) + self.assertEqual(result, 100000) + + @patch("ra_aid.anthropic_token_limiter.get_config_repository") + @patch("litellm.get_model_info") + def test_get_model_token_limit_for_different_agent_types(self, mock_get_model_info, mock_get_config_repo): + from ra_aid.config import DEFAULT_MODEL + + # Setup mocks for different agent types + mock_config = { + "provider": "anthropic", + "model": DEFAULT_MODEL, + "research_provider": "openai", + "research_model": "gpt-4", + "planner_provider": "anthropic", + "planner_model": "claude-3-sonnet-20240229" + } + mock_get_config_repo.return_value.get_all.return_value = mock_config + + # Mock different returns for different models + def model_info_side_effect(model_name): + if DEFAULT_MODEL in model_name or "claude-3-7-sonnet" in model_name: + return {"max_input_tokens": 200000} + elif "gpt-4" in model_name: + return {"max_input_tokens": 8192} + elif "claude-3-sonnet" in model_name: + return {"max_input_tokens": 100000} + else: + raise Exception(f"Unknown model: {model_name}") + + mock_get_model_info.side_effect = model_info_side_effect + + # Test default agent type + result = get_model_token_limit(mock_config, "default") + self.assertEqual(result, 200000) + + # Test research agent type + result = get_model_token_limit(mock_config, "research") + self.assertEqual(result, 8192) + + # Test planner agent type + result = get_model_token_limit(mock_config, "planner") + self.assertEqual(result, 100000) + + +if __name__ == "__main__": + unittest.main()