# Source code for pydantic_ai_toolsets.toolsets.monte_carlo_reasoning.toolset
"""Monte Carlo Tree Search toolset for pydantic-ai agents."""
from __future__ import annotations
import math
import sys
import time
import uuid
from typing import Any
from pydantic_ai import Agent
from pydantic_ai.toolsets import FunctionToolset
from .storage import MCTSStorage, MCTSStorageProtocol
from .types import (
BackpropagateItem,
ExpandNodeItem,
MCTSNode,
SelectNodeItem,
SimulateItem,
)
# =============================================================================
# SYSTEM PROMPT - Contains "when and why" to use the toolset
# =============================================================================
# Static prompt text returned by get_mcts_system_prompt() (on its own when no
# tree exists, otherwise followed by a "Current State" section). It lists the
# six tools registered by create_mcts_toolset(), the four MCTS phases, the
# UCB1 formula, and the 0.0-1.0 reward scale.
# NOTE(review): the workflow below tells the model to call both `simulate`
# and `backpropagate` in every iteration, but the `simulate` tool already
# updates visits/wins from the node up to the root. Following both steps
# would count each reward twice - confirm the intended usage.
MCTS_SYSTEM_PROMPT = """
## Monte Carlo Tree Search (MCTS)
You have access to tools for MCTS-based reasoning:
- `read_mcts`: Review current tree state
- `select_node`: Select promising node using UCB1
- `expand_node`: Expand node with possible children
- `simulate`: Run simulation from a node
- `backpropagate`: Update statistics from simulation
- `get_best_action`: Get best action based on visits
### When to Use MCTS
Use these tools in these scenarios:
1. Decision-making with many possible actions
2. Game-like problems with win/loss outcomes
3. Problems requiring exploration vs exploitation balance
4. Sequential decision problems
5. Simulations can provide reward signals
### MCTS Four Phases (Per Iteration)
1. **Selection**: Pick promising node (UCB1)
2. **Expansion**: Add children to selected node
3. **Simulation**: Evaluate with reward (0-1)
4. **Backpropagation**: Update path statistics
### UCB1 Formula
UCB1 = win_rate + c × √(ln(parent_visits) / visits)
- `win_rate`: wins/visits (exploitation)
- `c`: exploration constant (default √2 ≈ 1.414)
- Higher c = more exploration
### Workflow
1. Call `read_mcts` to see current state
2. Create root if tree is empty
3. For each iteration:
a. `select_node` - find promising leaf
b. `expand_node` - add possible actions
c. `simulate` - evaluate with reward (0-1)
d. `backpropagate` - update statistics
4. After iterations, `get_best_action` for result
### Rewards
- Use 0.0-1.0 scale
- 1.0 = best outcome (win)
- 0.0 = worst outcome (loss)
- Intermediate values for partial success
**IMPORTANT**: Always call `read_mcts` before modifying.
"""
# =============================================================================
# TOOL DESCRIPTIONS - Concise "how" for each tool
# =============================================================================
# Each constant below is passed as `description=` to the matching
# `@toolset.tool(...)` registration inside create_mcts_toolset().

