From 250bf0a84c7a9be60ccb11f3700d76454254d01c Mon Sep 17 00:00:00 2001
From: AI Christianson
Date: Fri, 13 Dec 2024 13:31:22 -0500
Subject: [PATCH] let aider figure out which model to use

---
 ra_aid/__main__.py         | 46 +++++++++++++++++++++---------------
 ra_aid/llm.py              | 48 ++++++++++++++++++++++----------------
 ra_aid/tools/programmer.py |  1 -
 3 files changed, 55 insertions(+), 40 deletions(-)

diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py
index b1db5e6..22eff1e 100644
--- a/ra_aid/__main__.py
+++ b/ra_aid/__main__.py
@@ -82,9 +82,6 @@ Examples:
 # Create console instance
 console = Console()
 
-# Create the base model
-model = initialize_llm(parse_arguments().provider, parse_arguments().model)
-
 # Create individual memory objects for each agent
 research_memory = MemorySaver()
 planning_memory = MemorySaver()
@@ -125,10 +122,6 @@ def get_research_tools(research_only: bool = False) -> list:
 planning_tools = [list_directory_tree, emit_expert_context, ask_expert, emit_plan, emit_task, emit_related_files, emit_key_facts, delete_key_facts, emit_key_snippets, delete_key_snippets, read_file_tool, fuzzy_find_project_files, ripgrep_search]
 implementation_tools = [list_directory_tree, run_shell_command, emit_expert_context, ask_expert, run_programming_task, emit_related_files, emit_key_facts, delete_key_facts, emit_key_snippets, delete_key_snippets, read_file_tool, fuzzy_find_project_files, ripgrep_search]
 
-# Create stage-specific agents with individual memory objects
-planning_agent = create_react_agent(model, planning_tools, checkpointer=planning_memory)
-implementation_agent = create_react_agent(model, implementation_tools, checkpointer=implementation_memory)
-
 def is_informational_query() -> bool:
     """Determine if the current query is informational based on implementation_requested state.
@@ -188,7 +181,7 @@ def run_agent_with_retry(agent, prompt: str, config: dict):
             time.sleep(delay)
             continue
 
-def run_implementation_stage(base_task, tasks, plan, related_files):
+def run_implementation_stage(base_task, tasks, plan, related_files, model):
     """Run implementation stage with a distinct agent for each task."""
     if not is_stage_requested('implementation'):
         print_stage_header("Implementation Stage Skipped")
@@ -224,7 +217,7 @@
         run_agent_with_retry(task_agent, task_prompt, {"configurable": {"thread_id": "abc123"}, "recursion_limit": 100})
 
-def run_research_subtasks(base_task: str, config: dict):
+def run_research_subtasks(base_task: str, config: dict, model):
     """Run research subtasks with separate agents."""
     subtasks = _global_memory.get('research_subtasks', [])
     if not subtasks:
@@ -255,11 +248,13 @@
         run_agent_with_retry(subtask_agent, subtask_prompt, config)
 
-def validate_environment():
-    """Validate required environment variables and dependencies."""
-    missing = []
+def validate_environment(args):
+    """Validate required environment variables and dependencies.
 
-    args = parse_arguments()
+    Args:
+        args: The parsed command line arguments
+    """
+    missing = []
     provider = args.provider
 
     # Check API keys based on provider
@@ -288,8 +283,11 @@ def main():
     """Main entry point for the ra-aid command line tool."""
     try:
         try:
-            validate_environment()
             args = parse_arguments()
+            validate_environment(args)  # Will exit if env vars missing
+
+            # Create the base model after validation
+            model = initialize_llm(args.provider, args.model)
 
             # Validate message is provided
             if not args.message:
@@ -309,11 +307,16 @@ def main():
             # Store config in global memory for access by is_informational_query
             _global_memory['config'] = config
 
-            # Create research agent now that config is available
-            research_agent = create_react_agent(model, get_research_tools(research_only=_global_memory.get('config', {}).get('research_only', False)), checkpointer=research_memory)
-
             # Run research stage
             print_stage_header("Research Stage")
+
+            # Create research agent with local model
+            research_agent = create_react_agent(
+                model,
+                get_research_tools(research_only=_global_memory.get('config', {}).get('research_only', False)),
+                checkpointer=research_memory
+            )
+
             research_prompt = f"""User query: {base_task} --keep it simple
 
 {RESEARCH_PROMPT}
 
 Be very thorough in your research and emit lots of snippets, key facts. If you t
                 raise # Re-raise to be caught by outer handler
 
             # Run any research subtasks
-            run_research_subtasks(base_task, config)
+            run_research_subtasks(base_task, config, model)
 
             # Proceed with planning and implementation if not an informational query
             if not is_informational_query():
                 print_stage_header("Planning Stage")
+
+                # Create planning agent
+                planning_agent = create_react_agent(model, planning_tools, checkpointer=planning_memory)
+
                 planning_prompt = PLANNING_PROMPT.format(
                     research_notes=get_memory_value('research_notes'),
                     key_facts=get_memory_value('key_facts'),
@@ -348,7 +355,8 @@
                     base_task,
                     get_memory_value('tasks'),
                     get_memory_value('plan'),
-                    get_related_files()
+                    get_related_files(),
+                    model
                 )
         except TaskCompletedException:
             sys.exit(0)
diff --git a/ra_aid/llm.py b/ra_aid/llm.py
index 22b153e..b94cbd8 100644
--- a/ra_aid/llm.py
+++ b/ra_aid/llm.py
@@ -4,33 +4,41 @@ from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
 
 def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
+    """Initialize a language model client based on the specified provider and model.
+
+    Note: Environment variables must be validated before calling this function.
+    Use validate_environment() to ensure all required variables are set.
+
+    Args:
+        provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible')
+        model_name: Name of the model to use
+
+    Returns:
+        BaseChatModel: Configured language model client
+
+    Raises:
+        ValueError: If the provider is not supported
+    """
     if provider == "openai":
-        api_key = os.getenv("OPENAI_API_KEY")
-        if not api_key:
-            raise ValueError("OPENAI_API_KEY environment variable is not set.")
-        return ChatOpenAI(openai_api_key=api_key, model=model_name)
-    elif provider == "anthropic":
-        api_key = os.getenv("ANTHROPIC_API_KEY")
-        if not api_key:
-            raise ValueError("ANTHROPIC_API_KEY environment variable is not set.")
-        return ChatAnthropic(anthropic_api_key=api_key, model=model_name)
-    elif provider == "openrouter":
-        api_key = os.getenv("OPENROUTER_API_KEY")
-        if not api_key:
-            raise ValueError("OPENROUTER_API_KEY environment variable is not set.")
         return ChatOpenAI(
-            openai_api_key=api_key,
+            openai_api_key=os.getenv("OPENAI_API_KEY"),
+            model=model_name
+        )
+    elif provider == "anthropic":
+        return ChatAnthropic(
+            anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
+            model=model_name
+        )
+    elif provider == "openrouter":
+        return ChatOpenAI(
+            openai_api_key=os.getenv("OPENROUTER_API_KEY"),
             openai_api_base="https://openrouter.ai/api/v1",
             model=model_name
         )
     elif provider == "openai-compatible":
-        api_key = os.getenv("OPENAI_API_KEY")
-        api_base = os.getenv("OPENAI_API_BASE")
-        if not api_key or not api_base:
-            raise ValueError("Both OPENAI_API_KEY and OPENAI_API_BASE environment variables must be set.")
         return ChatOpenAI(
-            openai_api_key=api_key,
-            openai_api_base=api_base,
+            openai_api_key=os.getenv("OPENAI_API_KEY"),
+            openai_api_base=os.getenv("OPENAI_API_BASE"),
             model=model_name
         )
     else:
diff --git a/ra_aid/tools/programmer.py b/ra_aid/tools/programmer.py
index a27ea5b..1f893d0 100644
--- a/ra_aid/tools/programmer.py
+++ b/ra_aid/tools/programmer.py
@@ -42,7 +42,6 @@ def run_programming_task(input: RunProgrammingTaskInput) -> Dict[str, Union[str,
     # Build command
     command = [
         "aider",
-        "--sonnet",
         "--yes-always",
         "--no-auto-commits",
         "--dark-mode",
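
The net effect of this patch on start-up order: arguments are parsed first, the environment is validated, and only then is the shared model constructed and passed down to each stage, since initialize_llm() no longer guards against missing environment variables itself. A minimal sketch of that ordering contract (function names as in the diff above; the surrounding CLI plumbing and the config/base_task values are elided stand-ins, not verbatim code from main()):

    # Sketch only: validate_environment() must run before initialize_llm(),
    # because the key checks were removed from initialize_llm() in this patch.
    args = parse_arguments()
    validate_environment(args)  # exits/reports if required env vars are missing
    model = initialize_llm(args.provider, args.model)

    # The model is now threaded through explicitly instead of being created
    # at module import time:
    run_research_subtasks(base_task, config, model)
    run_implementation_stage(base_task, tasks, plan, related_files, model)

Dropping "--sonnet" from the aider command in programmer.py means aider now selects its own default model rather than being pinned to Claude Sonnet, which is what the commit subject refers to.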