# Source code for pydantic_ai_toolsets.toolsets.beam_search_reasoning.toolset
"""Beam search toolset for pydantic-ai agents."""
from __future__ import annotations
import sys
import time
import uuid
from typing import Any
from pydantic_ai import Agent
from pydantic_ai.toolsets import FunctionToolset
from .storage import BeamStorage, BeamStorageProtocol
from .types import (
BeamCandidate,
BeamStep,
CreateCandidateItem,
ExpandCandidateItem,
PruneBeamItem,
ScoreCandidateItem,
)
# =============================================================================
# SYSTEM PROMPT - Contains "when and why" to use the toolset
# =============================================================================
# Static base prompt describing the beam-search tools and workflow.
# get_beam_system_prompt() returns this text as-is when there is no storage or
# no candidates, and otherwise appends a "## Current State" section after it.
# NOTE: this is runtime text sent to the model — edit wording deliberately.
BEAM_SYSTEM_PROMPT = """
## Beam Search
You have access to tools for beam search exploration:
- `read_beam`: Review current beam state and candidates
- `create_candidate`: Create initial candidates
- `expand_candidate`: Generate next steps from a candidate
- `score_candidate`: Assign quality score (0-100)
- `prune_beam`: Keep only top-k candidates at a step
- `get_best_path`: Find highest-scoring path to terminal
### When to Use Beam Search
Use these tools in these scenarios:
1. Problems requiring simultaneous multi-path exploration
2. Tasks needing systematic exploration with pruning
3. Balancing exploration vs exploitation
4. Problems with clear scoring/evaluation functions
5. When breadth-first is too expensive
### Beam Search Process
1. **Initialize**: Create initial candidates (step 0)
2. **Expand**: Generate possible next steps from beam candidates
3. **Score**: Evaluate each candidate (0-100)
4. **Prune**: Keep only top-k highest-scoring (the "beam")
5. **Repeat**: Continue until terminal states or depth limit
6. **Select**: Return best path via get_best_path
### Key Parameters
- **Beam width (k)**: Candidates to keep per step
- k=1: greedy search (fast, may miss optimal)
- k=3-10: typical for practical applications
- **Scoring**: 0-100, higher is better
- **Terminal**: Mark solution candidates with is_terminal=true
### Workflow
1. Call `read_beam` to see current state
2. Create initial candidates if none exist
3. Expand candidates to generate continuations
4. Score all new candidates
5. Prune beam to keep top-k
6. Repeat until terminal candidates found
7. Use get_best_path for final result
**IMPORTANT**: Always call `read_beam` before modifying.
"""
# =============================================================================
# TOOL DESCRIPTIONS - Contains "how" to use each specific tool
# =============================================================================
# Each constant below is passed as `description=` to the matching tool in
# create_beam_toolset(); the text is shown to the model at runtime.

