support open models
This commit is contained in:
parent
49078adc10
commit
9f721410ff
|
|
@ -28,6 +28,7 @@ from ra_aid.prompts import (
|
|||
from ra_aid.exceptions import TaskCompletedException
|
||||
import time
|
||||
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
|
||||
from ra_aid.llm import initialize_llm
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
|
@ -49,18 +50,40 @@ Examples:
|
|||
action='store_true',
|
||||
help='Only perform research without implementation'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--provider',
|
||||
type=str,
|
||||
default='anthropic',
|
||||
choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'],
|
||||
help='The LLM provider to use'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
type=str,
|
||||
help='The model name to use (required for non-Anthropic providers)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--cowboy-mode',
|
||||
action='store_true',
|
||||
help='Skip interactive approval for shell commands'
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set default model for Anthropic, require model for other providers
|
||||
if args.provider == 'anthropic':
|
||||
if not args.model:
|
||||
args.model = 'claude-3-5-sonnet-20241022'
|
||||
elif not args.model:
|
||||
parser.error(f"--model is required when using provider '{args.provider}'")
|
||||
|
||||
return args
|
||||
|
||||
# Create console instance
|
||||
console = Console()
|
||||
|
||||
# Create the base model
|
||||
model = ChatAnthropic(model_name="claude-3-5-sonnet-20241022")
|
||||
model = initialize_llm(parse_arguments().provider, parse_arguments().model)
|
||||
|
||||
# Create individual memory objects for each agent
|
||||
research_memory = MemorySaver()
|
||||
|
|
@ -239,22 +262,31 @@ def validate_environment():
|
|||
"""Validate required environment variables and dependencies."""
|
||||
missing = []
|
||||
|
||||
# Check API keys
|
||||
args = parse_arguments()
|
||||
provider = args.provider
|
||||
|
||||
# Check API keys based on provider
|
||||
if provider == "anthropic":
|
||||
if not os.environ.get('ANTHROPIC_API_KEY'):
|
||||
missing.append('ANTHROPIC_API_KEY environment variable is not set')
|
||||
elif provider == "openai":
|
||||
if not os.environ.get('OPENAI_API_KEY'):
|
||||
missing.append('OPENAI_API_KEY environment variable is not set')
|
||||
|
||||
# Check for aider binary
|
||||
if not shutil.which('aider'):
|
||||
missing.append('aider binary not found in PATH. Please install aider: pip install aider')
|
||||
elif provider == "openrouter":
|
||||
if not os.environ.get('OPENROUTER_API_KEY'):
|
||||
missing.append('OPENROUTER_API_KEY environment variable is not set')
|
||||
elif provider == "openai-compatible":
|
||||
if not os.environ.get('OPENAI_API_KEY'):
|
||||
missing.append('OPENAI_API_KEY environment variable is not set')
|
||||
if not os.environ.get('OPENAI_API_BASE'):
|
||||
missing.append('OPENAI_API_BASE environment variable is not set')
|
||||
|
||||
if missing:
|
||||
error_list = "\n".join(f"- {error}" for error in missing)
|
||||
print_error(f"Missing required dependencies:\n\n{error_list}")
|
||||
print_error("Missing required dependencies:")
|
||||
for item in missing:
|
||||
print_error(f"- {item}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the ra-aid command line tool."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -20,7 +20,8 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
|
|||
# Handle text content
|
||||
if isinstance(msg.content, list):
|
||||
for content in msg.content:
|
||||
if content['type'] == 'text':
|
||||
if content['type'] == 'text' and content['text'].strip():
|
||||
console.print(Panel(Markdown(content['text']), title="🤖 Assistant"))
|
||||
else:
|
||||
if msg.content.strip():
|
||||
console.print(Panel(Markdown(msg.content), title="🤖 Assistant"))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,37 @@
|
|||
import os
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
    """Create a chat model client for the requested provider.

    Args:
        provider: One of "openai", "anthropic", "openrouter", or
            "openai-compatible".
        model_name: Provider-specific model identifier to instantiate.

    Returns:
        A configured chat model instance for the given provider.

    Raises:
        ValueError: If the provider is unknown, or if the environment
            variables it requires are not set.
    """
    if provider == "anthropic":
        key = os.getenv("ANTHROPIC_API_KEY")
        if not key:
            raise ValueError("ANTHROPIC_API_KEY environment variable is not set.")
        return ChatAnthropic(anthropic_api_key=key, model=model_name)

    if provider == "openai":
        key = os.getenv("OPENAI_API_KEY")
        if not key:
            raise ValueError("OPENAI_API_KEY environment variable is not set.")
        return ChatOpenAI(openai_api_key=key, model=model_name)

    if provider == "openrouter":
        key = os.getenv("OPENROUTER_API_KEY")
        if not key:
            raise ValueError("OPENROUTER_API_KEY environment variable is not set.")
        # OpenRouter speaks the OpenAI wire protocol at a fixed base URL.
        return ChatOpenAI(
            openai_api_key=key,
            openai_api_base="https://openrouter.ai/api/v1",
            model=model_name,
        )

    if provider == "openai-compatible":
        key = os.getenv("OPENAI_API_KEY")
        base = os.getenv("OPENAI_API_BASE")
        # Both pieces are required to reach an arbitrary OpenAI-style endpoint.
        if not (key and base):
            raise ValueError("Both OPENAI_API_KEY and OPENAI_API_BASE environment variables must be set.")
        return ChatOpenAI(
            openai_api_key=key,
            openai_api_base=base,
            model=model_name,
        )

    raise ValueError(f"Unsupported provider: {provider}")
|
||||
|
|
@ -27,8 +27,8 @@ def fuzzy_find_project_files(
|
|||
repo_path: str = ".",
|
||||
threshold: int = 60,
|
||||
max_results: int = 10,
|
||||
include_paths: Optional[List[str]] = None,
|
||||
exclude_patterns: Optional[List[str]] = None
|
||||
include_paths: List[str] = None,
|
||||
exclude_patterns: List[str] = None
|
||||
) -> List[Tuple[str, int]]:
|
||||
"""Fuzzy find files in a git repository matching the search term.
|
||||
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ def list_directory_tree(
|
|||
follow_links: bool = False,
|
||||
show_size: bool = False, # Default to not showing size
|
||||
show_modified: bool = False, # Default to not showing modified time
|
||||
exclude_patterns: Optional[List[str]] = None
|
||||
exclude_patterns: List[str] = None
|
||||
) -> str:
|
||||
"""List directory contents in a tree format with optional metadata.
|
||||
|
||||
|
|
|
|||
|
|
@ -28,11 +28,11 @@ DEFAULT_EXCLUDE_DIRS = [
|
|||
def ripgrep_search(
|
||||
pattern: str,
|
||||
*,
|
||||
file_type: Optional[str] = None,
|
||||
file_type: str = None,
|
||||
case_sensitive: bool = True,
|
||||
include_hidden: bool = False,
|
||||
follow_links: bool = False,
|
||||
exclude_dirs: Optional[List[str]] = None
|
||||
exclude_dirs: List[str] = None
|
||||
) -> Dict[str, Union[str, int, bool]]:
|
||||
"""Execute a ripgrep (rg) search with formatting and common options.
|
||||
|
||||
|
|
|
|||
|
|
@ -30,15 +30,6 @@ def run_shell_command(command: str) -> Dict[str, Union[str, int, bool]]:
|
|||
- Environment: .env, venv, env
|
||||
- IDE: .idea, .vscode
|
||||
3. Avoid doing recursive lists, finds, etc. that could be slow and have a ton of output. Likewise, avoid flags like '-l' that needlessly increase the output. But if you really need to, you can.
|
||||
|
||||
Args:
|
||||
command: List of command arguments. First item is the command, rest are arguments.
|
||||
|
||||
Returns:
|
||||
A dictionary containing:
|
||||
- output: The command output (stdout + stderr combined)
|
||||
- return_code: The process return code (0 typically means success)
|
||||
- success: Boolean indicating if the command succeeded (return code == 0)
|
||||
"""
|
||||
# Check if we need approval
|
||||
cowboy_mode = _global_memory.get('config', {}).get('cowboy_mode', False)
|
||||
|
|
|
|||
Loading…
Reference in New Issue