support open models
This commit is contained in:
parent
49078adc10
commit
9f721410ff
|
|
@ -28,6 +28,7 @@ from ra_aid.prompts import (
|
||||||
from ra_aid.exceptions import TaskCompletedException
|
from ra_aid.exceptions import TaskCompletedException
|
||||||
import time
|
import time
|
||||||
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
|
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
|
||||||
|
from ra_aid.llm import initialize_llm
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
|
@ -49,18 +50,40 @@ Examples:
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='Only perform research without implementation'
|
help='Only perform research without implementation'
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--provider',
|
||||||
|
type=str,
|
||||||
|
default='anthropic',
|
||||||
|
choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'],
|
||||||
|
help='The LLM provider to use'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--model',
|
||||||
|
type=str,
|
||||||
|
help='The model name to use (required for non-Anthropic providers)'
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--cowboy-mode',
|
'--cowboy-mode',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='Skip interactive approval for shell commands'
|
help='Skip interactive approval for shell commands'
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Set default model for Anthropic, require model for other providers
|
||||||
|
if args.provider == 'anthropic':
|
||||||
|
if not args.model:
|
||||||
|
args.model = 'claude-3-5-sonnet-20241022'
|
||||||
|
elif not args.model:
|
||||||
|
parser.error(f"--model is required when using provider '{args.provider}'")
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
# Create console instance
|
# Create console instance
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
# Create the base model
|
# Create the base model
|
||||||
model = ChatAnthropic(model_name="claude-3-5-sonnet-20241022")
|
model = initialize_llm(parse_arguments().provider, parse_arguments().model)
|
||||||
|
|
||||||
# Create individual memory objects for each agent
|
# Create individual memory objects for each agent
|
||||||
research_memory = MemorySaver()
|
research_memory = MemorySaver()
|
||||||
|
|
@ -239,22 +262,31 @@ def validate_environment():
|
||||||
"""Validate required environment variables and dependencies."""
|
"""Validate required environment variables and dependencies."""
|
||||||
missing = []
|
missing = []
|
||||||
|
|
||||||
# Check API keys
|
args = parse_arguments()
|
||||||
if not os.environ.get('ANTHROPIC_API_KEY'):
|
provider = args.provider
|
||||||
missing.append('ANTHROPIC_API_KEY environment variable is not set')
|
|
||||||
if not os.environ.get('OPENAI_API_KEY'):
|
|
||||||
missing.append('OPENAI_API_KEY environment variable is not set')
|
|
||||||
|
|
||||||
# Check for aider binary
|
# Check API keys based on provider
|
||||||
if not shutil.which('aider'):
|
if provider == "anthropic":
|
||||||
missing.append('aider binary not found in PATH. Please install aider: pip install aider')
|
if not os.environ.get('ANTHROPIC_API_KEY'):
|
||||||
|
missing.append('ANTHROPIC_API_KEY environment variable is not set')
|
||||||
|
elif provider == "openai":
|
||||||
|
if not os.environ.get('OPENAI_API_KEY'):
|
||||||
|
missing.append('OPENAI_API_KEY environment variable is not set')
|
||||||
|
elif provider == "openrouter":
|
||||||
|
if not os.environ.get('OPENROUTER_API_KEY'):
|
||||||
|
missing.append('OPENROUTER_API_KEY environment variable is not set')
|
||||||
|
elif provider == "openai-compatible":
|
||||||
|
if not os.environ.get('OPENAI_API_KEY'):
|
||||||
|
missing.append('OPENAI_API_KEY environment variable is not set')
|
||||||
|
if not os.environ.get('OPENAI_API_BASE'):
|
||||||
|
missing.append('OPENAI_API_BASE environment variable is not set')
|
||||||
|
|
||||||
if missing:
|
if missing:
|
||||||
error_list = "\n".join(f"- {error}" for error in missing)
|
print_error("Missing required dependencies:")
|
||||||
print_error(f"Missing required dependencies:\n\n{error_list}")
|
for item in missing:
|
||||||
|
print_error(f"- {item}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Main entry point for the ra-aid command line tool."""
|
"""Main entry point for the ra-aid command line tool."""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,8 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
|
||||||
# Handle text content
|
# Handle text content
|
||||||
if isinstance(msg.content, list):
|
if isinstance(msg.content, list):
|
||||||
for content in msg.content:
|
for content in msg.content:
|
||||||
if content['type'] == 'text':
|
if content['type'] == 'text' and content['text'].strip():
|
||||||
console.print(Panel(Markdown(content['text']), title="🤖 Assistant"))
|
console.print(Panel(Markdown(content['text']), title="🤖 Assistant"))
|
||||||
else:
|
else:
|
||||||
console.print(Panel(Markdown(msg.content), title="🤖 Assistant"))
|
if msg.content.strip():
|
||||||
|
console.print(Panel(Markdown(msg.content), title="🤖 Assistant"))
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,37 @@
|
||||||
|
import os
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
|
||||||
|
if provider == "openai":
|
||||||
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("OPENAI_API_KEY environment variable is not set.")
|
||||||
|
return ChatOpenAI(openai_api_key=api_key, model=model_name)
|
||||||
|
elif provider == "anthropic":
|
||||||
|
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("ANTHROPIC_API_KEY environment variable is not set.")
|
||||||
|
return ChatAnthropic(anthropic_api_key=api_key, model=model_name)
|
||||||
|
elif provider == "openrouter":
|
||||||
|
api_key = os.getenv("OPENROUTER_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("OPENROUTER_API_KEY environment variable is not set.")
|
||||||
|
return ChatOpenAI(
|
||||||
|
openai_api_key=api_key,
|
||||||
|
openai_api_base="https://openrouter.ai/api/v1",
|
||||||
|
model=model_name
|
||||||
|
)
|
||||||
|
elif provider == "openai-compatible":
|
||||||
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
api_base = os.getenv("OPENAI_API_BASE")
|
||||||
|
if not api_key or not api_base:
|
||||||
|
raise ValueError("Both OPENAI_API_KEY and OPENAI_API_BASE environment variables must be set.")
|
||||||
|
return ChatOpenAI(
|
||||||
|
openai_api_key=api_key,
|
||||||
|
openai_api_base=api_base,
|
||||||
|
model=model_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported provider: {provider}")
|
||||||
|
|
@ -27,8 +27,8 @@ def fuzzy_find_project_files(
|
||||||
repo_path: str = ".",
|
repo_path: str = ".",
|
||||||
threshold: int = 60,
|
threshold: int = 60,
|
||||||
max_results: int = 10,
|
max_results: int = 10,
|
||||||
include_paths: Optional[List[str]] = None,
|
include_paths: List[str] = None,
|
||||||
exclude_patterns: Optional[List[str]] = None
|
exclude_patterns: List[str] = None
|
||||||
) -> List[Tuple[str, int]]:
|
) -> List[Tuple[str, int]]:
|
||||||
"""Fuzzy find files in a git repository matching the search term.
|
"""Fuzzy find files in a git repository matching the search term.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -151,7 +151,7 @@ def list_directory_tree(
|
||||||
follow_links: bool = False,
|
follow_links: bool = False,
|
||||||
show_size: bool = False, # Default to not showing size
|
show_size: bool = False, # Default to not showing size
|
||||||
show_modified: bool = False, # Default to not showing modified time
|
show_modified: bool = False, # Default to not showing modified time
|
||||||
exclude_patterns: Optional[List[str]] = None
|
exclude_patterns: List[str] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""List directory contents in a tree format with optional metadata.
|
"""List directory contents in a tree format with optional metadata.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,11 +28,11 @@ DEFAULT_EXCLUDE_DIRS = [
|
||||||
def ripgrep_search(
|
def ripgrep_search(
|
||||||
pattern: str,
|
pattern: str,
|
||||||
*,
|
*,
|
||||||
file_type: Optional[str] = None,
|
file_type: str = None,
|
||||||
case_sensitive: bool = True,
|
case_sensitive: bool = True,
|
||||||
include_hidden: bool = False,
|
include_hidden: bool = False,
|
||||||
follow_links: bool = False,
|
follow_links: bool = False,
|
||||||
exclude_dirs: Optional[List[str]] = None
|
exclude_dirs: List[str] = None
|
||||||
) -> Dict[str, Union[str, int, bool]]:
|
) -> Dict[str, Union[str, int, bool]]:
|
||||||
"""Execute a ripgrep (rg) search with formatting and common options.
|
"""Execute a ripgrep (rg) search with formatting and common options.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,15 +30,6 @@ def run_shell_command(command: str) -> Dict[str, Union[str, int, bool]]:
|
||||||
- Environment: .env, venv, env
|
- Environment: .env, venv, env
|
||||||
- IDE: .idea, .vscode
|
- IDE: .idea, .vscode
|
||||||
3. Avoid doing recursive lists, finds, etc. that could be slow and have a ton of output. Likewise, avoid flags like '-l' that needlessly increase the output. But if you really need to, you can.
|
3. Avoid doing recursive lists, finds, etc. that could be slow and have a ton of output. Likewise, avoid flags like '-l' that needlessly increase the output. But if you really need to, you can.
|
||||||
|
|
||||||
Args:
|
|
||||||
command: List of command arguments. First item is the command, rest are arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary containing:
|
|
||||||
- output: The command output (stdout + stderr combined)
|
|
||||||
- return_code: The process return code (0 typically means success)
|
|
||||||
- success: Boolean indicating if the command succeeded (return code == 0)
|
|
||||||
"""
|
"""
|
||||||
# Check if we need approval
|
# Check if we need approval
|
||||||
cowboy_mode = _global_memory.get('config', {}).get('cowboy_mode', False)
|
cowboy_mode = _global_memory.get('config', {}).get('cowboy_mode', False)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue