agent context
This commit is contained in:
parent
724dbd4fda
commit
28d9032ca5
|
|
@ -0,0 +1,150 @@
|
|||
"""Context manager for tracking agent state and completion status."""
|
||||
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Optional, Set
|
||||
|
||||
# Thread-local storage for context variables
|
||||
_thread_local = threading.local()
|
||||
|
||||
|
||||
class AgentContext:
    """Completion state for a single agent run.

    An instance is a plain state holder: two boolean flags plus the message
    recorded by the most recent ``mark_*`` call. The class itself does not
    implement the context-manager protocol.
    """

    def __init__(self, parent_context=None):
        """Create a context, optionally seeded from another one.

        Args:
            parent_context: Context whose flags and message are copied into
                the new instance at construction time. When omitted (or
                falsy) the new context starts with nothing completed and an
                empty message.
        """
        if parent_context:
            seed = (
                parent_context.task_completed,
                parent_context.plan_completed,
                parent_context.completion_message,
            )
        else:
            seed = (False, False, "")
        self.task_completed, self.plan_completed, self.completion_message = seed

    def mark_task_completed(self, message: str) -> None:
        """Record that the task is finished.

        Only the task flag and the message change; the plan flag is left
        as it was.

        Args:
            message: Explanation of how/why the task is complete.
        """
        self.completion_message = message
        self.task_completed = True

    def mark_plan_completed(self, message: str) -> None:
        """Record that the whole plan is finished.

        Sets both the plan flag and the task flag.

        Args:
            message: Explanation of how the plan was completed.
        """
        self.task_completed = self.plan_completed = True
        self.completion_message = message

    def reset_completion_flags(self) -> None:
        """Clear both flags and blank the completion message."""
        self.task_completed = self.plan_completed = False
        self.completion_message = ""

    @property
    def is_completed(self) -> bool:
        """Whether either the task flag or the plan flag is set."""
        done = self.task_completed or self.plan_completed
        return done
|
||||
|
||||
|
||||
def get_current_context() -> Optional[AgentContext]:
    """Return the agent context active on the calling thread.

    Returns:
        Whatever is stored in the thread-local ``current_context`` slot, or
        None when that slot has not been set on this thread.
    """
    try:
        return _thread_local.current_context
    except AttributeError:
        # Nothing has been stored on this thread yet.
        return None
|
||||
|
||||
|
||||
@contextmanager
def agent_context(parent_context=None):
    """Run a ``with`` block under a fresh agent context.

    A new AgentContext is created and stored as the calling thread's current
    context while the block runs. On exit, including when the block raises,
    the context that was current beforehand is put back.

    Args:
        parent_context: Context to copy initial state from. When None, the
            context that is current on entry (if any) is used as the parent.

    Yields:
        The newly created AgentContext.
    """
    # Remember what was active so it can be reinstated on the way out.
    outer = getattr(_thread_local, "current_context", None)

    # An explicit parent wins; otherwise fall back to the enclosing context,
    # which may itself be None and then yields a pristine context.
    seed = outer if parent_context is None else parent_context
    current = AgentContext(seed)

    _thread_local.current_context = current
    try:
        yield current
    finally:
        _thread_local.current_context = outer
|
||||
|
||||
|
||||
def mark_task_completed(message: str) -> None:
    """Flag the task as done on the calling thread's current context.

    Does nothing when no context is active.

    Args:
        message: Explanation of how/why the task is complete.
    """
    active = get_current_context()
    if not active:
        return
    active.mark_task_completed(message)
|
||||
|
||||
|
||||
def mark_plan_completed(message: str) -> None:
    """Flag the plan as done on the calling thread's current context.

    Does nothing when no context is active.

    Args:
        message: Explanation of how the plan was completed.
    """
    active = get_current_context()
    if not active:
        return
    active.mark_plan_completed(message)
|
||||
|
||||
|
||||
def reset_completion_flags() -> None:
    """Clear the completion state of the calling thread's current context.

    Does nothing when no context is active.
    """
    active = get_current_context()
    if not active:
        return
    active.reset_completion_flags()
|
||||
|
||||
|
||||
def is_completed() -> bool:
    """Report whether the calling thread's current context is completed.

    Returns:
        The current context's ``is_completed`` value, or False when no
        context is active.
    """
    active = get_current_context()
    if active:
        return active.is_completed
    return False
|
||||
|
||||
|
||||
def get_completion_message() -> str:
    """Fetch the completion message of the calling thread's current context.

    Returns:
        The current context's ``completion_message``, or an empty string
        when no context is active.
    """
    active = get_current_context()
    if active:
        return active.completion_message
    return ""
