support open models

This commit is contained in:
AI Christianson 2024-12-13 09:44:06 -05:00
parent 49078adc10
commit 9f721410ff
7 changed files with 93 additions and 32 deletions

View File

@ -28,6 +28,7 @@ from ra_aid.prompts import (
from ra_aid.exceptions import TaskCompletedException from ra_aid.exceptions import TaskCompletedException
import time import time
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
from ra_aid.llm import initialize_llm
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -49,18 +50,40 @@ Examples:
action='store_true', action='store_true',
help='Only perform research without implementation' help='Only perform research without implementation'
) )
parser.add_argument(
'--provider',
type=str,
default='anthropic',
choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'],
help='The LLM provider to use'
)
parser.add_argument(
'--model',
type=str,
help='The model name to use (required for non-Anthropic providers)'
)
parser.add_argument( parser.add_argument(
'--cowboy-mode', '--cowboy-mode',
action='store_true', action='store_true',
help='Skip interactive approval for shell commands' help='Skip interactive approval for shell commands'
) )
return parser.parse_args()
args = parser.parse_args()
# Set default model for Anthropic, require model for other providers
if args.provider == 'anthropic':
if not args.model:
args.model = 'claude-3-5-sonnet-20241022'
elif not args.model:
parser.error(f"--model is required when using provider '{args.provider}'")
return args
# Create console instance # Create console instance
console = Console() console = Console()
# Create the base model # Create the base model
model = ChatAnthropic(model_name="claude-3-5-sonnet-20241022") model = initialize_llm(parse_arguments().provider, parse_arguments().model)
# Create individual memory objects for each agent # Create individual memory objects for each agent
research_memory = MemorySaver() research_memory = MemorySaver()
@ -239,22 +262,31 @@ def validate_environment():
"""Validate required environment variables and dependencies.""" """Validate required environment variables and dependencies."""
missing = [] missing = []
# Check API keys args = parse_arguments()
provider = args.provider
# Check API keys based on provider
if provider == "anthropic":
if not os.environ.get('ANTHROPIC_API_KEY'): if not os.environ.get('ANTHROPIC_API_KEY'):
missing.append('ANTHROPIC_API_KEY environment variable is not set') missing.append('ANTHROPIC_API_KEY environment variable is not set')
elif provider == "openai":
if not os.environ.get('OPENAI_API_KEY'): if not os.environ.get('OPENAI_API_KEY'):
missing.append('OPENAI_API_KEY environment variable is not set') missing.append('OPENAI_API_KEY environment variable is not set')
elif provider == "openrouter":
# Check for aider binary if not os.environ.get('OPENROUTER_API_KEY'):
if not shutil.which('aider'): missing.append('OPENROUTER_API_KEY environment variable is not set')
missing.append('aider binary not found in PATH. Please install aider: pip install aider') elif provider == "openai-compatible":
if not os.environ.get('OPENAI_API_KEY'):
missing.append('OPENAI_API_KEY environment variable is not set')
if not os.environ.get('OPENAI_API_BASE'):
missing.append('OPENAI_API_BASE environment variable is not set')
if missing: if missing:
error_list = "\n".join(f"- {error}" for error in missing) print_error("Missing required dependencies:")
print_error(f"Missing required dependencies:\n\n{error_list}") for item in missing:
print_error(f"- {item}")
sys.exit(1) sys.exit(1)
def main(): def main():
"""Main entry point for the ra-aid command line tool.""" """Main entry point for the ra-aid command line tool."""
try: try:

View File

@ -20,7 +20,8 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
# Handle text content # Handle text content
if isinstance(msg.content, list): if isinstance(msg.content, list):
for content in msg.content: for content in msg.content:
if content['type'] == 'text': if content['type'] == 'text' and content['text'].strip():
console.print(Panel(Markdown(content['text']), title="🤖 Assistant")) console.print(Panel(Markdown(content['text']), title="🤖 Assistant"))
else: else:
if msg.content.strip():
console.print(Panel(Markdown(msg.content), title="🤖 Assistant")) console.print(Panel(Markdown(msg.content), title="🤖 Assistant"))