# `read_beam`: renders the beam steps, candidates grouped by depth, and a
# "Next:" hint.
READ_BEAM_DESCRIPTION = """Read the current beam search state.
Returns candidates organized by depth with scores and terminal status.
Returns:
- Beam steps with candidate lists
- Candidates by depth with scores
- Summary statistics
"""
# `create_candidate`: every candidate created by this tool is a root
# (depth 0) registered in step 0.
CREATE_CANDIDATE_DESCRIPTION = """Create a new initial candidate.
Parameters:
- content: Reasoning content for this candidate
- is_terminal: True if this is a solution
Returns candidate ID and placement info.
Precondition: Call read_beam first.
"""
# `expand_candidate`: children are placed at the parent's depth + 1 and in
# step (parent step_index + 1); is_terminal, when given, must have one entry
# per expansion.
EXPAND_CANDIDATE_DESCRIPTION = """Expand a candidate into next steps.
Parameters:
- candidate_id: ID of candidate to expand
- expansions: List of continuation contents
- is_terminal: Optional list marking which are solutions
Creates new candidates at depth+1 in next step.
Precondition: Call read_beam first.
"""
# `score_candidate`: the score is what prune_beam and get_best_path rank by.
# NOTE(review): the tool implementation only stores `score` on the candidate;
# the `reasoning` text mentioned here is not persisted.
SCORE_CANDIDATE_DESCRIPTION = """Score a candidate's quality.
Parameters:
- candidate_id: ID to score
- score: 0-100 (higher is better)
- reasoning: Explanation for the score
Score determines pruning priority.
Precondition: Call read_beam first.
"""
# `prune_beam`: unscored candidates are ranked as score -1, i.e. below any
# non-negative score, so they are the first to be discarded.
PRUNE_BEAM_DESCRIPTION = """Prune beam to keep top-k candidates.
Parameters:
- step_index: Which step to prune
- beam_width: k - how many top candidates to keep
Candidates sorted by score, top-k kept.
Precondition: Score candidates first.
"""
# `get_best_path`: "best" means the highest average score over the scored
# candidates on the root-to-terminal path (only scored terminals qualify).
GET_BEST_PATH_DESCRIPTION = """Find best path to terminal candidate.
Returns highest-scoring path from initial to terminal candidates.
Returns:
- Best path with average score
- Full reasoning chain
"""
# Legacy constant
# Alias of CREATE_CANDIDATE_DESCRIPTION; presumably kept so older imports keep
# working — verify external users before removing.
BEAM_TOOL_DESCRIPTION = CREATE_CANDIDATE_DESCRIPTION
def create_beam_toolset(
    storage: BeamStorageProtocol | None = None,
    *,
    id: str | None = None,
    track_usage: bool = False,
) -> FunctionToolset[Any]:
    """Create a beam search toolset for beam-based reasoning exploration.

    This toolset provides tools for AI agents to explore reasoning using beam search,
    maintaining a beam of top-k candidates at each step.

    Args:
        storage: Optional storage backend. Defaults to in-memory BeamStorage.
        id: Optional unique ID for the toolset.
        track_usage: If True, enables usage metrics collection. Only used when
            ``storage`` is None (it configures the default BeamStorage); a
            supplied storage contributes its own ``metrics`` attribute, if any.

    Returns:
        FunctionToolset compatible with any pydantic-ai agent.

    Example:
        ```python
        from pydantic_ai import Agent
        from pydantic_ai_toolsets import create_beam_toolset, BeamStorage

        # With storage and metrics
        storage = BeamStorage(track_usage=True)
        agent = Agent("openai:gpt-4.1", toolsets=[create_beam_toolset(storage)])
        print(storage.metrics.total_tokens())
        ```
    """
    _storage = storage if storage is not None else BeamStorage(track_usage=track_usage)
    toolset: FunctionToolset[Any] = FunctionToolset(id=id)
    # Storage backends without a ``metrics`` attribute simply disable usage tracking.
    _metrics = getattr(_storage, "metrics", None)

    # NOTE(review): the tools below "save" by assigning a single BeamCandidate /
    # BeamStep to ``_storage.candidates`` / ``_storage.steps``, while reads treat
    # those attributes as a mapping / iterable of all items. Presumably the
    # storage exposes upserting setters — confirm against BeamStorageProtocol.

    # ------------------------------------------------------------------ helpers

    def _finish(tool_name: str, input_text: str, result: str, start_time: float) -> str:
        """Record the invocation (when metrics are enabled) and return ``result``.

        Used on every return path, including errors, so usage metrics also
        account for error messages sent back to the model.
        """
        if _metrics is not None:
            duration_ms = (time.perf_counter() - start_time) * 1000
            _metrics.record_invocation(tool_name, input_text, result, duration_ms)
        return result

    def _score_or(candidate: BeamCandidate, fallback: float) -> float:
        """Return the candidate's score, or ``fallback`` if unscored (0 is a real score)."""
        return candidate.score if candidate.score is not None else fallback

    def _fmt_score(candidate: BeamCandidate) -> str:
        """Format a score as an integer string, or "?" when unscored."""
        return f"{candidate.score:.0f}" if candidate.score is not None else "?"

    def _star(candidate: BeamCandidate) -> str:
        """Return the terminal marker suffix for display."""
        return " ⭐" if candidate.is_terminal else ""

    def _available_ids() -> str:
        """Comma-separated list of every known candidate ID, for error messages."""
        return ", ".join(c.candidate_id for c in _storage.candidates.values())

    def _find_step(step_index: int) -> BeamStep | None:
        """Return the stored step with ``step_index``, or None if it does not exist."""
        return next((s for s in _storage.steps if s.step_index == step_index), None)

    def _get_status_summary() -> str:
        """Get one-line status summary."""
        if not _storage.candidates:
            return "Status: ○ Empty"
        total = len(_storage.candidates)
        terminal = sum(1 for c in _storage.candidates.values() if c.is_terminal)
        max_step = max((s.step_index for s in _storage.steps), default=0)
        if terminal > 0:
            return f"Status: ✓ Has solutions | Step {max_step}, {total} candidates, {terminal} terminal"
        return f"Status: ● Active | Step {max_step}, {total} candidates"

    def _get_next_hint() -> str:
        """Get contextual hint for next action."""
        candidates = list(_storage.candidates.values())
        if not candidates:
            return "Use create_candidate to create initial candidates."
        if any(c.is_terminal and c.score is not None for c in candidates):
            return "Terminal candidates found. Use get_best_path to find the best solution."
        unscored = next((c for c in candidates if c.score is None), None)
        if unscored is not None:
            return f"Use score_candidate on [{unscored.candidate_id}] to evaluate quality."
        # Only suggest candidates still in a beam (not pruned away) that have not
        # been expanded yet; re-suggesting an expanded or pruned candidate would
        # duplicate children or resurrect discarded branches.
        in_beam = {cid for step in _storage.steps for cid in step.candidate_ids}
        expanded = {c.parent_id for c in candidates if c.parent_id}
        expandable = [
            c
            for c in candidates
            if not c.is_terminal
            and c.score is not None
            and c.candidate_id in in_beam
            and c.candidate_id not in expanded
        ]
        if expandable:
            best = max(expandable, key=lambda c: _score_or(c, 0))
            return f"Use expand_candidate on [{best.candidate_id}] to generate next steps, then prune_beam."
        return "Create more candidates or mark solutions as terminal."

    def _path_to(candidate_id: str) -> list[str]:
        """Return candidate IDs from the root down to ``candidate_id`` via parent links."""
        path: list[str] = []
        current: str | None = candidate_id
        while current:
            path.append(current)
            node = _storage.candidates.get(current)
            current = node.parent_id if node else None
        path.reverse()
        return path

    def _best_path_report() -> str:
        """Render the root-to-terminal path with the highest average score."""
        if not _storage.candidates:
            return "No candidates. Create candidates first."
        terminals = [c for c in _storage.candidates.values() if c.is_terminal and c.score is not None]
        if not terminals:
            return "No scored terminal candidates. Mark is_terminal=true and score them."

        best_path: list[str] | None = None
        best_score: float | None = None
        for term in terminals:
            path = _path_to(term.candidate_id)
            # Average only over scored nodes; unscored ancestors are ignored.
            scores = [
                node.score
                for node in (_storage.candidates.get(cid) for cid in path)
                if node is not None and node.score is not None
            ]
            if not scores:
                continue
            avg = sum(scores) / len(scores)
            if best_score is None or avg > best_score:  # first best wins ties
                best_score, best_path = avg, path
        if best_path is None or best_score is None:
            return "No scored path found."

        lines = [f"Best Path (avg score: {best_score:.0f}/100):", ""]
        for position, cid in enumerate(best_path, start=1):
            node = _storage.candidates.get(cid)
            if node:
                lines.append(f"{position}. [{cid}] {_fmt_score(node)}{_star(node)}")
                lines.append(f" {node.content}")
        return "\n".join(lines)

    # -------------------------------------------------------------------- tools

    @toolset.tool(description=READ_BEAM_DESCRIPTION)
    async def read_beam() -> str:
        """Read the current beam search state."""
        start_time = time.perf_counter()
        if not _storage.candidates:
            result = f"{_get_status_summary()}\n\nNo candidates.\n\nNext: {_get_next_hint()}"
            return _finish("read_beam", "", result, start_time)

        lines: list[str] = [_get_status_summary(), "", "Beam Search State:", ""]

        # Steps: the candidate IDs currently kept in each step's beam.
        if _storage.steps:
            lines.append("Steps:")
            for step in sorted(_storage.steps, key=lambda s: s.step_index):
                lines.append(f" Step {step.step_index} (k={step.beam_width}):")
                for cid in step.candidate_ids:
                    c = _storage.candidates.get(cid)
                    if c:
                        lines.append(f" [{cid}] score={_fmt_score(c)}{_star(c)}")
            lines.append("")

        # Candidates by depth, best first; unscored ones sort last.
        by_depth: dict[int, list[BeamCandidate]] = {}
        for c in _storage.candidates.values():
            by_depth.setdefault(c.depth, []).append(c)
        lines.append("Candidates:")
        for depth in sorted(by_depth):
            lines.append(f" Depth {depth}:")
            for c in sorted(by_depth[depth], key=lambda c: _score_or(c, -1), reverse=True):
                parent = f" ←[{c.parent_id}]" if c.parent_id else " (root)"
                lines.append(f" [{c.candidate_id}] {_fmt_score(c)}{_star(c)}{parent}")
                lines.append(f" {c.content}")
        lines.append("")

        # Summary (only for storages that provide statistics).
        stats = _storage.get_statistics() if hasattr(_storage, "get_statistics") else {}
        if stats:
            lines.append(
                f"Stats: {stats.get('total_candidates', 0)} candidates, "
                f"{stats.get('terminal_candidates', 0)} terminal, "
                f"depth {stats.get('max_depth', 0)}"
            )
            lines.append("")

        lines.append(f"Next: {_get_next_hint()}")
        return _finish("read_beam", "", "\n".join(lines), start_time)

    @toolset.tool(description=CREATE_CANDIDATE_DESCRIPTION)
    async def create_candidate(candidate: CreateCandidateItem) -> str:
        """Create a new candidate in the beam search."""
        start_time = time.perf_counter()
        input_text = candidate.model_dump_json() if _metrics is not None else ""
        candidate_id = str(uuid.uuid4())
        # Created candidates are always roots: depth 0, registered in step 0.
        _storage.candidates = BeamCandidate(
            candidate_id=candidate_id,
            content=candidate.content,
            depth=0,
            is_terminal=candidate.is_terminal,
            step_index=0,
        )
        step = _find_step(0)
        if step is None:
            step = BeamStep(step_index=0, candidate_ids=[], beam_width=1)
        step.candidate_ids.append(candidate_id)
        _storage.steps = step

        result = f"Created [{candidate_id}] at step 0"
        if candidate.is_terminal:
            result += " ⭐"
        return _finish("create_candidate", input_text, result, start_time)

    @toolset.tool(description=EXPAND_CANDIDATE_DESCRIPTION)
    async def expand_candidate(expand: ExpandCandidateItem) -> str:
        """Expand a candidate to generate next steps."""
        start_time = time.perf_counter()
        input_text = expand.model_dump_json() if _metrics is not None else ""
        if expand.candidate_id not in _storage.candidates:
            result = (
                f"Error: Candidate '{expand.candidate_id}' not found. "
                f"Available: [{_available_ids()}]. Call read_beam."
            )
            return _finish("expand_candidate", input_text, result, start_time)
        flags = expand.is_terminal
        if flags and len(flags) != len(expand.expansions):
            result = f"Error: is_terminal length ({len(flags)}) must match expansions ({len(expand.expansions)})."
            return _finish("expand_candidate", input_text, result, start_time)
        if not expand.expansions:
            # Without this guard an empty step (beam_width=0) would be created.
            result = "Error: expansions must contain at least one continuation."
            return _finish("expand_candidate", input_text, result, start_time)

        parent = _storage.candidates[expand.candidate_id]
        next_depth = parent.depth + 1
        next_step = parent.step_index + 1
        # A missing/empty is_terminal list means no expansion is a solution.
        terminal_flags = flags if flags else [False] * len(expand.expansions)
        new_ids: list[str] = []
        for content, is_term in zip(expand.expansions, terminal_flags):
            cid = str(uuid.uuid4())
            _storage.candidates = BeamCandidate(
                candidate_id=cid,
                content=content,
                depth=next_depth,
                parent_id=expand.candidate_id,
                is_terminal=is_term,
                step_index=next_step,
            )
            new_ids.append(cid)

        step = _find_step(next_step)
        if step is not None:
            step.candidate_ids.extend(new_ids)
        else:
            step = BeamStep(step_index=next_step, candidate_ids=new_ids.copy(), beam_width=len(new_ids))
        _storage.steps = step

        result = f"Expanded [{expand.candidate_id}] → {len(expand.expansions)} candidates at step {next_step}"
        return _finish("expand_candidate", input_text, result, start_time)

    @toolset.tool(description=SCORE_CANDIDATE_DESCRIPTION)
    async def score_candidate(score: ScoreCandidateItem) -> str:
        """Score a candidate to evaluate its quality."""
        start_time = time.perf_counter()
        input_text = score.model_dump_json() if _metrics is not None else ""
        if score.candidate_id not in _storage.candidates:
            result = (
                f"Error: Candidate '{score.candidate_id}' not found. "
                f"Available: [{_available_ids()}]. Call read_beam."
            )
            return _finish("score_candidate", input_text, result, start_time)
        # Only the numeric score is stored; score.reasoning is not persisted.
        _storage.candidates[score.candidate_id].score = score.score
        result = f"Scored [{score.candidate_id}]: {score.score:.0f}/100"
        return _finish("score_candidate", input_text, result, start_time)

    @toolset.tool(description=PRUNE_BEAM_DESCRIPTION)
    async def prune_beam(prune: PruneBeamItem) -> str:
        """Prune the beam to keep only top-k candidates at a step."""
        start_time = time.perf_counter()
        input_text = prune.model_dump_json() if _metrics is not None else ""
        if prune.beam_width < 1:
            # A zero/negative width would slice as [:0] or [:-n] and silently
            # drop the wrong candidates.
            result = f"Error: beam_width must be at least 1 (got {prune.beam_width})."
            return _finish("prune_beam", input_text, result, start_time)
        step = _find_step(prune.step_index)
        if step is None:
            available_steps = ", ".join(str(s.step_index) for s in _storage.steps)
            result = f"Error: Step {prune.step_index} not found. Available steps: [{available_steps}]. Call read_beam."
            return _finish("prune_beam", input_text, result, start_time)

        # Unscored candidates rank as -1 so they are discarded first; IDs no
        # longer present in storage are dropped from the step.
        ranked = [
            (cid, _score_or(_storage.candidates[cid], -1))
            for cid in step.candidate_ids
            if cid in _storage.candidates
        ]
        if not ranked:
            result = f"No candidates in step {prune.step_index}."
            return _finish("prune_beam", input_text, result, start_time)
        ranked.sort(key=lambda item: item[1], reverse=True)  # stable: ties keep insertion order
        kept = [cid for cid, _ in ranked[: prune.beam_width]]
        discarded = len(ranked) - len(kept)
        step.candidate_ids = kept
        step.beam_width = prune.beam_width
        _storage.steps = step

        result = f"Pruned step {prune.step_index}: kept {len(kept)}, discarded {discarded}"
        return _finish("prune_beam", input_text, result, start_time)

    @toolset.tool(description=GET_BEST_PATH_DESCRIPTION)
    async def get_best_path() -> str:
        """Find the best path found so far in the beam search."""
        start_time = time.perf_counter()
        return _finish("get_best_path", "", _best_path_report(), start_time)

    return toolset