# Description for the `read_mcts` tool (takes no parameters).
READ_MCTS_DESCRIPTION = """Read the current MCTS tree state.
Returns nodes with visits, wins, UCB1 values.
"""
# Description for the `select_node` tool.
SELECT_NODE_DESCRIPTION = """Select a promising node using UCB1.
Parameters:
- node_id: Optional specific node (or auto-select from root)
- exploration_constant: c for UCB1 (default 1.414)
Returns selected node for expansion.
Precondition: Call read_mcts first.
"""
# Description for the `expand_node` tool.
EXPAND_NODE_DESCRIPTION = """Expand a node with possible children.
Parameters:
- node_id: Node to expand
- children: List of child contents (possible actions)
- is_terminal: Optional list marking terminal children
Returns created child node IDs.
Precondition: Call read_mcts first.
"""
# Description for the `simulate` tool. The tool itself walks the parent chain
# and updates visits/wins, which is what "Triggers backpropagation
# automatically" refers to.
SIMULATE_DESCRIPTION = """Run simulation and record result.
Parameters:
- node_id: Starting node
- simulation_result: Reward 0.0-1.0
- simulation_path: Optional path of node IDs
Triggers backpropagation automatically.
Precondition: Call read_mcts first.
"""
# Description for the `backpropagate` tool.
BACKPROPAGATE_DESCRIPTION = """Update statistics from node to root.
Parameters:
- node_id: Leaf/terminal node
- reward: 0.0-1.0 value from simulation
Updates visits and wins for all ancestors.
Precondition: Call read_mcts first.
"""
# Description for the `get_best_action` tool (takes no parameters).
GET_BEST_ACTION_DESCRIPTION = """Get best action based on statistics.
Returns highest-visited child of root.
Most robust selection criterion.
"""
# Legacy constant: an alias of SELECT_NODE_DESCRIPTION under an older name
# (presumably kept so existing imports keep working - confirm before removing).
MCTS_TOOL_DESCRIPTION = SELECT_NODE_DESCRIPTION
# [docs] (documentation-viewer link marker; not part of the module code)
def calculate_ucb1(node: MCTSNode, parent_visits: int, exploration_constant: float) -> float:
    """Calculate the UCB1 value for a node.

    UCB1 = win_rate + c * sqrt(ln(parent_visits) / visits)

    Args:
        node: The node to score; only its ``visits`` and ``wins`` are read.
        parent_visits: Total visits of the node's parent.
        exploration_constant: The exploration constant (c).

    Returns:
        ``float("inf")`` when the node has not been visited, so unvisited
        nodes are always preferred. Otherwise the win rate plus the
        exploration bonus. When ``parent_visits`` is below 1 the bonus is
        treated as zero and only the win rate is returned.
    """
    if node.visits <= 0:
        return float("inf")

    exploitation = node.wins / node.visits

    # ln() is undefined for values <= 0; an inconsistent tree (visited child
    # under an unvisited parent) must not crash the caller with ValueError.
    if parent_visits < 1:
        return exploitation

    exploration = exploration_constant * math.sqrt(math.log(parent_visits) / node.visits)
    return exploitation + exploration
