let aider figure out which model to use
This commit is contained in:
parent
ed04230267
commit
250bf0a84c
@@ -82,9 +82,6 @@ Examples:
 # Create console instance
 console = Console()
 
-# Create the base model
-model = initialize_llm(parse_arguments().provider, parse_arguments().model)
-
 # Create individual memory objects for each agent
 research_memory = MemorySaver()
 planning_memory = MemorySaver()
@@ -125,10 +122,6 @@ def get_research_tools(research_only: bool = False) -> list:
 planning_tools = [list_directory_tree, emit_expert_context, ask_expert, emit_plan, emit_task, emit_related_files, emit_key_facts, delete_key_facts, emit_key_snippets, delete_key_snippets, read_file_tool, fuzzy_find_project_files, ripgrep_search]
 implementation_tools = [list_directory_tree, run_shell_command, emit_expert_context, ask_expert, run_programming_task, emit_related_files, emit_key_facts, delete_key_facts, emit_key_snippets, delete_key_snippets, read_file_tool, fuzzy_find_project_files, ripgrep_search]
 
-# Create stage-specific agents with individual memory objects
-planning_agent = create_react_agent(model, planning_tools, checkpointer=planning_memory)
-implementation_agent = create_react_agent(model, implementation_tools, checkpointer=implementation_memory)
-
 
 def is_informational_query() -> bool:
 """Determine if the current query is informational based on implementation_requested state.
@@ -188,7 +181,7 @@ def run_agent_with_retry(agent, prompt: str, config: dict):
 time.sleep(delay)
 continue
 
-def run_implementation_stage(base_task, tasks, plan, related_files):
+def run_implementation_stage(base_task, tasks, plan, related_files, model):
 """Run implementation stage with a distinct agent for each task."""
 if not is_stage_requested('implementation'):
 print_stage_header("Implementation Stage Skipped")
@@ -224,7 +217,7 @@ def run_implementation_stage(base_task, tasks, plan, related_files):
 run_agent_with_retry(task_agent, task_prompt, {"configurable": {"thread_id": "abc123"}, "recursion_limit": 100})
 
 
-def run_research_subtasks(base_task: str, config: dict):
+def run_research_subtasks(base_task: str, config: dict, model):
 """Run research subtasks with separate agents."""
 subtasks = _global_memory.get('research_subtasks', [])
 if not subtasks:
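Both stage runners above now receive the model explicitly instead of relying on a module-level agent, so each stage can build a fresh agent from the shared model. A minimal sketch of what the per-task agent construction inside run_implementation_stage might look like under the new signature; the tool list, MemorySaver checkpointer, thread config, and run_agent_with_retry helper are taken from the surrounding diff, while the loop body itself is an assumption:

from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

def run_implementation_stage(base_task, tasks, plan, related_files, model):
    """Sketch: build a distinct agent per task from the passed-in model."""
    for task in tasks:
        # implementation_tools and run_agent_with_retry come from the module shown above
        task_agent = create_react_agent(model, implementation_tools, checkpointer=MemorySaver())
        task_prompt = f"Base task: {base_task}\nPlan: {plan}\nCurrent task: {task}"
        run_agent_with_retry(task_agent, task_prompt, {"configurable": {"thread_id": "abc123"}, "recursion_limit": 100})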
@@ -255,11 +248,13 @@ def run_research_subtasks(base_task: str, config: dict):
 run_agent_with_retry(subtask_agent, subtask_prompt, config)
 
 
-def validate_environment():
-"""Validate required environment variables and dependencies."""
-missing = []
-
-args = parse_arguments()
+def validate_environment(args):
+"""Validate required environment variables and dependencies.
+
+Args:
+args: The parsed command line arguments
+"""
+missing = []
 provider = args.provider
 
 # Check API keys based on provider
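Because validate_environment no longer calls parse_arguments() itself, it can be driven by whatever namespace main() already parsed. A rough sketch of the provider check it performs, assuming the per-provider variables implied by the llm changes further down; the exact message text and exit behaviour of the real function may differ:

import os
import sys

def validate_environment(args):
    """Sketch: exit early when the chosen provider's environment variables are unset."""
    # Assumed mapping, derived from the provider branches in initialize_llm below
    required = {
        "openai": ["OPENAI_API_KEY"],
        "anthropic": ["ANTHROPIC_API_KEY"],
        "openrouter": ["OPENROUTER_API_KEY"],
        "openai-compatible": ["OPENAI_API_KEY", "OPENAI_API_BASE"],
    }
    missing = [var for var in required.get(args.provider, []) if not os.getenv(var)]
    if missing:
        print(f"Missing required environment variables: {', '.join(missing)}")
        sys.exit(1)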
@@ -288,8 +283,11 @@ def main():
 """Main entry point for the ra-aid command line tool."""
 try:
 try:
-validate_environment()
 args = parse_arguments()
+validate_environment(args) # Will exit if env vars missing
 
+# Create the base model after validation
+model = initialize_llm(args.provider, args.model)
+
 # Validate message is provided
 if not args.message:
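The net effect in main() is a fixed startup order: parse arguments once, validate the environment for the chosen provider, then create the model and thread it through the stage functions. A condensed sketch of that order using names from the diff; the surrounding try blocks and error handling are omitted, so this is not a drop-in replacement:

args = parse_arguments()
validate_environment(args)  # exits if required variables are missing
model = initialize_llm(args.provider, args.model)

run_research_subtasks(base_task, config, model)
run_implementation_stage(base_task, get_memory_value('tasks'), get_memory_value('plan'), get_related_files(), model)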
@@ -309,11 +307,16 @@ def main():
 # Store config in global memory for access by is_informational_query
 _global_memory['config'] = config
 
-# Create research agent now that config is available
-research_agent = create_react_agent(model, get_research_tools(research_only=_global_memory.get('config', {}).get('research_only', False)), checkpointer=research_memory)
 # Run research stage
 print_stage_header("Research Stage")
 
+# Create research agent with local model
+research_agent = create_react_agent(
+model,
+get_research_tools(research_only=_global_memory.get('config', {}).get('research_only', False)),
+checkpointer=research_memory
+)
+
 research_prompt = f"""User query: {base_task} --keep it simple
 
 {RESEARCH_PROMPT}
@@ -327,11 +330,15 @@ Be very thorough in your research and emit lots of snippets, key facts. If you t
 raise # Re-raise to be caught by outer handler
 
 # Run any research subtasks
-run_research_subtasks(base_task, config)
+run_research_subtasks(base_task, config, model)
 
 # Proceed with planning and implementation if not an informational query
 if not is_informational_query():
 print_stage_header("Planning Stage")
+
+# Create planning agent
+planning_agent = create_react_agent(model, planning_tools, checkpointer=planning_memory)
+
 planning_prompt = PLANNING_PROMPT.format(
 research_notes=get_memory_value('research_notes'),
 key_facts=get_memory_value('key_facts'),
@@ -348,7 +355,8 @@ Be very thorough in your research and emit lots of snippets, key facts. If you t
 base_task,
 get_memory_value('tasks'),
 get_memory_value('plan'),
-get_related_files()
+get_related_files(),
+model
 )
 except TaskCompletedException:
 sys.exit(0)
@@ -4,33 +4,41 @@ from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
 
 def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
+"""Initialize a language model client based on the specified provider and model.
+
+Note: Environment variables must be validated before calling this function.
+Use validate_environment() to ensure all required variables are set.
+
+Args:
+provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible')
+model_name: Name of the model to use
+
+Returns:
+BaseChatModel: Configured language model client
+
+Raises:
+ValueError: If the provider is not supported
+"""
 if provider == "openai":
-api_key = os.getenv("OPENAI_API_KEY")
-if not api_key:
-raise ValueError("OPENAI_API_KEY environment variable is not set.")
-return ChatOpenAI(openai_api_key=api_key, model=model_name)
-elif provider == "anthropic":
-api_key = os.getenv("ANTHROPIC_API_KEY")
-if not api_key:
-raise ValueError("ANTHROPIC_API_KEY environment variable is not set.")
-return ChatAnthropic(anthropic_api_key=api_key, model=model_name)
-elif provider == "openrouter":
-api_key = os.getenv("OPENROUTER_API_KEY")
-if not api_key:
-raise ValueError("OPENROUTER_API_KEY environment variable is not set.")
 return ChatOpenAI(
-openai_api_key=api_key,
+openai_api_key=os.getenv("OPENAI_API_KEY"),
+model=model_name
+)
+elif provider == "anthropic":
+return ChatAnthropic(
+anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"),
+model=model_name
+)
+elif provider == "openrouter":
+return ChatOpenAI(
+openai_api_key=os.getenv("OPENROUTER_API_KEY"),
 openai_api_base="https://openrouter.ai/api/v1",
 model=model_name
 )
 elif provider == "openai-compatible":
-api_key = os.getenv("OPENAI_API_KEY")
-api_base = os.getenv("OPENAI_API_BASE")
-if not api_key or not api_base:
-raise ValueError("Both OPENAI_API_KEY and OPENAI_API_BASE environment variables must be set.")
 return ChatOpenAI(
-openai_api_key=api_key,
-openai_api_base=api_base,
+openai_api_key=os.getenv("OPENAI_API_KEY"),
+openai_api_base=os.getenv("OPENAI_API_BASE"),
 model=model_name
 )
 else:
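With the per-branch key checks removed, initialize_llm now assumes the environment was validated up front; a missing key surfaces later as an authentication error from the client rather than a ValueError here. A small usage sketch: the model name is illustrative, and per the new docstring only unsupported providers should still raise ValueError:

import os

os.environ.setdefault("ANTHROPIC_API_KEY", "sk-ant-...")  # normally checked earlier by validate_environment(args)
model = initialize_llm("anthropic", "claude-3-5-sonnet-20241022")  # illustrative model name

try:
    initialize_llm("unknown-provider", "some-model")
except ValueError as err:
    print(f"Unsupported provider: {err}")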
@@ -42,7 +42,6 @@ def run_programming_task(input: RunProgrammingTaskInput) -> Dict[str, Union[str,
 # Build command
 command = [
 "aider",
-"--sonnet",
 "--yes-always",
 "--no-auto-commits",
 "--dark-mode",
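Dropping --sonnet leaves model selection to aider itself: it uses its own default unless the user overrides it through aider's configuration (for example the --model flag or environment settings). A rough sketch of how the resulting command might be assembled and run; the real run_programming_task may drive aider differently, and the file list and message handling here are assumptions:

import subprocess

def build_aider_command(files, instructions):
    """Sketch: aider invocation without a hard-coded model flag."""
    command = ["aider", "--yes-always", "--no-auto-commits", "--dark-mode"]
    command.extend(files)
    command.extend(["--message", instructions])  # --message runs aider non-interactively
    return command

result = subprocess.run(build_aider_command(["app.py"], "add a --verbose flag"), capture_output=True, text=True)
print(result.stdout)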