def get_beam_system_prompt(
    storage: BeamStorageProtocol | None = None,
    *,
    max_listed: int | None = 10,
) -> str:
    """Generate dynamic system prompt section for beam search.

    Args:
        storage: Optional storage to read current beam from.
        max_listed: Maximum number of scored candidates shown under
            "Top candidates" (highest score first). ``None`` lists all of them.
            Bounding the list keeps the system prompt from growing with every
            candidate created during a run.

    Returns:
        ``BEAM_SYSTEM_PROMPT`` alone when there is no storage or it has no
        candidates; otherwise the prompt followed by a "## Current State"
        section with counts and the top-scored candidates.
    """
    # getattr on None also yields None, so this covers a missing storage, a
    # storage without ``candidates``, and an empty candidate collection.
    candidates_map = getattr(storage, "candidates", None)
    if not candidates_map:
        return BEAM_SYSTEM_PROMPT

    candidates = list(candidates_map.values())
    # Only scored candidates are ranked; sorted() is stable so ties keep storage order.
    ranked = sorted(
        (c for c in candidates if c.score is not None),
        key=lambda c: c.score,
        reverse=True,
    )
    terminal = sum(1 for c in candidates if c.is_terminal)
    max_depth = max((c.depth for c in candidates), default=0)

    lines: list[str] = [
        BEAM_SYSTEM_PROMPT,
        "",
        "## Current State",
        f"Candidates: {len(candidates)}, Scored: {len(ranked)}, Terminal: {terminal}, Depth: {max_depth}",
    ]
    if ranked:
        shown = ranked if max_listed is None else ranked[: max(max_listed, 0)]
        lines += ["", "Top candidates:"]
        for c in shown:
            term = " ⭐" if c.is_terminal else ""
            lines.append(f"- [{c.candidate_id}] {c.score:.0f}/100{term}")
        hidden = len(ranked) - len(shown)
        if hidden > 0:
            lines.append(f"- … and {hidden} more scored candidates (use read_beam to see all)")
    return "\n".join(lines)
def create_beam_toolset_agent(model: str = "openrouter:x-ai/grok-4.1-fast") -> Agent:
    """Build a pydantic-ai agent wired to its own in-memory beam search toolset.

    The agent receives a fixed system prompt listing the beam tools, plus an
    ``@agent.instructions`` function that renders ``get_beam_system_prompt``
    for the agent's private storage.

    Args:
        model: Model identifier handed directly to ``Agent``.

    Returns:
        The configured pydantic-ai agent.
    """
    beam_storage = BeamStorage()
    intro_prompt = (
        "\n"
        "You are a beam search agent. You have access to tools for beam search exploration:\n"
        "- `read_beam`: Review current beam state and candidates\n"
        "- `create_candidate`: Create initial candidates\n"
        "- `expand_candidate`: Generate next steps from a candidate\n"
        "- `score_candidate`: Assign quality score\n"
        "- `prune_beam`: Keep only top-k candidates\n"
        "- `get_best_path`: Find highest-scoring path\n"
        "**IMPORTANT**: Use these tools to explore reasoning using beam search with pruning.\n"
    )
    beam_agent = Agent(
        model,
        system_prompt=intro_prompt,
        toolsets=[create_beam_toolset(storage=beam_storage)],
    )

    @beam_agent.instructions
    async def add_prompt() -> str:
        """Render the beam prompt together with this agent's current beam state."""
        return get_beam_system_prompt(beam_storage)

    return beam_agent