# [docs] (documentation-viewer link marker; not part of the module code)
def create_mcts_toolset(
    storage: MCTSStorageProtocol | None = None,
    *,
    id: str | None = None,
    track_usage: bool = False,
) -> FunctionToolset[Any]:
    """Create an MCTS toolset for tree-based exploration with statistics.

    The returned toolset registers six tools that operate on a shared tree
    held in ``storage``: ``read_mcts``, ``select_node``, ``expand_node``,
    ``simulate``, ``backpropagate`` and ``get_best_action``. Every tool
    returns a human-readable string; failures are reported as strings rather
    than raised. Note that ``simulate`` already propagates its reward to the
    root, so it does not need a follow-up ``backpropagate`` call.

    Args:
        storage: Optional storage backend. Defaults to in-memory MCTSStorage.
        id: Optional unique ID for the toolset.
        track_usage: If True, enables usage metrics collection. Only used
            when ``storage`` is not supplied; a supplied storage keeps its
            own configuration.

    Returns:
        FunctionToolset compatible with any pydantic-ai agent.

    Example:
        ```python
        from pydantic_ai import Agent
        from pydantic_ai_toolsets import create_mcts_toolset, MCTSStorage
        # With storage and metrics
        storage = MCTSStorage(track_usage=True)
        agent = Agent("openai:gpt-4.1", toolsets=[create_mcts_toolset(storage)])
        print(storage.metrics.total_tokens())
        ```
    """
    if storage is not None:
        _storage = storage
    else:
        _storage = MCTSStorage(track_usage=track_usage)

    toolset: FunctionToolset[Any] = FunctionToolset(id=id)
    # Custom storage backends are not required to expose metrics.
    _metrics = getattr(_storage, "metrics", None)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _find_root() -> MCTSNode | None:
        """Return the first stored node without a parent, if any."""
        return next((n for n in _storage.nodes.values() if n.parent_id is None), None)

    def _get_statistics() -> dict[str, Any]:
        """Return storage statistics, or an empty dict if unsupported."""
        return _storage.get_statistics() if hasattr(_storage, "get_statistics") else {}

    def _finish(tool_name: str, input_text: str, result: str, start_time: float) -> str:
        """Record the invocation (when metrics are enabled) and return result."""
        if _metrics is not None:
            duration_ms = (time.perf_counter() - start_time) * 1000
            _metrics.record_invocation(tool_name, input_text, result, duration_ms)
        return result

    def _not_found(node_id: str) -> str:
        """Build the standard error message for an unknown node ID."""
        available = ", ".join([n.node_id for n in _storage.nodes.values()])
        return f"Error: Node '{node_id}' not found. Available: [{available}]. Call read_mcts."

    def _propagate(node_id: str, reward: float) -> int:
        """Add one visit and ``reward`` wins to the node and every ancestor.

        Returns the number of nodes updated. Stops early if a parent ID does
        not resolve to a stored node.
        """
        current: str | None = node_id
        updated = 0
        while current:
            node = _storage.nodes.get(current)
            if node is None:
                break
            node.visits += 1
            node.wins += reward
            updated += 1
            current = node.parent_id
        return updated

    def _get_status_summary() -> str:
        """Get one-line status summary."""
        if not _storage.nodes:
            return "Status: ○ Empty"
        stats = _get_statistics()
        total = stats.get("total_nodes", len(_storage.nodes))
        iterations = stats.get("iterations", 0)
        if any(n.is_terminal for n in _storage.nodes.values()):
            return f"Status: ✓ Has solutions | {total} nodes, {iterations} iterations"
        return f"Status: ● Active | {total} nodes, {iterations} iterations"

    def _get_next_hint() -> str:
        """Get contextual hint for next action."""
        if not _storage.nodes:
            return "Use expand_node with a root node_id to create the tree."
        if not _find_root():
            return "Use expand_node to create a root node."
        nodes = list(_storage.nodes.values())
        if any(n.is_terminal and n.visits > 0 for n in nodes):
            return "Terminal nodes found. Use get_best_action for the most visited solution."
        if any(not n.is_expanded and not n.is_terminal for n in nodes):
            return "Use select_node to find a promising leaf, then expand_node and simulate."
        # Standard MCTS iteration
        return "Run MCTS iteration: select_node → expand_node → simulate → backpropagate."

    # ------------------------------------------------------------------
    # Tools
    # ------------------------------------------------------------------

    @toolset.tool(description=READ_MCTS_DESCRIPTION)
    async def read_mcts() -> str:
        """Read the current MCTS tree state."""
        start_time = time.perf_counter()
        if not _storage.nodes:
            result = f"{_get_status_summary()}\n\nEmpty tree.\n\nNext: {_get_next_hint()}"
            return _finish("read_mcts", "", result, start_time)

        lines: list[str] = [_get_status_summary(), "", "MCTS Tree:", ""]
        root = _find_root()

        def display_node(node: MCTSNode, indent: str = "") -> None:
            term = " ⭐" if node.is_terminal else ""
            rate = f"{node.wins / node.visits:.2f}" if node.visits > 0 else "?"
            ucb = ""
            # UCB1 is defined relative to the node's own parent, not the
            # root, so look the parent up instead of reusing root.visits.
            parent = _storage.nodes.get(node.parent_id) if node.parent_id else None
            if parent is not None and parent.visits > 0:
                ucb_val = calculate_ucb1(node, parent.visits, math.sqrt(2))
                ucb = f" ucb={ucb_val:.2f}" if ucb_val != float("inf") else " ucb=∞"
            lines.append(
                f"{indent}[{node.node_id}] visits={node.visits} "
                f"wins={node.wins:.1f} rate={rate}{ucb}{term}"
            )
            lines.append(f"{indent} {node.content}")
            for cid in node.children_ids:
                child = _storage.nodes.get(cid)
                if child:
                    display_node(child, indent + " ")

        if root:
            display_node(root)
        lines.append("")

        # Summary
        stats = _get_statistics()
        if stats:
            lines.append(
                f"Stats: {stats.get('total_nodes', 0)} nodes, "
                f"{stats.get('iterations', 0)} iterations, "
                f"depth {stats.get('max_depth', 0)}"
            )
            lines.append("")
        lines.append(f"Next: {_get_next_hint()}")
        return _finish("read_mcts", "", "\n".join(lines), start_time)

    @toolset.tool(description=SELECT_NODE_DESCRIPTION)
    async def select_node(select: SelectNodeItem) -> str:
        """Select a promising node using UCB1."""
        start_time = time.perf_counter()
        input_text = select.model_dump_json() if _metrics is not None else ""

        if select.node_id:
            # Explicit selection: just validate that the node exists.
            if select.node_id not in _storage.nodes:
                return _finish("select_node", input_text, _not_found(select.node_id), start_time)
            return _finish("select_node", input_text, f"Selected [{select.node_id}]", start_time)

        # UCB1 selection from root
        root = _find_root()
        if not root:
            return _finish(
                "select_node",
                input_text,
                "No root. Create tree with expand_node first.",
                start_time,
            )

        current = root
        path: list[str] = [current.node_id]
        # Descend using UCB1
        while current.children_ids and current.is_expanded:
            candidates = [
                child
                for child in (_storage.nodes.get(cid) for cid in current.children_ids)
                if child
            ]
            if not candidates:
                break
            # Clamp to 1 so ln() stays defined even if the stored statistics
            # are inconsistent (visited child under an unvisited parent).
            parent_visits = max(current.visits, 1)
            # max() keeps the first of equally scored children, matching a
            # strict ">" scan in child order.
            current = max(
                candidates,
                key=lambda c: calculate_ucb1(c, parent_visits, select.exploration_constant),
            )
            path.append(current.node_id)
            # Stop at unexpanded or terminal
            if not current.is_expanded or current.is_terminal:
                break

        path_str = " → ".join(f"[{n}]" for n in path)
        result = f"Selected [{current.node_id}] via {path_str}"
        return _finish("select_node", input_text, result, start_time)

    @toolset.tool(description=EXPAND_NODE_DESCRIPTION)
    async def expand_node(expand: ExpandNodeItem) -> str:
        """Expand a node by adding children."""
        start_time = time.perf_counter()
        input_text = expand.model_dump_json() if _metrics is not None else ""

        # An unknown ID is only acceptable when the tree is empty, in which
        # case it names the root to create.
        creating_root = expand.node_id not in _storage.nodes
        if creating_root and _storage.nodes:
            return _finish("expand_node", input_text, _not_found(expand.node_id), start_time)

        # Validate before mutating anything so a rejected call leaves the
        # tree untouched.
        is_terminal_list = expand.is_terminal
        if is_terminal_list and len(is_terminal_list) != len(expand.children):
            return _finish(
                "expand_node",
                input_text,
                "Error: is_terminal length must match children length.",
                start_time,
            )

        if creating_root:
            # NOTE(review): assigning a single node to `nodes` presumably goes
            # through a storage setter that registers it under its node_id -
            # confirm against the storage module.
            _storage.nodes = MCTSNode(
                node_id=expand.node_id,
                content="Root",
                depth=0,
                is_expanded=True,
            )

        parent = _storage.nodes[expand.node_id]
        new_ids: list[str] = []
        for i, content in enumerate(expand.children):
            child_id = str(uuid.uuid4())
            child = MCTSNode(
                node_id=child_id,
                content=content,
                parent_id=parent.node_id,
                depth=parent.depth + 1,
                is_terminal=is_terminal_list[i] if is_terminal_list else False,
            )
            _storage.nodes = child
            parent.children_ids.append(child_id)
            new_ids.append(child_id)
        parent.is_expanded = True

        result = f"Expanded [{expand.node_id}] → {len(expand.children)} children"
        if new_ids:
            # Child IDs are generated here, so the caller can only learn them
            # from this response (listed in the same order as `children`).
            result += ": " + ", ".join(f"[{cid}]" for cid in new_ids)
        return _finish("expand_node", input_text, result, start_time)

    @toolset.tool(description=SIMULATE_DESCRIPTION)
    async def simulate(sim: SimulateItem) -> str:
        """Run a simulation and record the result."""
        start_time = time.perf_counter()
        input_text = sim.model_dump_json() if _metrics is not None else ""

        if sim.node_id not in _storage.nodes:
            return _finish("simulate", input_text, _not_found(sim.node_id), start_time)

        # Backpropagate from this node
        nodes_updated = _propagate(sim.node_id, sim.simulation_result)

        # Increment iteration counter
        if hasattr(_storage, "increment_iteration"):
            _storage.increment_iteration()

        result = (
            f"Simulated [{sim.node_id}] reward={sim.simulation_result:.2f}, "
            f"updated {nodes_updated} nodes"
        )
        return _finish("simulate", input_text, result, start_time)

    @toolset.tool(description=BACKPROPAGATE_DESCRIPTION)
    async def backpropagate(backprop: BackpropagateItem) -> str:
        """Backpropagate statistics from a node to the root."""
        start_time = time.perf_counter()
        input_text = backprop.model_dump_json() if _metrics is not None else ""

        if backprop.node_id not in _storage.nodes:
            return _finish("backpropagate", input_text, _not_found(backprop.node_id), start_time)

        nodes_updated = _propagate(backprop.node_id, backprop.reward)
        result = (
            f"Backpropagated from [{backprop.node_id}] reward={backprop.reward:.2f}, "
            f"updated {nodes_updated} nodes"
        )
        return _finish("backpropagate", input_text, result, start_time)

    @toolset.tool(description=GET_BEST_ACTION_DESCRIPTION)
    async def get_best_action() -> str:
        """Get the best action based on visit counts."""
        start_time = time.perf_counter()
        root = _find_root()
        if not root:
            return _finish("get_best_action", "", "No root. Create tree first.", start_time)
        if not root.children_ids:
            return _finish(
                "get_best_action", "", "Root has no children. Expand root first.", start_time
            )

        # Resolve children once, dropping IDs that no longer point to a node.
        children: list[tuple[str, MCTSNode]] = []
        for cid in root.children_ids:
            child = _storage.nodes.get(cid)
            if child:
                children.append((cid, child))

        # Select by most visits (most robust). Starting at 0 means a child
        # needs at least one visit to qualify, so an unexplored tree reports
        # "No visited children." instead of an arbitrary zero-visit child.
        best_child: MCTSNode | None = None
        best_visits = 0
        for _, child in children:
            if child.visits > best_visits:
                best_visits = child.visits
                best_child = child

        if best_child is None:
            return _finish("get_best_action", "", "No visited children.", start_time)

        rate = best_child.wins / best_child.visits if best_child.visits > 0 else 0
        lines = [
            f"Best action: [{best_child.node_id}]",
            f" Visits: {best_child.visits}",
            f" Win rate: {rate:.2%}",
            f" Content: {best_child.content}",
            # Show all children for comparison
            "",
            "All root children:",
        ]
        # Stable sort: children with equal visits keep their expansion order.
        children.sort(key=lambda pair: pair[1].visits, reverse=True)
        for cid, child in children:
            c_rate = child.wins / child.visits if child.visits > 0 else 0
            star = " ←" if child.node_id == best_child.node_id else ""
            lines.append(f" [{cid}] v={child.visits} rate={c_rate:.2%}{star}")
        return _finish("get_best_action", "", "\n".join(lines), start_time)

    return toolset