|
||||
|
|
@ -7,7 +7,7 @@ import threading
|
|||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, ContextManager
|
||||
|
||||
import litellm
|
||||
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
||||
|
|
@ -69,6 +69,13 @@ from ra_aid.tool_configs import (
|
|||
get_web_research_tools,
|
||||
)
|
||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||
from ra_aid.agent_context import (
|
||||
agent_context,
|
||||
get_current_context,
|
||||
is_completed,
|
||||
reset_completion_flags,
|
||||
get_completion_message,
|
||||
)
|
||||
from ra_aid.tools.memory import (
|
||||
_global_memory,
|
||||
get_memory_value,
|
||||
|
|
@ -821,9 +828,8 @@ def _decrement_agent_depth():
|
|||
|
||||
|
||||
def reset_agent_completion_flags():
|
||||
_global_memory["plan_completed"] = False
|
||||
_global_memory["task_completed"] = False
|
||||
_global_memory["completion_message"] = ""
|
||||
"""Reset completion flags in the current context."""
|
||||
reset_completion_flags()
|
||||
|
||||
|
||||
def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
|
||||
|
|
@ -897,8 +903,8 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict)
|
|||
check_interrupt()
|
||||
agent_type = get_agent_type(agent)
|
||||
print_agent_output(chunk, agent_type)
|
||||
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
|
||||
reset_agent_completion_flags()
|
||||
if is_completed():
|
||||
reset_completion_flags()
|
||||
break
|
||||
|
||||
|
||||
|
|
@ -919,7 +925,8 @@ def run_agent_with_retry(
|
|||
original_prompt = prompt
|
||||
msg_list = [HumanMessage(content=prompt)]
|
||||
|
||||
with InterruptibleSection():
|
||||
# Create a new agent context for this run
|
||||
with InterruptibleSection(), agent_context() as ctx:
|
||||
try:
|
||||
_increment_agent_depth()
|
||||
for attempt in range(max_retries):
|
||||
|
|
|
|||
|
|
@ -247,7 +247,7 @@ def create_llm_client(
|
|||
if supports_thinking:
|
||||
temp_kwargs = {"thinking": {
|
||||
"type": "enabled",
|
||||
"budget_tokens": 8000
|
||||
"budget_tokens": 12000
|
||||
}}
|
||||
|
||||
if provider == "deepseek":
|
||||
|
|
|
|||
|
|
@ -527,6 +527,8 @@ FOLLOW TEST DRIVEN DEVELOPMENT (TDD) PRACTICES WHERE POSSIBE. E.G. COMPILE CODE
|
|||
|
||||
IF YOU CAN SEE THE CODE WRITTEN/CHANGED BY THE PROGRAMMER, TRUST IT. YOU DO NOT NEED TO RE-READ EVERY FILE WITH EVERY SMALL EDIT.
|
||||
|
||||
YOU MUST CALL emit_related_files BEFORE CALLING run_programming_task WITH ALL RELEVANT FILES, UNLESS THEY ARE ALREADY RECORDED AS RELATED FILES.
|
||||
|
||||
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,372 @@
|
|||
"""Unit tests for the agent_context module."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import pytest
|
||||
|
||||
from ra_aid.agent_context import (
|
||||
AgentContext,
|
||||
agent_context,
|
||||
get_current_context,
|
||||
mark_task_completed,
|
||||
mark_plan_completed,
|
||||
reset_completion_flags,
|
||||
is_completed,
|
||||
get_completion_message,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentContext:
    """Exercise AgentContext directly, without the context manager."""

    def test_context_creation(self):
        """A fresh context has nothing completed and an empty message."""
        fresh = AgentContext()
        assert fresh.task_completed is False
        assert fresh.plan_completed is False
        assert fresh.completion_message == ""

    def test_context_inheritance(self):
        """A child built from a parent starts with the parent's state."""
        source = AgentContext()
        source.mark_task_completed("Parent task completed")

        derived = AgentContext(parent_context=source)
        assert derived.task_completed is True
        assert derived.completion_message == "Parent task completed"

    def test_mark_task_completed(self):
        """Completing a task sets the task flag and message only."""
        subject = AgentContext()
        subject.mark_task_completed("Task done")

        assert subject.task_completed is True
        assert subject.plan_completed is False
        assert subject.completion_message == "Task done"

    def test_mark_plan_completed(self):
        """Completing a plan sets both flags and the message."""
        subject = AgentContext()
        subject.mark_plan_completed("Plan done")

        assert subject.task_completed is True
        assert subject.plan_completed is True
        assert subject.completion_message == "Plan done"

    def test_reset_completion_flags(self):
        """Resetting returns the context to its initial state."""
        subject = AgentContext()
        subject.mark_task_completed("Task done")

        subject.reset_completion_flags()
        assert subject.task_completed is False
        assert subject.plan_completed is False
        assert subject.completion_message == ""

    def test_is_completed_property(self):
        """is_completed follows the flags through mark and reset calls."""
        subject = AgentContext()
        assert subject.is_completed is False

        subject.mark_task_completed("Task done")
        assert subject.is_completed is True

        subject.reset_completion_flags()
        assert subject.is_completed is False

        subject.mark_plan_completed("Plan done")
        assert subject.is_completed is True
|
||||
|
||||
|
||||
class TestContextManager:
    """Behaviour of the agent_context() context manager."""

    def test_context_manager_basic(self):
        """The yielded context is current inside the block and gone after it."""
        assert get_current_context() is None

        with agent_context() as active:
            assert get_current_context() is active
            assert active.task_completed is False

        assert get_current_context() is None

    def test_nested_context_managers(self):
        """Leaving an inner block reinstates the outer context."""
        with agent_context() as outer_ctx:
            assert get_current_context() is outer_ctx

            with agent_context() as inner_ctx:
                assert get_current_context() is inner_ctx
                assert inner_ctx is not outer_ctx

            assert get_current_context() is outer_ctx

    def test_context_manager_with_parent(self):
        """An explicitly supplied parent seeds the new context."""
        seed = AgentContext()
        seed.mark_task_completed("Parent task")

        with agent_context(parent_context=seed) as active:
            assert active.task_completed is True
            assert active.completion_message == "Parent task"

    def test_context_manager_inheritance(self):
        """Without an explicit parent, a nested context copies the outer one."""
        with agent_context() as outer:
            outer.mark_task_completed("Outer task")

            with agent_context() as inner:
                assert inner.task_completed is True
                assert inner.completion_message == "Outer task"

                inner.mark_plan_completed("Inner plan")

            # Changes made on the inner context must not leak outwards.
            assert outer.task_completed is True
            assert outer.plan_completed is False
            assert outer.completion_message == "Outer task"
|
||||
|
||||
|
||||
class TestThreadIsolation:
    """Verify that each thread only ever sees its own agent context."""

    def test_thread_isolation(self):
        """Contexts open concurrently in several threads do not leak into each other."""
        thread_count = 3
        results = {}
        errors = []
        # Every worker waits here while its context is still open, so all
        # contexts are guaranteed to be live at the same time. This replaces
        # a fixed sleep, which only made the overlap likely and slowed the
        # test down.
        barrier = threading.Barrier(thread_count)

        def thread_func(thread_id):
            try:
                with agent_context() as ctx:
                    ctx.mark_task_completed(f"Thread {thread_id}")
                    # The timeout keeps a failure in one worker from hanging
                    # the others; they get BrokenBarrierError instead.
                    barrier.wait(timeout=5)
                    # Read back through the thread-local accessor, not ctx,
                    # to prove the lookup resolves to this thread's context.
                    results[thread_id] = get_completion_message()
            except Exception as exc:
                # Exceptions raised in a worker are not propagated to the
                # main thread, so record them for the assertion below.
                errors.append((thread_id, exc))

        threads = [
            threading.Thread(target=thread_func, args=(i,))
            for i in range(thread_count)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert errors == []

        # Each thread should have its own message
        assert results[0] == "Thread 0"
        assert results[1] == "Thread 1"
        assert results[2] == "Thread 2"
|
||||
|
||||
|
||||
class TestUtilityFunctions:
    """Module-level helpers that act on the current thread's context."""

    def test_mark_task_completed_utility(self):
        """mark_task_completed() updates the active context."""
        with agent_context():
            mark_task_completed("Task done via utility")
            assert is_completed() is True
            assert get_completion_message() == "Task done via utility"

    def test_mark_plan_completed_utility(self):
        """mark_plan_completed() updates the active context."""
        with agent_context():
            mark_plan_completed("Plan done via utility")
            assert is_completed() is True
            assert get_completion_message() == "Plan done via utility"

    def test_reset_completion_flags_utility(self):
        """reset_completion_flags() clears the active context."""
        with agent_context():
            mark_task_completed("Task done")
            reset_completion_flags()
            assert is_completed() is False
            assert get_completion_message() == ""

    def test_utility_functions_without_context(self):
        """With no active context the helpers are no-ops with safe defaults."""
        # The mutating helpers must tolerate the absence of a context.
        mark_task_completed("No context")
        mark_plan_completed("No context")
        reset_completion_flags()

        # The query helpers fall back to "not completed" and an empty message.
        assert is_completed() is False
        assert get_completion_message() == ""
|
||||
"""Unit tests for the agent_context module."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import pytest
|
||||
|
||||
from ra_aid.agent_context import (
|
||||
AgentContext,
|
||||
agent_context,
|
||||
get_current_context,
|
||||
mark_task_completed,
|
||||
mark_plan_completed,
|
||||
reset_completion_flags,
|
||||
is_completed,
|
||||
get_completion_message,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentContext:
    """Exercise AgentContext directly, without the context manager."""

    def test_context_creation(self):
        """A fresh context has nothing completed and an empty message."""
        fresh = AgentContext()
        assert fresh.task_completed is False
        assert fresh.plan_completed is False
        assert fresh.completion_message == ""

    def test_context_inheritance(self):
        """A child built from a parent starts with the parent's state."""
        source = AgentContext()
        source.mark_task_completed("Parent task completed")
        derived = AgentContext(parent_context=source)
        assert derived.task_completed is True
        assert derived.completion_message == "Parent task completed"

    def test_mark_task_completed(self):
        """Completing a task sets the task flag and message only."""
        subject = AgentContext()
        subject.mark_task_completed("Task done")
        assert subject.task_completed is True
        assert subject.plan_completed is False
        assert subject.completion_message == "Task done"

    def test_mark_plan_completed(self):
        """Completing a plan sets both flags and the message."""
        subject = AgentContext()
        subject.mark_plan_completed("Plan done")
        assert subject.task_completed is True
        assert subject.plan_completed is True
        assert subject.completion_message == "Plan done"

    def test_reset_completion_flags(self):
        """Resetting returns the context to its initial state."""
        subject = AgentContext()
        subject.mark_task_completed("Task done")
        subject.reset_completion_flags()
        assert subject.task_completed is False
        assert subject.plan_completed is False
        assert subject.completion_message == ""

    def test_is_completed_property(self):
        """is_completed follows the flags through mark and reset calls."""
        subject = AgentContext()
        assert subject.is_completed is False
        subject.mark_task_completed("Task done")
        assert subject.is_completed is True
        subject.reset_completion_flags()
        assert subject.is_completed is False
        subject.mark_plan_completed("Plan done")
        assert subject.is_completed is True
|
||||
|
||||
|
||||
class TestContextManager:
    """Behaviour of the agent_context() context manager."""

    def test_context_manager_basic(self):
        """The yielded context is current inside the block and gone after it."""
        assert get_current_context() is None
        with agent_context() as active:
            assert get_current_context() is active
            assert active.task_completed is False
        assert get_current_context() is None

    def test_nested_context_managers(self):
        """Leaving an inner block reinstates the outer context."""
        with agent_context() as outer_ctx:
            assert get_current_context() is outer_ctx
            with agent_context() as inner_ctx:
                assert get_current_context() is inner_ctx
                assert inner_ctx is not outer_ctx
            assert get_current_context() is outer_ctx

    def test_context_manager_with_parent(self):
        """An explicitly supplied parent seeds the new context."""
        seed = AgentContext()
        seed.mark_task_completed("Parent task")
        with agent_context(parent_context=seed) as active:
            assert active.task_completed is True
            assert active.completion_message == "Parent task"

    def test_context_manager_inheritance(self):
        """Without an explicit parent, a nested context copies the outer one."""
        with agent_context() as outer:
            outer.mark_task_completed("Outer task")
            with agent_context() as inner:
                assert inner.task_completed is True
                assert inner.completion_message == "Outer task"
                inner.mark_plan_completed("Inner plan")
            # Changes made on the inner context must not leak outwards.
            assert outer.task_completed is True
            assert outer.plan_completed is False
            assert outer.completion_message == "Outer task"
|
||||
|
||||
|
||||
class TestThreadIsolation:
    """Verify that each thread only ever sees its own agent context."""

    def test_thread_isolation(self):
        """Contexts open concurrently in several threads do not leak into each other."""
        thread_count = 3
        results = {}
        errors = []
        # Every worker waits here while its context is still open, so all
        # contexts are guaranteed to be live at the same time. This replaces
        # a fixed sleep, which only made the overlap likely and slowed the
        # test down.
        barrier = threading.Barrier(thread_count)

        def thread_func(thread_id):
            try:
                with agent_context() as ctx:
                    ctx.mark_task_completed(f"Thread {thread_id}")
                    # The timeout keeps a failure in one worker from hanging
                    # the others; they get BrokenBarrierError instead.
                    barrier.wait(timeout=5)
                    # Read back through the thread-local accessor, not ctx,
                    # to prove the lookup resolves to this thread's context.
                    results[thread_id] = get_completion_message()
            except Exception as exc:
                # Exceptions raised in a worker are not propagated to the
                # main thread, so record them for the assertion below.
                errors.append((thread_id, exc))

        threads = [
            threading.Thread(target=thread_func, args=(i,))
            for i in range(thread_count)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert errors == []

        # Each thread should have its own message
        assert results[0] == "Thread 0"
        assert results[1] == "Thread 1"
        assert results[2] == "Thread 2"
|
||||
|
||||
|
||||
class TestUtilityFunctions:
    """Module-level helpers that act on the current thread's context."""

    def test_mark_task_completed_utility(self):
        """mark_task_completed() updates the active context."""
        with agent_context():
            mark_task_completed("Task done via utility")
            assert is_completed() is True
            assert get_completion_message() == "Task done via utility"

    def test_mark_plan_completed_utility(self):
        """mark_plan_completed() updates the active context."""
        with agent_context():
            mark_plan_completed("Plan done via utility")
            assert is_completed() is True
            assert get_completion_message() == "Plan done via utility"

    def test_reset_completion_flags_utility(self):
        """reset_completion_flags() clears the active context."""
        with agent_context():
            mark_task_completed("Task done")
            reset_completion_flags()
            assert is_completed() is False
            assert get_completion_message() == ""

    def test_utility_functions_without_context(self):
        """With no active context the helpers are no-ops with safe defaults."""
        # The mutating helpers must tolerate the absence of a context.
        mark_task_completed("No context")
        mark_plan_completed("No context")
        reset_completion_flags()
        # The query helpers fall back to "not completed" and an empty message.
        assert is_completed() is False
        assert get_completion_message() == ""
|
||||
|
|
@ -5,6 +5,7 @@ from typing import Any, Dict, List, Union
|
|||
from langchain_core.tools import tool
|
||||
from rich.console import Console
|
||||
|
||||
from ra_aid.agent_context import get_completion_message, reset_completion_flags
|
||||
from ra_aid.console.formatting import print_error
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
|
@ -84,16 +85,12 @@ def request_research(query: str) -> ResearchResult:
|
|||
reason = f"error: {str(e)}"
|
||||
finally:
|
||||
# Get completion message if available
|
||||
completion_message = _global_memory.get(
|
||||
"completion_message",
|
||||
"Task was completed successfully." if success else None,
|
||||
)
|
||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||
|
||||
work_log = get_work_log()
|
||||
|
||||
# Clear completion state from global memory
|
||||
_global_memory["completion_message"] = ""
|
||||
_global_memory["task_completed"] = False
|
||||
# Clear completion state
|
||||
reset_completion_flags()
|
||||
|
||||
response_data = {
|
||||
"completion_message": completion_message,
|
||||
|
|
@ -152,16 +149,12 @@ def request_web_research(query: str) -> ResearchResult:
|
|||
reason = f"error: {str(e)}"
|
||||
finally:
|
||||
# Get completion message if available
|
||||
completion_message = _global_memory.get(
|
||||
"completion_message",
|
||||
"Task was completed successfully." if success else None,
|
||||
)
|
||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||
|
||||
work_log = get_work_log()
|
||||
|
||||
# Clear completion state from global memory
|
||||
_global_memory["completion_message"] = ""
|
||||
_global_memory["task_completed"] = False
|
||||
# Clear completion state
|
||||
reset_completion_flags()
|
||||
|
||||
response_data = {
|
||||
"completion_message": completion_message,
|
||||
|
|
@ -222,16 +215,12 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
|
|||
reason = f"error: {str(e)}"
|
||||
|
||||
# Get completion message if available
|
||||
completion_message = _global_memory.get(
|
||||
"completion_message", "Task was completed successfully." if success else None
|
||||
)
|
||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||
|
||||
work_log = get_work_log()
|
||||
|
||||
# Clear completion state from global memory
|
||||
_global_memory["completion_message"] = ""
|
||||
_global_memory["task_completed"] = False
|
||||
_global_memory["plan_completed"] = False
|
||||
# Clear completion state
|
||||
reset_completion_flags()
|
||||
|
||||
response_data = {
|
||||
"completion_message": completion_message,
|
||||
|
|
@ -276,7 +265,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
|
|||
# Run implementation agent
|
||||
from ..agent_utils import run_task_implementation_agent
|
||||
|
||||
_global_memory["completion_message"] = ""
|
||||
reset_completion_flags()
|
||||
|
||||
_result = run_task_implementation_agent(
|
||||
base_task=_global_memory.get("base_task", ""),
|
||||
|
|
@ -304,16 +293,13 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
|
|||
reason = f"error: {str(e)}"
|
||||
|
||||
# Get completion message if available
|
||||
completion_message = _global_memory.get(
|
||||
"completion_message", "Task was completed successfully." if success else None
|
||||
)
|
||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||
|
||||
# Get and reset work log if at root depth
|
||||
work_log = get_work_log()
|
||||
|
||||
# Clear completion state from global memory
|
||||
_global_memory["completion_message"] = ""
|
||||
_global_memory["task_completed"] = False
|
||||
# Clear completion state
|
||||
reset_completion_flags()
|
||||
|
||||
response_data = {
|
||||
"key_facts": get_memory_value("key_facts"),
|
||||
|
|
@ -325,6 +311,7 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
|
|||
}
|
||||
if work_log is not None:
|
||||
response_data["work_log"] = work_log
|
||||
print("TASK HERE", response_data)
|
||||
return response_data
|
||||
|
||||
|
||||
|
|
@ -347,7 +334,7 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
|
|||
# Run planning agent
|
||||
from ..agent_utils import run_planning_agent
|
||||
|
||||
_global_memory["completion_message"] = ""
|
||||
reset_completion_flags()
|
||||
|
||||
_result = run_planning_agent(
|
||||
task_spec,
|
||||
|
|
@ -372,17 +359,13 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
|
|||
reason = f"error: {str(e)}"
|
||||
|
||||
# Get completion message if available
|
||||
completion_message = _global_memory.get(
|
||||
"completion_message", "Task was completed successfully." if success else None
|
||||
)
|
||||
completion_message = get_completion_message() or ("Task was completed successfully." if success else None)
|
||||
|
||||
# Get and reset work log if at root depth
|
||||
work_log = get_work_log()
|
||||
|
||||
# Clear completion state from global memory
|
||||
_global_memory["completion_message"] = ""
|
||||
_global_memory["task_completed"] = False
|
||||
_global_memory["plan_completed"] = False
|
||||
# Clear completion state
|
||||
reset_completion_flags()
|
||||
|
||||
response_data = {
|
||||
"completion_message": completion_message,
|
||||
|
|
@ -394,4 +377,6 @@ def request_implementation(task_spec: str) -> Dict[str, Any]:
|
|||
}
|
||||
if work_log is not None:
|
||||
response_data["work_log"] = work_log
|
||||
|
||||
print("HERE", response_data)
|
||||
return response_data
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
try:
|
||||
import magic
|
||||
|
|
@ -12,6 +12,8 @@ from rich.markdown import Markdown
|
|||
from rich.panel import Panel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from ra_aid.agent_context import mark_task_completed, mark_plan_completed
|
||||
|
||||
|
||||
class WorkLogEntry(TypedDict):
|
||||
timestamp: str
|
||||
|
|
@ -28,25 +30,10 @@ class SnippetInfo(TypedDict):
|
|||
console = Console()
|
||||
|
||||
# Global memory store
|
||||
_global_memory: Dict[
|
||||
str,
|
||||
Union[
|
||||
Dict[int, str],
|
||||
Dict[int, SnippetInfo],
|
||||
Dict[int, WorkLogEntry],
|
||||
int,
|
||||
Set[str],
|
||||
bool,
|
||||
str,
|
||||
List[str],
|
||||
List[WorkLogEntry],
|
||||
],
|
||||
] = {
|
||||
_global_memory: Dict[str, Any] = {
|
||||
"research_notes": [],
|
||||
"plans": [],
|
||||
"tasks": {}, # Dict[int, str] - ID to task mapping
|
||||
"task_completed": False, # Flag indicating if task is complete
|
||||
"completion_message": "", # Message explaining completion
|
||||
"task_id_counter": 1, # Counter for generating unique task IDs
|
||||
"key_facts": {}, # Dict[int, str] - ID to fact mapping
|
||||
"key_fact_id_counter": 1, # Counter for generating unique fact IDs
|
||||
|
|
@ -55,7 +42,6 @@ _global_memory: Dict[
|
|||
"implementation_requested": False,
|
||||
"related_files": {}, # Dict[int, str] - ID to filepath mapping
|
||||
"related_file_id_counter": 1, # Counter for generating unique file IDs
|
||||
"plan_completed": False,
|
||||
"agent_depth": 0,
|
||||
"work_log": [], # List[WorkLogEntry] - Timestamped work events
|
||||
}
|
||||
|
|
@ -327,10 +313,9 @@ def one_shot_completed(message: str) -> str:
|
|||
if _global_memory.get("implementation_requested", False):
|
||||
return "Cannot complete in one shot - implementation was requested"
|
||||
|
||||
_global_memory["task_completed"] = True
|
||||
_global_memory["completion_message"] = message
|
||||
mark_task_completed(message)
|
||||
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
||||
log_work_event(f"Task completed\n\n{message}")
|
||||
log_work_event(f"Task completed:\n\n{message}")
|
||||
return "Completion noted."
|
||||
|
||||
|
||||
|
|
@ -341,9 +326,9 @@ def task_completed(message: str) -> str:
|
|||
Args:
|
||||
message: Message explaining how/why the task is complete
|
||||
"""
|
||||
_global_memory["task_completed"] = True
|
||||
_global_memory["completion_message"] = message
|
||||
mark_task_completed(message)
|
||||
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
||||
log_work_event(f"Task completed:\n\n{message}")
|
||||
return "Completion noted."
|
||||
|
||||
|
||||
|
|
@ -354,14 +339,11 @@ def plan_implementation_completed(message: str) -> str:
|
|||
Args:
|
||||
message: Message explaining how the implementation plan was completed
|
||||
"""
|
||||
_global_memory["task_completed"] = True
|
||||
_global_memory["completion_message"] = message
|
||||
_global_memory["plan_completed"] = True
|
||||
_global_memory["completion_message"] = message
|
||||
mark_plan_completed(message)
|
||||
_global_memory["tasks"].clear() # Clear task list when plan is completed
|
||||
_global_memory["task_id_counter"] = 1
|
||||
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
|
||||
log_work_event(f"Plan execution completed:\n\n{message}")
|
||||
log_work_event(f"Completed implementation:\n\n{message}")
|
||||
return "Plan completion noted and task list cleared."
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,178 @@
|
|||
"""Unit tests for the agent_context module."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import pytest
|
||||
|
||||
from ra_aid.agent_context import (
|
||||
AgentContext,
|
||||
agent_context,
|
||||
get_current_context,
|
||||
mark_task_completed,
|
||||
mark_plan_completed,
|
||||
reset_completion_flags,
|
||||
is_completed,
|
||||
get_completion_message,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentContext:
    """Exercise AgentContext directly, without the context manager."""

    def test_context_creation(self):
        """A fresh context has nothing completed and an empty message."""
        fresh = AgentContext()
        assert fresh.task_completed is False
        assert fresh.plan_completed is False
        assert fresh.completion_message == ""

    def test_context_inheritance(self):
        """A child built from a parent starts with the parent's state."""
        source = AgentContext()
        source.mark_task_completed("Parent task completed")
        derived = AgentContext(parent_context=source)
        assert derived.task_completed is True
        assert derived.completion_message == "Parent task completed"

    def test_mark_task_completed(self):
        """Completing a task sets the task flag and message only."""
        subject = AgentContext()
        subject.mark_task_completed("Task done")
        assert subject.task_completed is True
        assert subject.plan_completed is False
        assert subject.completion_message == "Task done"

    def test_mark_plan_completed(self):
        """Completing a plan sets both flags and the message."""
        subject = AgentContext()
        subject.mark_plan_completed("Plan done")
        assert subject.task_completed is True
        assert subject.plan_completed is True
        assert subject.completion_message == "Plan done"

    def test_reset_completion_flags(self):
        """Resetting returns the context to its initial state."""
        subject = AgentContext()
        subject.mark_task_completed("Task done")
        subject.reset_completion_flags()
        assert subject.task_completed is False
        assert subject.plan_completed is False
        assert subject.completion_message == ""

    def test_is_completed_property(self):
        """is_completed follows the flags through mark and reset calls."""
        subject = AgentContext()
        assert subject.is_completed is False
        subject.mark_task_completed("Task done")
        assert subject.is_completed is True
        subject.reset_completion_flags()
        assert subject.is_completed is False
        subject.mark_plan_completed("Plan done")
        assert subject.is_completed is True
|
||||
|
||||
|
||||
class TestContextManager:
    """Exercise the agent_context() context manager and get_current_context()."""

    def test_context_manager_basic(self):
        """The current context is set inside the block and cleared afterwards."""
        assert get_current_context() is None

        with agent_context() as active:
            assert get_current_context() is active
            assert active.task_completed is False

        assert get_current_context() is None

    def test_nested_context_managers(self):
        """Entering a nested block swaps the current context; leaving restores it."""
        with agent_context() as outer_ctx:
            assert get_current_context() is outer_ctx

            with agent_context() as inner_ctx:
                assert get_current_context() is inner_ctx
                assert inner_ctx is not outer_ctx

            # Back in the outer block once the inner one has exited.
            assert get_current_context() is outer_ctx

    def test_context_manager_with_parent(self):
        """An explicitly supplied parent context seeds the new context's state."""
        seed = AgentContext()
        seed.mark_task_completed("Parent task")

        with agent_context(parent_context=seed) as ctx:
            assert ctx.task_completed is True
            assert ctx.completion_message == "Parent task"

    def test_context_manager_inheritance(self):
        """A nested context starts from the enclosing context's state by default."""
        with agent_context() as outer:
            outer.mark_task_completed("Outer task")

            with agent_context() as inner:
                assert inner.task_completed is True
                assert inner.completion_message == "Outer task"
                inner.mark_plan_completed("Inner plan")

            # Outer context should not be affected by inner context changes
            assert outer.task_completed is True
            assert outer.plan_completed is False
            assert outer.completion_message == "Outer task"
|
||||
|
||||
|
||||
class TestThreadIsolation:
    """Test thread isolation of context variables."""

    def test_thread_isolation(self):
        """Test that contexts are isolated between threads.

        Every worker opens its own context, records a thread-specific
        completion message, and then waits on a barrier until all workers
        have done the same. Only after that does it read the message back
        through get_completion_message(), so each read happens while the
        other workers' contexts are still open.
        """
        thread_count = 3
        results = {}

        # A barrier makes the overlap deterministic; a fixed sleep only makes
        # it likely and slows the test down. The timeout keeps a failing
        # worker from hanging the remaining ones (they get BrokenBarrierError).
        all_marked = threading.Barrier(thread_count)

        def thread_func(thread_id):
            with agent_context() as ctx:
                ctx.mark_task_completed(f"Thread {thread_id}")
                all_marked.wait(timeout=5)
                # Store the message seen by this thread for verification
                results[thread_id] = get_completion_message()

        threads = [
            threading.Thread(target=thread_func, args=(i,))
            for i in range(thread_count)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Each thread should have read back its own message; a worker that
        # failed leaves its key missing and fails the comparison.
        assert results == {i: f"Thread {i}" for i in range(thread_count)}
|
||||
|
||||
|
||||
class TestUtilityFunctions:
    """Exercise the module-level helpers that act on the current context."""

    def test_mark_task_completed_utility(self):
        """mark_task_completed() is visible through is_completed() and the message getter."""
        with agent_context():
            mark_task_completed("Task done via utility")

            assert is_completed() is True
            assert get_completion_message() == "Task done via utility"

    def test_mark_plan_completed_utility(self):
        """mark_plan_completed() is visible through is_completed() and the message getter."""
        with agent_context():
            mark_plan_completed("Plan done via utility")

            assert is_completed() is True
            assert get_completion_message() == "Plan done via utility"

    def test_reset_completion_flags_utility(self):
        """reset_completion_flags() undoes a prior completion on the current context."""
        with agent_context():
            mark_task_completed("Task done")
            reset_completion_flags()

            assert is_completed() is False
            assert get_completion_message() == ""

    def test_utility_functions_without_context(self):
        """With no active context the helpers neither raise nor report completion."""
        # These should not raise exceptions even without an active context
        mark_task_completed("No context")
        mark_plan_completed("No context")
        reset_completion_flags()

        # These should have safe default returns
        assert is_completed() is False
        assert get_completion_message() == ""
|
||||
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from ra_aid.agent_context import agent_context, get_current_context, reset_completion_flags
|
||||
from ra_aid.agent_utils import (
|
||||
AgentState,
|
||||
create_agent,
|
||||
|
|
@ -325,7 +326,7 @@ def test_increment_and_decrement_agent_depth():
|
|||
|
||||
|
||||
def test_run_agent_stream(monkeypatch):
|
||||
from ra_aid.agent_utils import _global_memory, _run_agent_stream
|
||||
from ra_aid.agent_utils import _run_agent_stream
|
||||
|
||||
# Create a dummy agent that yields one chunk
|
||||
class DummyAgent:
|
||||
|
|
@ -334,9 +335,11 @@ def test_run_agent_stream(monkeypatch):
|
|||
|
||||
dummy_agent = DummyAgent()
|
||||
# Set flags so that _run_agent_stream will reset them
|
||||
_global_memory["plan_completed"] = True
|
||||
_global_memory["task_completed"] = True
|
||||
_global_memory["completion_message"] = "existing"
|
||||
with agent_context() as ctx:
|
||||
ctx.plan_completed = True
|
||||
ctx.task_completed = True
|
||||
ctx.completion_message = "existing"
|
||||
|
||||
call_flag = {"called": False}
|
||||
|
||||
def fake_print_agent_output(
|
||||
|
|
@ -349,9 +352,11 @@ def test_run_agent_stream(monkeypatch):
|
|||
)
|
||||
_run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {})
|
||||
assert call_flag["called"]
|
||||
assert _global_memory["plan_completed"] is False
|
||||
assert _global_memory["task_completed"] is False
|
||||
assert _global_memory["completion_message"] == ""
|
||||
|
||||
with agent_context() as ctx:
|
||||
assert ctx.plan_completed is False
|
||||
assert ctx.task_completed is False
|
||||
assert ctx.completion_message == ""
|
||||
|
||||
|
||||
def test_execute_test_command_wrapper(monkeypatch):
|
||||
|
|
|
|||
Loading…
Reference in New Issue