diff --git a/.gitignore b/.gitignore index e4b8ccb..a500b70 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ __pycache__/ /venv /.idea /htmlcov +.envrc diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index caf3baa..9860b2a 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -6,112 +6,120 @@ from rich.panel import Panel from rich.console import Console from langgraph.checkpoint.memory import MemorySaver from ra_aid.env import validate_environment -from ra_aid.project_info import get_project_info, format_project_info, display_project_status +from ra_aid.project_info import ( + get_project_info, + format_project_info, + display_project_status, +) from ra_aid.tools.memory import _global_memory from ra_aid.tools.human import ask_human from ra_aid import print_stage_header, print_error -from ra_aid.tools.human import ask_human from ra_aid.__version__ import __version__ from ra_aid.agent_utils import ( AgentInterrupt, run_agent_with_retry, run_research_agent, run_planning_agent, - create_agent -) -from ra_aid.prompts import ( - CHAT_PROMPT, - WEB_RESEARCH_PROMPT_SECTION_CHAT + create_agent, ) +from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT from ra_aid.llm import initialize_llm from ra_aid.logging_config import setup_logging, get_logger -from ra_aid.tool_configs import ( - get_chat_tools -) +from ra_aid.tool_configs import get_chat_tools from ra_aid.dependencies import check_dependencies import os logger = get_logger(__name__) + def parse_arguments(args=None): - VALID_PROVIDERS = ['anthropic', 'openai', 'openrouter', 'openai-compatible', 'gemini'] - ANTHROPIC_DEFAULT_MODEL = 'claude-3-5-sonnet-20241022' - OPENAI_DEFAULT_MODEL = 'gpt-4o' + VALID_PROVIDERS = [ + "anthropic", + "openai", + "openrouter", + "openai-compatible", + "gemini", + ] + ANTHROPIC_DEFAULT_MODEL = "claude-3-5-sonnet-20241022" + OPENAI_DEFAULT_MODEL = "gpt-4o" parser = argparse.ArgumentParser( - description='RA.Aid - AI Agent for executing programming and research tasks', + description="RA.Aid - AI Agent for executing programming and research tasks", formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=''' + epilog=""" Examples: ra-aid -m "Add error handling to the database module" ra-aid -m "Explain the authentication flow" --research-only - ''' + """, ) parser.add_argument( - '-m', '--message', + "-m", + "--message", type=str, - help='The task or query to be executed by the agent' + help="The task or query to be executed by the agent", ) parser.add_argument( - '--version', - action='version', - version=f'%(prog)s {__version__}', - help='Show program version number and exit' + "--version", + action="version", + version=f"%(prog)s {__version__}", + help="Show program version number and exit", ) parser.add_argument( - '--research-only', - action='store_true', - help='Only perform research without implementation' + "--research-only", + action="store_true", + help="Only perform research without implementation", ) parser.add_argument( - '--provider', + "--provider", type=str, - default='openai' if (os.getenv('OPENAI_API_KEY') and not os.getenv('ANTHROPIC_API_KEY')) else 'anthropic', + default="openai" + if (os.getenv("OPENAI_API_KEY") and not os.getenv("ANTHROPIC_API_KEY")) + else "anthropic", choices=VALID_PROVIDERS, - help='The LLM provider to use' + help="The LLM provider to use", + ) + parser.add_argument("--model", type=str, help="The model name to use") + parser.add_argument( + "--cowboy-mode", + action="store_true", + help="Skip interactive approval for shell 
commands", ) parser.add_argument( - '--model', + "--expert-provider", type=str, - help='The model name to use' - ) - parser.add_argument( - '--cowboy-mode', - action='store_true', - help='Skip interactive approval for shell commands' - ) - parser.add_argument( - '--expert-provider', - type=str, - default='openai', + default="openai", choices=VALID_PROVIDERS, - help='The LLM provider to use for expert knowledge queries (default: openai)' + help="The LLM provider to use for expert knowledge queries (default: openai)", ) parser.add_argument( - '--expert-model', + "--expert-model", type=str, - help='The model name to use for expert knowledge queries (required for non-OpenAI providers)' + help="The model name to use for expert knowledge queries (required for non-OpenAI providers)", ) parser.add_argument( - '--hil', '-H', - action='store_true', - help='Enable human-in-the-loop mode, where the agent can prompt the user for additional information.' + "--hil", + "-H", + action="store_true", + help="Enable human-in-the-loop mode, where the agent can prompt the user for additional information.", ) parser.add_argument( - '--chat', - action='store_true', - help='Enable chat mode with direct human interaction (implies --hil)' + "--chat", + action="store_true", + help="Enable chat mode with direct human interaction (implies --hil)", ) parser.add_argument( - '--verbose', - action='store_true', - help='Enable verbose logging output' + "--verbose", action="store_true", help="Enable verbose logging output" ) parser.add_argument( - '--temperature', + "--temperature", type=float, - help='LLM temperature (0.0-2.0). Controls randomness in responses', - default=None + help="LLM temperature (0.0-2.0). Controls randomness in responses", + default=None, + ) + parser.add_argument( + "--disable-limit-tokens", + action="store_false", + help="Whether to disable token limiting for Anthropic Claude react agents. 
Token limiter removes older messages to prevent maximum token limit API errors.", ) if args is None: @@ -129,23 +137,34 @@ Examples: if parsed_args.provider == "openai": parsed_args.model = parsed_args.model or OPENAI_DEFAULT_MODEL - if parsed_args.provider == 'anthropic': + if parsed_args.provider == "anthropic": # Always use default model for Anthropic parsed_args.model = ANTHROPIC_DEFAULT_MODEL elif not parsed_args.model and not parsed_args.research_only: # Require model for other providers unless in research mode - parser.error(f"--model is required when using provider '{parsed_args.provider}'") + parser.error( + f"--model is required when using provider '{parsed_args.provider}'" + ) # Validate expert model requirement - if parsed_args.expert_provider != 'openai' and not parsed_args.expert_model and not parsed_args.research_only: - parser.error(f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'") + if ( + parsed_args.expert_provider != "openai" + and not parsed_args.expert_model + and not parsed_args.research_only + ): + parser.error( + f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'" + ) # Validate temperature range if provided - if parsed_args.temperature is not None and not (0.0 <= parsed_args.temperature <= 2.0): - parser.error('Temperature must be between 0.0 and 2.0') + if parsed_args.temperature is not None and not ( + 0.0 <= parsed_args.temperature <= 2.0 + ): + parser.error("Temperature must be between 0.0 and 2.0") return parsed_args + # Create console instance console = Console() @@ -157,14 +176,18 @@ implementation_memory = MemorySaver() def is_informational_query() -> bool: """Determine if the current query is informational based on implementation_requested state.""" - return _global_memory.get('config', {}).get('research_only', False) or not is_stage_requested('implementation') + return _global_memory.get("config", {}).get( + "research_only", False + ) or not is_stage_requested("implementation") + def is_stage_requested(stage: str) -> bool: """Check if a stage has been requested to proceed.""" - if stage == 'implementation': - return _global_memory.get('implementation_requested', False) + if stage == "implementation": + return _global_memory.get("implementation_requested", False) return False + def main(): """Main entry point for the ra-aid command line tool.""" args = parse_arguments() @@ -175,26 +198,32 @@ def main(): # Check dependencies before proceeding check_dependencies() - expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) # Will exit if main env vars missing + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) # Will exit if main env vars missing logger.debug("Environment validation successful") if expert_missing: - console.print(Panel( - f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n" + - "\n".join(f"- {m}" for m in expert_missing) + - "\nSet the required environment variables or args to enable expert mode.", - title="Expert Tools Disabled", - style="yellow" - )) + console.print( + Panel( + f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n" + + "\n".join(f"- {m}" for m in expert_missing) + + "\nSet the required environment variables or args to enable expert mode.", + title="Expert Tools Disabled", + style="yellow", + ) + ) if web_research_missing: - console.print(Panel( - f"[yellow]Web research disabled due to missing 
configuration:[/yellow]\n" + - "\n".join(f"- {m}" for m in web_research_missing) + - "\nSet the required environment variables to enable web research.", - title="Web Research Disabled", - style="yellow" - )) + console.print( + Panel( + f"[yellow]Web research disabled due to missing configuration:[/yellow]\n" + + "\n".join(f"- {m}" for m in web_research_missing) + + "\nSet the required environment variables to enable web research.", + title="Web Research Disabled", + style="yellow", + ) + ) # Create the base model after validation model = initialize_llm(args.provider, args.model, temperature=args.temperature) @@ -216,7 +245,9 @@ def main(): formatted_project_info = "" # Get initial request from user - initial_request = ask_human.invoke({"question": "What would you like help with?"}) + initial_request = ask_human.invoke( + {"question": "What would you like help with?"} + ) # Get working directory and current date working_directory = os.getcwd() @@ -230,31 +261,41 @@ def main(): "cowboy_mode": args.cowboy_mode, "hil": True, # Always true in chat mode "web_research_enabled": web_research_enabled, - "initial_request": initial_request + "initial_request": initial_request, + "limit_tokens": args.disable_limit_tokens, } # Store config in global memory - _global_memory['config'] = config - _global_memory['config']['provider'] = args.provider - _global_memory['config']['model'] = args.model - _global_memory['config']['expert_provider'] = args.expert_provider - _global_memory['config']['expert_model'] = args.expert_model + _global_memory["config"] = config + _global_memory["config"]["provider"] = args.provider + _global_memory["config"]["model"] = args.model + _global_memory["config"]["expert_provider"] = args.expert_provider + _global_memory["config"]["expert_model"] = args.expert_model # Create chat agent with appropriate tools chat_agent = create_agent( model, - get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled), - checkpointer=MemorySaver() + get_chat_tools( + expert_enabled=expert_enabled, + web_research_enabled=web_research_enabled, + ), + checkpointer=MemorySaver(), ) # Run chat agent and exit - run_agent_with_retry(chat_agent, CHAT_PROMPT.format( + run_agent_with_retry( + chat_agent, + CHAT_PROMPT.format( initial_request=initial_request, - web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else "", + web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT + if web_research_enabled + else "", working_directory=working_directory, current_date=current_date, - project_info=formatted_project_info - ), config) + project_info=formatted_project_info, + ), + config, + ) return # Validate message is provided @@ -268,19 +309,20 @@ def main(): "recursion_limit": 100, "research_only": args.research_only, "cowboy_mode": args.cowboy_mode, - "web_research_enabled": web_research_enabled + "web_research_enabled": web_research_enabled, + "limit_tokens": args.disable_limit_tokens, } # Store config in global memory for access by is_informational_query - _global_memory['config'] = config + _global_memory["config"] = config # Store model configuration - _global_memory['config']['provider'] = args.provider - _global_memory['config']['model'] = args.model + _global_memory["config"]["provider"] = args.provider + _global_memory["config"]["model"] = args.model # Store expert provider and model in config - _global_memory['config']['expert_provider'] = args.expert_provider - _global_memory['config']['expert_model'] = args.expert_model + 
_global_memory["config"]["expert_provider"] = args.expert_provider + _global_memory["config"]["expert_model"] = args.expert_model # Run research stage print_stage_header("Research Stage") @@ -292,7 +334,7 @@ def main(): research_only=args.research_only, hil=args.hil, memory=research_memory, - config=config + config=config, ) # Proceed with planning and implementation if not an informational query @@ -304,7 +346,7 @@ def main(): expert_enabled=expert_enabled, hil=args.hil, memory=planning_memory, - config=config + config=config, ) except (KeyboardInterrupt, AgentInterrupt): @@ -313,5 +355,6 @@ def main(): print() sys.exit(0) + if __name__ == "__main__": main() diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 4784c64..e1c5479 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -3,22 +3,27 @@ import sys import time import uuid -from typing import Optional, Any +from typing import Optional, Any, List, Dict, Sequence +from langchain_core.messages import BaseMessage, trim_messages import signal -import threading -import time -from typing import Optional -from ra_aid.project_info import get_project_info, format_project_info, display_project_status +from langgraph.checkpoint.memory import MemorySaver +from langgraph.prebuilt.chat_agent_executor import AgentState +from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT, models_tokens +from ra_aid.agents.ciayn_agent import CiaynAgent +import threading + +from ra_aid.project_info import ( + get_project_info, + format_project_info, + display_project_status, +) from langgraph.prebuilt import create_react_agent -from ra_aid.agents.ciayn_agent import CiaynAgent -from ra_aid.project_info import get_project_info, format_project_info, display_project_status from ra_aid.console.formatting import print_stage_header, print_error from langchain_core.language_models import BaseChatModel from langchain_core.tools import tool -from typing import List, Any from ra_aid.console.output import print_agent_output from ra_aid.logging_config import get_logger from ra_aid.exceptions import AgentInterrupt @@ -26,7 +31,7 @@ from ra_aid.tool_configs import ( get_implementation_tools, get_research_tools, get_planning_tools, - get_web_research_tools + get_web_research_tools, ) from ra_aid.prompts import ( IMPLEMENTATION_PROMPT, @@ -41,13 +46,9 @@ from ra_aid.prompts import ( HUMAN_PROMPT_SECTION_RESEARCH, PLANNING_PROMPT, EXPERT_PROMPT_SECTION_PLANNING, - WEB_RESEARCH_PROMPT_SECTION_PLANNING, HUMAN_PROMPT_SECTION_PLANNING, WEB_RESEARCH_PROMPT, - EXPERT_PROMPT_SECTION_CHAT, - CHAT_PROMPT, ) -from langgraph.checkpoint.memory import MemorySaver from langchain_core.messages import HumanMessage from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError @@ -60,58 +61,189 @@ from ra_aid.tools.memory import ( get_memory_value, get_related_files, ) -from ra_aid.tool_configs import get_research_tools -from ra_aid.prompts import ( - RESEARCH_PROMPT, - RESEARCH_ONLY_PROMPT, - EXPERT_PROMPT_SECTION_RESEARCH, - HUMAN_PROMPT_SECTION_RESEARCH -) console = Console() logger = get_logger(__name__) + @tool def output_markdown_message(message: str) -> str: """Outputs a message to the user, optionally prompting for input.""" console.print(Panel(Markdown(message.strip()), title="🤖 Assistant")) return "Message output." + +def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int: + """Helper function to estimate total tokens in a sequence of messages. 
+ + Args: + messages: Sequence of messages to count tokens for + + Returns: + Total estimated token count + """ + if not messages: + return 0 + + estimate_tokens = CiaynAgent._estimate_tokens + return sum(estimate_tokens(msg) for msg in messages) + + +def state_modifier( + state: AgentState, max_tokens: int = DEFAULT_TOKEN_LIMIT +) -> list[BaseMessage]: + """Given the agent state and max_tokens, return a trimmed list of messages. + + Args: + state: The current agent state containing messages + max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT) + + Returns: + list[BaseMessage]: Trimmed list of messages that fits within token limit + """ + messages = state["messages"] + + if not messages: + return [] + + first_message = messages[0] + remaining_messages = messages[1:] + first_tokens = estimate_messages_tokens([first_message]) + new_max_tokens = max_tokens - first_tokens + + trimmed_remaining = trim_messages( + remaining_messages, + token_counter=estimate_messages_tokens, + max_tokens=new_max_tokens, + strategy="last", + allow_partial=False, + ) + + return [first_message] + trimmed_remaining + + +def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]: + """Get the token limit for the current model configuration. + + Returns: + Optional[int]: The token limit if found, None otherwise + """ + try: + provider = config.get("provider", "") + model_name = config.get("model", "") + + provider_tokens = models_tokens.get(provider, {}) + token_limit = provider_tokens.get(model_name, None) + if token_limit: + logger.debug( + f"Found token limit for {provider}/{model_name}: {token_limit}" + ) + else: + logger.debug(f"Could not find token limit for {provider}/{model_name}") + + return token_limit + + except Exception as e: + logger.warning(f"Failed to get model token limit: {e}") + return None + + +def build_agent_kwargs( + checkpointer: Optional[Any] = None, + config: Dict[str, Any] = None, + token_limit: Optional[int] = None, +) -> Dict[str, Any]: + """Build kwargs dictionary for agent creation. + + Args: + checkpointer: Optional memory checkpointer + config: Optional configuration dictionary + token_limit: Optional token limit for the model + + Returns: + Dictionary of kwargs for agent creation + """ + agent_kwargs = {} + + if checkpointer is not None: + agent_kwargs["checkpointer"] = checkpointer + + if config.get("limit_tokens", True) and is_anthropic_claude(config): + + def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]: + return state_modifier(state, max_tokens=token_limit) + + agent_kwargs["state_modifier"] = wrapped_state_modifier + + return agent_kwargs + + +def is_anthropic_claude(config: Dict[str, Any]) -> bool: + """Check if the provider and model name indicate an Anthropic Claude model. + + Args: + provider: The provider name + model_name: The model name + + Returns: + bool: True if this is an Anthropic Claude model + """ + provider = config.get("provider", "") + model_name = config.get("model", "") + return ( + provider.lower() == "anthropic" + and model_name + and "claude" in model_name.lower() + ) + + def create_agent( model: BaseChatModel, tools: List[Any], *, - checkpointer: Any = None + checkpointer: Any = None, ) -> Any: """Create a react agent with the given configuration. 
- + Args: model: The LLM model to use tools: List of tools to provide to the agent checkpointer: Optional memory checkpointer - + config: Optional configuration dictionary containing settings like: + - limit_tokens (bool): Whether to apply token limiting (default: True) + - provider (str): The LLM provider name + - model (str): The model name + Returns: The created agent instance + + Token limiting helps prevent context window overflow by trimming older messages + while preserving system messages. It can be disabled by setting + config['limit_tokens'] = False. """ try: - # Get model name if available - provider = _global_memory.get('config', {}).get('provider') - model_name = _global_memory.get('config', {}).get('model') - + config = _global_memory.get("config", {}) + token_limit = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT + # Use REACT agent for Anthropic Claude models, otherwise use CIAYN - if provider == 'anthropic' and 'claude' in model_name: + if is_anthropic_claude(config): logger.debug("Using create_react_agent to instantiate agent.") - return create_react_agent(model, tools, checkpointer=checkpointer) + agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit) + return create_react_agent(model, tools, **agent_kwargs) else: - logger.debug("Using CiaynAgent agent instance.") - return CiaynAgent(model, tools) - + logger.debug("Using CiaynAgent agent instance") + return CiaynAgent(model, tools, max_tokens=token_limit) + except Exception as e: # Default to REACT agent if provider/model detection fails logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.") - return create_react_agent(model, tools, checkpointer=checkpointer) + config = _global_memory.get("config", {}) + token_limit = get_model_token_limit(config) + agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit) + return create_react_agent(model, tools, **agent_kwargs) + def run_research_agent( base_task_or_query: str, @@ -124,7 +256,7 @@ def run_research_agent( memory: Optional[Any] = None, config: Optional[dict] = None, thread_id: Optional[str] = None, - console_message: Optional[str] = None + console_message: Optional[str] = None, ) -> Optional[str]: """Run a research agent with the given configuration. 
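The hunks above add token limiting for Anthropic Claude react agents: get_model_token_limit looks up the context window for the configured provider/model, and build_agent_kwargs attaches a state_modifier only when limit_tokens is enabled and the model is a Claude variant. A minimal sketch of how these helpers compose, assuming the config keys used in the diff (the values below are illustrative, not a prescribed setup):

# Illustrative sketch only: exercising the new token-limiting helpers.
from ra_aid.agent_utils import get_model_token_limit, build_agent_kwargs
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT

config = {
    "provider": "anthropic",
    "model": "claude-3-5-sonnet-20241022",
    "limit_tokens": True,  # populated from args.disable_limit_tokens (store_false)
}

# Falls back to DEFAULT_TOKEN_LIMIT when the model is not listed in models_tokens.
token_limit = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT

# For Claude models with limit_tokens enabled, a state_modifier is attached so
# create_react_agent trims older messages before each model call.
agent_kwargs = build_agent_kwargs(checkpointer=None, config=config, token_limit=token_limit)
assert "state_modifier" in agent_kwargs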
@@ -153,8 +285,13 @@ def run_research_agent( """ thread_id = thread_id or str(uuid.uuid4()) logger.debug("Starting research agent with thread_id=%s", thread_id) - logger.debug("Research configuration: expert=%s, research_only=%s, hil=%s, web=%s", - expert_enabled, research_only, hil, web_research_enabled) + logger.debug( + "Research configuration: expert=%s, research_only=%s, hil=%s, web=%s", + expert_enabled, + research_only, + hil, + web_research_enabled, + ) # Initialize memory if not provided if memory is None: @@ -169,7 +306,7 @@ def run_research_agent( research_only=research_only, expert_enabled=expert_enabled, human_interaction=hil, - web_research_enabled=config.get('web_research_enabled', False) + web_research_enabled=config.get("web_research_enabled", False), ) # Create agent @@ -178,7 +315,11 @@ def run_research_agent( # Format prompt sections expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" - web_research_section = WEB_RESEARCH_PROMPT_SECTION_RESEARCH if config.get('web_research_enabled') else "" + web_research_section = ( + WEB_RESEARCH_PROMPT_SECTION_RESEARCH + if config.get("web_research_enabled") + else "" + ) # Get research context from memory key_facts = _global_memory.get("key_facts", "") @@ -196,29 +337,30 @@ def run_research_agent( # Build prompt prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format( base_task=base_task_or_query, - research_only_note='' if research_only else ' Only request implementation if the user explicitly asked for changes to be made.', + research_only_note="" + if research_only + else " Only request implementation if the user explicitly asked for changes to be made.", expert_section=expert_section, human_section=human_section, web_research_section=web_research_section, key_facts=key_facts, - work_log=get_memory_value('work_log'), + work_log=get_memory_value("work_log"), code_snippets=code_snippets, related_files=related_files, - project_info=formatted_project_info + project_info=formatted_project_info, ) # Set up configuration - run_config = { - "configurable": {"thread_id": thread_id}, - "recursion_limit": 100 - } + run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100} if config: run_config.update(config) try: # Display console message if provided if console_message: - console.print(Panel(Markdown(console_message), title="🔬 Looking into it...")) + console.print( + Panel(Markdown(console_message), title="🔬 Looking into it...") + ) if project_info: display_project_status(project_info) @@ -239,7 +381,7 @@ def run_research_agent( memory=memory, config=config, thread_id=thread_id, - console_message=console_message + console_message=console_message, ) except (KeyboardInterrupt, AgentInterrupt): raise @@ -247,6 +389,7 @@ def run_research_agent( logger.error("Research agent failed: %s", str(e), exc_info=True) raise + def run_web_research_agent( query: str, model, @@ -257,7 +400,7 @@ def run_web_research_agent( memory: Optional[Any] = None, config: Optional[dict] = None, thread_id: Optional[str] = None, - console_message: Optional[str] = None + console_message: Optional[str] = None, ) -> Optional[str]: """Run a web research agent with the given configuration. 
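For reference, the state_modifier introduced above always keeps the first message (the original task prompt) and then retains only the most recent messages that still fit under max_tokens, via trim_messages with strategy="last" and allow_partial=False. A hedged sketch of that behavior, with made-up message contents and token counts estimated by CiaynAgent._estimate_tokens:

# Illustrative sketch only: older middle messages are dropped, the first is kept.
from langchain_core.messages import AIMessage, HumanMessage
from ra_aid.agent_utils import state_modifier

state = {
    "messages": [HumanMessage(content="original task prompt")]
    + [AIMessage(content="x" * 4000) for _ in range(10)],
}
trimmed = state_modifier(state, max_tokens=2000)

assert trimmed[0] is state["messages"][0]      # first message is always preserved
assert len(trimmed) < len(state["messages"])   # older messages were trimmed away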
@@ -284,8 +427,12 @@ def run_web_research_agent( """ thread_id = thread_id or str(uuid.uuid4()) logger.debug("Starting web research agent with thread_id=%s", thread_id) - logger.debug("Web research configuration: expert=%s, hil=%s, web=%s", - expert_enabled, hil, web_research_enabled) + logger.debug( + "Web research configuration: expert=%s, hil=%s, web=%s", + expert_enabled, + hil, + web_research_enabled, + ) # Initialize memory if not provided if memory is None: @@ -317,14 +464,11 @@ def run_web_research_agent( human_section=human_section, key_facts=key_facts, code_snippets=code_snippets, - related_files=related_files + related_files=related_files, ) # Set up configuration - run_config = { - "configurable": {"thread_id": thread_id}, - "recursion_limit": 100 - } + run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100} if config: run_config.update(config) @@ -342,6 +486,7 @@ def run_web_research_agent( logger.error("Web research agent failed: %s", str(e), exc_info=True) raise + def run_planning_agent( base_task: str, model, @@ -350,7 +495,7 @@ def run_planning_agent( hil: bool = False, memory: Optional[Any] = None, config: Optional[dict] = None, - thread_id: Optional[str] = None + thread_id: Optional[str] = None, ) -> Optional[str]: """Run a planning agent to create implementation plans. @@ -379,7 +524,10 @@ def run_planning_agent( thread_id = str(uuid.uuid4()) # Configure tools - tools = get_planning_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False)) + tools = get_planning_tools( + expert_enabled=expert_enabled, + web_research_enabled=config.get("web_research_enabled", False), + ) # Create agent agent = create_agent(model, tools, checkpointer=memory) @@ -387,7 +535,11 @@ def run_planning_agent( # Format prompt sections expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else "" human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else "" - web_research_section = WEB_RESEARCH_PROMPT_SECTION_PLANNING if config.get('web_research_enabled') else "" + web_research_section = ( + WEB_RESEARCH_PROMPT_SECTION_PLANNING + if config.get("web_research_enabled") + else "" + ) # Build prompt planning_prompt = PLANNING_PROMPT.format( @@ -395,19 +547,18 @@ def run_planning_agent( human_section=human_section, web_research_section=web_research_section, base_task=base_task, - research_notes=get_memory_value('research_notes'), + research_notes=get_memory_value("research_notes"), related_files="\n".join(get_related_files()), - key_facts=get_memory_value('key_facts'), - key_snippets=get_memory_value('key_snippets'), - work_log=get_memory_value('work_log'), - research_only_note='' if config.get('research_only') else ' Only request implementation if the user explicitly asked for changes to be made.' 
+ key_facts=get_memory_value("key_facts"), + key_snippets=get_memory_value("key_snippets"), + work_log=get_memory_value("work_log"), + research_only_note="" + if config.get("research_only") + else " Only request implementation if the user explicitly asked for changes to be made.", ) # Set up configuration - run_config = { - "configurable": {"thread_id": thread_id}, - "recursion_limit": 100 - } + run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100} if config: run_config.update(config) @@ -421,6 +572,7 @@ def run_planning_agent( logger.error("Planning agent failed: %s", str(e), exc_info=True) raise + def run_task_implementation_agent( base_task: str, tasks: list, @@ -433,7 +585,7 @@ def run_task_implementation_agent( web_research_enabled: bool = False, memory: Optional[Any] = None, config: Optional[dict] = None, - thread_id: Optional[str] = None + thread_id: Optional[str] = None, ) -> Optional[str]: """Run an implementation agent for a specific task. @@ -454,7 +606,11 @@ def run_task_implementation_agent( """ thread_id = thread_id or str(uuid.uuid4()) logger.debug("Starting implementation agent with thread_id=%s", thread_id) - logger.debug("Implementation configuration: expert=%s, web=%s", expert_enabled, web_research_enabled) + logger.debug( + "Implementation configuration: expert=%s, web=%s", + expert_enabled, + web_research_enabled, + ) logger.debug("Task details: base_task=%s, current_task=%s", base_task, task) logger.debug("Related files: %s", related_files) @@ -467,7 +623,10 @@ def run_task_implementation_agent( thread_id = str(uuid.uuid4()) # Configure tools - tools = get_implementation_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False)) + tools = get_implementation_tools( + expert_enabled=expert_enabled, + web_research_enabled=config.get("web_research_enabled", False), + ) # Create agent agent = create_agent(model, tools, checkpointer=memory) @@ -479,20 +638,21 @@ def run_task_implementation_agent( tasks=tasks, plan=plan, related_files=related_files, - key_facts=get_memory_value('key_facts'), - key_snippets=get_memory_value('key_snippets'), - research_notes=get_memory_value('research_notes'), - work_log=get_memory_value('work_log'), + key_facts=get_memory_value("key_facts"), + key_snippets=get_memory_value("key_snippets"), + research_notes=get_memory_value("research_notes"), + work_log=get_memory_value("work_log"), expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "", - human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION if _global_memory.get('config', {}).get('hil', False) else "", - web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT if config.get('web_research_enabled') else "" + human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION + if _global_memory.get("config", {}).get("hil", False) + else "", + web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT + if config.get("web_research_enabled") + else "", ) # Set up configuration - run_config = { - "configurable": {"thread_id": thread_id}, - "recursion_limit": 100 - } + run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100} if config: run_config.update(config) @@ -505,10 +665,12 @@ def run_task_implementation_agent( logger.error("Implementation agent failed: %s", str(e), exc_info=True) raise + _CONTEXT_STACK = [] _INTERRUPT_CONTEXT = None _FEEDBACK_MODE = False + def _request_interrupt(signum, frame): global _INTERRUPT_CONTEXT if _CONTEXT_STACK: @@ -520,6 +682,7 @@ def _request_interrupt(signum, frame): print() 
sys.exit(0) + class InterruptibleSection: def __enter__(self): _CONTEXT_STACK.append(self) @@ -528,10 +691,12 @@ class InterruptibleSection: def __exit__(self, exc_type, exc_value, traceback): _CONTEXT_STACK.remove(self) + def check_interrupt(): if _CONTEXT_STACK and _INTERRUPT_CONTEXT is _CONTEXT_STACK[-1]: raise AgentInterrupt("Interrupt requested") + def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: """Run an agent with retry logic for API errors.""" logger.debug("Running agent with prompt length: %d", len(prompt)) @@ -546,48 +711,68 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: with InterruptibleSection(): try: # Track agent execution depth - current_depth = _global_memory.get('agent_depth', 0) - _global_memory['agent_depth'] = current_depth + 1 + current_depth = _global_memory.get("agent_depth", 0) + _global_memory["agent_depth"] = current_depth + 1 for attempt in range(max_retries): logger.debug("Attempt %d/%d", attempt + 1, max_retries) check_interrupt() try: - for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config): + for chunk in agent.stream( + {"messages": [HumanMessage(content=prompt)]}, config + ): logger.debug("Agent output: %s", chunk) check_interrupt() print_agent_output(chunk) - if _global_memory['plan_completed']: - _global_memory['plan_completed'] = False - _global_memory['task_completed'] = False - _global_memory['completion_message'] = '' + if _global_memory["plan_completed"]: + _global_memory["plan_completed"] = False + _global_memory["task_completed"] = False + _global_memory["completion_message"] = "" break - if _global_memory['task_completed']: - _global_memory['task_completed'] = False - _global_memory['completion_message'] = '' + if _global_memory["task_completed"]: + _global_memory["task_completed"] = False + _global_memory["completion_message"] = "" break logger.debug("Agent run completed successfully") return "Agent run completed successfully" except (KeyboardInterrupt, AgentInterrupt): raise - except (InternalServerError, APITimeoutError, RateLimitError, APIError, ValueError) as e: + except ( + InternalServerError, + APITimeoutError, + RateLimitError, + APIError, + ValueError, + ) as e: if isinstance(e, ValueError): error_str = str(e).lower() - if 'code' not in error_str or '429' not in error_str: + if "code" not in error_str or "429" not in error_str: raise # Re-raise ValueError if it's not a Lambda 429 if attempt == max_retries - 1: logger.error("Max retries reached, failing: %s", str(e)) - raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}") - logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e)) - delay = base_delay * (2 ** attempt) - print_error(f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})") + raise RuntimeError( + f"Max retries ({max_retries}) exceeded. Last error: {e}" + ) + logger.warning( + "API error (attempt %d/%d): %s", + attempt + 1, + max_retries, + str(e), + ) + delay = base_delay * (2**attempt) + print_error( + f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... 
(Attempt {attempt+1}/{max_retries})" + ) start = time.monotonic() while time.monotonic() - start < delay: check_interrupt() time.sleep(0.1) finally: # Reset depth tracking - _global_memory['agent_depth'] = _global_memory.get('agent_depth', 1) - 1 + _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1 - if original_handler and threading.current_thread() is threading.main_thread(): + if ( + original_handler + and threading.current_thread() is threading.main_thread() + ): signal.signal(signal.SIGINT, original_handler) diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index a1bbd82..388cbc5 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -1,8 +1,8 @@ import re from dataclasses import dataclass from typing import Dict, Any, Generator, List, Optional, Union -from typing import Dict, Any, Generator, List, Optional, Union +from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT from ra_aid.tools.reflection import get_function_info from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage @@ -66,7 +66,7 @@ class CiaynAgent: """ - def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = 100000): + def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT): """Initialize the agent with a model and list of tools. Args: @@ -263,6 +263,10 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" text = content.content else: text = content + + # create-react-agent tool calls can be lists + if isinstance(text, List): + return 0 if not text: return 0 diff --git a/ra_aid/llm.py b/ra_aid/llm.py index e14f411..02b3c57 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -104,4 +104,4 @@ def initialize_expert_llm(provider: str = "openai", model_name: str = "o1") -> B model=model_name, ) else: - raise ValueError(f"Unsupported provider: {provider}") \ No newline at end of file + raise ValueError(f"Unsupported provider: {provider}") diff --git a/ra_aid/models_tokens.py b/ra_aid/models_tokens.py new file mode 100644 index 0000000..9f32761 --- /dev/null +++ b/ra_aid/models_tokens.py @@ -0,0 +1,266 @@ +""" +List of model tokens +""" + +DEFAULT_TOKEN_LIMIT = 100000 + +models_tokens = { + "openai": { + "gpt-3.5-turbo-0125": 16385, + "gpt-3.5": 4096, + "gpt-3.5-turbo": 16385, + "gpt-3.5-turbo-1106": 16385, + "gpt-3.5-turbo-instruct": 4096, + "gpt-4-0125-preview": 128000, + "gpt-4-turbo-preview": 128000, + "gpt-4-turbo": 128000, + "gpt-4-turbo-2024-04-09": 128000, + "gpt-4-1106-preview": 128000, + "gpt-4-vision-preview": 128000, + "gpt-4": 8192, + "gpt-4-0613": 8192, + "gpt-4-32k": 32768, + "gpt-4-32k-0613": 32768, + "gpt-4o": 128000, + "gpt-4o-2024-08-06": 128000, + "gpt-4o-2024-05-13": 128000, + "gpt-4o-mini": 128000, + "o1-preview": 128000, + "o1-mini": 128000, + }, + "azure_openai": { + "gpt-3.5-turbo-0125": 16385, + "gpt-3.5": 4096, + "gpt-3.5-turbo": 16385, + "gpt-3.5-turbo-1106": 16385, + "gpt-3.5-turbo-instruct": 4096, + "gpt-4-0125-preview": 128000, + "gpt-4-turbo-preview": 128000, + "gpt-4-turbo": 128000, + "gpt-4-turbo-2024-04-09": 128000, + "gpt-4-1106-preview": 128000, + "gpt-4-vision-preview": 128000, + "gpt-4": 8192, + "gpt-4-0613": 8192, + "gpt-4-32k": 32768, + "gpt-4-32k-0613": 32768, + "gpt-4o": 128000, + "gpt-4o-mini": 128000, + "chatgpt-4o-latest": 128000, + "o1-preview": 128000, + "o1-mini": 128000, + }, + "google_genai": { + "gemini-pro": 128000, + "gemini-1.5-flash-latest": 128000, + 
"gemini-1.5-pro-latest": 128000, + "models/embedding-001": 2048, + }, + "google_vertexai": { + "gemini-1.5-flash": 128000, + "gemini-1.5-pro": 128000, + "gemini-1.0-pro": 128000, + }, + "ollama": { + "command-r": 12800, + "codellama": 16000, + "dbrx": 32768, + "deepseek-coder:33b": 16000, + "falcon": 2048, + "llama2": 4096, + "llama2:7b": 4096, + "llama2:13b": 4096, + "llama2:70b": 4096, + "llama3": 8192, + "llama3:8b": 8192, + "llama3:70b": 8192, + "llama3.1": 128000, + "llama3.1:8b": 128000, + "llama3.1:70b": 128000, + "lama3.1:405b": 128000, + "llama3.2": 128000, + "llama3.2:1b": 128000, + "llama3.2:3b": 128000, + "llama3.3:70b": 128000, + "scrapegraph": 8192, + "mistral-small": 128000, + "mistral-openorca": 32000, + "mistral-large": 128000, + "grok-1": 8192, + "llava": 4096, + "mixtral:8x22b-instruct": 65536, + "nomic-embed-text": 8192, + "nous-hermes2:34b": 4096, + "orca-mini": 2048, + "phi3:3.8b": 12800, + "phi3:14b": 128000, + "qwen:0.5b": 32000, + "qwen:1.8b": 32000, + "qwen:4b": 32000, + "qwen:14b": 32000, + "qwen:32b": 32000, + "qwen:72b": 32000, + "qwen:110b": 32000, + "stablelm-zephyr": 8192, + "wizardlm2:8x22b": 65536, + "mistral": 128000, + "gemma2": 128000, + "gemma2:9b": 128000, + "gemma2:27b": 128000, + # embedding models + "shaw/dmeta-embedding-zh-small-q4": 8192, + "shaw/dmeta-embedding-zh-q4": 8192, + "chevalblanc/acge_text_embedding": 8192, + "martcreation/dmeta-embedding-zh": 8192, + "snowflake-arctic-embed": 8192, + "mxbai-embed-large": 512, + }, + "oneapi": { + "qwen-turbo": 6000, + }, + "nvidia": { + "meta/llama3-70b-instruct": 419, + "meta/llama3-8b-instruct": 419, + "nemotron-4-340b-instruct": 1024, + "databricks/dbrx-instruct": 4096, + "google/codegemma-7b": 8192, + "google/gemma-2b": 2048, + "google/gemma-7b": 8192, + "google/recurrentgemma-2b": 2048, + "meta/codellama-70b": 16384, + "meta/llama2-70b": 4096, + "microsoft/phi-3-mini-128k-instruct": 122880, + "mistralai/mistral-7b-instruct-v0.2": 4096, + "mistralai/mistral-large": 8192, + "mistralai/mixtral-8x22b-instruct-v0.1": 32768, + "mistralai/mixtral-8x7b-instruct-v0.1": 8192, + "snowflake/arctic": 16384, + }, + "groq": { + "llama3-8b-8192": 8192, + "llama3-70b-8192": 8192, + "mixtral-8x7b-32768": 32768, + "gemma-7b-it": 8192, + "claude-3-haiku-20240307'": 8192, + }, + "toghetherai": { + "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": 128000, + "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": 128000, + "mistralai/Mixtral-8x22B-Instruct-v0.1": 128000, + "stabilityai/stable-diffusion-xl-base-1.0": 2048, + "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": 128000, + "NousResearch/Hermes-3-Llama-3.1-405B-Turbo": 128000, + "Gryphe/MythoMax-L2-13b-Lite": 8192, + "Salesforce/Llama-Rank-V1": 8192, + "meta-llama/Meta-Llama-Guard-3-8B": 128000, + "meta-llama/Meta-Llama-3-70B-Instruct-Turbo": 128000, + "meta-llama/Llama-3-8b-chat-hf": 8192, + "meta-llama/Llama-3-70b-chat-hf": 8192, + "Qwen/Qwen2-72B-Instruct": 128000, + "google/gemma-2-27b-it": 8192, + }, + "anthropic": { + "claude_instant": 100000, + "claude2": 9000, + "claude2.1": 200000, + "claude3": 200000, + "claude3.5": 200000, + "claude-3-opus-20240229": 200000, + "claude-3-sonnet-20240229": 200000, + "claude-3-haiku-20240307": 200000, + "claude-3-5-sonnet-20240620": 200000, + "claude-3-5-sonnet-20241022": 200000, + "claude-3-5-haiku-latest": 200000, + }, + "bedrock": { + "anthropic.claude-3-haiku-20240307-v1:0": 200000, + "anthropic.claude-3-sonnet-20240229-v1:0": 200000, + "anthropic.claude-3-opus-20240229-v1:0": 200000, + 
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200000, + "claude-3-5-haiku-latest": 200000, + "anthropic.claude-v2:1": 200000, + "anthropic.claude-v2": 100000, + "anthropic.claude-instant-v1": 100000, + "meta.llama3-8b-instruct-v1:0": 8192, + "meta.llama3-70b-instruct-v1:0": 8192, + "meta.llama2-13b-chat-v1": 4096, + "meta.llama2-70b-chat-v1": 4096, + "mistral.mistral-7b-instruct-v0:2": 32768, + "mistral.mixtral-8x7b-instruct-v0:1": 32768, + "mistral.mistral-large-2402-v1:0": 32768, + "mistral.mistral-small-2402-v1:0": 32768, + "amazon.titan-embed-text-v1": 8000, + "amazon.titan-embed-text-v2:0": 8000, + "cohere.embed-english-v3": 512, + "cohere.embed-multilingual-v3": 512, + }, + "mistralai": { + "mistral-large-latest": 128000, + "open-mistral-nemo": 128000, + "codestral-latest": 32000, + "mistral-embed": 8000, + "open-mistral-7b": 32000, + "open-mixtral-8x7b": 32000, + "open-mixtral-8x22b": 64000, + "open-codestral-mamba": 256000, + }, + "hugging_face": { + "xai-org/grok-1": 8192, + "meta-llama/Meta-Llama-3-8B": 8192, + "meta-llama/Meta-Llama-3-8B-Instruct": 8192, + "meta-llama/Meta-Llama-3-70B": 8192, + "meta-llama/Meta-Llama-3-70B-Instruct": 8192, + "google/gemma-2b": 8192, + "google/gemma-2b-it": 8192, + "google/gemma-7b": 8192, + "google/gemma-7b-it": 8192, + "microsoft/phi-2": 2048, + "openai-community/gpt2": 1024, + "openai-community/gpt2-medium": 1024, + "openai-community/gpt2-large": 1024, + "facebook/opt-125m": 2048, + "petals-team/StableBeluga2": 8192, + "distilbert/distilgpt2": 1024, + "mistralai/Mistral-7B-Instruct-v0.2": 32768, + "gradientai/Llama-3-8B-Instruct-Gradient-1048k": 1040200, + "NousResearch/Hermes-2-Pro-Llama-3-8B": 8192, + "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF": 8192, + "nvidia/Llama3-ChatQA-1.5-8B": 8192, + "microsoft/Phi-3-mini-4k-instruct": 4192, + "microsoft/Phi-3-mini-128k-instruct": 131072, + "mlabonne/Meta-Llama-3-120B-Instruct": 8192, + "cognitivecomputations/dolphin-2.9-llama3-8b": 8192, + "cognitivecomputations/dolphin-2.9-llama3-8b-gguf": 8192, + "cognitivecomputations/dolphin-2.8-mistral-7b-v02": 32768, + "cognitivecomputations/dolphin-2.5-mixtral-8x7b": 32768, + "TheBloke/dolphin-2.7-mixtral-8x7b-GGUF": 32768, + "deepseek-ai/DeepSeek-V2": 131072, + "deepseek-ai/DeepSeek-V2-Chat": 131072, + "claude-3-haiku": 200000, + }, + "deepseek": { + "deepseek-chat": 28672, + "deepseek-coder": 16384, + }, + "ernie": { + "ernie-bot-turbo": 4096, + "ernie-bot": 4096, + "ernie-bot-2": 4096, + "ernie-bot-2-base": 4096, + "ernie-bot-2-base-zh": 4096, + "ernie-bot-2-base-en": 4096, + "ernie-bot-2-base-en-zh": 4096, + "ernie-bot-2-base-zh-en": 4096, + }, + "fireworks": { + "llama-v2-7b": 4096, + "mixtral-8x7b-instruct": 4096, + "nomic-ai/nomic-embed-text-v1.5": 8192, + "llama-3.1-405B-instruct": 131072, + "llama-3.1-70B-instruct": 131072, + "llama-3.1-8B-instruct": 131072, + "mixtral-moe-8x22B-instruct": 65536, + "mixtral-moe-8x7B-instruct": 65536, + }, + "togetherai": {"Meta-Llama-3.1-70B-Instruct-Turbo": 128000}, +} diff --git a/ra_aid/tools/programmer.py b/ra_aid/tools/programmer.py index cf602d4..62e8b1f 100644 --- a/ra_aid/tools/programmer.py +++ b/ra_aid/tools/programmer.py @@ -1,5 +1,5 @@ import os -from typing import List, Optional, Dict, Union +from typing import List, Dict, Union from ra_aid.tools.memory import _global_memory from langchain_core.tools import tool from rich.console import Console @@ -7,7 +7,6 @@ from rich.panel import Panel from rich.markdown import Markdown from rich.text import Text from ra_aid.proc.interactive import 
run_interactive_command -from pydantic import BaseModel, Field from ra_aid.text.processing import truncate_output console = Console() diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index e69de29..87216f6 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -0,0 +1,204 @@ +"""Unit tests for agent_utils.py.""" + +import pytest +from langchain_core.messages import SystemMessage, HumanMessage, AIMessage +from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT +from ra_aid.agent_utils import state_modifier, AgentState +from unittest.mock import Mock, patch +from langchain_core.language_models import BaseChatModel + +from ra_aid.agent_utils import create_agent, get_model_token_limit +from ra_aid.models_tokens import models_tokens + + +@pytest.fixture +def mock_model(): + """Fixture providing a mock LLM model.""" + model = Mock(spec=BaseChatModel) + return model + + +@pytest.fixture +def mock_memory(): + """Fixture providing a mock global memory store.""" + with patch("ra_aid.agent_utils._global_memory") as mock_mem: + mock_mem.get.return_value = {} + yield mock_mem + + +def test_get_model_token_limit_anthropic(mock_memory): + """Test get_model_token_limit with Anthropic model.""" + config = {"provider": "anthropic", "model": "claude2"} + + token_limit = get_model_token_limit(config) + assert token_limit == models_tokens["anthropic"]["claude2"] + + +def test_get_model_token_limit_openai(mock_memory): + """Test get_model_token_limit with OpenAI model.""" + config = {"provider": "openai", "model": "gpt-4"} + + token_limit = get_model_token_limit(config) + assert token_limit == models_tokens["openai"]["gpt-4"] + + +def test_get_model_token_limit_unknown(mock_memory): + """Test get_model_token_limit with unknown provider/model.""" + config = {"provider": "unknown", "model": "unknown-model"} + + token_limit = get_model_token_limit(config) + assert token_limit is None + + +def test_get_model_token_limit_missing_config(mock_memory): + """Test get_model_token_limit with missing configuration.""" + config = {} + + token_limit = get_model_token_limit(config) + assert token_limit is None + + +def test_create_agent_anthropic(mock_model, mock_memory): + """Test create_agent with Anthropic Claude model.""" + mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"} + + with patch("ra_aid.agent_utils.create_react_agent") as mock_react: + mock_react.return_value = "react_agent" + agent = create_agent(mock_model, []) + + assert agent == "react_agent" + mock_react.assert_called_once_with( + mock_model, [], state_modifier=mock_react.call_args[1]["state_modifier"] + ) + + +def test_create_agent_openai(mock_model, mock_memory): + """Test create_agent with OpenAI model.""" + mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"} + + with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn: + mock_ciayn.return_value = "ciayn_agent" + agent = create_agent(mock_model, []) + + assert agent == "ciayn_agent" + mock_ciayn.assert_called_once_with( + mock_model, [], max_tokens=models_tokens["openai"]["gpt-4"] + ) + + +def test_create_agent_no_token_limit(mock_model, mock_memory): + """Test create_agent when no token limit is found.""" + mock_memory.get.return_value = {"provider": "unknown", "model": "unknown-model"} + + with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn: + mock_ciayn.return_value = "ciayn_agent" + agent = create_agent(mock_model, []) + + assert agent == "ciayn_agent" + 
mock_ciayn.assert_called_once_with( + mock_model, [], max_tokens=DEFAULT_TOKEN_LIMIT + ) + + +def test_create_agent_missing_config(mock_model, mock_memory): + """Test create_agent with missing configuration.""" + mock_memory.get.return_value = {"provider": "openai"} + + with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn: + mock_ciayn.return_value = "ciayn_agent" + agent = create_agent(mock_model, []) + + assert agent == "ciayn_agent" + mock_ciayn.assert_called_once_with( + mock_model, + [], + max_tokens=DEFAULT_TOKEN_LIMIT, + ) + + +@pytest.fixture +def mock_messages(): + """Fixture providing mock message objects.""" + + return [ + SystemMessage(content="System prompt"), + HumanMessage(content="Human message 1"), + AIMessage(content="AI response 1"), + HumanMessage(content="Human message 2"), + AIMessage(content="AI response 2"), + ] + + +def test_state_modifier(mock_messages): + """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens.""" + state = AgentState(messages=mock_messages) + + with patch( + "ra_aid.agents.ciayn_agent.CiaynAgent._estimate_tokens" + ) as mock_estimate: + mock_estimate.side_effect = lambda msg: 100 if msg else 0 + + result = state_modifier(state, max_tokens=250) + + assert len(result) < len(mock_messages) + assert isinstance(result[0], SystemMessage) + assert result[-1] == mock_messages[-1] + + +def test_create_agent_with_checkpointer(mock_model, mock_memory): + """Test create_agent with checkpointer argument.""" + mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"} + mock_checkpointer = Mock() + + with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn: + mock_ciayn.return_value = "ciayn_agent" + agent = create_agent(mock_model, [], checkpointer=mock_checkpointer) + + assert agent == "ciayn_agent" + mock_ciayn.assert_called_once_with( + mock_model, [], max_tokens=models_tokens["openai"]["gpt-4"] + ) + + +def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory): + """Test create_agent sets up token limiting for Claude models when enabled.""" + mock_memory.get.return_value = { + "provider": "anthropic", + "model": "claude-2", + "limit_tokens": True, + } + + with ( + patch("ra_aid.agent_utils.create_react_agent") as mock_react, + patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit, + ): + mock_react.return_value = "react_agent" + mock_limit.return_value = 100000 + + agent = create_agent(mock_model, []) + + assert agent == "react_agent" + args = mock_react.call_args + assert "state_modifier" in args[1] + assert callable(args[1]["state_modifier"]) + + +def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory): + """Test create_agent doesn't set up token limiting for Claude models when disabled.""" + mock_memory.get.return_value = { + "provider": "anthropic", + "model": "claude-2", + "limit_tokens": False, + } + + with ( + patch("ra_aid.agent_utils.create_react_agent") as mock_react, + patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit, + ): + mock_react.return_value = "react_agent" + mock_limit.return_value = 100000 + + agent = create_agent(mock_model, []) + + assert agent == "react_agent" + mock_react.assert_called_once_with(mock_model, [])
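The new --disable-limit-tokens flag uses action="store_false", so the parsed value defaults to True (token limiting on) and flips to False when the flag is passed; main() then stores it under config["limit_tokens"]. A small sketch of that behavior, assuming argparse's default dest naming for the flag:

# Illustrative sketch only: limiting is enabled by default, the flag disables it.
from ra_aid.__main__ import parse_arguments

args = parse_arguments(["-m", "do something"])
assert args.disable_limit_tokens is True    # limiting enabled by default

args = parse_arguments(["-m", "do something", "--disable-limit-tokens"])
assert args.disable_limit_tokens is False   # limiting disabled via the flag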