fix agent context

This commit is contained in:
AI Christianson 2025-03-07 18:33:45 -05:00
parent 3607803bf5
commit 60e4616313
2 changed files with 10 additions and 17 deletions

View File

@ -1,11 +1,12 @@
"""Context manager for tracking agent state and completion status.""" """Context manager for tracking agent state and completion status."""
import threading import threading # Keep for backward compatibility
import contextvars
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional from typing import Optional
# Thread-local storage for context variables # Create contextvar to hold the agent context
_thread_local = threading.local() agent_context_var = contextvars.ContextVar("agent_context", default=None)
class AgentContext: class AgentContext:
@ -121,7 +122,7 @@ def get_current_context() -> Optional[AgentContext]:
Returns: Returns:
The current AgentContext or None if no context is active The current AgentContext or None if no context is active
""" """
return getattr(_thread_local, "current_context", None) return agent_context_var.get()
def get_depth() -> int: def get_depth() -> int:
@ -150,7 +151,7 @@ def agent_context(parent_context=None):
The newly created AgentContext The newly created AgentContext
""" """
# Save the previous context # Save the previous context
previous_context = getattr(_thread_local, "current_context", None) previous_context = agent_context_var.get()
# Create a new context, inheriting from parent if provided # Create a new context, inheriting from parent if provided
# If parent_context is None but previous_context exists, use previous_context as parent # If parent_context is None but previous_context exists, use previous_context as parent
@ -159,14 +160,14 @@ def agent_context(parent_context=None):
else: else:
context = AgentContext(parent_context) context = AgentContext(parent_context)
# Set as current context # Set as current context and get token for resetting later
_thread_local.current_context = context token = agent_context_var.set(context)
try: try:
yield context yield context
finally: finally:
# Restore previous context # Restore previous context
_thread_local.current_context = previous_context agent_context_var.reset(token)
def mark_task_completed(message: str) -> None: def mark_task_completed(message: str) -> None:

View File

@ -1077,16 +1077,13 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict)
human-in-the-loop interruptions using interrupt_after=["tools"]. human-in-the-loop interruptions using interrupt_after=["tools"].
""" """
while True: while True:
print("HERE")
# Process each chunk from the agent stream. # Process each chunk from the agent stream.
for chunk in agent.stream({"messages": msg_list}, config): for chunk in agent.stream({"messages": msg_list}, config):
print("HERE IN FOR CHUNK")
logger.debug("Agent output: %s", chunk) logger.debug("Agent output: %s", chunk)
check_interrupt() check_interrupt()
agent_type = get_agent_type(agent) agent_type = get_agent_type(agent)
print_agent_output(chunk, agent_type) print_agent_output(chunk, agent_type)
if is_completed() or should_exit(): if is_completed() or should_exit():
print("IS COMPLETED OR SHOULD EXIT TRIGGERED")
reset_completion_flags() reset_completion_flags()
return True # Exit immediately when finished or signaled to exit. return True # Exit immediately when finished or signaled to exit.
logger.debug("Stream iteration ended; checking agent state for continuation.") logger.debug("Stream iteration ended; checking agent state for continuation.")
@ -1108,20 +1105,15 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict)
# If the state indicates that further steps remain (i.e. state.next is non-empty), # If the state indicates that further steps remain (i.e. state.next is non-empty),
# then resume execution by invoking the agent with no new input. # then resume execution by invoking the agent with no new input.
if state.next: if state.next:
print("HAS NEXT STATE")
logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next) logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next)
agent.invoke(None, config) agent.invoke(None, config)
continue continue
else: else:
print("NO NEXT STATE")
logger.debug("No continuation indicated in state; exiting stream loop.") logger.debug("No continuation indicated in state; exiting stream loop.")
break break
print("WHILE EXITED")
return True return True
def run_agent_with_retry( def run_agent_with_retry(
agent: RAgents, agent: RAgents,
prompt: str, prompt: str,