# [docs] (documentation-viewer link marker; not part of the module code)
def get_mcts_system_prompt(storage: MCTSStorageProtocol | None = None) -> str:
    """Build the MCTS system prompt, optionally extended with live tree state.

    Args:
        storage: Storage whose tree should be summarised. May be omitted.

    Returns:
        The static MCTS prompt when no storage is given or it holds no nodes.
        Otherwise the static prompt followed by a "Current State" section
        with the root's visit count, the node total, and the root's children
        ordered by visits (highest first).
    """
    if storage is None:
        return MCTS_SYSTEM_PROMPT
    if not storage.nodes:
        return MCTS_SYSTEM_PROMPT

    sections: list[str] = [MCTS_SYSTEM_PROMPT, "", "## Current State"]
    node_count = len(storage.nodes)

    # The root is the first stored node that has no parent.
    root_node = None
    for candidate in storage.nodes.values():
        if candidate.parent_id is None:
            root_node = candidate
            break

    if root_node:
        sections.append(f"Root visits: {root_node.visits}, Total nodes: {node_count}")
        if root_node.children_ids:
            sections.extend(["", "Top actions by visits:"])
            # Skip child IDs that do not resolve to a stored node.
            resolved = []
            for child_id in root_node.children_ids:
                child_node = storage.nodes.get(child_id)
                if child_node:
                    resolved.append((child_id, child_node))
            ranked = sorted(resolved, key=lambda pair: pair[1].visits, reverse=True)
            for child_id, child_node in ranked:
                if child_node.visits > 0:
                    win_rate = child_node.wins / child_node.visits
                else:
                    win_rate = 0
                sections.append(f"- [{child_id}] v={child_node.visits} rate={win_rate:.2%}")

    return "\n".join(sections)
