diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index becccb9..c2c2eef 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -28,6 +28,7 @@ from ra_aid.prompts import ( from ra_aid.exceptions import TaskCompletedException import time from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError +from ra_aid.llm import initialize_llm def parse_arguments(): parser = argparse.ArgumentParser( @@ -49,18 +50,40 @@ Examples: action='store_true', help='Only perform research without implementation' ) + parser.add_argument( + '--provider', + type=str, + default='anthropic', + choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'], + help='The LLM provider to use' + ) + parser.add_argument( + '--model', + type=str, + help='The model name to use (required for non-Anthropic providers)' + ) parser.add_argument( '--cowboy-mode', action='store_true', help='Skip interactive approval for shell commands' ) - return parser.parse_args() + + args = parser.parse_args() + + # Set default model for Anthropic, require model for other providers + if args.provider == 'anthropic': + if not args.model: + args.model = 'claude-3-5-sonnet-20241022' + elif not args.model: + parser.error(f"--model is required when using provider '{args.provider}'") + + return args # Create console instance console = Console() # Create the base model -model = ChatAnthropic(model_name="claude-3-5-sonnet-20241022") +model = initialize_llm(parse_arguments().provider, parse_arguments().model) # Create individual memory objects for each agent research_memory = MemorySaver() @@ -239,21 +262,30 @@ def validate_environment(): """Validate required environment variables and dependencies.""" missing = [] - # Check API keys - if not os.environ.get('ANTHROPIC_API_KEY'): - missing.append('ANTHROPIC_API_KEY environment variable is not set') - if not os.environ.get('OPENAI_API_KEY'): - missing.append('OPENAI_API_KEY environment variable is not set') - - # Check for aider binary - if not shutil.which('aider'): - missing.append('aider binary not found in PATH. Please install aider: pip install aider') - - if missing: - error_list = "\n".join(f"- {error}" for error in missing) - print_error(f"Missing required dependencies:\n\n{error_list}") - sys.exit(1) + args = parse_arguments() + provider = args.provider + # Check API keys based on provider + if provider == "anthropic": + if not os.environ.get('ANTHROPIC_API_KEY'): + missing.append('ANTHROPIC_API_KEY environment variable is not set') + elif provider == "openai": + if not os.environ.get('OPENAI_API_KEY'): + missing.append('OPENAI_API_KEY environment variable is not set') + elif provider == "openrouter": + if not os.environ.get('OPENROUTER_API_KEY'): + missing.append('OPENROUTER_API_KEY environment variable is not set') + elif provider == "openai-compatible": + if not os.environ.get('OPENAI_API_KEY'): + missing.append('OPENAI_API_KEY environment variable is not set') + if not os.environ.get('OPENAI_API_BASE'): + missing.append('OPENAI_API_BASE environment variable is not set') + + if missing: + print_error("Missing required dependencies:") + for item in missing: + print_error(f"- {item}") + sys.exit(1) def main(): """Main entry point for the ra-aid command line tool.""" diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index 88f4593..acaeb50 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -20,7 +20,8 @@ def print_agent_output(chunk: Dict[str, Any]) -> None: # Handle text content if isinstance(msg.content, list): for content in msg.content: - if content['type'] == 'text': + if content['type'] == 'text' and content['text'].strip(): console.print(Panel(Markdown(content['text']), title="🤖 Assistant")) else: - console.print(Panel(Markdown(msg.content), title="🤖 Assistant")) + if msg.content.strip(): + console.print(Panel(Markdown(msg.content), title="🤖 Assistant")) diff --git a/ra_aid/llm.py b/ra_aid/llm.py new file mode 100644 index 0000000..22b153e --- /dev/null +++ b/ra_aid/llm.py @@ -0,0 +1,37 @@ +import os +from langchain_openai import ChatOpenAI +from langchain_anthropic import ChatAnthropic +from langchain_core.language_models import BaseChatModel + +def initialize_llm(provider: str, model_name: str) -> BaseChatModel: + if provider == "openai": + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError("OPENAI_API_KEY environment variable is not set.") + return ChatOpenAI(openai_api_key=api_key, model=model_name) + elif provider == "anthropic": + api_key = os.getenv("ANTHROPIC_API_KEY") + if not api_key: + raise ValueError("ANTHROPIC_API_KEY environment variable is not set.") + return ChatAnthropic(anthropic_api_key=api_key, model=model_name) + elif provider == "openrouter": + api_key = os.getenv("OPENROUTER_API_KEY") + if not api_key: + raise ValueError("OPENROUTER_API_KEY environment variable is not set.") + return ChatOpenAI( + openai_api_key=api_key, + openai_api_base="https://openrouter.ai/api/v1", + model=model_name + ) + elif provider == "openai-compatible": + api_key = os.getenv("OPENAI_API_KEY") + api_base = os.getenv("OPENAI_API_BASE") + if not api_key or not api_base: + raise ValueError("Both OPENAI_API_KEY and OPENAI_API_BASE environment variables must be set.") + return ChatOpenAI( + openai_api_key=api_key, + openai_api_base=api_base, + model=model_name + ) + else: + raise ValueError(f"Unsupported provider: {provider}") diff --git a/ra_aid/tools/fuzzy_find.py b/ra_aid/tools/fuzzy_find.py index 6ec71c3..c1457ab 100644 --- a/ra_aid/tools/fuzzy_find.py +++ b/ra_aid/tools/fuzzy_find.py @@ -27,8 +27,8 @@ def fuzzy_find_project_files( repo_path: str = ".", threshold: int = 60, max_results: int = 10, - include_paths: Optional[List[str]] = None, - exclude_patterns: Optional[List[str]] = None + include_paths: List[str] = None, + exclude_patterns: List[str] = None ) -> List[Tuple[str, int]]: """Fuzzy find files in a git repository matching the search term. diff --git a/ra_aid/tools/list_directory.py b/ra_aid/tools/list_directory.py index 972c9bd..7d89a0c 100644 --- a/ra_aid/tools/list_directory.py +++ b/ra_aid/tools/list_directory.py @@ -151,7 +151,7 @@ def list_directory_tree( follow_links: bool = False, show_size: bool = False, # Default to not showing size show_modified: bool = False, # Default to not showing modified time - exclude_patterns: Optional[List[str]] = None + exclude_patterns: List[str] = None ) -> str: """List directory contents in a tree format with optional metadata. diff --git a/ra_aid/tools/ripgrep.py b/ra_aid/tools/ripgrep.py index fff7a57..85bf823 100644 --- a/ra_aid/tools/ripgrep.py +++ b/ra_aid/tools/ripgrep.py @@ -28,11 +28,11 @@ DEFAULT_EXCLUDE_DIRS = [ def ripgrep_search( pattern: str, *, - file_type: Optional[str] = None, + file_type: str = None, case_sensitive: bool = True, include_hidden: bool = False, follow_links: bool = False, - exclude_dirs: Optional[List[str]] = None + exclude_dirs: List[str] = None ) -> Dict[str, Union[str, int, bool]]: """Execute a ripgrep (rg) search with formatting and common options. diff --git a/ra_aid/tools/shell.py b/ra_aid/tools/shell.py index 18942e7..9324fe4 100644 --- a/ra_aid/tools/shell.py +++ b/ra_aid/tools/shell.py @@ -30,15 +30,6 @@ def run_shell_command(command: str) -> Dict[str, Union[str, int, bool]]: - Environment: .env, venv, env - IDE: .idea, .vscode 3. Avoid doing recursive lists, finds, etc. that could be slow and have a ton of output. Likewise, avoid flags like '-l' that needlessly increase the output. But if you really need to, you can. - - Args: - command: List of command arguments. First item is the command, rest are arguments. - - Returns: - A dictionary containing: - - output: The command output (stdout + stderr combined) - - return_code: The process return code (0 typically means success) - - success: Boolean indicating if the command succeeded (return code == 0) """ # Check if we need approval cowboy_mode = _global_memory.get('config', {}).get('cowboy_mode', False)