Tavily integration

This commit is contained in:
user 2024-12-23 17:20:24 -05:00
parent f5482dce2a
commit 1ab7ce60ea
5 changed files with 17 additions and 7 deletions

View File

@ -161,7 +161,7 @@ def run_research_agent(
return run_agent_with_retry(agent, prompt, run_config) return run_agent_with_retry(agent, prompt, run_config)
def run_web_research_agent( def run_web_research_agent(
base_task_or_query: str, query: str,
model, model,
*, *,
expert_enabled: bool = False, expert_enabled: bool = False,
@ -174,7 +174,7 @@ def run_web_research_agent(
"""Run a web research agent with the given configuration. """Run a web research agent with the given configuration.
Args: Args:
base_task_or_query: The main task or query for web research query: The mainquery for web research
model: The LLM model to use model: The LLM model to use
expert_enabled: Whether expert mode is enabled expert_enabled: Whether expert mode is enabled
hil: Whether human-in-the-loop mode is enabled hil: Whether human-in-the-loop mode is enabled
@ -218,7 +218,7 @@ def run_web_research_agent(
# Build prompt # Build prompt
prompt = WEB_RESEARCH_PROMPT.format( prompt = WEB_RESEARCH_PROMPT.format(
base_task=base_task_or_query, web_research_query=query,
expert_section=expert_section, expert_section=expert_section,
human_section=human_section, human_section=human_section,
key_facts=key_facts, key_facts=key_facts,

View File

@ -117,7 +117,8 @@ def get_web_research_tools(expert_enabled: bool = True) -> list:
list: List of tools configured for web research list: List of tools configured for web research
""" """
tools = [ tools = [
web_search_tavily web_search_tavily,
emit_research_notes
] ]
if expert_enabled: if expert_enabled:

View File

@ -111,10 +111,12 @@ def request_web_research(query: str) -> ResearchResult:
success = True success = True
reason = None reason = None
web_research_notes = []
try: try:
# Run web research agent # Run web research agent
from ..agent_utils import run_web_research_agent from ..agent_utils import run_web_research_agent
original_research_notes = _global_memory.get('research_notes', [])
result = run_web_research_agent( result = run_web_research_agent(
query, query,
model, model,
@ -123,6 +125,7 @@ def request_web_research(query: str) -> ResearchResult:
console_message=query, console_message=query,
config=config config=config
) )
web_research_notes = _global_memory.get('research_notes', [])
except AgentInterrupt: except AgentInterrupt:
print() print()
response = ask_human.invoke({"question": "Why did you interrupt me?"}) response = ask_human.invoke({"question": "Why did you interrupt me?"})
@ -135,6 +138,7 @@ def request_web_research(query: str) -> ResearchResult:
success = False success = False
reason = f"error: {str(e)}" reason = f"error: {str(e)}"
finally: finally:
_global_memory['research_notes'] = original_research_notes
# Get completion message if available # Get completion message if available
completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None) completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None)
@ -151,9 +155,7 @@ def request_web_research(query: str) -> ResearchResult:
return { return {
"work_log": work_log, "work_log": work_log,
"completion_message": completion_message, "completion_message": completion_message,
"key_facts": get_memory_value("key_facts"), "web_research_notes": web_research_notes,
"related_files": get_related_files(),
"research_notes": get_memory_value("research_notes"),
"key_snippets": get_memory_value("key_snippets"), "key_snippets": get_memory_value("key_snippets"),
"success": success, "success": success,
"reason": reason "reason": reason

View File

@ -5,6 +5,7 @@ from typing import Dict, Optional, Tuple
from langchain_core.tools import tool from langchain_core.tools import tool
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
from rich.markdown import Markdown
from ra_aid.text.processing import truncate_output from ra_aid.text.processing import truncate_output
console = Console() console = Console()

View File

@ -2,6 +2,11 @@ import os
from typing import Dict from typing import Dict
from tavily import TavilyClient from tavily import TavilyClient
from langchain_core.tools import tool from langchain_core.tools import tool
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
console = Console()
@tool @tool
def web_search_tavily(query: str) -> Dict: def web_search_tavily(query: str) -> Dict:
@ -15,5 +20,6 @@ def web_search_tavily(query: str) -> Dict:
Dict containing search results from Tavily Dict containing search results from Tavily
""" """
client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"]) client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
console.print(Panel(Markdown(query), title="🔍 Searching Tavily", border_style="bright_blue"))
search_result = client.search(query=query) search_result = client.search(query=query)
return search_result return search_result