def create_mcts_toolset_agent(model: str = "openrouter:x-ai/grok-4.1-fast") -> Agent:
    """Build a ready-to-use pydantic-ai agent wired to a fresh MCTS toolset.

    A new MCTSStorage is created for the agent; the same storage backs both
    the toolset and the dynamic instructions, so the instructions reflect the
    tree the tools are modifying.

    Args:
        model: The model identifier passed to ``Agent``.

    Returns:
        Pydantic-ai agent with the MCTS toolset attached.
    """
    tree_storage = MCTSStorage()
    mcts_tools = create_mcts_toolset(storage=tree_storage)

    base_prompt = """
You are an MCTS agent. You have access to tools for MCTS-based reasoning:
- `read_mcts`: Review current tree state
- `select_node`: Select promising node using UCB1
- `expand_node`: Expand node with possible children
- `simulate`: Run simulation from a node
- `backpropagate`: Update statistics from simulation
- `get_best_action`: Get best action based on visits
**IMPORTANT**: Use these tools to explore reasoning using Monte Carlo Tree Search.
"""

    mcts_agent = Agent(
        model,
        system_prompt=base_prompt,
        toolsets=[mcts_tools],
    )

    @mcts_agent.instructions
    async def add_prompt() -> str:
        """Supply the MCTS prompt built from the agent's own storage."""
        return get_mcts_system_prompt(tree_storage)

    return mcts_agent