fix tests
This commit is contained in:
parent
ae9cf5021b
commit
37764c7d56
|
|
@ -63,6 +63,42 @@ def mock_config_repository():
|
|||
yield mock_repo
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_trajectory_repository():
|
||||
"""Mock the TrajectoryRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Setup create method to return a mock trajectory
|
||||
def mock_create(**kwargs):
|
||||
mock_trajectory = MagicMock()
|
||||
mock_trajectory.id = 1
|
||||
return mock_trajectory
|
||||
mock_repo.create.side_effect = mock_create
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_human_input_repository():
|
||||
"""Mock the HumanInputRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Setup get_most_recent_id method to return a dummy ID
|
||||
mock_repo.get_most_recent_id.return_value = 1
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
|
||||
def test_get_model_token_limit_anthropic(mock_config_repository):
|
||||
"""Test get_model_token_limit with Anthropic model."""
|
||||
config = {"provider": "anthropic", "model": "claude2"}
|
||||
|
|
@ -725,4 +761,4 @@ def test_handle_api_error_resource_exhausted():
|
|||
|
||||
# ResourceExhausted exception should be handled without raising
|
||||
resource_exhausted_error = ResourceExhausted("429 Resource has been exhausted (e.g. check quota).")
|
||||
_handle_api_error(resource_exhausted_error, 0, 5, 1)
|
||||
_handle_api_error(resource_exhausted_error, 0, 5, 1)
|
||||
|
|
@ -113,6 +113,40 @@ def mock_work_log_repository():
|
|||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_trajectory_repository():
|
||||
"""Mock the TrajectoryRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Setup create method to return a mock trajectory
|
||||
def mock_create(**kwargs):
|
||||
mock_trajectory = MagicMock()
|
||||
mock_trajectory.id = 1
|
||||
return mock_trajectory
|
||||
mock_repo.create.side_effect = mock_create
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_human_input_repository():
|
||||
"""Mock the HumanInputRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Setup get_most_recent_id method to return a dummy ID
|
||||
mock_repo.get_most_recent_id.return_value = 1
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture
|
||||
def mock_functions():
|
||||
"""Mock functions used in agent.py"""
|
||||
|
|
@ -126,7 +160,9 @@ def mock_functions():
|
|||
patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \
|
||||
patch('ra_aid.tools.agent.get_work_log') as mock_get_work_log, \
|
||||
patch('ra_aid.tools.agent.reset_completion_flags') as mock_reset, \
|
||||
patch('ra_aid.tools.agent.get_completion_message') as mock_get_completion:
|
||||
patch('ra_aid.tools.agent.get_completion_message') as mock_get_completion, \
|
||||
patch('ra_aid.tools.agent.get_trajectory_repository') as mock_get_trajectory_repo, \
|
||||
patch('ra_aid.tools.agent.get_human_input_repository') as mock_get_human_input_repo:
|
||||
|
||||
# Setup mock return values
|
||||
mock_fact_repo.get_facts_dict.return_value = {1: "Test fact 1", 2: "Test fact 2"}
|
||||
|
|
@ -138,6 +174,15 @@ def mock_functions():
|
|||
mock_get_work_log.return_value = "Test work log"
|
||||
mock_get_completion.return_value = "Task completed"
|
||||
|
||||
# Setup mock for trajectory repository
|
||||
mock_trajectory_repo = MagicMock()
|
||||
mock_get_trajectory_repo.return_value = mock_trajectory_repo
|
||||
|
||||
# Setup mock for human input repository
|
||||
mock_human_input_repo = MagicMock()
|
||||
mock_human_input_repo.get_most_recent_id.return_value = 1
|
||||
mock_get_human_input_repo.return_value = mock_human_input_repo
|
||||
|
||||
# Return all mocks as a dictionary
|
||||
yield {
|
||||
'get_key_fact_repository': mock_get_fact_repo,
|
||||
|
|
@ -148,7 +193,9 @@ def mock_functions():
|
|||
'get_related_files': mock_get_files,
|
||||
'get_work_log': mock_get_work_log,
|
||||
'reset_completion_flags': mock_reset,
|
||||
'get_completion_message': mock_get_completion
|
||||
'get_completion_message': mock_get_completion,
|
||||
'get_trajectory_repository': mock_get_trajectory_repo,
|
||||
'get_human_input_repository': mock_get_human_input_repo
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,40 @@ def mock_config_repository():
|
|||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_trajectory_repository():
|
||||
"""Mock the TrajectoryRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Setup create method to return a mock trajectory
|
||||
def mock_create(**kwargs):
|
||||
mock_trajectory = MagicMock()
|
||||
mock_trajectory.id = 1
|
||||
return mock_trajectory
|
||||
mock_repo.create.side_effect = mock_create
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_human_input_repository():
|
||||
"""Mock the HumanInputRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Setup get_most_recent_id method to return a dummy ID
|
||||
mock_repo.get_most_recent_id.return_value = 1
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
|
||||
def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
|
||||
"""Test shell command execution in cowboy mode (no approval)"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue