agent context
This commit is contained in:
parent
724dbd4fda
commit
28d9032ca5
|
|
@ -0,0 +1,150 @@
|
||||||
|
"""Context manager for tracking agent state and completion status."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Dict, Optional, Set
|
||||||
|
|
||||||
|
# Thread-local storage for context variables
|
||||||
|
_thread_local = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
class AgentContext:
|
||||||
|
"""Context manager for agent state tracking."""
|
||||||
|
|
||||||
|
def __init__(self, parent_context=None):
|
||||||
|
"""Initialize a new agent context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parent_context: Optional parent context to inherit state from
|
||||||
|
"""
|
||||||
|
# Initialize completion flags
|
||||||
|
self.task_completed = False
|
||||||
|
self.plan_completed = False
|
||||||
|
self.completion_message = ""
|
||||||
|
|
||||||
|
# Inherit state from parent if provided
|
||||||
|
if parent_context:
|
||||||
|
self.task_completed = parent_context.task_completed
|
||||||
|
self.plan_completed = parent_context.plan_completed
|
||||||
|
self.completion_message = parent_context.completion_message
|
||||||
|
|
||||||
|
def mark_task_completed(self, message: str) -> None:
|
||||||
|
"""Mark the current task as completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Completion message explaining how/why the task is complete
|
||||||
|
"""
|
||||||
|
self.task_completed = True
|
||||||
|
self.completion_message = message
|
||||||
|
|
||||||
|
def mark_plan_completed(self, message: str) -> None:
|
||||||
|
"""Mark the current plan as completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Completion message explaining how the plan was completed
|
||||||
|
"""
|
||||||
|
self.task_completed = True
|
||||||
|
self.plan_completed = True
|
||||||
|
self.completion_message = message
|
||||||
|
|
||||||
|
def reset_completion_flags(self) -> None:
|
||||||
|
"""Reset all completion flags."""
|
||||||
|
self.task_completed = False
|
||||||
|
self.plan_completed = False
|
||||||
|
self.completion_message = ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_completed(self) -> bool:
|
||||||
|
"""Check if the current context is marked as completed."""
|
||||||
|
return self.task_completed or self.plan_completed
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_context() -> Optional[AgentContext]:
|
||||||
|
"""Get the current agent context for this thread.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The current AgentContext or None if no context is active
|
||||||
|
"""
|
||||||
|
return getattr(_thread_local, "current_context", None)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def agent_context(parent_context=None):
|
||||||
|
"""Context manager for agent execution.
|
||||||
|
|
||||||
|
Creates a new agent context and makes it the current context for the duration
|
||||||
|
of the with block. Restores the previous context when exiting the block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parent_context: Optional parent context to inherit state from
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The newly created AgentContext
|
||||||
|
"""
|
||||||
|
# Save the previous context
|
||||||
|
previous_context = getattr(_thread_local, "current_context", None)
|
||||||
|
|
||||||
|
# Create a new context, inheriting from parent if provided
|
||||||
|
# If parent_context is None but previous_context exists, use previous_context as parent
|
||||||
|
if parent_context is None and previous_context is not None:
|
||||||
|
context = AgentContext(previous_context)
|
||||||
|
else:
|
||||||
|
context = AgentContext(parent_context)
|
||||||
|
|
||||||
|
# Set as current context
|
||||||
|
_thread_local.current_context = context
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield context
|
||||||
|
finally:
|
||||||
|
# Restore previous context
|
||||||
|
_thread_local.current_context = previous_context
|
||||||
|
|
||||||
|
|
||||||
|
def mark_task_completed(message: str) -> None:
|
||||||
|
"""Mark the current task as completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Completion message explaining how/why the task is complete
|
||||||
|
"""
|
||||||
|
context = get_current_context()
|
||||||
|
if context:
|
||||||
|
context.mark_task_completed(message)
|
||||||
|
|
||||||
|
|
||||||
|
def mark_plan_completed(message: str) -> None:
|
||||||
|
"""Mark the current plan as completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Completion message explaining how the plan was completed
|
||||||
|
"""
|
||||||
|
context = get_current_context()
|
||||||
|
if context:
|
||||||
|
context.mark_plan_completed(message)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_completion_flags() -> None:
|
||||||
|
"""Reset completion flags in the current context."""
|
||||||
|
context = get_current_context()
|
||||||
|
if context:
|
||||||
|
context.reset_completion_flags()
|
||||||
|
|
||||||
|
|
||||||
|
def is_completed() -> bool:
|
||||||
|
"""Check if the current context is marked as completed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the current context is marked as completed, False otherwise
|
||||||
|
"""
|
||||||
|
context = get_current_context()
|
||||||
|
return context.is_completed if context else False
|
||||||
|
|
||||||
|
|
||||||
|
def get_completion_message() -> str:
|
||||||
|
"""Get the completion message from the current context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The completion message or empty string if no context or no message
|
||||||
|
"""
|
||||||
|
context = get_current_context()
|
||||||
|
return context.completion_message if context else ""
|
||||||
|
|
@ -7,7 +7,7 @@ import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
from typing import Any, Dict, List, Literal, Optional, Sequence, ContextManager
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
||||||
|
|
@ -69,6 +69,13 @@ from ra_aid.tool_configs import (
|
||||||
get_web_research_tools,
|
get_web_research_tools,
|
||||||
)
|
)
|
||||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||||
|
from ra_aid.agent_context import (
|
||||||
|
agent_context,
|
||||||
|
get_current_context,
|
||||||
|
is_completed,
|
||||||
|
reset_completion_flags,
|
||||||
|
get_completion_message,
|
||||||
|
)
|
||||||
from ra_aid.tools.memory import (
|
from ra_aid.tools.memory import (
|
||||||
_global_memory,
|
_global_memory,
|
||||||
get_memory_value,
|
get_memory_value,
|
||||||
|
|
@ -821,9 +828,8 @@ def _decrement_agent_depth():
|
||||||
|
|
||||||
|
|
||||||
def reset_agent_completion_flags():
|
def reset_agent_completion_flags():
|
||||||
_global_memory["plan_completed"] = False
|
"""Reset completion flags in the current context."""
|
||||||
_global_memory["task_completed"] = False
|
reset_completion_flags()
|
||||||
_global_memory["completion_message"] = ""
|
|
||||||
|
|
||||||
|
|
||||||
def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
|
def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
|
||||||
|
|
@ -897,8 +903,8 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict)
|
||||||
check_interrupt()
|
check_interrupt()
|
||||||
agent_type = get_agent_type(agent)
|
agent_type = get_agent_type(agent)
|
||||||
print_agent_output(chunk, agent_type)
|
print_agent_output(chunk, agent_type)
|
||||||
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
|
if is_completed():
|
||||||
reset_agent_completion_flags()
|
reset_completion_flags()
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -919,7 +925,8 @@ def run_agent_with_retry(
|
||||||
original_prompt = prompt
|
original_prompt = prompt
|
||||||
msg_list = [HumanMessage(content=prompt)]
|
msg_list = [HumanMessage(content=prompt)]
|
||||||
|
|
||||||
with InterruptibleSection():
|
# Create a new agent context for this run
|
||||||
|
with InterruptibleSection(), agent_context() as ctx:
|
||||||
try:
|
try:
|
||||||
_increment_agent_depth()
|
_increment_agent_depth()
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
|
|
|
||||||
|
|
@ -247,7 +247,7 @@ def create_llm_client(
|
||||||
if supports_thinking:
|
if supports_thinking:
|
||||||
temp_kwargs = {"thinking": {
|
temp_kwargs = {"thinking": {
|
||||||
"type": "enabled",
|
"type": "enabled",
|
||||||
"budget_tokens": 8000
|
"budget_tokens": 12000
|
||||||
}}
|
}}
|
||||||
|
|
||||||
if provider == "deepseek":
|
if provider == "deepseek":
|
||||||
|
|
|
||||||
|
|
@ -527,6 +527,8 @@ FOLLOW TEST DRIVEN DEVELOPMENT (TDD) PRACTICES WHERE POSSIBE. E.G. COMPILE CODE
|
||||||
|
|
||||||
IF YOU CAN SEE THE CODE WRITTEN/CHANGED BY THE PROGRAMMER, TRUST IT. YOU DO NOT NEED TO RE-READ EVERY FILE WITH EVERY SMALL EDIT.
|
IF YOU CAN SEE THE CODE WRITTEN/CHANGED BY THE PROGRAMMER, TRUST IT. YOU DO NOT NEED TO RE-READ EVERY FILE WITH EVERY SMALL EDIT.
|
||||||
|
|
||||||
|
YOU MUST CALL emit_related_files BEFORE CALLING run_programming_task WITH ALL RELEVANT FILES, UNLESS THEY ARE ALREADY RECORDED AS RELATED FILES.
|
||||||
|
|
||||||
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
|
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,372 @@
|
||||||
|
"""Unit tests for the agent_context module."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ra_aid.agent_context import (
|
||||||
|
AgentContext,
|
||||||
|
agent_context,
|
||||||
|
get_current_context,
|
||||||
|
mark_task_completed,
|
||||||
|
mark_plan_completed,
|
||||||
|
reset_completion_flags,
|
||||||
|
is_completed,
|
||||||
|
get_completion_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentContext:
|
||||||
|
"""Test cases for the AgentContext class and related functions."""
|
||||||
|
|
||||||
|
def test_context_creation(self):
|
||||||
|
"""Test creating a new context."""
|
||||||
|
context = AgentContext()
|
||||||
|
assert context.task_completed is False
|
||||||
|
assert context.plan_completed is False
|
||||||
|
assert context.completion_message == ""
|
||||||
|
|
||||||
|
def test_context_inheritance(self):
|
||||||
|
"""Test that child contexts inherit state from parent contexts."""
|
||||||
|
parent = AgentContext()
|
||||||
|
parent.mark_task_completed("Parent task completed")
|
||||||
|
|
||||||
|
child = AgentContext(parent_context=parent)
|
||||||
|
assert child.task_completed is True
|
||||||
|
assert child.completion_message == "Parent task completed"
|
||||||
|
|
||||||
|
def test_mark_task_completed(self):
|
||||||
|
"""Test marking a task as completed."""
|
||||||
|
context = AgentContext()
|
||||||
|
context.mark_task_completed("Task done")
|
||||||
|
|
||||||
|
assert context.task_completed is True
|
||||||
|
assert context.plan_completed is False
|
||||||
|
assert context.completion_message == "Task done"
|
||||||
|
|
||||||
|
def test_mark_plan_completed(self):
|
||||||
|
"""Test marking a plan as completed."""
|
||||||
|
context = AgentContext()
|
||||||
|
context.mark_plan_completed("Plan done")
|
||||||
|
|
||||||
|
assert context.task_completed is True
|
||||||
|
assert context.plan_completed is True
|
||||||
|
assert context.completion_message == "Plan done"
|
||||||
|
|
||||||
|
def test_reset_completion_flags(self):
|
||||||
|
"""Test resetting completion flags."""
|
||||||
|
context = AgentContext()
|
||||||
|
context.mark_task_completed("Task done")
|
||||||
|
|
||||||
|
context.reset_completion_flags()
|
||||||
|
assert context.task_completed is False
|
||||||
|
assert context.plan_completed is False
|
||||||
|
assert context.completion_message == ""
|
||||||
|
|
||||||
|
def test_is_completed_property(self):
|
||||||
|
"""Test the is_completed property."""
|
||||||
|
context = AgentContext()
|
||||||
|
assert context.is_completed is False
|
||||||
|
|
||||||
|
context.mark_task_completed("Task done")
|
||||||
|
assert context.is_completed is True
|
||||||
|
|
||||||
|
context.reset_completion_flags()
|
||||||
|
assert context.is_completed is False
|
||||||
|
|
||||||
|
context.mark_plan_completed("Plan done")
|
||||||
|
assert context.is_completed is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextManager:
|
||||||
|
"""Test cases for the agent_context context manager."""
|
||||||
|
|
||||||
|
def test_context_manager_basic(self):
|
||||||
|
"""Test basic context manager functionality."""
|
||||||
|
assert get_current_context() is None
|
||||||
|
|
||||||
|
with agent_context() as ctx:
|
||||||
|
assert get_current_context() is ctx
|
||||||
|
assert ctx.task_completed is False
|
||||||
|
|
||||||
|
assert get_current_context() is None
|
||||||
|
|
||||||
|
def test_nested_context_managers(self):
|
||||||
|
"""Test nested context managers."""
|
||||||
|
with agent_context() as outer_ctx:
|
||||||
|
assert get_current_context() is outer_ctx
|
||||||
|
|
||||||
|
with agent_context() as inner_ctx:
|
||||||
|
assert get_current_context() is inner_ctx
|
||||||
|
assert inner_ctx is not outer_ctx
|
||||||
|
|
||||||
|
assert get_current_context() is outer_ctx
|
||||||
|
|
||||||
|
def test_context_manager_with_parent(self):
|
||||||
|
"""Test context manager with explicit parent context."""
|
||||||
|
parent = AgentContext()
|
||||||
|
parent.mark_task_completed("Parent task")
|
||||||
|
|
||||||
|
with agent_context(parent_context=parent) as ctx:
|
||||||
|
assert ctx.task_completed is True
|
||||||
|
assert ctx.completion_message == "Parent task"
|
||||||
|
|
||||||
|
def test_context_manager_inheritance(self):
|
||||||
|
"""Test that nested contexts inherit from outer contexts by default."""
|
||||||
|
with agent_context() as outer:
|
||||||
|
outer.mark_task_completed("Outer task")
|
||||||
|
|
||||||
|
with agent_context() as inner:
|
||||||
|
assert inner.task_completed is True
|
||||||
|
assert inner.completion_message == "Outer task"
|
||||||
|
|
||||||
|
inner.mark_plan_completed("Inner plan")
|
||||||
|
|
||||||
|
# Outer context should not be affected by inner context changes
|
||||||
|
assert outer.task_completed is True
|
||||||
|
assert outer.plan_completed is False
|
||||||
|
assert outer.completion_message == "Outer task"
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadIsolation:
|
||||||
|
"""Test thread isolation of context variables."""
|
||||||
|
|
||||||
|
def test_thread_isolation(self):
|
||||||
|
"""Test that contexts are isolated between threads."""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
def thread_func(thread_id):
|
||||||
|
with agent_context() as ctx:
|
||||||
|
ctx.mark_task_completed(f"Thread {thread_id}")
|
||||||
|
time.sleep(0.1) # Give other threads time to run
|
||||||
|
# Store the context's message for verification
|
||||||
|
results[thread_id] = get_completion_message()
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for i in range(3):
|
||||||
|
t = threading.Thread(target=thread_func, args=(i,))
|
||||||
|
threads.append(t)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
# Each thread should have its own message
|
||||||
|
assert results[0] == "Thread 0"
|
||||||
|
assert results[1] == "Thread 1"
|
||||||
|
assert results[2] == "Thread 2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUtilityFunctions:
|
||||||
|
"""Test utility functions that operate on the current context."""
|
||||||
|
|
||||||
|
def test_mark_task_completed_utility(self):
|
||||||
|
"""Test the mark_task_completed utility function."""
|
||||||
|
with agent_context():
|
||||||
|
mark_task_completed("Task done via utility")
|
||||||
|
assert is_completed() is True
|
||||||
|
assert get_completion_message() == "Task done via utility"
|
||||||
|
|
||||||
|
def test_mark_plan_completed_utility(self):
|
||||||
|
"""Test the mark_plan_completed utility function."""
|
||||||
|
with agent_context():
|
||||||
|
mark_plan_completed("Plan done via utility")
|
||||||
|
assert is_completed() is True
|
||||||
|
assert get_completion_message() == "Plan done via utility"
|
||||||
|
|
||||||
|
def test_reset_completion_flags_utility(self):
|
||||||
|
"""Test the reset_completion_flags utility function."""
|
||||||
|
with agent_context():
|
||||||
|
mark_task_completed("Task done")
|
||||||
|
reset_completion_flags()
|
||||||
|
assert is_completed() is False
|
||||||
|
assert get_completion_message() == ""
|
||||||
|
|
||||||
|
def test_utility_functions_without_context(self):
|
||||||
|
"""Test utility functions when no context is active."""
|
||||||
|
# These should not raise exceptions even without an active context
|
||||||
|
mark_task_completed("No context")
|
||||||
|
mark_plan_completed("No context")
|
||||||
|
reset_completion_flags()
|
||||||
|
|
||||||
|
# These should have safe default returns
|
||||||
|
assert is_completed() is False
|
||||||
|
assert get_completion_message() == ""
|
||||||
|
"""Unit tests for the agent_context module."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ra_aid.agent_context import (
|
||||||
|
AgentContext,
|
||||||
|
agent_context,
|
||||||
|
get_current_context,
|
||||||
|
mark_task_completed,
|
||||||
|
mark_plan_completed,
|
||||||
|
reset_completion_flags,
|
||||||
|
is_completed,
|
||||||
|
get_completion_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentContext:
|
||||||
|
"""Test cases for the AgentContext class and related functions."""
|
||||||
|
|
||||||
|
def test_context_creation(self):
|
||||||
|
"""Test creating a new context."""
|
||||||
|
context = AgentContext()
|
||||||
|
assert context.task_completed is False
|
||||||
|
assert context.plan_completed is False
|
||||||
|
assert context.completion_message == ""
|
||||||
|
|
||||||
|
def test_context_inheritance(self):
|
||||||
|
"""Test that child contexts inherit state from parent contexts."""
|
||||||
|
parent = AgentContext()
|
||||||
|
parent.mark_task_completed("Parent task completed")
|
||||||
|
child = AgentContext(parent_context=parent)
|
||||||
|
assert child.task_completed is True
|
||||||
|
assert child.completion_message == "Parent task completed"
|
||||||
|
|
||||||
|
def test_mark_task_completed(self):
|
||||||
|
"""Test marking a task as completed."""
|
||||||
|
context = AgentContext()
|
||||||
|
context.mark_task_completed("Task done")
|
||||||
|
assert context.task_completed is True
|
||||||
|
assert context.plan_completed is False
|
||||||
|
assert context.completion_message == "Task done"
|
||||||
|
|
||||||
|
def test_mark_plan_completed(self):
|
||||||
|
"""Test marking a plan as completed."""
|
||||||
|
context = AgentContext()
|
||||||
|
context.mark_plan_completed("Plan done")
|
||||||
|
assert context.task_completed is True
|
||||||
|
assert context.plan_completed is True
|
||||||
|
assert context.completion_message == "Plan done"
|
||||||
|
|
||||||
|
def test_reset_completion_flags(self):
|
||||||
|
"""Test resetting completion flags."""
|
||||||
|
context = AgentContext()
|
||||||
|
context.mark_task_completed("Task done")
|
||||||
|
context.reset_completion_flags()
|
||||||
|
assert context.task_completed is False
|
||||||
|
assert context.plan_completed is False
|
||||||
|
assert context.completion_message == ""
|
||||||
|
|
||||||
|
def test_is_completed_property(self):
|
||||||
|
"""Test the is_completed property."""
|
||||||
|
context = AgentContext()
|
||||||
|
assert context.is_completed is False
|
||||||
|
context.mark_task_completed("Task done")
|
||||||
|
assert context.is_completed is True
|
||||||
|
context.reset_completion_flags()
|
||||||
|
assert context.is_completed is False
|
||||||
|
context.mark_plan_completed("Plan done")
|
||||||
|
assert context.is_completed is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextManager:
|
||||||
|
"""Test cases for the agent_context context manager."""
|
||||||
|
|
||||||
|
def test_context_manager_basic(self):
|
||||||
|
"""Test basic context manager functionality."""
|
||||||
|
assert get_current_context() is None
|
||||||
|
with agent_context() as ctx:
|
||||||
|
assert get_current_context() is ctx
|
||||||
|
assert ctx.task_completed is False
|
||||||
|
assert get_current_context() is None
|
||||||
|
|
||||||
|
def test_nested_context_managers(self):
|
||||||
|
"""Test nested context managers."""
|
||||||
|
with agent_context() as outer_ctx:
|
||||||
|
assert get_current_context() is outer_ctx
|
||||||
|
with agent_context() as inner_ctx:
|
||||||
|
assert get_current_context() is inner_ctx
|
||||||
|
assert inner_ctx is not outer_ctx
|
||||||
|
assert get_current_context() is outer_ctx
|
||||||
|
|
||||||
|
def test_context_manager_with_parent(self):
|
||||||
|
"""Test context manager with explicit parent context."""
|
||||||
|
parent = AgentContext()
|
||||||
|
parent.mark_task_completed("Parent task")
|
||||||
|
with agent_context(parent_context=parent) as ctx:
|
||||||
|
assert ctx.task_completed is True
|
||||||
|
assert ctx.completion_message == "Parent task"
|
||||||
|
|
||||||
|
def test_context_manager_inheritance(self):
|
||||||
|
"""Test that nested contexts inherit from outer contexts by default."""
|
||||||
|
with agent_context() as outer:
|
||||||
|
outer.mark_task_completed("Outer task")
|
||||||
|
with agent_context() as inner:
|
||||||
|
assert inner.task_completed is True
|
||||||
|
assert inner.completion_message == "Outer task"
|
||||||
|
inner.mark_plan_completed("Inner plan")
|
||||||
|
# Outer context should not be affected by inner context changes
|
||||||
|
assert outer.task_completed is True
|
||||||
|
assert outer.plan_completed is False
|
||||||
|
assert outer.completion_message == "Outer task"
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadIsolation:
|
||||||
|
"""Test thread isolation of context variables."""
|
||||||
|
|
||||||
|
def test_thread_isolation(self):
|
||||||
|
"""Test that contexts are isolated between threads."""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
def thread_func(thread_id):
|
||||||
|
with agent_context() as ctx:
|
||||||
|
ctx.mark_task_completed(f"Thread {thread_id}")
|
||||||
|
time.sleep(0.1) # Give other threads time to run
|
||||||
|
# Store the context's message for verification
|
||||||
|
results[thread_id] = get_completion_message()
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for i in range(3):
|
||||||
|
t = threading.Thread(target=thread_func, args=(i,))
|
||||||
|
threads.append(t)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
# Each thread should have its own message
|
||||||
|
assert results[0] == "Thread 0"
|
||||||
|
assert results[1] == "Thread 1"
|
||||||
|
assert results[2] == "Thread 2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUtilityFunctions:
|
||||||
|
"""Test utility functions that operate on the current context."""
|
||||||
|
|
||||||
|
def test_mark_task_completed_utility(self):
|
||||||
|
"""Test the mark_task_completed utility function."""
|
||||||
|
with agent_context():
|
||||||
|
mark_task_completed("Task done via utility")
|
||||||
|
assert is_completed() is True
|
||||||
|
assert get_completion_message() == "Task done via utility"
|
||||||
|
|
||||||
|
def test_mark_plan_completed_utility(self):
|
||||||
|
"""Test the mark_plan_completed utility function."""
|
||||||
|
with agent_context():
|
||||||
|
mark_plan_completed("Plan done via utility")
|
||||||
|
assert is_completed() is True
|
||||||
|
assert get_completion_message() == "Plan done via utility"
|
||||||
|
|
||||||
|
def test_reset_completion_flags_utility(self):
|
||||||
|
"""Test the reset_completion_flags utility function."""
|
||||||
|
with agent_context():
|
||||||
|
mark_task_completed("Task done")
|
||||||
|
reset_completion_flags()
|
||||||
|
assert is_completed() is False
|
||||||
|
assert get_completion_message() == ""
|
||||||
|
|
||||||
|
def test_utility_functions_without_context(self):
|
||||||
|
"""Test utility functions when no context is active."""
|
||||||
|
# These should not raise exceptions even without an active context
|
||||||
|
mark_task_completed("No context")
|
||||||
|
mark_plan_completed("No context")
|
||||||
|
reset_completion_flags()
|
||||||
|
# These should have safe default returns
|
||||||
|
assert is_completed() is False
|
||||||
|
assert get_completion_message() == ""
|
||||||
|
|
@ -5,6 +5,7 @@ from typing import Any, Dict, List, Union
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
|
from ra_aid.agent_context import get_completion_message, reset_completion_flags
|
||||||
from ra_aid.console.formatting import print_error
|
from ra_aid.console.formatting import print_error
|
||||||
from ra_aid.exceptions import AgentInterrupt
|
from ra_aid.exceptions import AgentInterrupt
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.tools.memory import _global_memory
|
||||||
|
|
@ -84,16 +85,12 @@ def request_research(query: str) -> ResearchResult:
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
finally:
|
finally:
|
||||||
# Get completion message if available
|
# Get completion message if available
|
||||||
completion_message = _global_memory.get(
|
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||||
"completion_message",
|
|
||||||
"Task was completed successfully." if success else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
work_log = get_work_log()
|
work_log = get_work_log()
|
||||||
|
|
||||||
# Clear completion state from global memory
|
# Clear completion state
|
||||||
_global_memory["completion_message"] = ""
|
reset_completion_flags()
|
||||||
_global_memory["task_completed"] = False
|
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"completion_message": completion_message,
|
"completion_message": completion_message,
|
||||||
|
|
@ -152,16 +149,12 @@ def request_web_research(query: str) -> ResearchResult:
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
finally:
|
finally:
|
||||||
# Get completion message if available
|
# Get completion message if available
|
||||||
completion_message = _global_memory.get(
|
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||||
"completion_message",
|
|
||||||
"Task was completed successfully." if success else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
work_log = get_work_log()
|
work_log = get_work_log()
|
||||||
|
|
||||||
# Clear completion state from global memory
|
# Clear completion state
|
||||||
_global_memory["completion_message"] = ""
|
reset_completion_flags()
|
||||||
_global_memory["task_completed"] = False
|
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"completion_message": completion_message,
|
"completion_message": completion_message,
|
||||||
|
|
@ -222,16 +215,12 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
|
|
||||||
# Get completion message if available
|
# Get completion message if available
|
||||||
completion_message = _global_memory.get(
|
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||||
"completion_message", "Task was completed successfully." if success else None
|
|
||||||
)
|
|
||||||
|
|
||||||
work_log = get_work_log()
|
work_log = get_work_log()
|
||||||
|
|
||||||
# Clear completion state from global memory
|
# Clear completion state
|
||||||
_global_memory["completion_message"] = ""
|
reset_completion_flags()
|
||||||
_global_memory["task_completed"] = False
|
|
||||||
_global_memory["plan_completed"] = False
|
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"completion_message": completion_message,
|
"completion_message": completion_message,
|
||||||
|
|
@ -276,7 +265,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
|
||||||
# Run implementation agent
|
# Run implementation agent
|
||||||
from ..agent_utils import run_task_implementation_agent
|
from ..agent_utils import run_task_implementation_agent
|
||||||
|
|
||||||
_global_memory["completion_message"] = ""
|
reset_completion_flags()
|
||||||
|
|
||||||
_result = run_task_implementation_agent(
|
_result = run_task_implementation_agent(
|
||||||
base_task=_global_memory.get("base_task", ""),
|
base_task=_global_memory.get("base_task", ""),
|
||||||
|
|
@ -304,16 +293,13 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
|
|
||||||
# Get completion message if available
|
# Get completion message if available
|
||||||
completion_message = _global_memory.get(
|
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||||
"completion_message", "Task was completed successfully." if success else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get and reset work log if at root depth
|
# Get and reset work log if at root depth
|
||||||
work_log = get_work_log()
|
work_log = get_work_log()
|
||||||
|
|
||||||
# Clear completion state from global memory
|
# Clear completion state
|
||||||
_global_memory["completion_message"] = ""
|
reset_completion_flags()
|
||||||
_global_memory["task_completed"] = False
|
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"key_facts": get_memory_value("key_facts"),
|
"key_facts": get_memory_value("key_facts"),
|
||||||
|
|
@ -325,6 +311,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
|
||||||
}
|
}
|
||||||
if work_log is not None:
|
if work_log is not None:
|
||||||
response_data["work_log"] = work_log
|
response_data["work_log"] = work_log
|
||||||
|
print("TASK HERE", response_data)
|
||||||
return response_data
|
return response_data
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -347,7 +334,7 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
|
||||||
# Run planning agent
|
# Run planning agent
|
||||||
from ..agent_utils import run_planning_agent
|
from ..agent_utils import run_planning_agent
|
||||||
|
|
||||||
_global_memory["completion_message"] = ""
|
reset_completion_flags()
|
||||||
|
|
||||||
_result = run_planning_agent(
|
_result = run_planning_agent(
|
||||||
task_spec,
|
task_spec,
|
||||||
|
|
@ -372,17 +359,13 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
|
|
||||||
# Get completion message if available
|
# Get completion message if available
|
||||||
completion_message = _global_memory.get(
|
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||||
"completion_message", "Task was completed successfully." if success else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get and reset work log if at root depth
|
# Get and reset work log if at root depth
|
||||||
work_log = get_work_log()
|
work_log = get_work_log()
|
||||||
|
|
||||||
# Clear completion state from global memory
|
# Clear completion state
|
||||||
_global_memory["completion_message"] = ""
|
reset_completion_flags()
|
||||||
_global_memory["task_completed"] = False
|
|
||||||
_global_memory["plan_completed"] = False
|
|
||||||
|
|
||||||
response_data = {
|
response_data = {
|
||||||
"completion_message": completion_message,
|
"completion_message": completion_message,
|
||||||
|
|
@ -394,4 +377,6 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
|
||||||
}
|
}
|
||||||
if work_log is not None:
|
if work_log is not None:
|
||||||
response_data["work_log"] = work_log
|
response_data["work_log"] = work_log
|
||||||
|
|
||||||
|
print("HERE", response_data)
|
||||||
return response_data
|
return response_data
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import magic
|
import magic
|
||||||
|
|
@ -12,6 +12,8 @@ from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from ra_aid.agent_context import mark_task_completed, mark_plan_completed
|
||||||
|
|
||||||
|
|
||||||
class WorkLogEntry(TypedDict):
|
class WorkLogEntry(TypedDict):
|
||||||
timestamp: str
|
timestamp: str
|
||||||
|
|
@ -28,25 +30,10 @@ class SnippetInfo(TypedDict):
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
# Global memory store
|
# Global memory store
|
||||||
_global_memory: Dict[
|
_global_memory: Dict[str, Any] = {
|
||||||
str,
|
|
||||||
Union[
|
|
||||||
Dict[int, str],
|
|
||||||
Dict[int, SnippetInfo],
|
|
||||||
Dict[int, WorkLogEntry],
|
|
||||||
int,
|
|
||||||
Set[str],
|
|
||||||
bool,
|
|
||||||
str,
|
|
||||||
List[str],
|
|
||||||
List[WorkLogEntry],
|
|
||||||
],
|
|
||||||
] = {
|
|
||||||
"research_notes": [],
|
"research_notes": [],
|
||||||
"plans": [],
|
"plans": [],
|
||||||
"tasks": {}, # Dict[int, str] - ID to task mapping
|
"tasks": {}, # Dict[int, str] - ID to task mapping
|
||||||
"task_completed": False, # Flag indicating if task is complete
|
|
||||||
"completion_message": "", # Message explaining completion
|
|
||||||
"task_id_counter": 1, # Counter for generating unique task IDs
|
"task_id_counter": 1, # Counter for generating unique task IDs
|
||||||
"key_facts": {}, # Dict[int, str] - ID to fact mapping
|
"key_facts": {}, # Dict[int, str] - ID to fact mapping
|
||||||
"key_fact_id_counter": 1, # Counter for generating unique fact IDs
|
"key_fact_id_counter": 1, # Counter for generating unique fact IDs
|
||||||
|
|
@ -55,7 +42,6 @@ _global_memory: Dict[
|
||||||
"implementation_requested": False,
|
"implementation_requested": False,
|
||||||
"related_files": {}, # Dict[int, str] - ID to filepath mapping
|
"related_files": {}, # Dict[int, str] - ID to filepath mapping
|
||||||
"related_file_id_counter": 1, # Counter for generating unique file IDs
|
"related_file_id_counter": 1, # Counter for generating unique file IDs
|
||||||
"plan_completed": False,
|
|
||||||
"agent_depth": 0,
|
"agent_depth": 0,
|
||||||
"work_log": [], # List[WorkLogEntry] - Timestamped work events
|
"work_log": [], # List[WorkLogEntry] - Timestamped work events
|
||||||
}
|
}
|
||||||
|
|
@ -327,10 +313,9 @@ def one_shot_completed(message: str) -> str:
|
||||||
if _global_memory.get("implementation_requested", False):
|
if _global_memory.get("implementation_requested", False):
|
||||||
return "Cannot complete in one shot - implementation was requested"
|
return "Cannot complete in one shot - implementation was requested"
|
||||||
|
|
||||||
_global_memory["task_completed"] = True
|
mark_task_completed(message)
|
||||||
_global_memory["completion_message"] = message
|
|
||||||
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
||||||
log_work_event(f"Task completed\n\n{message}")
|
log_work_event(f"Task completed:\n\n{message}")
|
||||||
return "Completion noted."
|
return "Completion noted."
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -341,9 +326,9 @@ def task_completed(message: str) -> str:
|
||||||
Args:
|
Args:
|
||||||
message: Message explaining how/why the task is complete
|
message: Message explaining how/why the task is complete
|
||||||
"""
|
"""
|
||||||
_global_memory["task_completed"] = True
|
mark_task_completed(message)
|
||||||
_global_memory["completion_message"] = message
|
|
||||||
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
||||||
|
log_work_event(f"Task completed:\n\n{message}")
|
||||||
return "Completion noted."
|
return "Completion noted."
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -354,14 +339,11 @@ def plan_implementation_completed(message: str) -> str:
|
||||||
Args:
|
Args:
|
||||||
message: Message explaining how the implementation plan was completed
|
message: Message explaining how the implementation plan was completed
|
||||||
"""
|
"""
|
||||||
_global_memory["task_completed"] = True
|
mark_plan_completed(message)
|
||||||
_global_memory["completion_message"] = message
|
|
||||||
_global_memory["plan_completed"] = True
|
|
||||||
_global_memory["completion_message"] = message
|
|
||||||
_global_memory["tasks"].clear() # Clear task list when plan is completed
|
_global_memory["tasks"].clear() # Clear task list when plan is completed
|
||||||
_global_memory["task_id_counter"] = 1
|
_global_memory["task_id_counter"] = 1
|
||||||
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
|
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
|
||||||
log_work_event(f"Plan execution completed:\n\n{message}")
|
log_work_event(f"Completed implementation:\n\n{message}")
|
||||||
return "Plan completion noted and task list cleared."
|
return "Plan completion noted and task list cleared."
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,178 @@
|
||||||
|
"""Unit tests for the agent_context module."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ra_aid.agent_context import (
|
||||||
|
AgentContext,
|
||||||
|
agent_context,
|
||||||
|
get_current_context,
|
||||||
|
mark_task_completed,
|
||||||
|
mark_plan_completed,
|
||||||
|
reset_completion_flags,
|
||||||
|
is_completed,
|
||||||
|
get_completion_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentContext:
|
||||||
|
"""Test cases for the AgentContext class and related functions."""
|
||||||
|
|
||||||
|
def test_context_creation(self):
|
||||||
|
"""Test creating a new context."""
|
||||||
|
context = AgentContext()
|
||||||
|
assert context.task_completed is False
|
||||||
|
assert context.plan_completed is False
|
||||||
|
assert context.completion_message == ""
|
||||||
|
|
||||||
|
def test_context_inheritance(self):
|
||||||
|
"""Test that child contexts inherit state from parent contexts."""
|
||||||
|
parent = AgentContext()
|
||||||
|
parent.mark_task_completed("Parent task completed")
|
||||||
|
child = AgentContext(parent_context=parent)
|
||||||
|
assert child.task_completed is True
|
||||||
|
assert child.completion_message == "Parent task completed"
|
||||||
|
|
||||||
|
def test_mark_task_completed(self):
|
||||||
|
"""Test marking a task as completed."""
|
||||||
|
context = AgentContext()
|
||||||
|
context.mark_task_completed("Task done")
|
||||||
|
assert context.task_completed is True
|
||||||
|
assert context.plan_completed is False
|
||||||
|
assert context.completion_message == "Task done"
|
||||||
|
|
||||||
|
def test_mark_plan_completed(self):
|
||||||
|
"""Test marking a plan as completed."""
|
||||||
|
context = AgentContext()
|
||||||
|
context.mark_plan_completed("Plan done")
|
||||||
|
assert context.task_completed is True
|
||||||
|
assert context.plan_completed is True
|
||||||
|
assert context.completion_message == "Plan done"
|
||||||
|
|
||||||
|
def test_reset_completion_flags(self):
|
||||||
|
"""Test resetting completion flags."""
|
||||||
|
context = AgentContext()
|
||||||
|
context.mark_task_completed("Task done")
|
||||||
|
context.reset_completion_flags()
|
||||||
|
assert context.task_completed is False
|
||||||
|
assert context.plan_completed is False
|
||||||
|
assert context.completion_message == ""
|
||||||
|
|
||||||
|
def test_is_completed_property(self):
|
||||||
|
"""Test the is_completed property."""
|
||||||
|
context = AgentContext()
|
||||||
|
assert context.is_completed is False
|
||||||
|
context.mark_task_completed("Task done")
|
||||||
|
assert context.is_completed is True
|
||||||
|
context.reset_completion_flags()
|
||||||
|
assert context.is_completed is False
|
||||||
|
context.mark_plan_completed("Plan done")
|
||||||
|
assert context.is_completed is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextManager:
|
||||||
|
"""Test cases for the agent_context context manager."""
|
||||||
|
|
||||||
|
def test_context_manager_basic(self):
|
||||||
|
"""Test basic context manager functionality."""
|
||||||
|
assert get_current_context() is None
|
||||||
|
with agent_context() as ctx:
|
||||||
|
assert get_current_context() is ctx
|
||||||
|
assert ctx.task_completed is False
|
||||||
|
assert get_current_context() is None
|
||||||
|
|
||||||
|
def test_nested_context_managers(self):
|
||||||
|
"""Test nested context managers."""
|
||||||
|
with agent_context() as outer_ctx:
|
||||||
|
assert get_current_context() is outer_ctx
|
||||||
|
with agent_context() as inner_ctx:
|
||||||
|
assert get_current_context() is inner_ctx
|
||||||
|
assert inner_ctx is not outer_ctx
|
||||||
|
assert get_current_context() is outer_ctx
|
||||||
|
|
||||||
|
def test_context_manager_with_parent(self):
|
||||||
|
"""Test context manager with explicit parent context."""
|
||||||
|
parent = AgentContext()
|
||||||
|
parent.mark_task_completed("Parent task")
|
||||||
|
with agent_context(parent_context=parent) as ctx:
|
||||||
|
assert ctx.task_completed is True
|
||||||
|
assert ctx.completion_message == "Parent task"
|
||||||
|
|
||||||
|
def test_context_manager_inheritance(self):
|
||||||
|
"""Test that nested contexts inherit from outer contexts by default."""
|
||||||
|
with agent_context() as outer:
|
||||||
|
outer.mark_task_completed("Outer task")
|
||||||
|
with agent_context() as inner:
|
||||||
|
assert inner.task_completed is True
|
||||||
|
assert inner.completion_message == "Outer task"
|
||||||
|
inner.mark_plan_completed("Inner plan")
|
||||||
|
# Outer context should not be affected by inner context changes
|
||||||
|
assert outer.task_completed is True
|
||||||
|
assert outer.plan_completed is False
|
||||||
|
assert outer.completion_message == "Outer task"
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadIsolation:
|
||||||
|
"""Test thread isolation of context variables."""
|
||||||
|
|
||||||
|
def test_thread_isolation(self):
|
||||||
|
"""Test that contexts are isolated between threads."""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
def thread_func(thread_id):
|
||||||
|
with agent_context() as ctx:
|
||||||
|
ctx.mark_task_completed(f"Thread {thread_id}")
|
||||||
|
time.sleep(0.1) # Give other threads time to run
|
||||||
|
# Store the context's message for verification
|
||||||
|
results[thread_id] = get_completion_message()
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for i in range(3):
|
||||||
|
t = threading.Thread(target=thread_func, args=(i,))
|
||||||
|
threads.append(t)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
# Each thread should have its own message
|
||||||
|
assert results[0] == "Thread 0"
|
||||||
|
assert results[1] == "Thread 1"
|
||||||
|
assert results[2] == "Thread 2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUtilityFunctions:
|
||||||
|
"""Test utility functions that operate on the current context."""
|
||||||
|
|
||||||
|
def test_mark_task_completed_utility(self):
|
||||||
|
"""Test the mark_task_completed utility function."""
|
||||||
|
with agent_context():
|
||||||
|
mark_task_completed("Task done via utility")
|
||||||
|
assert is_completed() is True
|
||||||
|
assert get_completion_message() == "Task done via utility"
|
||||||
|
|
||||||
|
def test_mark_plan_completed_utility(self):
|
||||||
|
"""Test the mark_plan_completed utility function."""
|
||||||
|
with agent_context():
|
||||||
|
mark_plan_completed("Plan done via utility")
|
||||||
|
assert is_completed() is True
|
||||||
|
assert get_completion_message() == "Plan done via utility"
|
||||||
|
|
||||||
|
def test_reset_completion_flags_utility(self):
|
||||||
|
"""Test the reset_completion_flags utility function."""
|
||||||
|
with agent_context():
|
||||||
|
mark_task_completed("Task done")
|
||||||
|
reset_completion_flags()
|
||||||
|
assert is_completed() is False
|
||||||
|
assert get_completion_message() == ""
|
||||||
|
|
||||||
|
def test_utility_functions_without_context(self):
|
||||||
|
"""Test utility functions when no context is active."""
|
||||||
|
# These should not raise exceptions even without an active context
|
||||||
|
mark_task_completed("No context")
|
||||||
|
mark_plan_completed("No context")
|
||||||
|
reset_completion_flags()
|
||||||
|
# These should have safe default returns
|
||||||
|
assert is_completed() is False
|
||||||
|
assert get_completion_message() == ""
|
||||||
|
|
@ -8,6 +8,7 @@ import pytest
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from ra_aid.agent_context import agent_context, get_current_context, reset_completion_flags
|
||||||
from ra_aid.agent_utils import (
|
from ra_aid.agent_utils import (
|
||||||
AgentState,
|
AgentState,
|
||||||
create_agent,
|
create_agent,
|
||||||
|
|
@ -325,7 +326,7 @@ def test_increment_and_decrement_agent_depth():
|
||||||
|
|
||||||
|
|
||||||
def test_run_agent_stream(monkeypatch):
|
def test_run_agent_stream(monkeypatch):
|
||||||
from ra_aid.agent_utils import _global_memory, _run_agent_stream
|
from ra_aid.agent_utils import _run_agent_stream
|
||||||
|
|
||||||
# Create a dummy agent that yields one chunk
|
# Create a dummy agent that yields one chunk
|
||||||
class DummyAgent:
|
class DummyAgent:
|
||||||
|
|
@ -334,9 +335,11 @@ def test_run_agent_stream(monkeypatch):
|
||||||
|
|
||||||
dummy_agent = DummyAgent()
|
dummy_agent = DummyAgent()
|
||||||
# Set flags so that _run_agent_stream will reset them
|
# Set flags so that _run_agent_stream will reset them
|
||||||
_global_memory["plan_completed"] = True
|
with agent_context() as ctx:
|
||||||
_global_memory["task_completed"] = True
|
ctx.plan_completed = True
|
||||||
_global_memory["completion_message"] = "existing"
|
ctx.task_completed = True
|
||||||
|
ctx.completion_message = "existing"
|
||||||
|
|
||||||
call_flag = {"called": False}
|
call_flag = {"called": False}
|
||||||
|
|
||||||
def fake_print_agent_output(
|
def fake_print_agent_output(
|
||||||
|
|
@ -349,9 +352,11 @@ def test_run_agent_stream(monkeypatch):
|
||||||
)
|
)
|
||||||
_run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {})
|
_run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {})
|
||||||
assert call_flag["called"]
|
assert call_flag["called"]
|
||||||
assert _global_memory["plan_completed"] is False
|
|
||||||
assert _global_memory["task_completed"] is False
|
with agent_context() as ctx:
|
||||||
assert _global_memory["completion_message"] == ""
|
assert ctx.plan_completed is False
|
||||||
|
assert ctx.task_completed is False
|
||||||
|
assert ctx.completion_message == ""
|
||||||
|
|
||||||
|
|
||||||
def test_execute_test_command_wrapper(monkeypatch):
|
def test_execute_test_command_wrapper(monkeypatch):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue