Add Deepseek Provider Support and Custom Deepseek Reasoner Chat Model (#50)
* chore: Add DeepSeek provider environment variable support in env.py
* feat: Add DeepSeek provider validation strategy in provider_strategy.py
* feat: Add support for DEEPSEEK provider in initialize_llm function
* feat: Create ChatDeepseekReasoner for custom handling of R1 models
* feat: Configure custom OpenAI client for DeepSeek API integration
* chore: Remove unused json import from deepseek_chat.py
* refactor: Simplify invocation_params and update acompletion_with_retry method
* feat: Override _generate to ensure message alternation in DeepseekReasoner
* feat: Add support for ChatDeepseekReasoner in LLM initialization
* feat: Use custom ChatDeepseekReasoner for DeepSeek models in OpenRouter
* fix: Remove redundant condition for DeepSeek model initialization
* feat: Add DeepSeek support for expert model initialization in llm.py
* feat: Add DeepSeek model handling for OpenRouter in expert LLM initialization
* fix: Update model name checks for DeepSeek and OpenRouter providers
* refactor: Extract common logic for LLM initialization into reusable methods
* test: Add unit tests for DeepSeek and OpenRouter functionality
* test: Refactor tests to match updated LLM initialization and helpers
* fix: Import missing helper functions to resolve NameError in tests
* fix: Resolve NameError and improve environment variable fallback logic
* feat(readme): Add DeepSeek API key requirements to the documentation on environment variables
* feat(main.py): Include DeepSeek as a supported provider in argument parsing
* feat(deepseek_chat.py): Implement ChatDeepseekReasoner class for handling DeepSeek reasoning models
* feat(llm.py): Add DeepSeek client creation logic to support DeepSeek models
* feat(models_tokens.py): Define token limits for DeepSeek models
* fix(provider_strategy.py): Correct validation logic for DeepSeek environment variables
* chore(memory.py): Refactor global memory structure for better readability and maintainability
parent 7a68de2d06
commit 686ab42f88
README.md (15 lines changed)
@@ -293,6 +293,7 @@ RA.Aid supports multiple providers through environment variables:

- `ANTHROPIC_API_KEY`: Required for the default Anthropic provider
- `OPENAI_API_KEY`: Required for OpenAI provider
- `OPENROUTER_API_KEY`: Required for OpenRouter provider
- `DEEPSEEK_API_KEY`: Required for DeepSeek provider
- `OPENAI_API_BASE`: Required for OpenAI-compatible providers along with `OPENAI_API_KEY`
- `GEMINI_API_KEY`: Required for Gemini provider

@@ -302,6 +303,7 @@ Expert Tool Environment Variables:

- `EXPERT_OPENROUTER_API_KEY`: API key for expert tool using OpenRouter provider
- `EXPERT_OPENAI_API_BASE`: Base URL for expert tool using OpenAI-compatible provider
- `EXPERT_GEMINI_API_KEY`: API key for expert tool using Gemini provider
- `EXPERT_DEEPSEEK_API_KEY`: API key for expert tool using DeepSeek provider

You can set these permanently in your shell's configuration file (e.g., `~/.bashrc` or `~/.zshrc`):

@@ -343,6 +345,15 @@ export GEMINI_API_KEY=your_api_key_here

ra-aid -m "Your task" --provider openrouter --model mistralai/mistral-large-2411
```

4. **Using DeepSeek**

```bash
# Direct DeepSeek provider (requires DEEPSEEK_API_KEY)
ra-aid -m "Your task" --provider deepseek --model deepseek-reasoner

# DeepSeek via OpenRouter
ra-aid -m "Your task" --provider openrouter --model deepseek/deepseek-r1
```

4. **Configuring Expert Provider**

The expert tool is used by the agent for complex logic and debugging tasks. It can be configured to use different providers (OpenAI, Anthropic, OpenRouter, Gemini, openai-compatible) using the --expert-provider flag along with the corresponding EXPERT_*API_KEY environment variables.

@@ -356,6 +367,10 @@ export GEMINI_API_KEY=your_api_key_here

export OPENROUTER_API_KEY=your_openrouter_api_key
ra-aid -m "Your task" --expert-provider openrouter --expert-model mistralai/mistral-large-2411

# Use DeepSeek for expert tool
export DEEPSEEK_API_KEY=your_deepseek_api_key
ra-aid -m "Your task" --expert-provider deepseek --expert-model deepseek-reasoner

# Use default OpenAI for expert tool
export EXPERT_OPENAI_API_KEY=your_openai_api_key
ra-aid -m "Your task" --expert-provider openai --expert-model o1
main.py

@@ -39,6 +39,7 @@ def parse_arguments(args=None):

    "openai",
    "openrouter",
    "openai-compatible",
    "deepseek",
    "gemini",
]
ANTHROPIC_DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
ra_aid/chat_models/deepseek_chat.py (new file)

@@ -0,0 +1,52 @@

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_openai import ChatOpenAI
from typing import Any, List, Optional, Dict


# Docs: https://api-docs.deepseek.com/guides/reasoning_model
class ChatDeepseekReasoner(ChatOpenAI):
    """ChatDeepseekReasoner with custom overrides for R1/reasoner models."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def invocation_params(
        self, options: Optional[Dict[str, Any]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        params = super().invocation_params(options, **kwargs)

        # Remove unsupported params for R1 models
        params.pop("temperature", None)
        params.pop("top_p", None)
        params.pop("presence_penalty", None)
        params.pop("frequency_penalty", None)

        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Override _generate to ensure message alternation in accordance with the DeepSeek API."""

        processed = []
        prev_role = None

        for msg in messages:
            current_role = "user" if msg.type == "human" else "assistant"

            if prev_role == current_role:
                if processed:
                    processed[-1].content += f"\n\n{msg.content}"
            else:
                processed.append(msg)
            prev_role = current_role

        return super()._generate(
            processed, stop=stop, run_manager=run_manager, **kwargs
        )
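For reference, a minimal usage sketch (not part of the commit) of the reasoner override, assuming a valid `DEEPSEEK_API_KEY` is set in the environment; the prompts are illustrative only:

```python
import os

from langchain_core.messages import AIMessage, HumanMessage
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner

# Endpoint and model name as used elsewhere in this commit; key assumed present.
llm = ChatDeepseekReasoner(
    api_key=os.environ["DEEPSEEK_API_KEY"],
    base_url="https://api.deepseek.com",
    model="deepseek-reasoner",
)

# Two consecutive human messages: _generate() folds the second into the first so
# the request alternates user/assistant as the reasoner endpoint expects, and
# invocation_params() strips temperature/top_p/presence/frequency penalties.
messages = [
    HumanMessage(content="Review this function for bugs."),
    HumanMessage(content="Focus on edge cases."),
    AIMessage(content="Understood."),
    HumanMessage(content="Also suggest tests."),
]
print(llm.invoke(messages).content)
```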
env.py

@@ -50,6 +50,9 @@ def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None:

    'gemini': {
        'GEMINI_API_KEY': 'EXPERT_GEMINI_API_KEY',
        'GEMINI_MODEL': 'EXPERT_GEMINI_MODEL'
    },
    'deepseek': {
        'DEEPSEEK_API_KEY': 'EXPERT_DEEPSEEK_API_KEY'
    }
}
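A hedged illustration (not part of the diff) of what the new mapping enables, assuming `copy_base_to_expert_vars` copies each base variable to its expert counterpart when the expert value is unset; the module path is taken from the commit message:

```python
import os

from ra_aid.env import copy_base_to_expert_vars  # module path assumed

os.environ["DEEPSEEK_API_KEY"] = "sk-example"      # placeholder value
os.environ.pop("EXPERT_DEEPSEEK_API_KEY", None)

copy_base_to_expert_vars("deepseek", "deepseek")
print(os.environ.get("EXPERT_DEEPSEEK_API_KEY"))   # expected: "sk-example"
```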
ra_aid/llm.py (217 lines changed)

@@ -1,107 +1,184 @@

import os
from typing import Optional, Dict, Any
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner


def get_env_var(name: str, expert: bool = False) -> Optional[str]:
    """Get environment variable with optional expert prefix and fallback."""
    prefix = "EXPERT_" if expert else ""
    value = os.getenv(f"{prefix}{name}")

    # If expert mode and no expert value, fall back to base value
    if expert and not value:
        value = os.getenv(name)

    return value


def create_deepseek_client(
    model_name: str,
    api_key: str,
    base_url: str,
    temperature: Optional[float] = None,
    is_expert: bool = False,
) -> BaseChatModel:
    """Create DeepSeek client with appropriate configuration."""
    if model_name.lower() == "deepseek-reasoner":
        return ChatDeepseekReasoner(
            api_key=api_key,
            base_url=base_url,
            temperature=0
            if is_expert
            else (temperature if temperature is not None else 1),
            model=model_name,
        )

    return ChatOpenAI(
        api_key=api_key,
        base_url=base_url,
        temperature=0 if is_expert else (temperature if temperature is not None else 1),
        model=model_name,
    )


def create_openrouter_client(
    model_name: str,
    api_key: str,
    temperature: Optional[float] = None,
    is_expert: bool = False,
) -> BaseChatModel:
    """Create OpenRouter client with appropriate configuration."""
    if model_name.startswith("deepseek/") and "deepseek-r1" in model_name.lower():
        return ChatDeepseekReasoner(
            api_key=api_key,
            base_url="https://openrouter.ai/api/v1",
            temperature=0
            if is_expert
            else (temperature if temperature is not None else 1),
            model=model_name,
        )

    return ChatOpenAI(
        api_key=api_key,
        base_url="https://openrouter.ai/api/v1",
        model=model_name,
        **({"temperature": temperature} if temperature is not None else {}),
    )


def get_provider_config(provider: str, is_expert: bool = False) -> Dict[str, Any]:
    """Get provider-specific configuration."""
    configs = {
        "openai": {
            "api_key": get_env_var("OPENAI_API_KEY", is_expert),
            "base_url": None,
        },
        "anthropic": {
            "api_key": get_env_var("ANTHROPIC_API_KEY", is_expert),
            "base_url": None,
        },
        "openrouter": {
            "api_key": get_env_var("OPENROUTER_API_KEY", is_expert),
            "base_url": "https://openrouter.ai/api/v1",
        },
        "openai-compatible": {
            "api_key": get_env_var("OPENAI_API_KEY", is_expert),
            "base_url": get_env_var("OPENAI_API_BASE", is_expert),
        },
        "gemini": {
            "api_key": get_env_var("GEMINI_API_KEY", is_expert),
            "base_url": None,
        },
        "deepseek": {
            "api_key": get_env_var("DEEPSEEK_API_KEY", is_expert),
            "base_url": "https://api.deepseek.com",
        },
    }
    return configs.get(provider, {})


def create_llm_client(
    provider: str,
    model_name: str,
    temperature: Optional[float] = None,
    is_expert: bool = False,
) -> BaseChatModel:
    """Create a language model client with appropriate configuration.

    Args:
        provider: The LLM provider to use
        model_name: Name of the model to use
        temperature: Optional temperature setting (0.0-2.0)
        is_expert: Whether this is an expert model (uses deterministic output)

    Returns:
        Configured language model client
    """
    config = get_provider_config(provider, is_expert)
    if not config:
        raise ValueError(f"Unsupported provider: {provider}")

    # Handle temperature for expert mode
    if is_expert:
        temperature = 0

    if provider == "deepseek":
        return create_deepseek_client(
            model_name=model_name,
            api_key=config["api_key"],
            base_url=config["base_url"],
            temperature=temperature,
            is_expert=is_expert,
        )
    elif provider == "openrouter":
        return create_openrouter_client(
            model_name=model_name,
            api_key=config["api_key"],
            temperature=temperature,
            is_expert=is_expert,
        )
    elif provider == "openai":
        return ChatOpenAI(
            api_key=config["api_key"],
            model=model_name,
            **({"temperature": temperature} if temperature is not None else {}),
        )
    elif provider == "anthropic":
        return ChatAnthropic(
            api_key=config["api_key"],
            model_name=model_name,
            **({"temperature": temperature} if temperature is not None else {}),
        )
    elif provider == "openai-compatible":
        return ChatOpenAI(
            api_key=config["api_key"],
            base_url=config["base_url"],
            temperature=temperature if temperature is not None else 0.3,
            model=model_name,
        )
    elif provider == "gemini":
        return ChatGoogleGenerativeAI(
            api_key=config["api_key"],
            model=model_name,
            **({"temperature": temperature} if temperature is not None else {}),
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")


def initialize_llm(
    provider: str, model_name: str, temperature: float | None = None
) -> BaseChatModel:
    """Initialize a language model client based on the specified provider and model."""
    return create_llm_client(provider, model_name, temperature, is_expert=False)


def initialize_expert_llm(
    provider: str = "openai", model_name: str = "o1"
) -> BaseChatModel:
    """Initialize an expert language model client based on the specified provider and model."""
    return create_llm_client(provider, model_name, temperature=None, is_expert=True)
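A minimal sketch (illustrative, not part of the diff) of the new initialization path for DeepSeek, assuming the relevant API key is set; note that `get_env_var` falls back from the `EXPERT_`-prefixed variable to the base one:

```python
import os

from ra_aid.llm import get_env_var, initialize_expert_llm, initialize_llm

os.environ.setdefault("DEEPSEEK_API_KEY", "sk-example")  # placeholder value

# Base agent model: routed through create_deepseek_client -> ChatDeepseekReasoner.
chat = initialize_llm("deepseek", "deepseek-reasoner", temperature=1)

# Expert model: temperature is forced to 0 inside create_llm_client.
expert = initialize_expert_llm("deepseek", "deepseek-reasoner")

# EXPERT_DEEPSEEK_API_KEY is unset here, so the expert lookup falls back to DEEPSEEK_API_KEY.
print(get_env_var("DEEPSEEK_API_KEY", expert=True))
```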
models_tokens.py

@@ -241,6 +241,10 @@ models_tokens = {

    "deepseek": {
        "deepseek-chat": 28672,
        "deepseek-coder": 16384,
        "deepseek-reasoner": 65536,
    },
    "openrouter": {
        "deepseek/deepseek-r1": 65536,
    },
    "ernie": {
        "ernie-bot-turbo": 4096,
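For illustration (not part of the diff), the new limits are looked up by provider and model key; the module path here is assumed from the commit message:

```python
from ra_aid.models_tokens import models_tokens  # module path assumed

print(models_tokens["deepseek"]["deepseek-reasoner"])       # 65536
print(models_tokens["openrouter"]["deepseek/deepseek-r1"])  # 65536
```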
provider_strategy.py

@@ -239,6 +239,31 @@ class GeminiStrategy(ProviderStrategy):

        return ValidationResult(valid=len(missing) == 0, missing_vars=missing)


class DeepSeekStrategy(ProviderStrategy):
    """DeepSeek provider validation strategy."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Validate DeepSeek environment variables."""
        missing = []

        if args and hasattr(args, 'expert_provider') and args.expert_provider == 'deepseek':
            key = os.environ.get('EXPERT_DEEPSEEK_API_KEY')
            if not key or key == '':
                # Try to copy from base if not set
                base_key = os.environ.get('DEEPSEEK_API_KEY')
                if base_key:
                    os.environ['EXPERT_DEEPSEEK_API_KEY'] = base_key
                    key = base_key
                if not key:
                    missing.append('EXPERT_DEEPSEEK_API_KEY environment variable is not set')
        else:
            key = os.environ.get('DEEPSEEK_API_KEY')
            if not key:
                missing.append('DEEPSEEK_API_KEY environment variable is not set')

        return ValidationResult(valid=len(missing) == 0, missing_vars=missing)


class OllamaStrategy(ProviderStrategy):
    """Ollama provider validation strategy."""

@@ -272,7 +297,8 @@ class ProviderFactory:

            'anthropic': AnthropicStrategy(),
            'openrouter': OpenRouterStrategy(),
            'gemini': GeminiStrategy(),
            'ollama': OllamaStrategy(),
            'deepseek': DeepSeekStrategy()
        }
        strategy = strategies.get(provider)
        return strategy
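A hedged sketch (not part of the diff) of exercising the new validation strategy directly; only `validate()` and the `ValidationResult` fields shown above are assumed, and the module path is taken from the commit message:

```python
import os

from ra_aid.provider_strategy import DeepSeekStrategy  # module path assumed

os.environ.pop("DEEPSEEK_API_KEY", None)
result = DeepSeekStrategy().validate()    # no args: checks DEEPSEEK_API_KEY
print(result.valid, result.missing_vars)  # False, ['DEEPSEEK_API_KEY environment variable is not set']

os.environ["DEEPSEEK_API_KEY"] = "sk-example"  # placeholder value
print(DeepSeekStrategy().validate().valid)     # True
```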
memory.py

@@ -1,123 +1,147 @@

import os
from typing import Dict, List, Any, Union, Optional, Set
from typing_extensions import TypedDict
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from langchain_core.tools import tool


class WorkLogEntry(TypedDict):
    timestamp: str
    event: str


class SnippetInfo(TypedDict):
    """Type definition for source code snippet information"""

    filepath: str
    line_number: int
    snippet: str
    description: Optional[str]


console = Console()

# Global memory store
_global_memory: Dict[
    str,
    Union[
        List[Any],
        Dict[int, str],
        Dict[int, SnippetInfo],
        int,
        Set[str],
        bool,
        str,
        int,
        List[WorkLogEntry],
    ],
] = {
    "research_notes": [],
    "plans": [],
    "tasks": {},  # Dict[int, str] - ID to task mapping
    "task_completed": False,  # Flag indicating if task is complete
    "completion_message": "",  # Message explaining completion
    "task_id_counter": 1,  # Counter for generating unique task IDs
    "key_facts": {},  # Dict[int, str] - ID to fact mapping
    "key_fact_id_counter": 1,  # Counter for generating unique fact IDs
    "key_snippets": {},  # Dict[int, SnippetInfo] - ID to snippet mapping
    "key_snippet_id_counter": 1,  # Counter for generating unique snippet IDs
    "implementation_requested": False,
    "related_files": {},  # Dict[int, str] - ID to filepath mapping
    "related_file_id_counter": 1,  # Counter for generating unique file IDs
    "plan_completed": False,
    "agent_depth": 0,
    "work_log": [],  # List[WorkLogEntry] - Timestamped work events
}


@tool("emit_research_notes")
def emit_research_notes(notes: str) -> str:
    """Store research notes in global memory.

    Args:
        notes: REQUIRED The research notes to store

    Returns:
        The stored notes
    """
    _global_memory["research_notes"].append(notes)
    console.print(Panel(Markdown(notes), title="🔍 Research Notes"))
    return notes


@tool("emit_plan")
def emit_plan(plan: str) -> str:
    """Store a plan step in global memory.

    Args:
        plan: The plan step to store (markdown format; be clear, complete, use newlines, and use as many tokens as you need)

    Returns:
        The stored plan
    """
    _global_memory["plans"].append(plan)
    console.print(Panel(Markdown(plan), title="📋 Plan"))
    log_work_event(f"Added plan step:\n\n{plan}")
    return plan


@tool("emit_task")
def emit_task(task: str) -> str:
    """Store a task in global memory.

    Args:
        task: The task to store

    Returns:
        String confirming task storage with ID number
    """
    # Get and increment task ID
    task_id = _global_memory["task_id_counter"]
    _global_memory["task_id_counter"] += 1

    # Store task with ID
    _global_memory["tasks"][task_id] = task

    console.print(Panel(Markdown(task), title=f"✅ Task #{task_id}"))
    log_work_event(f"Task #{task_id} added:\n\n{task}")
    return f"Task #{task_id} stored."


@tool("emit_key_facts")
def emit_key_facts(facts: List[str]) -> str:
    """Store multiple key facts about the project or current task in global memory.

    Args:
        facts: List of key facts to store

    Returns:
        List of stored fact confirmation messages
    """
    results = []
    for fact in facts:
        # Get and increment fact ID
        fact_id = _global_memory["key_fact_id_counter"]
        _global_memory["key_fact_id_counter"] += 1

        # Store fact with ID
        _global_memory["key_facts"][fact_id] = fact

        # Display panel with ID
        console.print(
            Panel(
                Markdown(fact),
                title=f"💡 Key Fact #{fact_id}",
                border_style="bright_cyan",
            )
        )

        # Add result message
        results.append(f"Stored fact #{fact_id}: {fact}")

    log_work_event(f"Stored {len(facts)} key facts.")
    return "Facts stored."
@@ -125,50 +149,54 @@ def emit_key_facts(facts: List[str]) -> str:

def delete_key_facts(fact_ids: List[int]) -> str:
    """Delete multiple key facts from global memory by their IDs.
    Silently skips any IDs that don't exist.

    Args:
        fact_ids: List of fact IDs to delete

    Returns:
        List of success messages for deleted facts
    """
    results = []
    for fact_id in fact_ids:
        if fact_id in _global_memory["key_facts"]:
            # Delete the fact
            deleted_fact = _global_memory["key_facts"].pop(fact_id)
            success_msg = f"Successfully deleted fact #{fact_id}: {deleted_fact}"
            console.print(
                Panel(Markdown(success_msg), title="Fact Deleted", border_style="green")
            )
            results.append(success_msg)

    log_work_event(f"Deleted facts {fact_ids}.")
    return "Facts deleted."


@tool("delete_tasks")
def delete_tasks(task_ids: List[int]) -> str:
    """Delete multiple tasks from global memory by their IDs.
    Silently skips any IDs that don't exist.

    Args:
        task_ids: List of task IDs to delete

    Returns:
        Confirmation message
    """
    results = []
    for task_id in task_ids:
        if task_id in _global_memory["tasks"]:
            # Delete the task
            deleted_task = _global_memory["tasks"].pop(task_id)
            success_msg = f"Successfully deleted task #{task_id}: {deleted_task}"
            console.print(
                Panel(Markdown(success_msg), title="Task Deleted", border_style="green")
            )
            results.append(success_msg)

    log_work_event(f"Deleted tasks {task_ids}.")
    return "Tasks deleted."


@tool("request_implementation")
def request_implementation() -> str:
    """Request that implementation proceed after research/planning.
@@ -178,46 +206,47 @@ def request_implementation() -> str:

    Do you need to request research subtasks first?
    Have you run relevant unit tests, if they exist, to get a baseline (this can be a subtask)?
    Do you need to crawl deeper to find all related files and symbols?

    Returns:
        Empty string
    """
    _global_memory["implementation_requested"] = True
    console.print(Panel("🚀 Implementation Requested", style="yellow", padding=0))
    log_work_event("Implementation requested.")
    return ""


@tool("emit_key_snippets")
def emit_key_snippets(snippets: List[SnippetInfo]) -> str:
    """Store multiple key source code snippets in global memory.
    Automatically adds the filepaths of the snippets to related files.

    This is for **existing**, or **just-written** files, not for things to be created in the future.

    Args:
        snippets: REQUIRED List of snippet information dictionaries containing:
            - filepath: Path to the source file
            - line_number: Line number where the snippet starts
            - snippet: The source code snippet text
            - description: Optional description of the significance

    Returns:
        List of stored snippet confirmation messages
    """
    # First collect unique filepaths to add as related files
    emit_related_files.invoke(
        {"files": [snippet_info["filepath"] for snippet_info in snippets]}
    )

    results = []
    for snippet_info in snippets:
        # Get and increment snippet ID
        snippet_id = _global_memory["key_snippet_id_counter"]
        _global_memory["key_snippet_id_counter"] += 1

        # Store snippet info
        _global_memory["key_snippets"][snippet_id] = snippet_info

        # Format display text as markdown
        display_text = [
            f"**Source Location**:",
@@ -226,153 +255,170 @@ def emit_key_snippets(snippets: List[SnippetInfo]) -> str:

            "",  # Empty line before code block
            "**Code**:",
            "```python",
            snippet_info["snippet"].rstrip(),  # Remove trailing whitespace
            "```",
        ]
        if snippet_info["description"]:
            display_text.extend(["", "**Description**:", snippet_info["description"]])

        # Display panel
        console.print(
            Panel(
                Markdown("\n".join(display_text)),
                title=f"📝 Key Snippet #{snippet_id}",
                border_style="bright_cyan",
            )
        )

        results.append(f"Stored snippet #{snippet_id}")

    log_work_event(f"Stored {len(snippets)} code snippets.")
    return "Snippets stored."


@tool("delete_key_snippets")
def delete_key_snippets(snippet_ids: List[int]) -> str:
    """Delete multiple key snippets from global memory by their IDs.
    Silently skips any IDs that don't exist.

    Args:
        snippet_ids: List of snippet IDs to delete

    Returns:
        List of success messages for deleted snippets
    """
    results = []
    for snippet_id in snippet_ids:
        if snippet_id in _global_memory["key_snippets"]:
            # Delete the snippet
            deleted_snippet = _global_memory["key_snippets"].pop(snippet_id)
            success_msg = f"Successfully deleted snippet #{snippet_id} from {deleted_snippet['filepath']}"
            console.print(
                Panel(
                    Markdown(success_msg), title="Snippet Deleted", border_style="green"
                )
            )
            results.append(success_msg)

    log_work_event(f"Deleted snippets {snippet_ids}.")
    return "Snippets deleted."


@tool("swap_task_order")
def swap_task_order(id1: int, id2: int) -> str:
    """Swap the order of two tasks in global memory by their IDs.

    Args:
        id1: First task ID
        id2: Second task ID

    Returns:
        Success or error message depending on outcome
    """
    # Validate IDs are different
    if id1 == id2:
        return "Cannot swap task with itself"

    # Validate both IDs exist
    if id1 not in _global_memory["tasks"] or id2 not in _global_memory["tasks"]:
        return "Invalid task ID(s)"

    # Swap the tasks
    _global_memory["tasks"][id1], _global_memory["tasks"][id2] = (
        _global_memory["tasks"][id2],
        _global_memory["tasks"][id1],
    )

    # Display what was swapped
    console.print(
        Panel(
            Markdown(f"Swapped:\n- Task #{id1} ↔️ Task #{id2}"),
            title="🔄 Tasks Reordered",
            border_style="green",
        )
    )

    return "Tasks swapped."


@tool("one_shot_completed")
def one_shot_completed(message: str) -> str:
    """Signal that a one-shot task has been completed and execution should stop.

    Only call this if you have already **fully** completed the original request.

    Args:
        message: Completion message to display
    """
    if _global_memory.get("implementation_requested", False):
        return "Cannot complete in one shot - implementation was requested"

    _global_memory["task_completed"] = True
    _global_memory["completion_message"] = message
    console.print(Panel(Markdown(message), title="✅ Task Completed"))
    log_work_event(f"Task completed\n\n{message}")
    return "Completion noted."


@tool("task_completed")
def task_completed(message: str) -> str:
    """Mark the current task as completed with a completion message.

    Args:
        message: Message explaining how/why the task is complete

    Returns:
        The completion message
    """
    _global_memory["task_completed"] = True
    _global_memory["completion_message"] = message
    console.print(Panel(Markdown(message), title="✅ Task Completed"))
    return "Completion noted."


@tool("plan_implementation_completed")
def plan_implementation_completed(message: str) -> str:
    """Mark the entire implementation plan as completed.

    Args:
        message: Message explaining how the implementation plan was completed

    Returns:
        Confirmation message
    """
    _global_memory["plan_completed"] = True
    _global_memory["completion_message"] = message
    _global_memory["tasks"].clear()  # Clear task list when plan is completed
    _global_memory["task_id_counter"] = 1
    console.print(Panel(Markdown(message), title="✅ Plan Executed"))
    log_work_event(f"Plan execution completed:\n\n{message}")
    return "Plan completion noted and task list cleared."


def get_related_files() -> List[str]:
    """Get the current list of related files.

    Returns:
        List of formatted strings in the format 'ID#X path/to/file.py'
    """
    files = _global_memory["related_files"]
    return [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(files.items())]


@tool("emit_related_files")
def emit_related_files(files: List[str]) -> str:
    """Store multiple related files that tools should work with.

    Args:
        files: List of file paths to add

    Returns:
        Formatted string containing file IDs and paths for all processed files
    """
    results = []
    added_files = []
    invalid_paths = []

    # Process files
    for file in files:
        # First check if path exists
@@ -380,111 +426,115 @@ def emit_related_files(files: List[str]) -> str:

            invalid_paths.append(file)
            results.append(f"Error: Path '{file}' does not exist")
            continue

        # Then check if it's a directory
        if os.path.isdir(file):
            invalid_paths.append(file)
            results.append(f"Error: Path '{file}' is a directory, not a file")
            continue

        # Finally validate it's a regular file
        if not os.path.isfile(file):
            invalid_paths.append(file)
            results.append(f"Error: Path '{file}' exists but is not a regular file")
            continue

        # Check if file path already exists in values
        existing_id = None
        for fid, fpath in _global_memory["related_files"].items():
            if fpath == file:
                existing_id = fid
                break

        if existing_id is not None:
            # File exists, use existing ID
            results.append(f"File ID #{existing_id}: {file}")
        else:
            # New file, assign new ID
            file_id = _global_memory["related_file_id_counter"]
            _global_memory["related_file_id_counter"] += 1

            # Store file with ID
            _global_memory["related_files"][file_id] = file
            added_files.append((file_id, file))
            results.append(f"File ID #{file_id}: {file}")

    # Rich output - single consolidated panel
    if added_files:
        files_added_md = "\n".join(f"- `{file}`" for id, file in added_files)
        md_content = f"**Files Noted:**\n{files_added_md}"
        console.print(
            Panel(
                Markdown(md_content),
                title="📁 Related Files Noted",
                border_style="green",
            )
        )

    return "\n".join(results)


def log_work_event(event: str) -> str:
    """Add timestamped entry to work log.

    Internal function used to track major events during agent execution.
    Each entry is stored with an ISO format timestamp.

    Args:
        event: Description of the event to log

    Returns:
        Confirmation message

    Note:
        Entries can be retrieved with get_work_log() as markdown formatted text.
    """
    from datetime import datetime

    entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event)
    _global_memory["work_log"].append(entry)
    return f"Event logged: {event}"


def get_work_log() -> str:
    """Return formatted markdown of work log entries.

    Returns:
        Markdown formatted text with timestamps as headings and events as content,
        or 'No work log entries' if the log is empty.

    Example:
        ## 2024-12-23T11:39:10

        Task #1 added: Create login form
    """
    if not _global_memory["work_log"]:
        return "No work log entries"

    entries = []
    for entry in _global_memory["work_log"]:
        entries.extend(
            [
                f"## {entry['timestamp']}",
                "",
                entry["event"],
                "",  # Blank line between entries
            ]
        )

    return "\n".join(entries).rstrip()  # Remove trailing newline


def reset_work_log() -> str:
    """Clear the work log.

    Returns:
        Confirmation message

    Note:
        This permanently removes all work log entries. The operation cannot be undone.
    """
    _global_memory["work_log"].clear()
    return "Work log cleared"
@@ -492,37 +542,44 @@ def reset_work_log() -> str:
def deregister_related_files(file_ids: List[int]) -> str:
    """Delete multiple related files from global memory by their IDs.
    Silently skips any IDs that don't exist.

    Args:
        file_ids: List of file IDs to delete

    Returns:
        Success message string
    """
    results = []
    for file_id in file_ids:
        if file_id in _global_memory["related_files"]:
            # Delete the file reference
            deleted_file = _global_memory["related_files"].pop(file_id)
            success_msg = (
                f"Successfully removed related file #{file_id}: {deleted_file}"
            )
            console.print(
                Panel(
                    Markdown(success_msg),
                    title="File Reference Removed",
                    border_style="green",
                )
            )
            results.append(success_msg)

    return "File references removed."
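Note: a minimal sketch of the behavior described in that docstring; the IDs and file paths below are invented for illustration.

```python
# Assuming we are inside memory.py's namespace (or importing the same names as above).
_global_memory["related_files"] = {1: "src/login.py", 2: "src/forms.py"}

# ID 99 does not exist and is silently skipped.
deregister_related_files([1, 99])
# A Rich panel is printed: "Successfully removed related file #1: src/login.py"
# _global_memory["related_files"] is now {2: "src/forms.py"}
```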
def get_memory_value(key: str) -> str:
    """Get a value from global memory.

    Different memory types return different formats:
    - key_facts: Returns numbered list of facts in format '#ID: fact'
    - key_snippets: Returns formatted snippets with file path, line number and content
    - All other types: Returns newline-separated list of values

    Args:
        key: The key to get from memory

    Returns:
        String representation of the memory values:
        - For key_facts: '#ID: fact' format, one per line
@@ -530,23 +587,25 @@ def get_memory_value(key: str) -> str:
        - For other types: One value per line
    """
    values = _global_memory.get(key, [])

    if key == "key_facts":
        # For empty dict, return empty string
        if not values:
            return ""
        # Sort by ID for consistent output and format as markdown sections
        facts = []
        for k, v in sorted(values.items()):
            facts.extend(
                [
                    f"## 🔑 Key Fact #{k}",
                    "",  # Empty line for better markdown spacing
                    v,
                    "",  # Empty line between facts
                ]
            )
        return "\n".join(facts).rstrip()  # Remove trailing newline

    if key == "key_snippets":
        if not values:
            return ""
        # Format each snippet with file info and content using markdown
@@ -555,26 +614,25 @@ def get_memory_value(key: str) -> str:
            snippet_text = [
                f"## 📝 Code Snippet #{k}",
                "",  # Empty line for better markdown spacing
                "**Source Location**:",
                f"- File: `{v['filepath']}`",
                f"- Line: `{v['line_number']}`",
                "",  # Empty line before code block
                "**Code**:",
                "```python",
                v["snippet"].rstrip(),  # Remove trailing whitespace
                "```",
            ]
            if v["description"]:
                # Add empty line and description
                snippet_text.extend(["", "**Description**:", v["description"]])
            snippets.append("\n".join(snippet_text))
        return "\n\n".join(snippets)

    if key == "work_log":
        if not values:
            return ""
        entries = [f"## {entry['timestamp']}\n{entry['event']}" for entry in values]
        return "\n\n".join(entries)

    # For other types (lists), join with newlines
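Note: to make the formats above concrete, a small sketch of what `get_memory_value` returns for the key_facts branch; the stored values are invented for illustration.

```python
# Hypothetical memory contents, assuming memory.py's namespace as above.
_global_memory["key_facts"] = {0: "The API base URL is configurable", 1: "Tests use pytest"}

print(get_memory_value("key_facts"))
# ## 🔑 Key Fact #0
#
# The API base URL is configurable
#
# ## 🔑 Key Fact #1
#
# Tests use pytest

# Unknown keys default to an empty list, so this presumably prints an empty string.
print(get_memory_value("missing_key"))
```

The remaining hunks below move to the LLM test suite and exercise the new `get_env_var`, `get_provider_config`, and `create_llm_client` helpers.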
@@ -9,7 +9,13 @@ from dataclasses import dataclass
from ra_aid.agents.ciayn_agent import CiaynAgent

from ra_aid.env import validate_environment
from ra_aid.llm import (
    initialize_llm,
    initialize_expert_llm,
    get_env_var,
    get_provider_config,
    create_llm_client
)


@pytest.fixture
def clean_env(monkeypatch):
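Note: `get_env_var` itself is not shown in this diff. Based on how the tests below use it, a minimal sketch of the fallback behavior they assume might look like this; the details are an assumption, not the actual implementation.

```python
import os
from typing import Optional


def get_env_var(name: str, expert: bool = False) -> Optional[str]:
    """Sketch: prefer EXPERT_<NAME> when expert=True, falling back to <NAME>."""
    if expert:
        expert_value = os.environ.get(f"EXPERT_{name}")
        if expert_value:  # empty strings fall through to the base variable
            return expert_value
    return os.environ.get(name)
```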
@@ -39,7 +45,8 @@ def test_initialize_expert_defaults(clean_env, mock_openai, monkeypatch):

    mock_openai.assert_called_once_with(
        api_key="test-key",
        model="o1",
        temperature=0
    )

def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):

@@ -49,7 +56,8 @@ def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):

    mock_openai.assert_called_once_with(
        api_key="test-key",
        model="gpt-4-preview",
        temperature=0
    )

def test_initialize_expert_gemini(clean_env, mock_gemini, monkeypatch):

@@ -59,7 +67,8 @@ def test_initialize_expert_gemini(clean_env, mock_gemini, monkeypatch):

    mock_gemini.assert_called_once_with(
        api_key="test-key",
        model="gemini-2.0-flash-thinking-exp-1219",
        temperature=0
    )

def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch):

@@ -69,7 +78,8 @@ def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch):

    mock_anthropic.assert_called_once_with(
        api_key="test-key",
        model_name="claude-3",
        temperature=0
    )

def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch):

@@ -80,7 +90,8 @@ def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch):
    mock_openai.assert_called_once_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        model="models/mistral-large",
        temperature=0
    )

def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch):

@@ -92,7 +103,8 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch
    mock_openai.assert_called_once_with(
        api_key="test-key",
        base_url="http://test-url",
        model="local-model",
        temperature=0
    )

def test_initialize_expert_unsupported_provider(clean_env):
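Note: these assertion updates all encode the same convention: expert clients are now constructed with `temperature=0`. A rough sketch of that convention, not the actual llm.py wiring:

```python
# Sketch only; Anthropic is the one provider that takes model_name= instead of model=.
def expert_client_kwargs(provider_config: dict, model_name: str) -> dict:
    """Expert clients are always constructed with temperature=0 for determinism."""
    kwargs = dict(provider_config)  # api_key, plus base_url for OpenRouter/compatible providers
    kwargs["model"] = model_name
    kwargs["temperature"] = 0
    return kwargs

# expert_client_kwargs({"api_key": "test-key"}, "o1")
# -> {"api_key": "test-key", "model": "o1", "temperature": 0}
```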
@@ -311,45 +323,49 @@ def test_initialize_llm_cross_provider(clean_env, mock_openai, mock_anthropic, m
        model="gemini-2.0-flash-thinking-exp-1219"
    )


@dataclass
class Args:
    """Test arguments class."""
    provider: str
    expert_provider: str
    model: str = None
    expert_model: str = None


def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
    """Test environment variable precedence and fallback."""
    # Test get_env_var helper with fallback
    monkeypatch.setenv("TEST_KEY", "base-value")
    monkeypatch.setenv("EXPERT_TEST_KEY", "expert-value")

    assert get_env_var("TEST_KEY") == "base-value"
    assert get_env_var("TEST_KEY", expert=True) == "expert-value"

    # Test fallback when expert value not set
    monkeypatch.delenv("EXPERT_TEST_KEY", raising=False)
    assert get_env_var("TEST_KEY", expert=True) == "base-value"

    # Test provider config
    monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "expert-key")
    config = get_provider_config("openai", is_expert=True)
    assert config["api_key"] == "expert-key"

    # Test LLM client creation with expert mode
    llm = create_llm_client("openai", "o1", is_expert=True)

    mock_openai.assert_called_with(
        api_key="expert-key",
        model="o1",
        temperature=0
    )

    # Test environment validation
    monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "")
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.delenv("TAVILY_API_KEY", raising=False)
    monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key")
    monkeypatch.setenv("GEMINI_API_KEY", "gemini-key")
    monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307")

    args = Args(provider="anthropic", expert_provider="openai")
    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
    assert not expert_enabled
    assert expert_missing
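Note: `get_provider_config` is likewise only referenced here. A minimal, self-contained sketch of the behavior this test relies on; the provider list is limited to what these tests exercise, and the details are assumptions rather than the llm.py code.

```python
import os


def get_provider_config(provider: str, is_expert: bool = False) -> dict:
    """Sketch: api_key with EXPERT_ fallback, plus a fixed base_url where one applies."""
    key_names = {
        "openai": "OPENAI_API_KEY",
        "openrouter": "OPENROUTER_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
    }
    base_urls = {
        "openrouter": "https://openrouter.ai/api/v1",
        "deepseek": "https://api.deepseek.com",
    }
    name = key_names[provider]
    api_key = (os.environ.get(f"EXPERT_{name}") if is_expert else None) or os.environ.get(name)
    config = {"api_key": api_key}
    if provider in base_urls:
        config["base_url"] = base_urls[provider]
    return config
```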
@@ -372,3 +388,122 @@ def mock_gemini():
    with patch('ra_aid.llm.ChatGoogleGenerativeAI') as mock:
        mock.return_value = Mock(spec=ChatGoogleGenerativeAI)
        yield mock

@pytest.fixture
def mock_deepseek_reasoner():
    """Mock ChatDeepseekReasoner for testing DeepSeek provider initialization."""
    with patch('ra_aid.llm.ChatDeepseekReasoner') as mock:
        mock.return_value = Mock()
        yield mock

def test_initialize_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
    """Test DeepSeek provider initialization with different models."""
    monkeypatch.setenv("DEEPSEEK_API_KEY", "test-key")

    # Test with reasoner model
    model = initialize_llm("deepseek", "deepseek-reasoner")
    mock_deepseek_reasoner.assert_called_with(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        temperature=1,
        model="deepseek-reasoner"
    )

    # Test with non-reasoner model
    model = initialize_llm("deepseek", "deepseek-chat")
    mock_openai.assert_called_with(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        temperature=1,
        model="deepseek-chat"
    )

def test_initialize_expert_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
    """Test expert DeepSeek provider initialization."""
    monkeypatch.setenv("EXPERT_DEEPSEEK_API_KEY", "test-key")

    # Test with reasoner model
    model = initialize_expert_llm("deepseek", "deepseek-reasoner")
    mock_deepseek_reasoner.assert_called_with(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        temperature=0,
        model="deepseek-reasoner"
    )

    # Test with non-reasoner model
    model = initialize_expert_llm("deepseek", "deepseek-chat")
    mock_openai.assert_called_with(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        temperature=0,
        model="deepseek-chat"
    )

def test_initialize_openrouter_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
    """Test OpenRouter DeepSeek model initialization."""
    monkeypatch.setenv("OPENROUTER_API_KEY", "test-key")

    # Test with DeepSeek R1 model
    model = initialize_llm("openrouter", "deepseek/deepseek-r1")
    mock_deepseek_reasoner.assert_called_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        temperature=1,
        model="deepseek/deepseek-r1"
    )

    # Test with non-DeepSeek model
    model = initialize_llm("openrouter", "mistral/mistral-large")
    mock_openai.assert_called_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        model="mistral/mistral-large"
    )

def test_initialize_expert_openrouter_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
    """Test expert OpenRouter DeepSeek model initialization."""
    monkeypatch.setenv("EXPERT_OPENROUTER_API_KEY", "test-key")

    # Test with DeepSeek R1 model via create_llm_client
    model = create_llm_client("openrouter", "deepseek/deepseek-r1", is_expert=True)
    mock_deepseek_reasoner.assert_called_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        temperature=0,
        model="deepseek/deepseek-r1"
    )

    # Test with non-DeepSeek model
    model = create_llm_client("openrouter", "mistral/mistral-large", is_expert=True)
    mock_openai.assert_called_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        model="mistral/mistral-large",
        temperature=0
    )

def test_deepseek_environment_fallback(clean_env, mock_deepseek_reasoner, monkeypatch):
    """Test DeepSeek environment variable fallback behavior."""
    # Test environment variable helper with fallback
    monkeypatch.setenv("DEEPSEEK_API_KEY", "base-key")
    assert get_env_var("DEEPSEEK_API_KEY", expert=True) == "base-key"

    # Test provider config with fallback
    config = get_provider_config("deepseek", is_expert=True)
    assert config["api_key"] == "base-key"
    assert config["base_url"] == "https://api.deepseek.com"

    # Test with expert key
    monkeypatch.setenv("EXPERT_DEEPSEEK_API_KEY", "expert-key")
    config = get_provider_config("deepseek", is_expert=True)
    assert config["api_key"] == "expert-key"

    # Test client creation with expert key
    model = create_llm_client("deepseek", "deepseek-reasoner", is_expert=True)
    mock_deepseek_reasoner.assert_called_with(
        api_key="expert-key",
        base_url="https://api.deepseek.com",
        temperature=0,
        model="deepseek-reasoner"
    )
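Note: taken together, these tests pin down the dispatch rule for DeepSeek models. The llm.py code is not reproduced in this hunk, so the sketch below is inferred from the assertions above; the matching heuristics are assumptions.

```python
# Sketch of the selection logic the assertions imply; not the literal llm.py implementation.
def is_deepseek_reasoner_model(provider: str, model_name: str) -> bool:
    """R1/reasoner models get the custom ChatDeepseekReasoner client."""
    name = model_name.lower()
    if provider == "deepseek":
        return "reasoner" in name
    if provider == "openrouter":
        return name.startswith("deepseek/deepseek-r1")
    return False

# initialize_llm("deepseek", "deepseek-reasoner")      -> ChatDeepseekReasoner, temperature=1
# initialize_llm("deepseek", "deepseek-chat")          -> OpenAI-compatible client, temperature=1
# initialize_llm("openrouter", "deepseek/deepseek-r1") -> ChatDeepseekReasoner, temperature=1
# expert variants (is_expert=True)                     -> same dispatch, temperature=0
```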