expert/web enabled based on config
This commit is contained in:
parent
1dc9326154
commit
8d44ba0824
|
|
@ -44,8 +44,6 @@ class SpawnAgentRequest(BaseModel):
|
|||
Attributes:
|
||||
message: The message or task for the agent to process
|
||||
research_only: Whether to use research-only mode (default: False)
|
||||
expert_enabled: Whether to enable expert assistance (default: True)
|
||||
web_research_enabled: Whether to enable web research (default: False)
|
||||
"""
|
||||
message: str = Field(
|
||||
description="The message or task for the agent to process"
|
||||
|
|
@ -54,14 +52,6 @@ class SpawnAgentRequest(BaseModel):
|
|||
default=False,
|
||||
description="Whether to use research-only mode"
|
||||
)
|
||||
expert_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable expert assistance"
|
||||
)
|
||||
web_research_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable web research"
|
||||
)
|
||||
|
||||
class SpawnAgentResponse(BaseModel):
|
||||
"""
|
||||
|
|
@ -93,8 +83,6 @@ def run_agent_thread(
|
|||
message: str,
|
||||
session_id: str,
|
||||
research_only: bool = False,
|
||||
expert_enabled: bool = True,
|
||||
web_research_enabled: bool = False,
|
||||
):
|
||||
"""
|
||||
Run a research agent in a separate thread with proper repository initialization.
|
||||
|
|
@ -103,8 +91,10 @@ def run_agent_thread(
|
|||
message: The message or task for the agent to process
|
||||
session_id: The ID of the session to associate with this agent
|
||||
research_only: Whether to use research-only mode
|
||||
expert_enabled: Whether to enable expert assistance
|
||||
web_research_enabled: Whether to enable web research
|
||||
|
||||
Note:
|
||||
Values for expert_enabled and web_research_enabled are retrieved from the
|
||||
config repository, which stores the values set during server startup.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting agent thread for session {session_id}")
|
||||
|
|
@ -133,11 +123,15 @@ def run_agent_thread(
|
|||
# Import here to avoid circular imports
|
||||
from ra_aid.__main__ import run_research_agent
|
||||
|
||||
# Get the provider and model from config repository
|
||||
# Get configuration values from config repository
|
||||
provider = get_config_repository().get("provider", "anthropic")
|
||||
model_name = get_config_repository().get("model", "claude-3-7-sonnet-20250219")
|
||||
temperature = get_config_repository().get("temperature")
|
||||
|
||||
# Get expert_enabled and web_research_enabled from config repository
|
||||
expert_enabled = get_config_repository().get("expert_enabled", True)
|
||||
web_research_enabled = get_config_repository().get("web_research_enabled", False)
|
||||
|
||||
# Initialize model with provider and model name from config
|
||||
model = initialize_llm(provider, model_name, temperature=temperature)
|
||||
|
||||
|
|
@ -171,7 +165,7 @@ async def spawn_agent(
|
|||
Spawn a new RA.Aid agent to process a message or task.
|
||||
|
||||
Args:
|
||||
request: Request body with message and agent configuration
|
||||
request: Request body with message and agent configuration.
|
||||
repo: SessionRepository dependency injection
|
||||
|
||||
Returns:
|
||||
|
|
@ -181,11 +175,16 @@ async def spawn_agent(
|
|||
HTTPException: With a 500 status code if there's an error spawning the agent
|
||||
"""
|
||||
try:
|
||||
# Create a new session
|
||||
# Get configuration values from config repository
|
||||
config_repo = get_config_repository()
|
||||
expert_enabled = config_repo.get("expert_enabled", True)
|
||||
web_research_enabled = config_repo.get("web_research_enabled", False)
|
||||
|
||||
# Create a new session with config values (not request parameters)
|
||||
metadata = {
|
||||
"agent_type": "research-only" if request.research_only else "research",
|
||||
"expert_enabled": request.expert_enabled,
|
||||
"web_research_enabled": request.web_research_enabled,
|
||||
"expert_enabled": expert_enabled,
|
||||
"web_research_enabled": web_research_enabled,
|
||||
}
|
||||
session = repo.create_session(metadata=metadata)
|
||||
|
||||
|
|
@ -196,8 +195,6 @@ async def spawn_agent(
|
|||
request.message,
|
||||
str(session.id),
|
||||
request.research_only,
|
||||
request.expert_enabled,
|
||||
request.web_research_enabled,
|
||||
)
|
||||
)
|
||||
thread.daemon = True # Thread will terminate when main process exits
|
||||
|
|
|
|||
|
|
@ -48,7 +48,19 @@ def mock_repository(mock_session):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_repository, mock_thread, monkeypatch):
|
||||
def mock_config_repository():
|
||||
"""Create a mock config repository for testing."""
|
||||
mock_config = MagicMock()
|
||||
mock_config.get.side_effect = lambda key, default=None: {
|
||||
"expert_enabled": True,
|
||||
"web_research_enabled": False,
|
||||
"provider": "anthropic",
|
||||
"model": "claude-3-7-sonnet-20250219",
|
||||
}.get(key, default)
|
||||
return mock_config
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_repository, mock_thread, mock_config_repository, monkeypatch):
|
||||
"""Set up a test client with mocked dependencies."""
|
||||
# Create FastAPI app with router
|
||||
app = FastAPI()
|
||||
|
|
@ -63,6 +75,12 @@ def client(mock_repository, mock_thread, monkeypatch):
|
|||
lambda *args, **kwargs: None
|
||||
)
|
||||
|
||||
# Mock get_config_repository to use our mock
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.server.api_v1_spawn_agent.get_config_repository",
|
||||
lambda: mock_config_repository
|
||||
)
|
||||
|
||||
# Mock threading.Thread to return our mock thread
|
||||
def mock_thread_constructor(*args, **kwargs):
|
||||
mock_thread.target = kwargs.get('target')
|
||||
|
|
@ -81,6 +99,7 @@ def client(mock_repository, mock_thread, monkeypatch):
|
|||
# Add mocks to client for test access
|
||||
client.mock_repo = mock_repository
|
||||
client.mock_thread = mock_thread
|
||||
client.mock_config = mock_config_repository
|
||||
|
||||
yield client
|
||||
|
||||
|
|
@ -109,7 +128,7 @@ def test_spawn_agent(client, mock_repository, mock_thread):
|
|||
mock_repository.create_session.assert_called_once()
|
||||
|
||||
# Verify thread was created with correct args
|
||||
assert mock_thread.args == ("Test task for the agent", "123", False, True, False)
|
||||
assert mock_thread.args == ("Test task for the agent", "123", False)
|
||||
assert mock_thread.daemon is True
|
||||
|
||||
# Verify thread.start was called
|
||||
|
|
|
|||
Loading…
Reference in New Issue