Source code for pydantic_ai_toolsets.toolsets.chain_of_thought_reasoning.toolset
"""Chain of thoughts toolset for pydantic-ai agents."""
from __future__ import annotations
import sys
import time
from typing import Any
from pydantic_ai import Agent
from pydantic_ai.toolsets import FunctionToolset
from .storage import CoTStorage, CoTStorageProtocol
from .types import Thought
# =============================================================================
# SYSTEM PROMPT - Contains "when and why" to use the toolset
# =============================================================================
# Static prompt text naming the two tools (`read_thoughts`, `write_thoughts`)
# and the read-then-write workflow. `get_cot_system_prompt` returns it
# unchanged when there are no stored thoughts and uses it as the first section
# of the prompt otherwise.
COT_SYSTEM_PROMPT = """
## Chain of Thoughts
You have access to tools for managing your reasoning process:
- `read_thoughts`: Review your current chain of thoughts
- `write_thoughts`: Add a new thought to your chain
### When to Use Chain of Thoughts
Use these tools in these scenarios:
1. Complex problems requiring multi-step reasoning
2. Planning and design tasks that may need revision
3. Analysis where understanding evolves
4. Multi-step solutions needing context tracking
5. Problems with uncertainty requiring exploration
6. Hypothesis generation and verification
### Workflow
1. Call `read_thoughts` to see your current reasoning state
2. Call `write_thoughts` to add your next thought (increment thought_number)
3. Repeat until you reach a conclusion (set next_thought_needed=false)
### Thought Management
- Start with thought_number=1 and estimate total_thoughts
- Each thought should build on, question, or revise previous insights
- Mark is_revision=true when reconsidering previous thoughts
- Use branch_from_thought and branch_id for alternative paths
- Set next_thought_needed=false when you've reached a satisfactory answer
**IMPORTANT**: Always call `read_thoughts` before `write_thoughts` to:
- Review previous reasoning
- Determine the next thought_number
- Avoid repeating yourself
- Make informed revisions
"""
# =============================================================================
# TOOL DESCRIPTIONS - Contains "how" to use each specific tool
# =============================================================================
# Passed as the `description` argument when `create_cot_toolset` registers the
# `read_thoughts` tool.
READ_THOUGHTS_DESCRIPTION = """Read your current chain of thoughts.
Returns all recorded thoughts with their sequence numbers, revisions, and branches.
Use this to review your reasoning history before adding new thoughts.
Returns:
- Thoughts in sequence with metadata (revisions, branches)
- Summary statistics (total, revisions, branches, final)
"""
# Passed as the `description` argument when `create_cot_toolset` registers the
# `write_thoughts` tool. The parameter list describes the fields of the
# `Thought` argument that tool accepts.
WRITE_THOUGHTS_DESCRIPTION = """Add a new thought to your chain.
Parameters:
- thought: Your current thinking step content
- thought_number: Sequential number (1-based, increment from previous)
- total_thoughts: Estimated total needed (adjust as understanding deepens)
- is_revision: True if reconsidering a previous thought
- revises_thought: Which thought number being reconsidered (if is_revision)
- branch_from_thought: Branching point if exploring alternative path
- branch_id: Identifier for the branch (groups related thoughts)
- next_thought_needed: False when you've reached a conclusion
Returns confirmation with updated statistics.
Precondition: Call read_thoughts first to see current state.
"""
# Legacy constant for backward compatibility (now points to write description).
# It is an alias of WRITE_THOUGHTS_DESCRIPTION, not a separate string.
COT_TOOL_DESCRIPTION = WRITE_THOUGHTS_DESCRIPTION
[docs]
def create_cot_toolset(
    storage: CoTStorageProtocol | None = None,
    *,
    id: str | None = None,
    track_usage: bool = False,
) -> FunctionToolset[Any]:
    """Create a chain of thoughts toolset for reasoning exploration.

    This toolset provides read_thoughts and write_thoughts tools for AI agents
    to document and explore their reasoning process during a session.

    Args:
        storage: Optional storage backend. Defaults to in-memory CoTStorage.
            You can provide a custom storage implementing CoTStorageProtocol
            for persistence or integration with other systems.
        id: Optional unique ID for the toolset.
        track_usage: If True, enables usage metrics collection on the default
            storage. Ignored if custom storage is provided.

    Returns:
        FunctionToolset compatible with any pydantic-ai agent.

    Example (standalone):
        ```python
        from pydantic_ai import Agent
        from pydantic_ai_toolsets import create_cot_toolset

        agent = Agent("openai:gpt-4.1", toolsets=[create_cot_toolset()])
        result = await agent.run("Solve this complex problem step by step")
        ```

    Example (with custom storage):
        ```python
        from pydantic_ai_toolsets import create_cot_toolset, CoTStorage

        storage = CoTStorage()
        toolset = create_cot_toolset(storage=storage)
        # After agent runs, access thoughts directly
        print(storage.thoughts)
        ```

    Example (with usage tracking):
        ```python
        from pydantic_ai_toolsets import create_cot_toolset, CoTStorage

        storage = CoTStorage(track_usage=True)
        toolset = create_cot_toolset(storage=storage)
        # After agent runs, check usage metrics
        print(storage.metrics.total_tokens())
        print(storage.metrics.invocation_count())
        ```
    """
    _storage = storage if storage is not None else CoTStorage(track_usage=track_usage)
    toolset: FunctionToolset[Any] = FunctionToolset(id=id)

    # A storage backend may not expose a `metrics` attribute at all; a missing
    # attribute and an explicit None both mean "do not record invocations".
    _metrics = getattr(_storage, "metrics", None)

    def _count_stats() -> tuple[int, int, int, int]:
        """Return (total, revisions, distinct branches, final) for stored thoughts."""
        thoughts = _storage.thoughts
        total = len(thoughts)
        revisions = sum(1 for t in thoughts if t.is_revision)
        branches = len({t.branch_id for t in thoughts if t.branch_id})
        final = sum(1 for t in thoughts if not t.next_thought_needed)
        return total, revisions, branches, final

    def _get_status_summary() -> str:
        """Get one-line status summary."""
        if not _storage.thoughts:
            return "Status: ○ Empty"
        total, _, branches, final = _count_stats()
        # The chain counts as complete once any thought has
        # next_thought_needed set to False.
        state = "✓ Complete" if final > 0 else "● Active"
        return f"Status: {state} | {total} thoughts, {branches} branches"

    def _get_next_hint() -> str:
        """Get contextual hint for next action."""
        thoughts = _storage.thoughts
        if not thoughts:
            return "Use write_thoughts with thought_number=1 to start reasoning."
        if any(not t.next_thought_needed for t in thoughts):
            return "Reasoning complete. Provide your final answer."
        # Only the highest number is needed, so avoid sorting the whole list.
        next_num = max(t.thought_number for t in thoughts) + 1
        return f"Continue with write_thoughts using thought_number={next_num}."

    def _format_header(thought: Thought) -> str:
        """Build the header line for one thought: number, revision and branch tags."""
        header_parts: list[str] = [f"#{thought.thought_number}"]
        if thought.is_revision:
            # Emit the revision marker as one token so no space ends up
            # before the closing parenthesis.
            if thought.revises_thought:
                header_parts.append(f"(REVISION of #{thought.revises_thought})")
            else:
                header_parts.append("(REVISION)")
        if thought.branch_id:
            header_parts.append(f"[{thought.branch_id}]")
        if thought.branch_from_thought:
            header_parts.append(f"(from #{thought.branch_from_thought})")
        return " ".join(header_parts)

    @toolset.tool(description=READ_THOUGHTS_DESCRIPTION)
    async def read_thoughts() -> str:
        """Read the current chain of thoughts."""
        start_time = time.perf_counter()

        if not _storage.thoughts:
            result = (
                f"{_get_status_summary()}\n\n"
                "No thoughts recorded yet.\n\n"
                f"Next: {_get_next_hint()}"
            )
        else:
            lines: list[str] = [_get_status_summary(), "", "Chain of Thoughts:", ""]

            # Sort by thought_number to ensure correct order
            for thought in sorted(_storage.thoughts, key=lambda t: t.thought_number):
                lines.append(_format_header(thought))
                lines.append(f" {thought.thought}")
                if not thought.next_thought_needed:
                    lines.append(" [FINAL]")
                lines.append("")

            # Summary: zero-valued counters are omitted.
            total, revisions, branches, final = _count_stats()
            lines.append(f"Stats: {total} thoughts")
            if revisions > 0:
                lines.append(f" Revisions: {revisions}")
            if branches > 0:
                lines.append(f" Branches: {branches}")
            if final > 0:
                lines.append(f" Final: {final}")
            lines.append("")
            lines.append(f"Next: {_get_next_hint()}")
            result = "\n".join(lines)

        # Record metrics if tracking is enabled
        if _metrics is not None:
            duration_ms = (time.perf_counter() - start_time) * 1000
            _metrics.record_invocation("read_thoughts", "", result, duration_ms)
        return result

    @toolset.tool(description=WRITE_THOUGHTS_DESCRIPTION)
    async def write_thoughts(thought: Thought) -> str:
        """Add a new thought to the chain.

        Args:
            thought: Thought item with reasoning content and metadata.
        """
        start_time = time.perf_counter()

        # Serialize input for metrics. Uses the same `is not None` test as the
        # recording step below so a recorded invocation always has its input.
        input_text = thought.model_dump_json() if _metrics is not None else ""

        # NOTE(review): a single Thought is assigned to `thoughts`; presumably
        # the storage's setter adds it to the collection - confirm against
        # CoTStorageProtocol.
        _storage.thoughts = thought

        total, revisions, branches, _ = _count_stats()

        parts = [f"Added thought #{thought.thought_number}"]
        if thought.is_revision:
            parts.append("(revision)")
        if thought.branch_id:
            parts.append(f"[{thought.branch_id}]")
        if not thought.next_thought_needed:
            parts.append("[FINAL]")
        parts.append(f"| Total: {total}")
        if revisions > 0:
            parts.append(f"rev:{revisions}")
        if branches > 0:
            parts.append(f"branches:{branches}")
        result = " ".join(parts)

        # Record metrics if tracking is enabled
        if _metrics is not None:
            duration_ms = (time.perf_counter() - start_time) * 1000
            _metrics.record_invocation("write_thoughts", input_text, result, duration_ms)
        return result

    return toolset
