expert/web enabled based on config

This commit is contained in:
AI Christianson 2025-03-15 22:16:50 -04:00
parent 1dc9326154
commit 8d44ba0824
2 changed files with 39 additions and 23 deletions

View File

@ -44,8 +44,6 @@ class SpawnAgentRequest(BaseModel):
Attributes: Attributes:
message: The message or task for the agent to process message: The message or task for the agent to process
research_only: Whether to use research-only mode (default: False) research_only: Whether to use research-only mode (default: False)
expert_enabled: Whether to enable expert assistance (default: True)
web_research_enabled: Whether to enable web research (default: False)
""" """
message: str = Field( message: str = Field(
description="The message or task for the agent to process" description="The message or task for the agent to process"
@ -54,14 +52,6 @@ class SpawnAgentRequest(BaseModel):
default=False, default=False,
description="Whether to use research-only mode" description="Whether to use research-only mode"
) )
expert_enabled: bool = Field(
default=True,
description="Whether to enable expert assistance"
)
web_research_enabled: bool = Field(
default=False,
description="Whether to enable web research"
)
class SpawnAgentResponse(BaseModel): class SpawnAgentResponse(BaseModel):
""" """
@ -93,8 +83,6 @@ def run_agent_thread(
message: str, message: str,
session_id: str, session_id: str,
research_only: bool = False, research_only: bool = False,
expert_enabled: bool = True,
web_research_enabled: bool = False,
): ):
""" """
Run a research agent in a separate thread with proper repository initialization. Run a research agent in a separate thread with proper repository initialization.
@ -103,8 +91,10 @@ def run_agent_thread(
message: The message or task for the agent to process message: The message or task for the agent to process
session_id: The ID of the session to associate with this agent session_id: The ID of the session to associate with this agent
research_only: Whether to use research-only mode research_only: Whether to use research-only mode
expert_enabled: Whether to enable expert assistance
web_research_enabled: Whether to enable web research Note:
Values for expert_enabled and web_research_enabled are retrieved from the
config repository, which stores the values set during server startup.
""" """
try: try:
logger.info(f"Starting agent thread for session {session_id}") logger.info(f"Starting agent thread for session {session_id}")
@ -133,11 +123,15 @@ def run_agent_thread(
# Import here to avoid circular imports # Import here to avoid circular imports
from ra_aid.__main__ import run_research_agent from ra_aid.__main__ import run_research_agent
# Get the provider and model from config repository # Get configuration values from config repository
provider = get_config_repository().get("provider", "anthropic") provider = get_config_repository().get("provider", "anthropic")
model_name = get_config_repository().get("model", "claude-3-7-sonnet-20250219") model_name = get_config_repository().get("model", "claude-3-7-sonnet-20250219")
temperature = get_config_repository().get("temperature") temperature = get_config_repository().get("temperature")
# Get expert_enabled and web_research_enabled from config repository
expert_enabled = get_config_repository().get("expert_enabled", True)
web_research_enabled = get_config_repository().get("web_research_enabled", False)
# Initialize model with provider and model name from config # Initialize model with provider and model name from config
model = initialize_llm(provider, model_name, temperature=temperature) model = initialize_llm(provider, model_name, temperature=temperature)
@ -171,7 +165,7 @@ async def spawn_agent(
Spawn a new RA.Aid agent to process a message or task. Spawn a new RA.Aid agent to process a message or task.
Args: Args:
request: Request body with message and agent configuration request: Request body with message and agent configuration.
repo: SessionRepository dependency injection repo: SessionRepository dependency injection
Returns: Returns:
@ -181,11 +175,16 @@ async def spawn_agent(
HTTPException: With a 500 status code if there's an error spawning the agent HTTPException: With a 500 status code if there's an error spawning the agent
""" """
try: try:
# Create a new session # Get configuration values from config repository
config_repo = get_config_repository()
expert_enabled = config_repo.get("expert_enabled", True)
web_research_enabled = config_repo.get("web_research_enabled", False)
# Create a new session with config values (not request parameters)
metadata = { metadata = {
"agent_type": "research-only" if request.research_only else "research", "agent_type": "research-only" if request.research_only else "research",
"expert_enabled": request.expert_enabled, "expert_enabled": expert_enabled,
"web_research_enabled": request.web_research_enabled, "web_research_enabled": web_research_enabled,
} }
session = repo.create_session(metadata=metadata) session = repo.create_session(metadata=metadata)
@ -196,8 +195,6 @@ async def spawn_agent(
request.message, request.message,
str(session.id), str(session.id),
request.research_only, request.research_only,
request.expert_enabled,
request.web_research_enabled,
) )
) )
thread.daemon = True # Thread will terminate when main process exits thread.daemon = True # Thread will terminate when main process exits

View File

@ -48,7 +48,19 @@ def mock_repository(mock_session):
@pytest.fixture @pytest.fixture
def client(mock_repository, mock_thread, monkeypatch): def mock_config_repository():
"""Create a mock config repository for testing."""
mock_config = MagicMock()
mock_config.get.side_effect = lambda key, default=None: {
"expert_enabled": True,
"web_research_enabled": False,
"provider": "anthropic",
"model": "claude-3-7-sonnet-20250219",
}.get(key, default)
return mock_config
@pytest.fixture
def client(mock_repository, mock_thread, mock_config_repository, monkeypatch):
"""Set up a test client with mocked dependencies.""" """Set up a test client with mocked dependencies."""
# Create FastAPI app with router # Create FastAPI app with router
app = FastAPI() app = FastAPI()
@ -63,6 +75,12 @@ def client(mock_repository, mock_thread, monkeypatch):
lambda *args, **kwargs: None lambda *args, **kwargs: None
) )
# Mock get_config_repository to use our mock
monkeypatch.setattr(
"ra_aid.server.api_v1_spawn_agent.get_config_repository",
lambda: mock_config_repository
)
# Mock threading.Thread to return our mock thread # Mock threading.Thread to return our mock thread
def mock_thread_constructor(*args, **kwargs): def mock_thread_constructor(*args, **kwargs):
mock_thread.target = kwargs.get('target') mock_thread.target = kwargs.get('target')
@ -81,6 +99,7 @@ def client(mock_repository, mock_thread, monkeypatch):
# Add mocks to client for test access # Add mocks to client for test access
client.mock_repo = mock_repository client.mock_repo = mock_repository
client.mock_thread = mock_thread client.mock_thread = mock_thread
client.mock_config = mock_config_repository
yield client yield client
@ -109,7 +128,7 @@ def test_spawn_agent(client, mock_repository, mock_thread):
mock_repository.create_session.assert_called_once() mock_repository.create_session.assert_called_once()
# Verify thread was created with correct args # Verify thread was created with correct args
assert mock_thread.args == ("Test task for the agent", "123", False, True, False) assert mock_thread.args == ("Test task for the agent", "123", False)
assert mock_thread.daemon is True assert mock_thread.daemon is True
# Verify thread.start was called # Verify thread.start was called