37
ra_aid/llm.py Normal file
View File

@ -0,0 +1,37 @@
import os
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
if provider == "openai":
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY environment variable is not set.")
return ChatOpenAI(openai_api_key=api_key, model=model_name)
elif provider == "anthropic":
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
raise ValueError("ANTHROPIC_API_KEY environment variable is not set.")
return ChatAnthropic(anthropic_api_key=api_key, model=model_name)
elif provider == "openrouter":
api_key = os.getenv("OPENROUTER_API_KEY")
if not api_key:
raise ValueError("OPENROUTER_API_KEY environment variable is not set.")
return ChatOpenAI(
openai_api_key=api_key,
openai_api_base="https://openrouter.ai/api/v1",
model=model_name
)
elif provider == "openai-compatible":
api_key = os.getenv("OPENAI_API_KEY")
api_base = os.getenv("OPENAI_API_BASE")
if not api_key or not api_base:
raise ValueError("Both OPENAI_API_KEY and OPENAI_API_BASE environment variables must be set.")
return ChatOpenAI(
openai_api_key=api_key,
openai_api_base=api_base,
model=model_name
)
else:
raise ValueError(f"Unsupported provider: {provider}")

View File

@ -27,8 +27,8 @@ def fuzzy_find_project_files(
repo_path: str = ".", repo_path: str = ".",
threshold: int = 60, threshold: int = 60,
max_results: int = 10, max_results: int = 10,
include_paths: Optional[List[str]] = None, include_paths: List[str] = None,
exclude_patterns: Optional[List[str]] = None exclude_patterns: List[str] = None
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
"""Fuzzy find files in a git repository matching the search term. """Fuzzy find files in a git repository matching the search term.

View File

@ -151,7 +151,7 @@ def list_directory_tree(
follow_links: bool = False, follow_links: bool = False,
show_size: bool = False, # Default to not showing size show_size: bool = False, # Default to not showing size
show_modified: bool = False, # Default to not showing modified time show_modified: bool = False, # Default to not showing modified time
exclude_patterns: Optional[List[str]] = None exclude_patterns: List[str] = None
) -> str: ) -> str:
"""List directory contents in a tree format with optional metadata. """List directory contents in a tree format with optional metadata.

View File

@ -28,11 +28,11 @@ DEFAULT_EXCLUDE_DIRS = [
def ripgrep_search( def ripgrep_search(
pattern: str, pattern: str,
*, *,
file_type: Optional[str] = None, file_type: str = None,
case_sensitive: bool = True, case_sensitive: bool = True,
include_hidden: bool = False, include_hidden: bool = False,
follow_links: bool = False, follow_links: bool = False,
exclude_dirs: Optional[List[str]] = None exclude_dirs: List[str] = None
) -> Dict[str, Union[str, int, bool]]: ) -> Dict[str, Union[str, int, bool]]:
"""Execute a ripgrep (rg) search with formatting and common options. """Execute a ripgrep (rg) search with formatting and common options.

View File

@ -30,15 +30,6 @@ def run_shell_command(command: str) -> Dict[str, Union[str, int, bool]]:
- Environment: .env, venv, env - Environment: .env, venv, env
- IDE: .idea, .vscode - IDE: .idea, .vscode
3. Avoid doing recursive lists, finds, etc. that could be slow and have a ton of output. Likewise, avoid flags like '-l' that needlessly increase the output. But if you really need to, you can. 3. Avoid doing recursive lists, finds, etc. that could be slow and have a ton of output. Likewise, avoid flags like '-l' that needlessly increase the output. But if you really need to, you can.
Args:
command: List of command arguments. First item is the command, rest are arguments.
Returns:
A dictionary containing:
- output: The command output (stdout + stderr combined)
- return_code: The process return code (0 typically means success)
- success: Boolean indicating if the command succeeded (return code == 0)
""" """
# Check if we need approval # Check if we need approval
cowboy_mode = _global_memory.get('config', {}).get('cowboy_mode', False) cowboy_mode = _global_memory.get('config', {}).get('cowboy_mode', False)