[docs]
def get_cot_system_prompt(storage: CoTStorageProtocol | None = None) -> str:
    """Generate dynamic system prompt section for chain of thoughts.

    Args:
        storage: Optional storage to read current thoughts from.

    Returns:
        ``COT_SYSTEM_PROMPT`` unchanged when ``storage`` is ``None`` or holds
        no thoughts. Otherwise the base prompt followed by a "Current State"
        section containing the number of recorded thoughts, a preview of up
        to the three highest-numbered thoughts, and the next thought_number.
    """
    # Read the collection once so the non-empty check below also covers the
    # list that is sorted and indexed afterwards.
    thoughts = storage.thoughts if storage is not None else None
    if not thoughts:
        return COT_SYSTEM_PROMPT

    preview_size = 3
    sorted_thoughts = sorted(thoughts, key=lambda t: t.thought_number)

    lines: list[str] = [
        COT_SYSTEM_PROMPT,
        "",
        "## Current State",
        f"Thoughts recorded: {len(sorted_thoughts)}",
        "",
        "Recent thoughts:",
    ]

    # Show last few thoughts as preview; a negative slice start simply yields
    # the whole list when fewer than `preview_size` thoughts exist.
    for thought in sorted_thoughts[-preview_size:]:
        header = f"#{thought.thought_number}"
        if thought.is_revision:
            header += " (rev)"
        if thought.branch_id:
            header += f" [{thought.branch_id}]"
        lines.append(f"- {header}: {thought.thought}")

    # Next thought number hint: one past the highest recorded number.
    lines.append("")
    lines.append(f"Next thought_number: {sorted_thoughts[-1].thought_number + 1}")
    return "\n".join(lines)
def create_cot_toolset_agent(model: str = "openrouter:x-ai/grok-4.1-fast") -> Agent:
    """Build a reasoning agent wired to the chain of thoughts tools.

    Each call creates its own ``CoTStorage``; that one instance backs both the
    toolset and the instructions callback registered on the agent.

    Args:
        model: The model to use for the agent.

    Returns:
        Pydantic-ai agent with the chain of thoughts toolset.
    """
    thought_store = CoTStorage()

    # Fixed system prompt; the state-dependent part is supplied separately by
    # the instructions callback below.
    base_prompt = """
You are a reasoning agent. You have access to tools for managing your reasoning process:
- `read_thoughts`: Review your current chain of thoughts
- `write_thoughts`: Add a new thought to your chain
**IMPORTANT**: Use these tools to document and explore your reasoning process during complex problems.
"""

    reasoning_agent = Agent(
        model,
        system_prompt=base_prompt,
        toolsets=[create_cot_toolset(storage=thought_store)],
    )

    @reasoning_agent.instructions
    async def add_prompt() -> str:
        """Return the chain of thoughts prompt built from the agent's storage."""
        return get_cot_system_prompt(thought_store)

    return reasoning_agent