diff --git a/ra_aid/agent_context.py b/ra_aid/agent_context.py index c6d76ba..980626c 100644 --- a/ra_aid/agent_context.py +++ b/ra_aid/agent_context.py @@ -1,11 +1,12 @@ """Context manager for tracking agent state and completion status.""" -import threading +import threading # Keep for backward compatibility +import contextvars from contextlib import contextmanager from typing import Optional -# Thread-local storage for context variables -_thread_local = threading.local() +# Create contextvar to hold the agent context +agent_context_var = contextvars.ContextVar("agent_context", default=None) class AgentContext: @@ -121,7 +122,7 @@ def get_current_context() -> Optional[AgentContext]: Returns: The current AgentContext or None if no context is active """ - return getattr(_thread_local, "current_context", None) + return agent_context_var.get() def get_depth() -> int: @@ -150,7 +151,7 @@ def agent_context(parent_context=None): The newly created AgentContext """ # Save the previous context - previous_context = getattr(_thread_local, "current_context", None) + previous_context = agent_context_var.get() # Create a new context, inheriting from parent if provided # If parent_context is None but previous_context exists, use previous_context as parent @@ -159,14 +160,14 @@ def agent_context(parent_context=None): else: context = AgentContext(parent_context) - # Set as current context - _thread_local.current_context = context + # Set as current context and get token for resetting later + token = agent_context_var.set(context) try: yield context finally: # Restore previous context - _thread_local.current_context = previous_context + agent_context_var.reset(token) def mark_task_completed(message: str) -> None: @@ -271,4 +272,4 @@ def get_crash_message() -> Optional[str]: The crash message or None if the agent has not crashed """ context = get_current_context() - return context.agent_crashed_message if context and context.is_crashed() else None + return context.agent_crashed_message if context and context.is_crashed() else None \ No newline at end of file diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 6fe8ed5..7cf2ef1 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -1077,16 +1077,13 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict) human-in-the-loop interruptions using interrupt_after=["tools"]. """ while True: - print("HERE") # Process each chunk from the agent stream. for chunk in agent.stream({"messages": msg_list}, config): - print("HERE IN FOR CHUNK") logger.debug("Agent output: %s", chunk) check_interrupt() agent_type = get_agent_type(agent) print_agent_output(chunk, agent_type) if is_completed() or should_exit(): - print("IS COMPLETED OR SHOULD EXIT TRIGGERED") reset_completion_flags() return True # Exit immediately when finished or signaled to exit. logger.debug("Stream iteration ended; checking agent state for continuation.") @@ -1108,20 +1105,15 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict) # If the state indicates that further steps remain (i.e. state.next is non-empty), # then resume execution by invoking the agent with no new input. if state.next: - print("HAS NEXT STATE") logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next) agent.invoke(None, config) continue else: - print("NO NEXT STATE") logger.debug("No continuation indicated in state; exiting stream loop.") break - print("WHILE EXITED") return True - - def run_agent_with_retry( agent: RAgents, prompt: str,