Fix Web Tools, Upgrade MoA to GPT5, Add Trajectory Saving
This commit is contained in:
parent
4ece87efb0
commit
587d1cf720
5 changed files with 1090 additions and 131 deletions
220
run_agent.py
220
run_agent.py
|
|
@ -26,6 +26,7 @@ import time
|
|||
from typing import List, Dict, Any, Optional
|
||||
from openai import OpenAI
|
||||
import fire
|
||||
from datetime import datetime
|
||||
|
||||
# Import our tool system
|
||||
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
||||
|
|
@ -49,7 +50,8 @@ class AIAgent:
|
|||
enabled_tools: List[str] = None,
|
||||
disabled_tools: List[str] = None,
|
||||
enabled_toolsets: List[str] = None,
|
||||
disabled_toolsets: List[str] = None
|
||||
disabled_toolsets: List[str] = None,
|
||||
save_trajectories: bool = False
|
||||
):
|
||||
"""
|
||||
Initialize the AI Agent.
|
||||
|
|
@ -64,10 +66,12 @@ class AIAgent:
|
|||
disabled_tools (List[str]): Disable these specific tools (optional)
|
||||
enabled_toolsets (List[str]): Only enable tools from these toolsets (optional)
|
||||
disabled_toolsets (List[str]): Disable tools from these toolsets (optional)
|
||||
save_trajectories (bool): Whether to save conversation trajectories to JSONL files (default: False)
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
self.tool_delay = tool_delay
|
||||
self.save_trajectories = save_trajectories
|
||||
|
||||
# Store tool filtering options
|
||||
self.enabled_tools = enabled_tools
|
||||
|
|
@ -123,31 +127,184 @@ class AIAgent:
|
|||
missing_reqs = [name for name, available in requirements.items() if not available]
|
||||
if missing_reqs:
|
||||
print(f"⚠️ Some tools may not work due to missing requirements: {missing_reqs}")
|
||||
|
||||
# Show trajectory saving status
|
||||
if self.save_trajectories:
|
||||
print("📝 Trajectory saving enabled")
|
||||
|
||||
def create_system_message(self, custom_system: str = None) -> str:
|
||||
def _format_tools_for_system_message(self) -> str:
|
||||
"""
|
||||
Create the system message for the agent.
|
||||
Format tool definitions for the system message in the trajectory format.
|
||||
|
||||
Returns:
|
||||
str: JSON string representation of tool definitions
|
||||
"""
|
||||
if not self.tools:
|
||||
return "[]"
|
||||
|
||||
# Convert tool definitions to the format expected in trajectories
|
||||
formatted_tools = []
|
||||
for tool in self.tools:
|
||||
func = tool["function"]
|
||||
formatted_tool = {
|
||||
"name": func["name"],
|
||||
"description": func.get("description", ""),
|
||||
"parameters": func.get("parameters", {}),
|
||||
"required": None # Match the format in the example
|
||||
}
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
return json.dumps(formatted_tools)
|
||||
|
||||
def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert internal message format to trajectory format for saving.
|
||||
|
||||
Args:
|
||||
custom_system (str): Custom system message (optional)
|
||||
messages (List[Dict]): Internal message history
|
||||
user_query (str): Original user query
|
||||
completed (bool): Whether the conversation completed successfully
|
||||
|
||||
Returns:
|
||||
str: System message content
|
||||
List[Dict]: Messages in trajectory format
|
||||
"""
|
||||
if custom_system:
|
||||
return custom_system
|
||||
trajectory = []
|
||||
|
||||
return (
|
||||
"You are an AI assistant that provides helpful responses. You may use extremely long chains of thought "
|
||||
"to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help "
|
||||
"come to a correct solution prior to answering. You should enclose your thoughts and internal monologue "
|
||||
"inside <thinking> tags.\n\n"
|
||||
"You are equipped with web research tools that allow you to search the web, extract content from web pages, "
|
||||
"and crawl websites. Use these tools to gather current information and provide accurate, well-researched responses. "
|
||||
"You can call multiple tools in parallel if they are not reliant on each other's results. You can also use "
|
||||
"sequential tool calls to build on data you've collected from previous tool calls. Continue using tools until "
|
||||
"you feel confident you have enough information to provide a comprehensive answer."
|
||||
# Add system message with tool definitions
|
||||
system_msg = (
|
||||
"You are a function calling AI model. You are provided with function signatures within <tools> </tools> XML tags. "
|
||||
"You may call one or more functions to assist with the user query. If available tools are not relevant in assisting "
|
||||
"with user query, just respond in natural conversational language. Don't make assumptions about what values to plug "
|
||||
"into functions. After calling & executing the functions, you will be provided with function results within "
|
||||
"<tool_response> </tool_response> XML tags. Here are the available tools:\n"
|
||||
f"<tools>\n{self._format_tools_for_system_message()}\n</tools>\n"
|
||||
"For each function call return a JSON object, with the following pydantic model json schema for each:\n"
|
||||
"{'title': 'FunctionCall', 'type': 'object', 'properties': {'name': {'title': 'Name', 'type': 'string'}, "
|
||||
"'arguments': {'title': 'Arguments', 'type': 'object'}}, 'required': ['name', 'arguments']}\n"
|
||||
"Each function call should be enclosed within <tool_call> </tool_call> XML tags.\n"
|
||||
"Example:\n<tool_call>\n{'name': <function-name>,'arguments': <args-dict>}\n</tool_call>"
|
||||
)
|
||||
|
||||
trajectory.append({
|
||||
"from": "system",
|
||||
"value": system_msg
|
||||
})
|
||||
|
||||
# Add the initial user message
|
||||
trajectory.append({
|
||||
"from": "human",
|
||||
"value": user_query
|
||||
})
|
||||
|
||||
# Process remaining messages
|
||||
i = 1 # Skip the first user message as we already added it
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
|
||||
if msg["role"] == "assistant":
|
||||
# Check if this message has tool calls
|
||||
if "tool_calls" in msg and msg["tool_calls"]:
|
||||
# Format assistant message with tool calls
|
||||
content = ""
|
||||
if msg.get("content") and msg["content"].strip():
|
||||
content = msg["content"] + "\n"
|
||||
|
||||
# Add tool calls wrapped in XML tags
|
||||
for tool_call in msg["tool_calls"]:
|
||||
tool_call_json = {
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"]
|
||||
}
|
||||
content += f"<tool_call>\n{json.dumps(tool_call_json)}\n</tool_call>\n"
|
||||
|
||||
trajectory.append({
|
||||
"from": "gpt",
|
||||
"value": content.rstrip()
|
||||
})
|
||||
|
||||
# Collect all subsequent tool responses
|
||||
tool_responses = []
|
||||
j = i + 1
|
||||
while j < len(messages) and messages[j]["role"] == "tool":
|
||||
tool_msg = messages[j]
|
||||
# Format tool response with XML tags
|
||||
tool_response = f"<tool_response>\n"
|
||||
|
||||
# Try to parse tool content as JSON if it looks like JSON
|
||||
tool_content = tool_msg["content"]
|
||||
try:
|
||||
if tool_content.strip().startswith(("{", "[")):
|
||||
tool_content = json.loads(tool_content)
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass # Keep as string if not valid JSON
|
||||
|
||||
tool_response += json.dumps({
|
||||
"tool_call_id": tool_msg.get("tool_call_id", ""),
|
||||
"name": msg["tool_calls"][len(tool_responses)]["function"]["name"] if len(tool_responses) < len(msg["tool_calls"]) else "unknown",
|
||||
"content": tool_content
|
||||
})
|
||||
tool_response += "\n</tool_response>"
|
||||
tool_responses.append(tool_response)
|
||||
j += 1
|
||||
|
||||
# Add all tool responses as a single message
|
||||
if tool_responses:
|
||||
trajectory.append({
|
||||
"from": "tool",
|
||||
"value": "\n".join(tool_responses)
|
||||
})
|
||||
i = j - 1 # Skip the tool messages we just processed
|
||||
|
||||
else:
|
||||
# Regular assistant message without tool calls
|
||||
trajectory.append({
|
||||
"from": "gpt",
|
||||
"value": msg["content"] or ""
|
||||
})
|
||||
|
||||
elif msg["role"] == "user":
|
||||
trajectory.append({
|
||||
"from": "human",
|
||||
"value": msg["content"]
|
||||
})
|
||||
|
||||
i += 1
|
||||
|
||||
return trajectory
|
||||
|
||||
def _save_trajectory(self, messages: List[Dict[str, Any]], user_query: str, completed: bool):
|
||||
"""
|
||||
Save conversation trajectory to JSONL file.
|
||||
|
||||
Args:
|
||||
messages (List[Dict]): Complete message history
|
||||
user_query (str): Original user query
|
||||
completed (bool): Whether the conversation completed successfully
|
||||
"""
|
||||
if not self.save_trajectories:
|
||||
return
|
||||
|
||||
# Convert messages to trajectory format
|
||||
trajectory = self._convert_to_trajectory_format(messages, user_query, completed)
|
||||
|
||||
# Determine which file to save to
|
||||
filename = "trajectory_samples.jsonl" if completed else "failed_trajectories.jsonl"
|
||||
|
||||
# Create trajectory entry
|
||||
entry = {
|
||||
"conversations": trajectory,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": self.model,
|
||||
"completed": completed
|
||||
}
|
||||
|
||||
# Append to JSONL file
|
||||
try:
|
||||
with open(filename, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
print(f"💾 Trajectory saved to {filename}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to save trajectory: {e}")
|
||||
|
||||
def run_conversation(
|
||||
self,
|
||||
|
|
@ -169,13 +326,6 @@ class AIAgent:
|
|||
# Initialize conversation
|
||||
messages = conversation_history or []
|
||||
|
||||
# Add system message if not already present
|
||||
if not messages or messages[0]["role"] != "system":
|
||||
messages.insert(0, {
|
||||
"role": "system",
|
||||
"content": self.create_system_message(system_message)
|
||||
})
|
||||
|
||||
# Add user message
|
||||
messages.append({
|
||||
"role": "user",
|
||||
|
|
@ -292,11 +442,17 @@ class AIAgent:
|
|||
if final_response is None:
|
||||
final_response = "I've reached the maximum number of iterations. Here's what I found so far."
|
||||
|
||||
# Determine if conversation completed successfully
|
||||
completed = final_response is not None and api_call_count < self.max_iterations
|
||||
|
||||
# Save trajectory if enabled
|
||||
self._save_trajectory(messages, user_message, completed)
|
||||
|
||||
return {
|
||||
"final_response": final_response,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": final_response is not None
|
||||
"completed": completed
|
||||
}
|
||||
|
||||
def chat(self, message: str) -> str:
|
||||
|
|
@ -323,7 +479,8 @@ def main(
|
|||
disabled_tools: str = None,
|
||||
enabled_toolsets: str = None,
|
||||
disabled_toolsets: str = None,
|
||||
list_tools: bool = False
|
||||
list_tools: bool = False,
|
||||
save_trajectories: bool = False
|
||||
):
|
||||
"""
|
||||
Main function for running the agent directly.
|
||||
|
|
@ -339,6 +496,7 @@ def main(
|
|||
enabled_toolsets (str): Comma-separated list of toolsets to enable (e.g., "web_tools")
|
||||
disabled_toolsets (str): Comma-separated list of toolsets to disable (e.g., "terminal_tools")
|
||||
list_tools (bool): Just list available tools and exit
|
||||
save_trajectories (bool): Save conversation trajectories to JSONL files. Defaults to False.
|
||||
"""
|
||||
print("🤖 AI Agent with Tool Calling")
|
||||
print("=" * 50)
|
||||
|
|
@ -373,6 +531,8 @@ def main(
|
|||
print(f" python run_agent.py --enabled_tools=web_search,web_extract --query='research topic'")
|
||||
print(f" # Run without terminal tools")
|
||||
print(f" python run_agent.py --disabled_tools=terminal --query='web research only'")
|
||||
print(f" # Run with trajectory saving enabled")
|
||||
print(f" python run_agent.py --save_trajectories --query='your question here'")
|
||||
return
|
||||
|
||||
# Parse tool selection arguments
|
||||
|
|
@ -397,6 +557,11 @@ def main(
|
|||
disabled_toolsets_list = [t.strip() for t in disabled_toolsets.split(",")]
|
||||
print(f"🚫 Disabled toolsets: {disabled_toolsets_list}")
|
||||
|
||||
if save_trajectories:
|
||||
print(f"💾 Trajectory saving: ENABLED")
|
||||
print(f" - Successful conversations → trajectory_samples.jsonl")
|
||||
print(f" - Failed conversations → failed_trajectories.jsonl")
|
||||
|
||||
# Initialize agent with provided parameters
|
||||
try:
|
||||
agent = AIAgent(
|
||||
|
|
@ -407,7 +572,8 @@ def main(
|
|||
enabled_tools=enabled_tools_list,
|
||||
disabled_tools=disabled_tools_list,
|
||||
enabled_toolsets=enabled_toolsets_list,
|
||||
disabled_toolsets=disabled_toolsets_list
|
||||
disabled_toolsets=disabled_toolsets_list,
|
||||
save_trajectories=save_trajectories
|
||||
)
|
||||
except RuntimeError as e:
|
||||
print(f"❌ Failed to initialize agent: {e}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue