expert/web enabled based on config
This commit is contained in:
parent
1dc9326154
commit
8d44ba0824
|
|
@ -44,8 +44,6 @@ class SpawnAgentRequest(BaseModel):
|
||||||
Attributes:
|
Attributes:
|
||||||
message: The message or task for the agent to process
|
message: The message or task for the agent to process
|
||||||
research_only: Whether to use research-only mode (default: False)
|
research_only: Whether to use research-only mode (default: False)
|
||||||
expert_enabled: Whether to enable expert assistance (default: True)
|
|
||||||
web_research_enabled: Whether to enable web research (default: False)
|
|
||||||
"""
|
"""
|
||||||
message: str = Field(
|
message: str = Field(
|
||||||
description="The message or task for the agent to process"
|
description="The message or task for the agent to process"
|
||||||
|
|
@ -54,14 +52,6 @@ class SpawnAgentRequest(BaseModel):
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to use research-only mode"
|
description="Whether to use research-only mode"
|
||||||
)
|
)
|
||||||
expert_enabled: bool = Field(
|
|
||||||
default=True,
|
|
||||||
description="Whether to enable expert assistance"
|
|
||||||
)
|
|
||||||
web_research_enabled: bool = Field(
|
|
||||||
default=False,
|
|
||||||
description="Whether to enable web research"
|
|
||||||
)
|
|
||||||
|
|
||||||
class SpawnAgentResponse(BaseModel):
|
class SpawnAgentResponse(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
@ -93,8 +83,6 @@ def run_agent_thread(
|
||||||
message: str,
|
message: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
research_only: bool = False,
|
research_only: bool = False,
|
||||||
expert_enabled: bool = True,
|
|
||||||
web_research_enabled: bool = False,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Run a research agent in a separate thread with proper repository initialization.
|
Run a research agent in a separate thread with proper repository initialization.
|
||||||
|
|
@ -103,8 +91,10 @@ def run_agent_thread(
|
||||||
message: The message or task for the agent to process
|
message: The message or task for the agent to process
|
||||||
session_id: The ID of the session to associate with this agent
|
session_id: The ID of the session to associate with this agent
|
||||||
research_only: Whether to use research-only mode
|
research_only: Whether to use research-only mode
|
||||||
expert_enabled: Whether to enable expert assistance
|
|
||||||
web_research_enabled: Whether to enable web research
|
Note:
|
||||||
|
Values for expert_enabled and web_research_enabled are retrieved from the
|
||||||
|
config repository, which stores the values set during server startup.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting agent thread for session {session_id}")
|
logger.info(f"Starting agent thread for session {session_id}")
|
||||||
|
|
@ -133,11 +123,15 @@ def run_agent_thread(
|
||||||
# Import here to avoid circular imports
|
# Import here to avoid circular imports
|
||||||
from ra_aid.__main__ import run_research_agent
|
from ra_aid.__main__ import run_research_agent
|
||||||
|
|
||||||
# Get the provider and model from config repository
|
# Get configuration values from config repository
|
||||||
provider = get_config_repository().get("provider", "anthropic")
|
provider = get_config_repository().get("provider", "anthropic")
|
||||||
model_name = get_config_repository().get("model", "claude-3-7-sonnet-20250219")
|
model_name = get_config_repository().get("model", "claude-3-7-sonnet-20250219")
|
||||||
temperature = get_config_repository().get("temperature")
|
temperature = get_config_repository().get("temperature")
|
||||||
|
|
||||||
|
# Get expert_enabled and web_research_enabled from config repository
|
||||||
|
expert_enabled = get_config_repository().get("expert_enabled", True)
|
||||||
|
web_research_enabled = get_config_repository().get("web_research_enabled", False)
|
||||||
|
|
||||||
# Initialize model with provider and model name from config
|
# Initialize model with provider and model name from config
|
||||||
model = initialize_llm(provider, model_name, temperature=temperature)
|
model = initialize_llm(provider, model_name, temperature=temperature)
|
||||||
|
|
||||||
|
|
@ -171,7 +165,7 @@ async def spawn_agent(
|
||||||
Spawn a new RA.Aid agent to process a message or task.
|
Spawn a new RA.Aid agent to process a message or task.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: Request body with message and agent configuration
|
request: Request body with message and agent configuration.
|
||||||
repo: SessionRepository dependency injection
|
repo: SessionRepository dependency injection
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -181,11 +175,16 @@ async def spawn_agent(
|
||||||
HTTPException: With a 500 status code if there's an error spawning the agent
|
HTTPException: With a 500 status code if there's an error spawning the agent
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Create a new session
|
# Get configuration values from config repository
|
||||||
|
config_repo = get_config_repository()
|
||||||
|
expert_enabled = config_repo.get("expert_enabled", True)
|
||||||
|
web_research_enabled = config_repo.get("web_research_enabled", False)
|
||||||
|
|
||||||
|
# Create a new session with config values (not request parameters)
|
||||||
metadata = {
|
metadata = {
|
||||||
"agent_type": "research-only" if request.research_only else "research",
|
"agent_type": "research-only" if request.research_only else "research",
|
||||||
"expert_enabled": request.expert_enabled,
|
"expert_enabled": expert_enabled,
|
||||||
"web_research_enabled": request.web_research_enabled,
|
"web_research_enabled": web_research_enabled,
|
||||||
}
|
}
|
||||||
session = repo.create_session(metadata=metadata)
|
session = repo.create_session(metadata=metadata)
|
||||||
|
|
||||||
|
|
@ -196,8 +195,6 @@ async def spawn_agent(
|
||||||
request.message,
|
request.message,
|
||||||
str(session.id),
|
str(session.id),
|
||||||
request.research_only,
|
request.research_only,
|
||||||
request.expert_enabled,
|
|
||||||
request.web_research_enabled,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
thread.daemon = True # Thread will terminate when main process exits
|
thread.daemon = True # Thread will terminate when main process exits
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,19 @@ def mock_repository(mock_session):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client(mock_repository, mock_thread, monkeypatch):
|
def mock_config_repository():
|
||||||
|
"""Create a mock config repository for testing."""
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.get.side_effect = lambda key, default=None: {
|
||||||
|
"expert_enabled": True,
|
||||||
|
"web_research_enabled": False,
|
||||||
|
"provider": "anthropic",
|
||||||
|
"model": "claude-3-7-sonnet-20250219",
|
||||||
|
}.get(key, default)
|
||||||
|
return mock_config
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(mock_repository, mock_thread, mock_config_repository, monkeypatch):
|
||||||
"""Set up a test client with mocked dependencies."""
|
"""Set up a test client with mocked dependencies."""
|
||||||
# Create FastAPI app with router
|
# Create FastAPI app with router
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
@ -63,6 +75,12 @@ def client(mock_repository, mock_thread, monkeypatch):
|
||||||
lambda *args, **kwargs: None
|
lambda *args, **kwargs: None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mock get_config_repository to use our mock
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"ra_aid.server.api_v1_spawn_agent.get_config_repository",
|
||||||
|
lambda: mock_config_repository
|
||||||
|
)
|
||||||
|
|
||||||
# Mock threading.Thread to return our mock thread
|
# Mock threading.Thread to return our mock thread
|
||||||
def mock_thread_constructor(*args, **kwargs):
|
def mock_thread_constructor(*args, **kwargs):
|
||||||
mock_thread.target = kwargs.get('target')
|
mock_thread.target = kwargs.get('target')
|
||||||
|
|
@ -81,6 +99,7 @@ def client(mock_repository, mock_thread, monkeypatch):
|
||||||
# Add mocks to client for test access
|
# Add mocks to client for test access
|
||||||
client.mock_repo = mock_repository
|
client.mock_repo = mock_repository
|
||||||
client.mock_thread = mock_thread
|
client.mock_thread = mock_thread
|
||||||
|
client.mock_config = mock_config_repository
|
||||||
|
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
@ -109,7 +128,7 @@ def test_spawn_agent(client, mock_repository, mock_thread):
|
||||||
mock_repository.create_session.assert_called_once()
|
mock_repository.create_session.assert_called_once()
|
||||||
|
|
||||||
# Verify thread was created with correct args
|
# Verify thread was created with correct args
|
||||||
assert mock_thread.args == ("Test task for the agent", "123", False, True, False)
|
assert mock_thread.args == ("Test task for the agent", "123", False)
|
||||||
assert mock_thread.daemon is True
|
assert mock_thread.daemon is True
|
||||||
|
|
||||||
# Verify thread.start was called
|
# Verify thread.start was called
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue