fix(tests): update type hints in test_agent_utils.py for better clarity and type safety
refactor(tests): modify DummyAgent's stream method to use more descriptive parameter names and types for improved readability
This commit is contained in:
parent
cd8d1c459d
commit
c7712e0114
|
|
@ -1,11 +1,12 @@
|
|||
"""Unit tests for agent_utils.py."""
|
||||
|
||||
from typing import Any, Dict, Literal
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import litellm
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
from ra_aid.agent_utils import (
|
||||
AgentState,
|
||||
|
|
@ -317,7 +318,7 @@ def test_run_agent_stream(monkeypatch):
|
|||
|
||||
# Create a dummy agent that yields one chunk
|
||||
class DummyAgent:
|
||||
def stream(self, msg, cfg):
|
||||
def stream(self, input_data, cfg: dict):
|
||||
yield {"content": "chunk1"}
|
||||
|
||||
dummy_agent = DummyAgent()
|
||||
|
|
@ -327,13 +328,15 @@ def test_run_agent_stream(monkeypatch):
|
|||
_global_memory["completion_message"] = "existing"
|
||||
call_flag = {"called": False}
|
||||
|
||||
def fake_print_agent_output(chunk):
|
||||
def fake_print_agent_output(
|
||||
chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"]
|
||||
):
|
||||
call_flag["called"] = True
|
||||
|
||||
monkeypatch.setattr(
|
||||
"ra_aid.agent_utils.print_agent_output", fake_print_agent_output
|
||||
)
|
||||
_run_agent_stream(dummy_agent, "dummy prompt", {})
|
||||
_run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {})
|
||||
assert call_flag["called"]
|
||||
assert _global_memory["plan_completed"] is False
|
||||
assert _global_memory["task_completed"] is False
|
||||
|
|
|
|||
Loading…
Reference in New Issue