+ );
+}
+
+export default ReframeViewer;
diff --git a/src/agentgit/enhancers/viewer/index.html b/src/agentgit/enhancers/viewer/index.html
new file mode 100644
index 0000000..6fd36fb
--- /dev/null
+++ b/src/agentgit/enhancers/viewer/index.html
@@ -0,0 +1,220 @@
+
+
+
+
+
+ Reframe - Mental Model Viewer
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/agentgit/interactive_model.py b/src/agentgit/interactive_model.py
new file mode 100644
index 0000000..e60a1f8
--- /dev/null
+++ b/src/agentgit/interactive_model.py
@@ -0,0 +1,562 @@
+"""
+Interactive Mental Model - Time travel and focused prompting.
+
+Extends the living mental model with:
+1. Version history / snapshots for "time travel" through the model's evolution
+2. Interactive prompting - clicking on diagram elements generates contextual prompts
+
+The key insight: the diagram isn't just a visualization, it's an INTERFACE for
+directing the AI agent. Clicking on a component lets you ask focused questions
+or make targeted changes to that part of the system.
+"""
+
+from dataclasses import dataclass, field
+from typing import Optional, Callable
+from datetime import datetime
+from pathlib import Path
+import json
+import copy
+
+from .mental_model import (
+ MentalModel,
+ MentalModelNode,
+ MentalModelEdge,
+ ModelUpdate,
+)
+from .core import Prompt
+
+
+@dataclass
+class ModelSnapshot:
+ """A point-in-time snapshot of the mental model."""
+
+ version: int
+ timestamp: datetime
+ model_state: MentalModel
+ trigger_prompt_id: Optional[str] = None
+ trigger_prompt_text: Optional[str] = None
+ change_summary: str = ""
+
+ def to_dict(self) -> dict:
+ return {
+ "version": self.version,
+ "timestamp": self.timestamp.isoformat(),
+ "trigger_prompt_id": self.trigger_prompt_id,
+ "trigger_prompt_text": self.trigger_prompt_text,
+ "change_summary": self.change_summary,
+ "model_state": json.loads(self.model_state.to_json()),
+ }
+
+
+@dataclass
+class FocusedPrompt:
+ """A prompt generated from interacting with the diagram."""
+
+ prompt_text: str
+ context: str # What part of the diagram this relates to
+ suggested_files: list[str] = field(default_factory=list)
+ node_ids: list[str] = field(default_factory=list) # Related nodes
+
+
+class InteractiveMentalModel:
+ """
+ A mental model with time travel and interactive prompting capabilities.
+
+ This transforms the diagram from a passive visualization into an
+ active interface for directing an AI coding agent.
+ """
+
+ def __init__(self, project_name: str = ""):
+ self.current_model = MentalModel(project_name=project_name)
+ self.history: list[ModelSnapshot] = []
+ self.current_version: int = 0
+
+ # Map from node IDs to associated files (for focused context)
+ self.node_files: dict[str, set[str]] = {}
+
+ # Map from node IDs to prompts that modified them (for provenance)
+ self.node_prompt_history: dict[str, list[str]] = {}
+
+ def apply_update(
+ self,
+ update: ModelUpdate,
+ prompt: Optional[Prompt] = None,
+ affected_files: Optional[list[str]] = None,
+ ) -> None:
+ """Apply an update and save a snapshot for history."""
+
+ # Save snapshot BEFORE applying update
+ self._save_snapshot(prompt, update.reasoning)
+
+ # Apply the update
+ for node_id in update.remove_node_ids:
+ self.current_model.remove_node(node_id)
+
+ for source_id, target_id, relationship in update.remove_edges:
+ self.current_model.edges = [
+ e for e in self.current_model.edges
+ if not (e.source_id == source_id and
+ e.target_id == target_id and
+ e.relationship == relationship)
+ ]
+
+ for node in update.add_nodes:
+ self.current_model.add_node(node, prompt.prompt_id if prompt else None)
+
+ # Track file associations
+ if affected_files:
+ if node.id not in self.node_files:
+ self.node_files[node.id] = set()
+ self.node_files[node.id].update(affected_files)
+
+ # Track prompt history
+ if prompt:
+ if node.id not in self.node_prompt_history:
+ self.node_prompt_history[node.id] = []
+ self.node_prompt_history[node.id].append(prompt.prompt_id)
+
+ for edge in update.add_edges:
+ edge.introduced_by_prompt_id = prompt.prompt_id if prompt else None
+ self.current_model.add_edge(edge)
+
+ self.current_model.version += 1
+ self.current_model.last_updated = datetime.now()
+ self.current_version = self.current_model.version
+
+ def _save_snapshot(self, prompt: Optional[Prompt], change_summary: str) -> None:
+ """Save current state as a snapshot."""
+ snapshot = ModelSnapshot(
+ version=self.current_version,
+ timestamp=datetime.now(),
+ model_state=copy.deepcopy(self.current_model),
+ trigger_prompt_id=prompt.prompt_id if prompt else None,
+ trigger_prompt_text=prompt.text if prompt else None,
+ change_summary=change_summary,
+ )
+ self.history.append(snapshot)
+
+ # ===== TIME TRAVEL =====
+
+ def get_version(self, version: int) -> Optional[MentalModel]:
+ """Get the model state at a specific version."""
+ for snapshot in self.history:
+ if snapshot.version == version:
+ return copy.deepcopy(snapshot.model_state)
+ if version == self.current_version:
+ return copy.deepcopy(self.current_model)
+ return None
+
+ def get_timeline(self) -> list[dict]:
+ """Get a summary of all versions for timeline display."""
+ timeline = []
+ for snapshot in self.history:
+ timeline.append({
+ "version": snapshot.version,
+ "timestamp": snapshot.timestamp.isoformat(),
+ "prompt_preview": (
+ snapshot.trigger_prompt_text[:50] + "..."
+ if snapshot.trigger_prompt_text and len(snapshot.trigger_prompt_text) > 50
+ else snapshot.trigger_prompt_text
+ ),
+ "change_summary": snapshot.change_summary,
+ "node_count": len(snapshot.model_state.nodes),
+ "edge_count": len(snapshot.model_state.edges),
+ })
+
+ # Add current state
+ timeline.append({
+ "version": self.current_version,
+ "timestamp": self.current_model.last_updated.isoformat() if self.current_model.last_updated else None,
+ "prompt_preview": None,
+ "change_summary": "Current state",
+ "node_count": len(self.current_model.nodes),
+ "edge_count": len(self.current_model.edges),
+ })
+
+ return timeline
+
+ def diff_versions(self, from_version: int, to_version: int) -> dict:
+ """Get the diff between two versions."""
+ from_model = self.get_version(from_version)
+ to_model = self.get_version(to_version)
+
+ if not from_model or not to_model:
+ return {"error": "Version not found"}
+
+ from_nodes = set(from_model.nodes.keys())
+ to_nodes = set(to_model.nodes.keys())
+
+ added_nodes = to_nodes - from_nodes
+ removed_nodes = from_nodes - to_nodes
+
+ # Find modified nodes (same ID but different content)
+ modified_nodes = []
+ for node_id in from_nodes & to_nodes:
+ from_node = from_model.nodes[node_id]
+ to_node = to_model.nodes[node_id]
+ if from_node.label != to_node.label or from_node.description != to_node.description:
+ modified_nodes.append(node_id)
+
+ return {
+ "from_version": from_version,
+ "to_version": to_version,
+ "added_nodes": list(added_nodes),
+ "removed_nodes": list(removed_nodes),
+ "modified_nodes": modified_nodes,
+ }
+
+ # ===== INTERACTIVE PROMPTING =====
+
+ def generate_node_prompt(self, node_id: str, intent: str = "explore") -> FocusedPrompt:
+ """
+ Generate a focused prompt for a specific node in the diagram.
+
+ This is called when a user clicks on a node and wants to interact
+ with the AI about that specific part of the system.
+
+ Intent options:
+ - "explore": Learn more about this component
+ - "modify": Make changes to this component
+ - "connect": Add connections to/from this component
+ - "debug": Investigate issues with this component
+ - "test": Add tests for this component
+ """
+ node = self.current_model.nodes.get(node_id)
+ if not node:
+ return FocusedPrompt(
+ prompt_text=f"I don't see a component with ID '{node_id}' in the current model.",
+ context="Unknown node",
+ )
+
+ # Get related context
+ connected_nodes = self._get_connected_nodes(node_id)
+ associated_files = list(self.node_files.get(node_id, []))
+ prompt_history = self.node_prompt_history.get(node_id, [])
+
+ # Build context string
+ context_parts = [f"**{node.label}** ({node.node_type})"]
+ if node.description:
+ context_parts.append(f"Description: {node.description}")
+ if connected_nodes:
+ context_parts.append(f"Connected to: {', '.join(connected_nodes)}")
+ if associated_files:
+ context_parts.append(f"Related files: {', '.join(associated_files[:5])}")
+
+ context = "\n".join(context_parts)
+
+ # Generate prompt based on intent
+ prompts = {
+ "explore": f"Tell me more about the {node.label} component. What does it do, how does it work, and how does it fit into the overall system?",
+
+ "modify": f"I want to make changes to the {node.label} component. What are the key files I should look at, and what should I be careful about when modifying it?",
+
+ "connect": f"I want to add a new connection to/from the {node.label} component. What other parts of the system might it need to interact with?",
+
+ "debug": f"I'm having issues with the {node.label} component. Can you help me understand how it works and identify potential problem areas?",
+
+ "test": f"I want to add tests for the {node.label} component. What are the key behaviors I should test, and what edge cases should I consider?",
+
+ "explain": f"Explain the {node.label} component to me as if I'm new to this codebase. What's its purpose and how does it relate to the rest of the system?",
+
+ "refactor": f"I'm considering refactoring the {node.label} component. What improvements could be made, and what would be the impact on connected components?",
+ }
+
+ prompt_text = prompts.get(intent, prompts["explore"])
+
+ # Add context to make it more specific
+ if associated_files:
+ prompt_text += f"\n\nRelevant files: {', '.join(associated_files[:3])}"
+
+ return FocusedPrompt(
+ prompt_text=prompt_text,
+ context=context,
+ suggested_files=associated_files,
+ node_ids=[node_id] + list(connected_nodes),
+ )
+
+ def generate_edge_prompt(self, source_id: str, target_id: str) -> FocusedPrompt:
+ """Generate a prompt focused on a relationship between two nodes."""
+ source = self.current_model.nodes.get(source_id)
+ target = self.current_model.nodes.get(target_id)
+
+ if not source or not target:
+ return FocusedPrompt(
+ prompt_text="Could not find one or both nodes.",
+ context="Unknown relationship",
+ )
+
+ # Find the edge
+ edge = None
+ for e in self.current_model.edges:
+ if e.source_id == source_id and e.target_id == target_id:
+ edge = e
+ break
+
+ relationship_desc = edge.relationship if edge else "relates to"
+
+ prompt_text = f"Explain how {source.label} {relationship_desc} {target.label}. How do these components interact, and what data or control flows between them?"
+
+ return FocusedPrompt(
+ prompt_text=prompt_text,
+ context=f"Relationship: {source.label} → {target.label}",
+ suggested_files=list(
+ self.node_files.get(source_id, set()) |
+ self.node_files.get(target_id, set())
+ ),
+ node_ids=[source_id, target_id],
+ )
+
+ def generate_area_prompt(self, node_ids: list[str]) -> FocusedPrompt:
+ """Generate a prompt for a selected area of the diagram (multiple nodes)."""
+ nodes = [self.current_model.nodes[nid] for nid in node_ids if nid in self.current_model.nodes]
+
+ if not nodes:
+ return FocusedPrompt(
+ prompt_text="No valid nodes selected.",
+ context="Empty selection",
+ )
+
+ node_names = [n.label for n in nodes]
+ all_files = set()
+ for nid in node_ids:
+ all_files.update(self.node_files.get(nid, set()))
+
+ prompt_text = f"I'm looking at this area of the system: {', '.join(node_names)}. How do these components work together? What's the data flow and control flow between them?"
+
+ return FocusedPrompt(
+ prompt_text=prompt_text,
+ context=f"Selected components: {', '.join(node_names)}",
+ suggested_files=list(all_files),
+ node_ids=node_ids,
+ )
+
+ def _get_connected_nodes(self, node_id: str) -> set[str]:
+ """Get all nodes connected to the given node."""
+ connected = set()
+ for edge in self.current_model.edges:
+ if edge.source_id == node_id:
+ if edge.target_id in self.current_model.nodes:
+ connected.add(self.current_model.nodes[edge.target_id].label)
+ elif edge.target_id == node_id:
+ if edge.source_id in self.current_model.nodes:
+ connected.add(self.current_model.nodes[edge.source_id].label)
+ return connected
+
+ # ===== EXPORT =====
+
+ def to_interactive_html(self) -> str:
+ """
+ Export the model as an interactive HTML page.
+
+ This would render the Mermaid diagram with click handlers that
+ trigger the focused prompt generation.
+ """
+ mermaid_code = self.current_model.to_mermaid()
+ timeline_data = json.dumps(self.get_timeline())
+
+ # This is a simplified template - a real implementation would use
+ # a proper frontend framework and Mermaid's click callbacks
+ html = f"""
+
+
+ Living Mental Model
+
+
+
+
+
🧠 Living Mental Model
+
+
+
+
System Architecture
+
+{mermaid_code}
+
+
Click on any node to generate a focused prompt
+
+
Time Travel
+
+ Version {self.current_version}
+
+
+
+
+
📜 History
+
+
+
+
+
💬 Focused Prompt
+
Click a node in the diagram to generate a context-aware prompt.
+
+
+
+
+
+
+
+"""
+ return html
+
+ def save(self, path: Path) -> None:
+ """Save the complete interactive model state."""
+ state = {
+ "current_model": json.loads(self.current_model.to_json()),
+ "history": [s.to_dict() for s in self.history],
+ "current_version": self.current_version,
+ "node_files": {k: list(v) for k, v in self.node_files.items()},
+ "node_prompt_history": self.node_prompt_history,
+ }
+ Path(path).write_text(json.dumps(state, indent=2))
+
+
+# Example usage
+if __name__ == "__main__":
+ from .mental_model import MentalModelNode, MentalModelEdge, ModelUpdate
+ from .core import Prompt
+
+ # Create an interactive model
+ model = InteractiveMentalModel(project_name="E-commerce Platform")
+
+ # Simulate evolution through multiple prompts
+ prompts = [
+ Prompt(text="Create a basic product catalog", timestamp=datetime.now()),
+ Prompt(text="Add shopping cart functionality", timestamp=datetime.now()),
+ Prompt(text="Implement checkout with Stripe", timestamp=datetime.now()),
+ ]
+
+ updates = [
+ ModelUpdate(
+ reasoning="Adding product catalog components",
+ add_nodes=[
+ MentalModelNode(id="products", label="Product Catalog", node_type="component"),
+ MentalModelNode(id="db", label="Product Database", node_type="data"),
+ ],
+ add_edges=[
+ MentalModelEdge(source_id="products", target_id="db", relationship="uses"),
+ ],
+ ),
+ ModelUpdate(
+ reasoning="Adding shopping cart",
+ add_nodes=[
+ MentalModelNode(id="cart", label="Shopping Cart", node_type="component"),
+ ],
+ add_edges=[
+ MentalModelEdge(source_id="cart", target_id="products", relationship="uses"),
+ ],
+ ),
+ ModelUpdate(
+ reasoning="Adding checkout and payment processing",
+ add_nodes=[
+ MentalModelNode(id="checkout", label="Checkout Flow", node_type="process"),
+ MentalModelNode(id="stripe", label="Stripe Integration", node_type="interface"),
+ ],
+ add_edges=[
+ MentalModelEdge(source_id="checkout", target_id="cart", relationship="uses"),
+ MentalModelEdge(source_id="checkout", target_id="stripe", relationship="uses"),
+ ],
+ ),
+ ]
+
+ # Track file associations
+ file_associations = [
+ ["src/products/catalog.py", "src/products/models.py"],
+ ["src/cart/cart.py", "src/cart/session.py"],
+ ["src/checkout/flow.py", "src/payments/stripe.py"],
+ ]
+
+ # Apply updates
+ for prompt, update, files in zip(prompts, updates, file_associations):
+ model.apply_update(update, prompt, files)
+ print(f"Applied: {update.reasoning}")
+
+ # Show timeline
+ print("\n=== Timeline ===")
+ for item in model.get_timeline():
+ print(f" v{item['version']}: {item['change_summary']} ({item['node_count']} nodes)")
+
+ # Show diff
+ print("\n=== Diff v0 → v3 ===")
+ diff = model.diff_versions(0, 3)
+ print(f" Added: {diff['added_nodes']}")
+
+ # Generate focused prompts
+ print("\n=== Focused Prompts ===")
+
+ prompt = model.generate_node_prompt("checkout", "explore")
+ print(f"\n[Explore Checkout]")
+ print(f" Context: {prompt.context}")
+ print(f" Prompt: {prompt.prompt_text[:100]}...")
+
+ prompt = model.generate_node_prompt("cart", "modify")
+ print(f"\n[Modify Cart]")
+ print(f" Context: {prompt.context}")
+ print(f" Prompt: {prompt.prompt_text[:100]}...")
+
+ prompt = model.generate_edge_prompt("checkout", "stripe")
+ print(f"\n[Checkout → Stripe relationship]")
+ print(f" Prompt: {prompt.prompt_text}")
+
+ # Final diagram
+ print("\n=== Current Diagram ===")
+ print(model.current_model.to_mermaid())
diff --git a/src/agentgit/living_model.py b/src/agentgit/living_model.py
new file mode 100644
index 0000000..0b15832
--- /dev/null
+++ b/src/agentgit/living_model.py
@@ -0,0 +1,363 @@
+"""
+Living Mental Model - Real-time visualization that evolves with your codebase.
+
+This module ties together the mental model with the transcript watcher,
+creating a visualization that automatically adapts as an AI coding agent
+makes changes to the codebase.
+
+The key insight: as you work with AI coding agents, the important thing isn't
+the specific code structure, but maintaining a clear mental model of how
+the system works conceptually. This tool keeps that model current.
+"""
+
+from pathlib import Path
+from typing import Optional, Callable
+from datetime import datetime
+import os
+
+from .core import FileOperation, Transcript
+from .mental_model import (
+ MentalModel,
+ ModelUpdate,
+ apply_update,
+ generate_update_prompt,
+ parse_update_response,
+)
+
+
+class LivingMentalModel:
+ """
+ A mental model that automatically updates as code changes.
+
+ Usage:
+ model = LivingMentalModel(project_name="My App")
+
+ # When changes happen (e.g., from watcher callback):
+ model.process_changes(operations)
+
+ # Get current visualization:
+ print(model.to_mermaid())
+
+ # Save/load state:
+ model.save("mental_model.json")
+ model = LivingMentalModel.load("mental_model.json")
+ """
+
+ def __init__(
+ self,
+ project_name: str = "",
+ model_path: Optional[Path] = None,
+ ai_callback: Optional[Callable[[str], str]] = None,
+ ):
+ """
+ Initialize a living mental model.
+
+ Args:
+ project_name: Name for this project's model
+ model_path: Optional path to persist the model (auto-saves on update)
+ ai_callback: Function to call AI for model updates.
+ Signature: (prompt: str) -> str (AI response)
+ If None, uses a simple rule-based updater.
+ """
+ self.model = MentalModel(project_name=project_name)
+ self.model_path = Path(model_path) if model_path else None
+ self.ai_callback = ai_callback
+
+ # Track what we've seen for incremental updates
+ self._processed_tool_ids: set[str] = set()
+
+ # Load existing model if path provided and exists
+ if self.model_path and self.model_path.exists():
+ self._load()
+
+ def process_changes(
+ self,
+ operations: list[FileOperation],
+ force: bool = False,
+ ) -> Optional[ModelUpdate]:
+ """
+ Process a batch of file operations and update the mental model.
+
+ Args:
+ operations: List of file operations from a transcript
+ force: If True, reprocess even already-seen operations
+
+ Returns:
+ The ModelUpdate that was applied, or None if no update needed
+ """
+ # Filter to new operations only (unless forced)
+ if not force:
+ new_ops = [
+ op for op in operations
+ if op.tool_id and op.tool_id not in self._processed_tool_ids
+ ]
+ else:
+ new_ops = operations
+
+ if not new_ops:
+ return None
+
+ # Mark as processed
+ for op in new_ops:
+ if op.tool_id:
+ self._processed_tool_ids.add(op.tool_id)
+
+ # Get the update
+ update = self._compute_update(new_ops)
+
+ if update and not update.is_empty():
+ # Get prompt_id from first operation with a prompt
+ prompt_id = None
+ for op in new_ops:
+ if op.prompt:
+ prompt_id = op.prompt.prompt_id
+ break
+
+ apply_update(self.model, update, prompt_id)
+
+ # Auto-save if path configured
+ if self.model_path:
+ self._save()
+
+ return update
+
+ def process_transcript(self, transcript: Transcript) -> list[ModelUpdate]:
+ """
+ Process an entire transcript and return all updates made.
+
+ This is useful for building a mental model from a complete session.
+ """
+ updates = []
+
+ # Group operations by prompt for better context
+ ops_by_prompt: dict[str, list[FileOperation]] = {}
+ for op in transcript.operations:
+ key = op.prompt.prompt_id if op.prompt else "unknown"
+ if key not in ops_by_prompt:
+ ops_by_prompt[key] = []
+ ops_by_prompt[key].append(op)
+
+ # Process each group
+ for prompt_id, ops in ops_by_prompt.items():
+ update = self.process_changes(ops)
+ if update:
+ updates.append(update)
+
+ return updates
+
+ def _compute_update(self, operations: list[FileOperation]) -> Optional[ModelUpdate]:
+ """Compute an update to the mental model based on operations."""
+
+ if self.ai_callback:
+ return self._compute_update_with_ai(operations)
+ else:
+ return self._compute_update_rule_based(operations)
+
+ def _compute_update_with_ai(self, operations: list[FileOperation]) -> Optional[ModelUpdate]:
+ """Use AI to interpret changes and propose model updates."""
+
+ prompt = generate_update_prompt(self.model, operations)
+ response = self.ai_callback(prompt)
+ return parse_update_response(response)
+
+ def _compute_update_rule_based(self, operations: list[FileOperation]) -> Optional[ModelUpdate]:
+ """
+ Simple rule-based updater for when no AI is available.
+
+ This provides basic structure based on file patterns, but won't
+ capture semantic meaning like an AI would.
+ """
+ from .mental_model import MentalModelNode, MentalModelEdge
+
+ update = ModelUpdate(reasoning="Rule-based inference from file patterns")
+
+ # Simple heuristics based on file paths
+ for op in operations:
+ path = op.file_path.lower()
+ path_parts = Path(op.file_path).parts
+
+ # Skip if we already have a node for this
+ # (very simplistic - AI would do much better)
+ if len(path_parts) < 2:
+ continue
+
+ # Infer component from directory structure
+ if "src" in path_parts:
+ idx = path_parts.index("src")
+ if idx + 1 < len(path_parts):
+ component_name = path_parts[idx + 1]
+ node_id = f"component_{component_name}"
+
+ if node_id not in self.model.nodes:
+ # Infer type from common patterns
+ node_type = "component"
+ if "api" in component_name or "routes" in component_name:
+ node_type = "interface"
+ elif "service" in component_name:
+ node_type = "service"
+ elif "model" in component_name or "db" in component_name:
+ node_type = "data"
+
+ update.add_nodes.append(MentalModelNode(
+ id=node_id,
+ label=component_name.replace("_", " ").title(),
+ node_type=node_type,
+ description=f"Inferred from {op.file_path}",
+ ))
+
+ return update if not update.is_empty() else None
+
+ def to_mermaid(self) -> str:
+ """Get current model as Mermaid diagram."""
+ return self.model.to_mermaid()
+
+ def to_json(self) -> str:
+ """Get current model as JSON."""
+ return self.model.to_json()
+
+ def save(self, path: Optional[Path] = None) -> None:
+ """Save the model to a file."""
+ save_path = Path(path) if path else self.model_path
+ if not save_path:
+ raise ValueError("No path provided and no default path configured")
+ save_path.write_text(self.model.to_json())
+
+ def _save(self) -> None:
+ """Internal save to configured path."""
+ if self.model_path:
+ self.model_path.parent.mkdir(parents=True, exist_ok=True)
+ self.model_path.write_text(self.model.to_json())
+
+ def _load(self) -> None:
+ """Internal load from configured path."""
+ if self.model_path and self.model_path.exists():
+ self.model = MentalModel.from_json(self.model_path.read_text())
+
+ @classmethod
+ def load(cls, path: Path, ai_callback: Optional[Callable[[str], str]] = None) -> "LivingMentalModel":
+ """Load a model from a file."""
+ instance = cls(model_path=path, ai_callback=ai_callback)
+ return instance
+
+
+def create_watcher_callback(
+ living_model: LivingMentalModel,
+ output_path: Optional[Path] = None,
+ on_update: Optional[Callable[[MentalModel, ModelUpdate], None]] = None,
+):
+ """
+ Create a callback function for use with TranscriptWatcher.
+
+ This integrates the living model with agentgit's watch functionality,
+ automatically updating the mental model as the transcript grows.
+
+ Usage:
+ from agentgit import watch_transcript
+ from agentgit.living_model import LivingMentalModel, create_watcher_callback
+
+ model = LivingMentalModel("My Project")
+ callback = create_watcher_callback(model, output_path=Path("model.md"))
+
+ # Watch will call our callback on each update
+ watch_transcript(
+ "transcript.jsonl",
+ output_dir="./output",
+ on_update=lambda count: callback(latest_operations)
+ )
+ """
+
+ def callback(transcript: Transcript):
+ """Process new operations and update model."""
+ update = living_model.process_changes(transcript.operations)
+
+ if update and not update.is_empty():
+ # Write updated visualization
+ if output_path:
+ mermaid = living_model.to_mermaid()
+ output_path.write_text(f"# Mental Model\n\n```mermaid\n{mermaid}\n```\n")
+
+ # Call user callback
+ if on_update:
+ on_update(living_model.model, update)
+
+ return callback
+
+
+# Convenience function for quick visualization
+def visualize_transcript(
+ transcript: Transcript,
+ ai_callback: Optional[Callable[[str], str]] = None,
+) -> str:
+ """
+ Generate a mental model visualization from a transcript.
+
+ This is a one-shot function for visualizing a complete transcript.
+ For live updates, use LivingMentalModel with a watcher.
+ """
+ model = LivingMentalModel(
+ project_name=transcript.session_id or "Unknown Project",
+ ai_callback=ai_callback,
+ )
+ model.process_transcript(transcript)
+ return model.to_mermaid()
+
+
+# Example with mock AI for testing
+if __name__ == "__main__":
+ from .core import FileOperation, OperationType, Prompt, AssistantContext
+
+ # Mock AI callback that returns a simple response
+ def mock_ai(prompt: str) -> str:
+ return """```json
+{
+ "reasoning": "The changes introduce a new authentication system with user management",
+ "updates": {
+ "add_nodes": [
+ {"id": "auth", "label": "Authentication", "node_type": "service", "description": "Handles user auth"},
+ {"id": "users", "label": "User Management", "node_type": "component", "description": "User CRUD"}
+ ],
+ "add_edges": [
+ {"source_id": "auth", "target_id": "users", "relationship": "uses", "label": "validates"}
+ ],
+ "remove_node_ids": [],
+ "remove_edges": []
+ }
+}
+```"""
+
+ # Create living model
+ model = LivingMentalModel(project_name="Test App", ai_callback=mock_ai)
+
+ # Simulate some operations
+ ops = [
+ FileOperation(
+ file_path="src/auth/login.py",
+ operation_type=OperationType.WRITE,
+ content="# Login handler",
+ timestamp=datetime.now(),
+ tool_id="op1",
+ prompt=Prompt(text="Add user authentication", timestamp=datetime.now()),
+ assistant_context=AssistantContext(
+ thinking="I'll create an authentication module with login/logout"
+ ),
+ ),
+ FileOperation(
+ file_path="src/users/models.py",
+ operation_type=OperationType.WRITE,
+ content="# User model",
+ timestamp=datetime.now(),
+ tool_id="op2",
+ ),
+ ]
+
+ # Process changes
+ update = model.process_changes(ops)
+
+ print("=== Update Applied ===")
+ if update:
+ print(f"Reasoning: {update.reasoning}")
+ print(f"Added nodes: {[n.label for n in update.add_nodes]}")
+ print(f"Added edges: {len(update.add_edges)}")
+
+ print("\n=== Current Mental Model ===")
+ print(model.to_mermaid())
diff --git a/src/agentgit/mental_model.py b/src/agentgit/mental_model.py
new file mode 100644
index 0000000..72c07b0
--- /dev/null
+++ b/src/agentgit/mental_model.py
@@ -0,0 +1,444 @@
+"""
+Living Mental Model - A visualization that evolves with your codebase.
+
+This module provides a "mental model" diagram that automatically adapts as code changes,
+capturing the semantic structure of a system rather than just its file layout.
+
+The key insight: when working with AI coding agents, what matters isn't the exact code
+structure, but understanding *how the system works conceptually* and how that evolves.
+"""
+
+from dataclasses import dataclass, field
+from typing import Optional
+from datetime import datetime
+import json
+import hashlib
+
+from .core import FileOperation, Prompt, AssistantContext
+
+
+@dataclass
+class MentalModelNode:
+ """A node in the mental model - represents a concept, component, or capability."""
+
+ id: str # Stable identifier
+ label: str # Human-readable name
+ node_type: str # e.g., "component", "service", "concept", "data", "interface"
+ description: str = "" # What this node represents
+
+ # Provenance - where did this understanding come from?
+ introduced_by_prompt_id: Optional[str] = None
+ last_updated_prompt_id: Optional[str] = None
+
+ def __hash__(self):
+ return hash(self.id)
+
+
+@dataclass
+class MentalModelEdge:
+ """A relationship between nodes in the mental model."""
+
+ source_id: str
+ target_id: str
+ relationship: str # e.g., "uses", "contains", "transforms", "depends_on"
+ label: str = "" # Optional label for the edge
+
+ # Provenance
+ introduced_by_prompt_id: Optional[str] = None
+
+
+@dataclass
+class MentalModel:
+ """
+ A living representation of the system's conceptual structure.
+
+ This is NOT a file tree or class diagram - it's a semantic model of
+ what the system *does* and how its parts relate conceptually.
+ """
+
+ nodes: dict[str, MentalModelNode] = field(default_factory=dict)
+ edges: list[MentalModelEdge] = field(default_factory=list)
+
+ # Metadata
+ project_name: str = ""
+ version: int = 0 # Incremented on each update
+ last_updated: Optional[datetime] = None
+
+ # History of what changed
+ changelog: list[dict] = field(default_factory=list)
+
+ def add_node(self, node: MentalModelNode, prompt_id: Optional[str] = None) -> None:
+ """Add or update a node in the model."""
+ if node.id in self.nodes:
+ # Update existing
+ existing = self.nodes[node.id]
+ existing.label = node.label
+ existing.description = node.description
+ existing.last_updated_prompt_id = prompt_id
+ else:
+ # Add new
+ node.introduced_by_prompt_id = prompt_id
+ self.nodes[node.id] = node
+
+ def add_edge(self, edge: MentalModelEdge) -> None:
+ """Add an edge if it doesn't already exist."""
+ for existing in self.edges:
+ if (existing.source_id == edge.source_id and
+ existing.target_id == edge.target_id and
+ existing.relationship == edge.relationship):
+ return # Already exists
+ self.edges.append(edge)
+
+ def remove_node(self, node_id: str) -> None:
+ """Remove a node and all its edges."""
+ if node_id in self.nodes:
+ del self.nodes[node_id]
+ self.edges = [e for e in self.edges
+ if e.source_id != node_id and e.target_id != node_id]
+
+ def to_mermaid(self) -> str:
+ """Export the mental model as a Mermaid diagram."""
+ lines = ["graph TD"]
+
+ # Group nodes by type for styling
+ node_types = {}
+ for node in self.nodes.values():
+ if node.node_type not in node_types:
+ node_types[node.node_type] = []
+ node_types[node.node_type].append(node)
+
+ # Add nodes with type-specific shapes
+ shape_map = {
+ "component": ("[", "]"), # Rectangle
+ "service": ("[[", "]]"), # Subroutine shape
+ "concept": ("(", ")"), # Rounded
+ "data": ("[(", ")]"), # Cylinder (database)
+ "interface": ("{{", "}}"), # Hexagon
+ "process": ("([", "])"), # Stadium
+ }
+
+ for node in self.nodes.values():
+ left, right = shape_map.get(node.node_type, ("[", "]"))
+ # Escape quotes in label
+ safe_label = node.label.replace('"', "'")
+ lines.append(f' {node.id}{left}"{safe_label}"{right}')
+
+ # Add edges
+ arrow_map = {
+ "uses": "-->",
+ "contains": "-->",
+ "transforms": "==>",
+ "depends_on": "-.->",
+ "implements": "-->",
+ "extends": "-->",
+ }
+
+ for edge in self.edges:
+ arrow = arrow_map.get(edge.relationship, "-->")
+ if edge.label:
+ lines.append(f' {edge.source_id} {arrow}|"{edge.label}"| {edge.target_id}')
+ else:
+ lines.append(f' {edge.source_id} {arrow} {edge.target_id}')
+
+ # Add styling by node type
+ style_map = {
+ "component": "fill:#e1f5fe",
+ "service": "fill:#fff3e0",
+ "concept": "fill:#f3e5f5",
+ "data": "fill:#e8f5e9",
+ "interface": "fill:#fce4ec",
+ "process": "fill:#fff8e1",
+ }
+
+ for node_type, nodes in node_types.items():
+ if node_type in style_map:
+ node_ids = ",".join(n.id for n in nodes)
+ if node_ids:
+ lines.append(f' style {node_ids} {style_map[node_type]}')
+
+ return "\n".join(lines)
+
+ def to_json(self) -> str:
+ """Serialize the mental model to JSON."""
+ return json.dumps({
+ "project_name": self.project_name,
+ "version": self.version,
+ "last_updated": self.last_updated.isoformat() if self.last_updated else None,
+ "nodes": {
+ nid: {
+ "id": n.id,
+ "label": n.label,
+ "node_type": n.node_type,
+ "description": n.description,
+ "introduced_by_prompt_id": n.introduced_by_prompt_id,
+ "last_updated_prompt_id": n.last_updated_prompt_id,
+ }
+ for nid, n in self.nodes.items()
+ },
+ "edges": [
+ {
+ "source_id": e.source_id,
+ "target_id": e.target_id,
+ "relationship": e.relationship,
+ "label": e.label,
+ "introduced_by_prompt_id": e.introduced_by_prompt_id,
+ }
+ for e in self.edges
+ ],
+ "changelog": self.changelog,
+ }, indent=2)
+
+ @classmethod
+ def from_json(cls, json_str: str) -> "MentalModel":
+ """Deserialize a mental model from JSON."""
+ data = json.loads(json_str)
+ model = cls(
+ project_name=data.get("project_name", ""),
+ version=data.get("version", 0),
+ changelog=data.get("changelog", []),
+ )
+ if data.get("last_updated"):
+ model.last_updated = datetime.fromisoformat(data["last_updated"])
+
+ for nid, ndata in data.get("nodes", {}).items():
+ model.nodes[nid] = MentalModelNode(
+ id=ndata["id"],
+ label=ndata["label"],
+ node_type=ndata["node_type"],
+ description=ndata.get("description", ""),
+ introduced_by_prompt_id=ndata.get("introduced_by_prompt_id"),
+ last_updated_prompt_id=ndata.get("last_updated_prompt_id"),
+ )
+
+ for edata in data.get("edges", []):
+ model.edges.append(MentalModelEdge(
+ source_id=edata["source_id"],
+ target_id=edata["target_id"],
+ relationship=edata["relationship"],
+ label=edata.get("label", ""),
+ introduced_by_prompt_id=edata.get("introduced_by_prompt_id"),
+ ))
+
+ return model
+
+
+@dataclass
+class ModelUpdate:
+ """Describes a proposed update to the mental model."""
+
+ add_nodes: list[MentalModelNode] = field(default_factory=list)
+ remove_node_ids: list[str] = field(default_factory=list)
+ add_edges: list[MentalModelEdge] = field(default_factory=list)
+ remove_edges: list[tuple[str, str, str]] = field(default_factory=list) # (source, target, relationship)
+
+ # Explanation of why this update was made
+ reasoning: str = ""
+
+ def is_empty(self) -> bool:
+ return (not self.add_nodes and not self.remove_node_ids and
+ not self.add_edges and not self.remove_edges)
+
+
+def apply_update(model: MentalModel, update: ModelUpdate, prompt_id: Optional[str] = None) -> MentalModel:
+ """Apply an update to the mental model."""
+
+ # Remove nodes first
+ for node_id in update.remove_node_ids:
+ model.remove_node(node_id)
+
+ # Remove edges
+ for source_id, target_id, relationship in update.remove_edges:
+ model.edges = [e for e in model.edges
+ if not (e.source_id == source_id and
+ e.target_id == target_id and
+ e.relationship == relationship)]
+
+ # Add nodes
+ for node in update.add_nodes:
+ model.add_node(node, prompt_id)
+
+ # Add edges
+ for edge in update.add_edges:
+ edge.introduced_by_prompt_id = prompt_id
+ model.add_edge(edge)
+
+ # Update metadata
+ model.version += 1
+ model.last_updated = datetime.now()
+
+ if update.reasoning:
+ model.changelog.append({
+ "version": model.version,
+ "prompt_id": prompt_id,
+ "reasoning": update.reasoning,
+ "timestamp": model.last_updated.isoformat(),
+ })
+
+ return model
+
+
+def generate_update_prompt(
+ current_model: MentalModel,
+ operations: list[FileOperation],
+) -> str:
+ """
+ Generate a prompt for an AI to analyze changes and propose mental model updates.
+
+ This is the core of the "living mental model" concept - we're asking the AI
+ to interpret code changes in terms of conceptual/architectural impact.
+ """
+
+ # Build context about what changed
+ changes_summary = []
+ for op in operations:
+ change_desc = f"- {op.operation_type.value}: {op.file_path}"
+ if op.prompt:
+ change_desc += f"\n Intent: {op.prompt.text[:200]}..."
+ if op.assistant_context and op.assistant_context.summary:
+ change_desc += f"\n Reasoning: {op.assistant_context.summary[:200]}..."
+ changes_summary.append(change_desc)
+
+ current_diagram = current_model.to_mermaid() if current_model.nodes else "(empty - no existing model)"
+
+ prompt = f"""You are analyzing code changes to update a "mental model" diagram.
+
+The mental model represents the CONCEPTUAL structure of the system - not file paths or
+class hierarchies, but the semantic components, capabilities, and how they relate.
+
+## Current Mental Model
+```mermaid
+{current_diagram}
+```
+
+## Recent Code Changes
+{chr(10).join(changes_summary)}
+
+## Your Task
+Analyze these changes and determine if/how the mental model should be updated.
+
+Consider:
+1. Are new concepts or components being introduced?
+2. Are existing relationships changing?
+3. Is there a shift in how parts of the system interact?
+4. Should any concepts be removed or renamed?
+
+Respond with a JSON object:
+{{
+ "reasoning": "Explain your interpretation of how these changes affect the system's conceptual structure",
+ "updates": {{
+ "add_nodes": [
+ {{"id": "unique_id", "label": "Human Name", "node_type": "component|service|concept|data|interface|process", "description": "What this represents"}}
+ ],
+ "remove_node_ids": ["node_id_to_remove"],
+ "add_edges": [
+ {{"source_id": "from_node", "target_id": "to_node", "relationship": "uses|contains|transforms|depends_on", "label": "optional edge label"}}
+ ],
+ "remove_edges": [["source_id", "target_id", "relationship"]]
+ }}
+}}
+
+If no updates are needed, return: {{"reasoning": "...", "updates": null}}
+"""
+
+ return prompt
+
+
+def parse_update_response(response: str) -> Optional[ModelUpdate]:
+ """Parse an AI response into a ModelUpdate."""
+ import re
+
+ # Try to extract JSON from the response
+ # Look for JSON block or the whole response
+ json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
+ if json_match:
+ json_str = json_match.group(1)
+ else:
+ # Try to find raw JSON
+ json_match = re.search(r'\{.*\}', response, re.DOTALL)
+ if json_match:
+ json_str = json_match.group(0)
+ else:
+ return None
+
+ try:
+ data = json.loads(json_str)
+ except json.JSONDecodeError:
+ return None
+
+ if data.get("updates") is None:
+ return ModelUpdate(reasoning=data.get("reasoning", "No changes needed"))
+
+ updates = data["updates"]
+
+ update = ModelUpdate(reasoning=data.get("reasoning", ""))
+
+ for node_data in updates.get("add_nodes", []):
+ update.add_nodes.append(MentalModelNode(
+ id=node_data["id"],
+ label=node_data["label"],
+ node_type=node_data.get("node_type", "component"),
+ description=node_data.get("description", ""),
+ ))
+
+ update.remove_node_ids = updates.get("remove_node_ids", [])
+
+ for edge_data in updates.get("add_edges", []):
+ update.add_edges.append(MentalModelEdge(
+ source_id=edge_data["source_id"],
+ target_id=edge_data["target_id"],
+ relationship=edge_data.get("relationship", "uses"),
+ label=edge_data.get("label", ""),
+ ))
+
+ update.remove_edges = [tuple(e) for e in updates.get("remove_edges", [])]
+
+ return update
+
+
+# Example usage and testing
+if __name__ == "__main__":
+ # Create a simple mental model
+ model = MentalModel(project_name="Example System")
+
+ # Add some nodes
+ model.add_node(MentalModelNode(
+ id="api",
+ label="REST API",
+ node_type="interface",
+ description="External API for client applications"
+ ))
+
+ model.add_node(MentalModelNode(
+ id="processor",
+ label="Data Processor",
+ node_type="service",
+ description="Transforms and validates incoming data"
+ ))
+
+ model.add_node(MentalModelNode(
+ id="storage",
+ label="Data Store",
+ node_type="data",
+ description="Persistent storage for processed data"
+ ))
+
+ # Add relationships
+ model.add_edge(MentalModelEdge(
+ source_id="api",
+ target_id="processor",
+ relationship="uses",
+ label="sends requests"
+ ))
+
+ model.add_edge(MentalModelEdge(
+ source_id="processor",
+ target_id="storage",
+ relationship="transforms",
+ label="persists"
+ ))
+
+ print("=== Mental Model (Mermaid) ===")
+ print(model.to_mermaid())
+ print()
+ print("=== Mental Model (JSON) ===")
+ print(model.to_json())
diff --git a/tests/test_reframe.py b/tests/test_reframe.py
new file mode 100644
index 0000000..7588b10
--- /dev/null
+++ b/tests/test_reframe.py
@@ -0,0 +1,1172 @@
+"""Tests for Reframe (mental model enhancer)."""
+
+import json
+import pytest
+from datetime import datetime
+from pathlib import Path
+
+from agentgit.enhancers.mental_model import (
+ MentalModel,
+ ModelElement,
+ ModelRelation,
+ ModelSnapshot,
+ SequenceStep,
+ SequenceDiagram,
+ MentalModelEnhancer,
+ build_observation_prompt,
+ build_instruction_prompt,
+ build_focused_prompt,
+ reframe,
+ load_mental_model,
+ get_mental_model_enhancer,
+ _truncate,
+)
+
+
+class TestModelElement:
+ """Tests for ModelElement dataclass."""
+
+ def test_element_creation(self):
+ """Should create element with required fields."""
+ elem = ModelElement(id="auth", label="Authentication")
+ assert elem.id == "auth"
+ assert elem.label == "Authentication"
+ assert elem.shape == "box" # default
+ assert elem.color is None
+ assert elem.created_by == ""
+ assert elem.reasoning == ""
+
+ def test_element_with_all_fields(self):
+ """Should create element with all fields."""
+ elem = ModelElement(
+ id="auth",
+ label="Authentication",
+ properties={"complexity": "high"},
+ shape="hexagon",
+ color="#ff0000",
+ created_by="ai",
+ reasoning="Handles user identity",
+ )
+ assert elem.properties == {"complexity": "high"}
+ assert elem.shape == "hexagon"
+ assert elem.color == "#ff0000"
+ assert elem.created_by == "ai"
+ assert elem.reasoning == "Handles user identity"
+
+ def test_element_to_dict(self):
+ """Should serialize to dictionary."""
+ elem = ModelElement(
+ id="auth",
+ label="Auth",
+ shape="circle",
+ color="#00ff00",
+ created_by="human",
+ reasoning="test",
+ )
+ d = elem.to_dict()
+ assert d["id"] == "auth"
+ assert d["label"] == "Auth"
+ assert d["shape"] == "circle"
+ assert d["color"] == "#00ff00"
+ assert d["created_by"] == "human"
+ assert d["reasoning"] == "test"
+
+ def test_element_from_dict(self):
+ """Should deserialize from dictionary."""
+ d = {
+ "id": "cart",
+ "label": "Shopping Cart",
+ "properties": {"items": 0},
+ "shape": "rounded",
+ "color": "#0000ff",
+ "created_by": "ai",
+ "reasoning": "Stores items",
+ }
+ elem = ModelElement.from_dict(d)
+ assert elem.id == "cart"
+ assert elem.label == "Shopping Cart"
+ assert elem.properties == {"items": 0}
+ assert elem.shape == "rounded"
+ assert elem.color == "#0000ff"
+
+ def test_element_equality_by_id(self):
+ """Elements with same id should be distinguishable."""
+ elem1 = ModelElement(id="test", label="Test")
+ elem2 = ModelElement(id="test", label="Different Label")
+ # Dataclasses compare by all fields, so these are different
+ assert elem1.id == elem2.id
+ assert elem1.label != elem2.label
+
+
+class TestModelRelation:
+ """Tests for ModelRelation dataclass."""
+
+ def test_relation_creation(self):
+ """Should create relation with required fields."""
+ rel = ModelRelation(source_id="a", target_id="b")
+ assert rel.source_id == "a"
+ assert rel.target_id == "b"
+ assert rel.label == ""
+ assert rel.style == "solid"
+
+ def test_relation_with_all_fields(self):
+ """Should create relation with all fields."""
+ rel = ModelRelation(
+ source_id="auth",
+ target_id="users",
+ label="validates",
+ properties={"async": True},
+ style="dashed",
+ created_by="ai",
+ reasoning="Auth checks user DB",
+ )
+ assert rel.label == "validates"
+ assert rel.properties == {"async": True}
+ assert rel.style == "dashed"
+
+ def test_relation_to_dict(self):
+ """Should serialize to dictionary."""
+ rel = ModelRelation(
+ source_id="a",
+ target_id="b",
+ label="uses",
+ style="dotted",
+ )
+ d = rel.to_dict()
+ assert d["source_id"] == "a"
+ assert d["target_id"] == "b"
+ assert d["label"] == "uses"
+ assert d["style"] == "dotted"
+
+ def test_relation_from_dict(self):
+ """Should deserialize from dictionary."""
+ d = {
+ "source_id": "x",
+ "target_id": "y",
+ "label": "depends",
+ "style": "solid",
+ }
+ rel = ModelRelation.from_dict(d)
+ assert rel.source_id == "x"
+ assert rel.target_id == "y"
+ assert rel.label == "depends"
+
+
+class TestSequenceStep:
+ """Tests for SequenceStep dataclass."""
+
+ def test_step_creation(self):
+ """Should create step with required fields."""
+ step = SequenceStep(
+ from_participant="Client",
+ to_participant="Server",
+ message="GET /api/users",
+ )
+ assert step.from_participant == "Client"
+ assert step.to_participant == "Server"
+ assert step.message == "GET /api/users"
+ assert step.step_type == "sync" # default
+ assert step.note is None
+
+ def test_step_with_all_fields(self):
+ """Should create step with all fields."""
+ step = SequenceStep(
+ from_participant="API",
+ to_participant="Database",
+ message="SELECT * FROM users",
+ step_type="async",
+ note="May take a while",
+ )
+ assert step.step_type == "async"
+ assert step.note == "May take a while"
+
+
+class TestSequenceDiagram:
+ """Tests for SequenceDiagram dataclass."""
+
+ def test_diagram_creation(self):
+ """Should create empty diagram."""
+ diagram = SequenceDiagram(id="login", title="User Login Flow")
+ assert diagram.id == "login"
+ assert diagram.title == "User Login Flow"
+ assert diagram.participants == []
+ assert diagram.steps == []
+
+ def test_diagram_with_participants(self):
+ """Should create diagram with participants."""
+ diagram = SequenceDiagram(
+ id="api-flow",
+ title="API Request",
+ participants=["Client", "API", "Database"],
+ )
+ assert len(diagram.participants) == 3
+ assert "API" in diagram.participants
+
+ def test_to_mermaid_basic(self):
+ """Should generate valid mermaid sequence diagram."""
+ diagram = SequenceDiagram(
+ id="test",
+ title="Test Flow",
+ participants=["Client", "Server"],
+ )
+ diagram.steps.append(SequenceStep(
+ from_participant="Client",
+ to_participant="Server",
+ message="Request",
+ ))
+
+ mermaid = diagram.to_mermaid()
+
+ assert "sequenceDiagram" in mermaid
+ assert "title Test Flow" in mermaid
+ assert "participant Client" in mermaid
+ assert "participant Server" in mermaid
+ assert "Client->>Server: Request" in mermaid
+
+ def test_to_mermaid_arrow_styles(self):
+ """Should use correct arrow styles for step types."""
+ diagram = SequenceDiagram(
+ id="test",
+ title="",
+ participants=["A", "B"],
+ )
+ diagram.steps.append(SequenceStep(
+ from_participant="A",
+ to_participant="B",
+ message="sync",
+ step_type="sync",
+ ))
+ diagram.steps.append(SequenceStep(
+ from_participant="B",
+ to_participant="A",
+ message="async",
+ step_type="async",
+ ))
+ diagram.steps.append(SequenceStep(
+ from_participant="A",
+ to_participant="B",
+ message="reply",
+ step_type="reply",
+ ))
+
+ mermaid = diagram.to_mermaid()
+
+ assert "A->>B: sync" in mermaid
+ assert "B-->>A: async" in mermaid
+ assert "A-->>B: reply" in mermaid
+
+ def test_to_mermaid_with_notes(self):
+ """Should include notes in mermaid output."""
+ diagram = SequenceDiagram(
+ id="test",
+ title="",
+ participants=["A", "B"],
+ )
+ diagram.steps.append(SequenceStep(
+ from_participant="A",
+ to_participant="B",
+ message="Request",
+ note="Important step",
+ ))
+
+ mermaid = diagram.to_mermaid()
+
+ assert "Note over A,B: Important step" in mermaid
+
+ def test_to_mermaid_sanitizes_spaces(self):
+ """Should sanitize participant names with spaces."""
+ diagram = SequenceDiagram(
+ id="test",
+ title="",
+ participants=["API Gateway"],
+ )
+ diagram.steps.append(SequenceStep(
+ from_participant="API Gateway",
+ to_participant="API Gateway",
+ message="Self call",
+ ))
+
+ mermaid = diagram.to_mermaid()
+
+ assert "participant API_Gateway as API Gateway" in mermaid
+ assert "API_Gateway->>API_Gateway: Self call" in mermaid
+
+ def test_to_dict(self):
+ """Should serialize to dictionary."""
+ diagram = SequenceDiagram(
+ id="login",
+ title="Login",
+ participants=["User", "Auth"],
+ description="User authentication flow",
+ created_by="ai",
+ trigger="user login",
+ )
+ diagram.steps.append(SequenceStep(
+ from_participant="User",
+ to_participant="Auth",
+ message="Login",
+ step_type="sync",
+ note="With credentials",
+ ))
+
+ d = diagram.to_dict()
+
+ assert d["id"] == "login"
+ assert d["title"] == "Login"
+ assert d["participants"] == ["User", "Auth"]
+ assert d["description"] == "User authentication flow"
+ assert len(d["steps"]) == 1
+ assert d["steps"][0]["from"] == "User"
+ assert d["steps"][0]["to"] == "Auth"
+ assert d["steps"][0]["message"] == "Login"
+ assert d["steps"][0]["type"] == "sync"
+ assert d["steps"][0]["note"] == "With credentials"
+
+ def test_from_dict(self):
+ """Should deserialize from dictionary."""
+ d = {
+ "id": "checkout",
+ "title": "Checkout Flow",
+ "participants": ["Cart", "Payment", "Order"],
+ "steps": [
+ {"from": "Cart", "to": "Payment", "message": "Pay", "type": "sync"},
+ {"from": "Payment", "to": "Order", "message": "Create", "type": "async", "note": "Background"},
+ ],
+ "description": "Purchase flow",
+ "created_by": "human",
+ "trigger": "checkout button",
+ }
+
+ diagram = SequenceDiagram.from_dict(d)
+
+ assert diagram.id == "checkout"
+ assert diagram.title == "Checkout Flow"
+ assert len(diagram.participants) == 3
+ assert len(diagram.steps) == 2
+ assert diagram.steps[0].from_participant == "Cart"
+ assert diagram.steps[1].note == "Background"
+
+
+class TestMentalModel:
+ """Tests for MentalModel dataclass."""
+
+ def test_empty_model(self):
+ """Should create empty model."""
+ model = MentalModel()
+ assert len(model.elements) == 0
+ assert len(model.relations) == 0
+ assert model.version == 0
+ assert model.ai_summary == ""
+
+ def test_add_elements(self):
+ """Should add elements to model."""
+ model = MentalModel()
+ model.elements["auth"] = ModelElement(id="auth", label="Auth")
+ model.elements["db"] = ModelElement(id="db", label="Database")
+ assert len(model.elements) == 2
+ assert "auth" in model.elements
+ assert "db" in model.elements
+
+ def test_add_relations(self):
+ """Should add relations to model."""
+ model = MentalModel()
+ model.elements["a"] = ModelElement(id="a", label="A")
+ model.elements["b"] = ModelElement(id="b", label="B")
+ model.relations.append(ModelRelation(source_id="a", target_id="b"))
+ assert len(model.relations) == 1
+
+ def test_snapshot_saves_state(self):
+ """Should save current state as snapshot."""
+ model = MentalModel()
+ model.elements["test"] = ModelElement(id="test", label="Test")
+ model.version = 1
+
+ model.snapshot("first snapshot")
+
+ assert len(model.snapshots) == 1
+ assert model.snapshots[0].version == 1
+ assert model.snapshots[0].trigger == "first snapshot"
+ assert "test" in model.snapshots[0].elements
+
+ def test_snapshot_is_deep_copy(self):
+ """Snapshot should be independent of current state."""
+ model = MentalModel()
+ model.elements["test"] = ModelElement(id="test", label="Original")
+ model.snapshot("before change")
+
+ # Modify current state
+ model.elements["test"].label = "Modified"
+
+ # Snapshot should still have original
+ assert model.snapshots[0].elements["test"].label == "Original"
+
+ def test_to_mermaid_empty(self):
+ """Empty model should produce minimal mermaid."""
+ model = MentalModel()
+ mermaid = model.to_mermaid()
+ assert "graph TD" in mermaid
+
+ def test_to_mermaid_with_elements(self):
+ """Should generate valid mermaid with elements."""
+ model = MentalModel()
+ model.elements["api"] = ModelElement(id="api", label="REST API", shape="hexagon")
+ model.elements["db"] = ModelElement(id="db", label="Database", shape="cylinder")
+
+ mermaid = model.to_mermaid()
+
+ assert "graph TD" in mermaid
+ assert 'api{{"REST API"}}' in mermaid
+ assert 'db[("Database")]' in mermaid
+
+ def test_to_mermaid_with_relations(self):
+ """Should generate mermaid with relationships."""
+ model = MentalModel()
+ model.elements["a"] = ModelElement(id="a", label="A")
+ model.elements["b"] = ModelElement(id="b", label="B")
+ model.relations.append(ModelRelation(source_id="a", target_id="b", label="uses"))
+
+ mermaid = model.to_mermaid()
+
+ assert 'a -->|"uses"| b' in mermaid
+
+ def test_to_mermaid_with_colors(self):
+ """Should include style directives for colors."""
+ model = MentalModel()
+ model.elements["test"] = ModelElement(id="test", label="Test", color="#ff0000")
+
+ mermaid = model.to_mermaid()
+
+ assert "style test fill:#ff0000" in mermaid
+
+ def test_to_mermaid_shapes(self):
+ """Should use correct mermaid shapes."""
+ model = MentalModel()
+ model.elements["box"] = ModelElement(id="box", label="Box", shape="box")
+ model.elements["rounded"] = ModelElement(id="rounded", label="Rounded", shape="rounded")
+ model.elements["circle"] = ModelElement(id="circle", label="Circle", shape="circle")
+ model.elements["diamond"] = ModelElement(id="diamond", label="Diamond", shape="diamond")
+ model.elements["cylinder"] = ModelElement(id="cylinder", label="Cylinder", shape="cylinder")
+ model.elements["hexagon"] = ModelElement(id="hexagon", label="Hexagon", shape="hexagon")
+ model.elements["stadium"] = ModelElement(id="stadium", label="Stadium", shape="stadium")
+
+ mermaid = model.to_mermaid()
+
+ assert 'box["Box"]' in mermaid
+ assert 'rounded("Rounded")' in mermaid
+ assert 'circle(("Circle"))' in mermaid
+ assert 'diamond{"Diamond"}' in mermaid
+ assert 'cylinder[("Cylinder")]' in mermaid
+ assert 'hexagon{{"Hexagon"}}' in mermaid
+ assert 'stadium(["Stadium"])' in mermaid
+
+ def test_to_mermaid_relation_styles(self):
+ """Should use correct arrow styles for relations."""
+ model = MentalModel()
+ model.elements["a"] = ModelElement(id="a", label="A")
+ model.elements["b"] = ModelElement(id="b", label="B")
+ model.elements["c"] = ModelElement(id="c", label="C")
+ model.elements["d"] = ModelElement(id="d", label="D")
+
+ model.relations.append(ModelRelation(source_id="a", target_id="b", style="solid"))
+ model.relations.append(ModelRelation(source_id="b", target_id="c", style="dashed"))
+ model.relations.append(ModelRelation(source_id="c", target_id="d", style="thick"))
+
+ mermaid = model.to_mermaid()
+
+ assert "a --> b" in mermaid
+ assert "b -.-> c" in mermaid
+ assert "c ==> d" in mermaid
+
+ def test_to_json_roundtrip(self):
+ """Should serialize and deserialize correctly."""
+ model = MentalModel()
+ model.elements["test"] = ModelElement(
+ id="test", label="Test", shape="hexagon", color="#123456"
+ )
+ model.relations.append(ModelRelation(source_id="test", target_id="test", label="self"))
+ model.version = 5
+ model.ai_summary = "A test system"
+
+ json_str = model.to_json()
+ restored = MentalModel.from_dict(json.loads(json_str))
+
+ assert len(restored.elements) == 1
+ assert restored.elements["test"].label == "Test"
+ assert restored.elements["test"].shape == "hexagon"
+ assert len(restored.relations) == 1
+ assert restored.relations[0].label == "self"
+ assert restored.version == 5
+ assert restored.ai_summary == "A test system"
+
+ def test_to_dict(self):
+ """Should convert to dictionary."""
+ model = MentalModel()
+ model.elements["x"] = ModelElement(id="x", label="X")
+ model.version = 3
+
+ d = model.to_dict()
+
+ assert "elements" in d
+ assert "relations" in d
+ assert "version" in d
+ assert d["version"] == 3
+
+ def test_sequences_in_model(self):
+ """Should support sequence diagrams."""
+ model = MentalModel()
+ diagram = SequenceDiagram(
+ id="login",
+ title="Login Flow",
+ participants=["User", "Auth", "DB"],
+ )
+ diagram.steps.append(SequenceStep(
+ from_participant="User",
+ to_participant="Auth",
+ message="Login",
+ ))
+ model.sequences["login"] = diagram
+
+ assert len(model.sequences) == 1
+ assert "login" in model.sequences
+ assert model.sequences["login"].title == "Login Flow"
+
+ def test_to_dict_includes_sequences(self):
+ """Should include sequences in to_dict."""
+ model = MentalModel()
+ model.sequences["test"] = SequenceDiagram(
+ id="test",
+ title="Test",
+ participants=["A", "B"],
+ )
+
+ d = model.to_dict()
+
+ assert "sequences" in d
+ assert "test" in d["sequences"]
+ assert d["sequences"]["test"]["title"] == "Test"
+
+ def test_from_dict_restores_sequences(self):
+ """Should restore sequences from dict."""
+ d = {
+ "elements": {},
+ "relations": [],
+ "sequences": {
+ "checkout": {
+ "id": "checkout",
+ "title": "Checkout",
+ "participants": ["Cart", "Payment"],
+ "steps": [
+ {"from": "Cart", "to": "Payment", "message": "Pay", "type": "sync"}
+ ],
+ }
+ },
+ "version": 1,
+ "ai_summary": "",
+ }
+
+ model = MentalModel.from_dict(d)
+
+ assert len(model.sequences) == 1
+ assert "checkout" in model.sequences
+ assert model.sequences["checkout"].title == "Checkout"
+ assert len(model.sequences["checkout"].steps) == 1
+
+ def test_to_full_mermaid(self):
+ """Should export both component and sequence diagrams."""
+ model = MentalModel()
+ model.elements["api"] = ModelElement(id="api", label="API")
+ model.sequences["flow"] = SequenceDiagram(
+ id="flow",
+ title="Request Flow",
+ participants=["Client", "Server"],
+ description="How requests work",
+ )
+ model.sequences["flow"].steps.append(SequenceStep(
+ from_participant="Client",
+ to_participant="Server",
+ message="Request",
+ ))
+
+ full = model.to_full_mermaid()
+
+ # Should have component diagram
+ assert "## Component Diagram" in full
+ assert "graph TD" in full
+ assert "API" in full
+
+ # Should have sequence diagram
+ assert "## Request Flow" in full
+ assert "How requests work" in full
+ assert "sequenceDiagram" in full
+ assert "Client->>Server: Request" in full
+
+ def test_to_force_graph_empty(self):
+ """Empty model should produce empty nodes and links."""
+ model = MentalModel()
+ graph = model.to_force_graph()
+
+ assert graph["nodes"] == []
+ assert graph["links"] == []
+
+ def test_to_force_graph_with_elements(self):
+ """Should export elements as nodes."""
+ model = MentalModel()
+ model.elements["api"] = ModelElement(
+ id="api",
+ label="API Gateway",
+ shape="hexagon",
+ color="#e1f5fe",
+ reasoning="Entry point for all requests",
+ )
+ model.elements["db"] = ModelElement(
+ id="db",
+ label="Database",
+ shape="cylinder",
+ properties={"size": 2, "group": "storage"},
+ )
+
+ graph = model.to_force_graph()
+
+ assert len(graph["nodes"]) == 2
+
+ api_node = next(n for n in graph["nodes"] if n["id"] == "api")
+ assert api_node["name"] == "API Gateway"
+ assert api_node["group"] == "hexagon" # falls back to shape
+ assert api_node["color"] == "#e1f5fe"
+ assert api_node["description"] == "Entry point for all requests"
+ assert api_node["val"] == 1 # default size
+
+ db_node = next(n for n in graph["nodes"] if n["id"] == "db")
+ assert db_node["name"] == "Database"
+ assert db_node["group"] == "storage" # from properties
+ assert db_node["val"] == 2 # from properties.size
+
+ def test_to_force_graph_with_relations(self):
+ """Should export relations as links."""
+ model = MentalModel()
+ model.elements["a"] = ModelElement(id="a", label="A")
+ model.elements["b"] = ModelElement(id="b", label="B")
+ model.relations.append(ModelRelation(
+ source_id="a",
+ target_id="b",
+ label="calls",
+ style="dashed",
+ reasoning="A invokes B for processing",
+ ))
+
+ graph = model.to_force_graph()
+
+ assert len(graph["links"]) == 1
+ link = graph["links"][0]
+ assert link["source"] == "a"
+ assert link["target"] == "b"
+ assert link["label"] == "calls"
+ assert link["style"] == "dashed"
+ assert link["description"] == "A invokes B for processing"
+
+ def test_to_force_graph_includes_files(self):
+ """Should include associated files in nodes."""
+ model = MentalModel()
+ model.elements["auth"] = ModelElement(id="auth", label="Auth")
+ model.element_files["auth"] = {"src/auth.py", "src/auth_utils.py"}
+
+ graph = model.to_force_graph()
+
+ auth_node = graph["nodes"][0]
+ assert "files" in auth_node
+ assert set(auth_node["files"]) == {"src/auth.py", "src/auth_utils.py"}
+
+ def test_to_force_graph_omits_empty_fields(self):
+ """Should not include empty optional fields."""
+ model = MentalModel()
+ model.elements["simple"] = ModelElement(id="simple", label="Simple")
+ model.relations.append(ModelRelation(source_id="simple", target_id="simple"))
+
+ graph = model.to_force_graph()
+
+ node = graph["nodes"][0]
+ assert "color" not in node
+ assert "description" not in node
+ assert "files" not in node
+
+ link = graph["links"][0]
+ assert "label" not in link
+ assert "style" not in link
+ assert "description" not in link
+
+
+class TestMentalModelEnhancer:
+ """Tests for MentalModelEnhancer class."""
+
+ def test_init_without_repo(self):
+ """Should initialize without repo path."""
+ enhancer = MentalModelEnhancer()
+ assert enhancer.model is not None
+ assert enhancer._repo_path is None
+
+ def test_init_with_nonexistent_repo(self, tmp_path):
+ """Should initialize with new repo path."""
+ repo_path = tmp_path / "new_repo"
+ repo_path.mkdir()
+
+ enhancer = MentalModelEnhancer(repo_path)
+
+ assert enhancer._repo_path == repo_path
+ assert len(enhancer.model.elements) == 0
+
+ def test_init_loads_existing_model(self, tmp_path):
+ """Should load existing model from repo."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ # Create existing model file
+ model_dir = repo_path / ".agentgit"
+ model_dir.mkdir()
+ model_file = model_dir / "mental_model.json"
+
+ existing_model = {
+ "elements": {
+ "existing": {
+ "id": "existing",
+ "label": "Existing Element",
+ "shape": "box",
+ }
+ },
+ "relations": [],
+ "version": 10,
+ "ai_summary": "Existing model",
+ }
+ model_file.write_text(json.dumps(existing_model))
+
+ enhancer = MentalModelEnhancer(repo_path)
+
+ assert len(enhancer.model.elements) == 1
+ assert "existing" in enhancer.model.elements
+ assert enhancer.model.version == 10
+
+ def test_model_path_property(self, tmp_path):
+ """Should return correct model path."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ enhancer = MentalModelEnhancer(repo_path)
+
+ assert enhancer.model_path == repo_path / ".agentgit" / "mental_model.json"
+
+ def test_model_path_none_without_repo(self):
+ """Should return None if no repo configured."""
+ enhancer = MentalModelEnhancer()
+ assert enhancer.model_path is None
+
+ def test_save_creates_directory(self, tmp_path):
+ """Should create .agentgit directory if needed."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ enhancer = MentalModelEnhancer(repo_path)
+ enhancer.model.elements["test"] = ModelElement(id="test", label="Test")
+ enhancer.save()
+
+ assert (repo_path / ".agentgit").exists()
+ assert (repo_path / ".agentgit" / "mental_model.json").exists()
+
+ def test_save_writes_model(self, tmp_path):
+ """Should write model to file."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ enhancer = MentalModelEnhancer(repo_path)
+ enhancer.model.elements["test"] = ModelElement(id="test", label="Test Element")
+ enhancer.model.version = 42
+ enhancer.save()
+
+ content = (repo_path / ".agentgit" / "mental_model.json").read_text()
+ data = json.loads(content)
+
+ assert "test" in data["elements"]
+ assert data["elements"]["test"]["label"] == "Test Element"
+ assert data["version"] == 42
+
+ def test_save_insights_creates_file(self, tmp_path):
+ """Should create insights file with header."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ enhancer = MentalModelEnhancer(repo_path)
+ path = enhancer.save_insights()
+
+ assert path is not None
+ assert path.exists()
+ content = path.read_text()
+ assert "# Reframe Insights" in content
+
+ def test_save_insights_appends(self, tmp_path):
+ """Should append insights to file."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ enhancer = MentalModelEnhancer(repo_path)
+ enhancer.save_insights("First insight")
+ enhancer.save_insights("Second insight")
+
+ content = (repo_path / ".agentgit" / "mental_model.md").read_text()
+
+ assert "First insight" in content
+ assert "Second insight" in content
+
+ def test_save_insights_includes_timestamp(self, tmp_path):
+ """Should include timestamp with each insight."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ enhancer = MentalModelEnhancer(repo_path)
+ enhancer.save_insights("Test insight")
+
+ content = (repo_path / ".agentgit" / "mental_model.md").read_text()
+
+ # Should have a timestamp header like "## 2026-01-10 14:30"
+ import re
+ assert re.search(r"## \d{4}-\d{2}-\d{2} \d{2}:\d{2}", content)
+
+ def test_load_insights_returns_content(self, tmp_path):
+ """Should load existing insights."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ enhancer = MentalModelEnhancer(repo_path)
+ enhancer.save_insights("Test insight content")
+
+ loaded = enhancer.load_insights()
+
+ assert "Test insight content" in loaded
+
+ def test_load_insights_empty_if_no_file(self, tmp_path):
+ """Should return empty string if no insights file."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ enhancer = MentalModelEnhancer(repo_path)
+ loaded = enhancer.load_insights()
+
+ assert loaded == ""
+
+ def test_add_human_insight(self, tmp_path):
+ """Should add human insight with marker."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ enhancer = MentalModelEnhancer(repo_path)
+ enhancer.add_human_insight("The checkout is actually three steps")
+
+ content = (repo_path / ".agentgit" / "mental_model.md").read_text()
+
+ assert "**Human insight:**" in content
+ assert "The checkout is actually three steps" in content
+
+
+class TestPromptBuilding:
+ """Tests for prompt generation functions."""
+
+ def test_build_observation_prompt_empty_model(self):
+ """Should handle empty model."""
+ model = MentalModel()
+ context = {"prompts": ["Add auth"], "files_changed": ["auth.py"]}
+
+ prompt = build_observation_prompt(model, context)
+
+ assert "mental model" in prompt.lower()
+ assert "(empty)" in prompt
+ assert "auth.py" in prompt
+
+ def test_build_observation_prompt_with_model(self):
+ """Should include current model diagram."""
+ model = MentalModel()
+ model.elements["test"] = ModelElement(id="test", label="Test Component")
+ context = {"prompts": [], "files_changed": []}
+
+ prompt = build_observation_prompt(model, context)
+
+ assert "Test Component" in prompt
+ assert "```mermaid" in prompt
+
+ def test_build_observation_prompt_with_insights(self):
+ """Should include accumulated insights."""
+ model = MentalModel()
+ context = {"prompts": [], "files_changed": []}
+ insights = "Previous insight: The system uses microservices"
+
+ prompt = build_observation_prompt(model, context, insights)
+
+ assert "Previous insight" in prompt
+ assert "microservices" in prompt
+
+ def test_build_observation_prompt_truncates_long_insights(self):
+ """Should truncate very long insights."""
+ model = MentalModel()
+ context = {}
+ long_insights = "x" * 5000
+
+ prompt = build_observation_prompt(model, context, long_insights)
+
+ assert "(earlier insights truncated)" in prompt
+
+ def test_build_observation_prompt_includes_summary(self):
+ """Should include AI summary if present."""
+ model = MentalModel()
+ model.ai_summary = "An e-commerce platform"
+ context = {}
+
+ prompt = build_observation_prompt(model, context)
+
+ assert "An e-commerce platform" in prompt
+
+ def test_build_instruction_prompt_with_selection(self):
+ """Should include selected elements."""
+ model = MentalModel()
+ model.elements["checkout"] = ModelElement(
+ id="checkout",
+ label="Checkout",
+ reasoning="Where purchases happen",
+ )
+
+ prompt = build_instruction_prompt(
+ model,
+ selected_ids=["checkout"],
+ instruction="Split this into three steps",
+ )
+
+ assert "Checkout" in prompt
+ assert "Split this into three steps" in prompt
+ assert "Where purchases happen" in prompt
+
+ def test_build_instruction_prompt_empty_selection(self):
+ """Should handle empty selection."""
+ model = MentalModel()
+
+ prompt = build_instruction_prompt(
+ model,
+ selected_ids=[],
+ instruction="Add a database component",
+ )
+
+ assert "Add a database component" in prompt
+ assert "No specific elements" in prompt
+
+ def test_build_focused_prompt_explore(self):
+ """Should generate explore prompt."""
+ model = MentalModel()
+ model.elements["auth"] = ModelElement(id="auth", label="Authentication")
+
+ prompt = build_focused_prompt(model, "auth", "explore")
+
+ assert "Authentication" in prompt
+ assert "Tell me" in prompt or "tell me" in prompt
+
+ def test_build_focused_prompt_debug(self):
+ """Should generate debug prompt."""
+ model = MentalModel()
+ model.elements["auth"] = ModelElement(id="auth", label="Authentication")
+
+ prompt = build_focused_prompt(model, "auth", "debug")
+
+ assert "issues" in prompt.lower() or "help" in prompt.lower()
+
+ def test_build_focused_prompt_unknown_element(self):
+ """Should handle unknown element."""
+ model = MentalModel()
+
+ prompt = build_focused_prompt(model, "nonexistent", "explore")
+
+ assert "not found" in prompt.lower()
+
+ def test_build_focused_prompt_mentions_element(self):
+ """Should mention the element in the prompt."""
+ model = MentalModel()
+ model.elements["auth"] = ModelElement(id="auth", label="Authentication")
+ model.elements["users"] = ModelElement(id="users", label="User Store")
+ model.relations.append(
+ ModelRelation(source_id="auth", target_id="users", label="validates")
+ )
+
+ prompt = build_focused_prompt(model, "auth", "explore")
+
+ # Prompt should mention the element being explored
+ assert "Authentication" in prompt
+
+
+class TestHelperFunctions:
+ """Tests for helper functions."""
+
+ def test_truncate_short_text(self):
+ """Should not truncate short text."""
+ result = _truncate("short", 100)
+ assert result == "short"
+
+ def test_truncate_long_text(self):
+ """Should truncate and add ellipsis."""
+ result = _truncate("a" * 100, 10)
+ assert len(result) == 10
+ assert result.endswith("...")
+
+ def test_truncate_exact_length(self):
+ """Should not truncate at exact length."""
+ result = _truncate("exact", 5)
+ assert result == "exact"
+
+
+class TestGlobalFunctions:
+ """Tests for module-level functions."""
+
+ def test_load_mental_model_exists(self, tmp_path):
+ """Should load model from repo."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+ model_dir = repo_path / ".agentgit"
+ model_dir.mkdir()
+
+ model_data = {
+ "elements": {"test": {"id": "test", "label": "Test", "shape": "box"}},
+ "relations": [],
+ "version": 5,
+ "ai_summary": "Test model",
+ }
+ (model_dir / "mental_model.json").write_text(json.dumps(model_data))
+
+ loaded = load_mental_model(repo_path)
+
+ assert loaded is not None
+ assert "test" in loaded.elements
+ assert loaded.version == 5
+
+ def test_load_mental_model_not_exists(self, tmp_path):
+ """Should return None if model doesn't exist."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ loaded = load_mental_model(repo_path)
+
+ assert loaded is None
+
+ def test_get_mental_model_enhancer_creates_instance(self):
+ """Should create enhancer instance."""
+ # Reset global state
+ import agentgit.enhancers.mental_model as mm
+ mm._enhancer_instance = None
+
+ enhancer = get_mental_model_enhancer()
+
+ assert enhancer is not None
+ assert isinstance(enhancer, MentalModelEnhancer)
+
+ def test_get_mental_model_enhancer_reuses_instance(self):
+ """Should reuse existing instance."""
+ # Reset global state
+ import agentgit.enhancers.mental_model as mm
+ mm._enhancer_instance = None
+
+ enhancer1 = get_mental_model_enhancer()
+ enhancer2 = get_mental_model_enhancer()
+
+ assert enhancer1 is enhancer2
+
+ def test_get_mental_model_enhancer_new_repo(self, tmp_path):
+ """Should create new instance for different repo."""
+ import agentgit.enhancers.mental_model as mm
+ mm._enhancer_instance = None
+
+ repo1 = tmp_path / "repo1"
+ repo1.mkdir()
+ repo2 = tmp_path / "repo2"
+ repo2.mkdir()
+
+ enhancer1 = get_mental_model_enhancer(repo1)
+ enhancer2 = get_mental_model_enhancer(repo2)
+
+ assert enhancer1._repo_path == repo1
+ assert enhancer2._repo_path == repo2
+
+
+class TestIntegration:
+ """Integration tests for the full reframe flow."""
+
+ def test_full_model_evolution(self, tmp_path):
+ """Test a complete model evolution scenario."""
+ repo_path = tmp_path / "repo"
+ repo_path.mkdir()
+
+ enhancer = MentalModelEnhancer(repo_path)
+
+ # Step 1: Initial model
+ enhancer.model.elements["browse"] = ModelElement(
+ id="browse", label="Product Browser", shape="stadium"
+ )
+ enhancer.model.elements["cart"] = ModelElement(
+ id="cart", label="Shopping Cart", shape="rounded"
+ )
+ enhancer.model.relations.append(
+ ModelRelation(source_id="browse", target_id="cart", label="add to")
+ )
+ enhancer.model.version = 1
+ enhancer.model.ai_summary = "E-commerce product browsing"
+ enhancer.save()
+
+ # Step 2: Human adds insight
+ enhancer.add_human_insight("The cart should support wishlists too")
+
+ # Step 3: Snapshot and evolve
+ enhancer.model.snapshot("before wishlist")
+ enhancer.model.elements["wishlist"] = ModelElement(
+ id="wishlist", label="Wishlist", shape="rounded"
+ )
+ enhancer.model.relations.append(
+ ModelRelation(source_id="browse", target_id="wishlist", label="save for later")
+ )
+ enhancer.model.version = 2
+ enhancer.save()
+
+ # Verify final state
+ loaded = load_mental_model(repo_path)
+ assert len(loaded.elements) == 3
+ assert "wishlist" in loaded.elements
+ assert loaded.version == 2
+
+ # Verify insights
+ insights = enhancer.load_insights()
+ assert "wishlists" in insights
+
+ # Verify snapshots
+ assert len(enhancer.model.snapshots) == 1
+ assert len(enhancer.model.snapshots[0].elements) == 2 # Before wishlist
+
+ def test_mermaid_output_is_valid(self, tmp_path):
+ """Test that generated mermaid is syntactically valid."""
+ model = MentalModel()
+
+ # Add various elements and relations
+ model.elements["api"] = ModelElement(
+ id="api", label="API Gateway", shape="hexagon", color="#e1f5fe"
+ )
+ model.elements["auth"] = ModelElement(
+ id="auth", label="Auth Service", shape="box", color="#fff3e0"
+ )
+ model.elements["db"] = ModelElement(
+ id="db", label="Database", shape="cylinder", color="#e8f5e9"
+ )
+
+ model.relations.append(
+ ModelRelation(source_id="api", target_id="auth", label="authenticates", style="solid")
+ )
+ model.relations.append(
+ ModelRelation(source_id="auth", target_id="db", label="queries", style="dashed")
+ )
+
+ mermaid = model.to_mermaid()
+
+ # Check structure
+ lines = mermaid.strip().split("\n")
+ assert lines[0].strip() == "graph TD"
+
+ # Check no syntax errors (basic validation)
+ assert mermaid.count("(") == mermaid.count(")")
+ assert mermaid.count("[") == mermaid.count("]")
+ assert mermaid.count("{") == mermaid.count("}")
diff --git a/tests/test_static_analysis.py b/tests/test_static_analysis.py
new file mode 100644
index 0000000..be835b1
--- /dev/null
+++ b/tests/test_static_analysis.py
@@ -0,0 +1,516 @@
+"""Tests for static analysis module."""
+
+import pytest
+import textwrap
+from pathlib import Path
+
+from agentgit.enhancers.static_analysis import (
+ CodeEntity,
+ CodeRelation,
+ CodeStructure,
+ analyze_python_simple,
+ analyze_javascript_simple,
+ analyze_code,
+ structure_to_context,
+)
+
+
+class TestCodeEntity:
+ """Tests for CodeEntity dataclass."""
+
+ def test_entity_creation(self):
+ """Should create entity with required fields."""
+ entity = CodeEntity(
+ name="UserService",
+ entity_type="class",
+ file_path="src/services/user.py",
+ )
+ assert entity.name == "UserService"
+ assert entity.entity_type == "class"
+ assert entity.file_path == "src/services/user.py"
+ assert entity.line_number is None
+ assert entity.docstring is None
+ assert entity.properties == {}
+
+ def test_entity_with_all_fields(self):
+ """Should create entity with all fields."""
+ entity = CodeEntity(
+ name="authenticate",
+ entity_type="function",
+ file_path="auth.py",
+ line_number=42,
+ docstring="Authenticate a user.",
+ properties={"async": True, "public": True},
+ )
+ assert entity.line_number == 42
+ assert entity.docstring == "Authenticate a user."
+ assert entity.properties["async"] is True
+
+
+class TestCodeRelation:
+ """Tests for CodeRelation dataclass."""
+
+ def test_relation_creation(self):
+ """Should create relation with required fields."""
+ rel = CodeRelation(
+ source="UserService",
+ target="Database",
+ relation_type="uses",
+ file_path="services.py",
+ )
+ assert rel.source == "UserService"
+ assert rel.target == "Database"
+ assert rel.relation_type == "uses"
+ assert rel.file_path == "services.py"
+
+ def test_relation_types(self):
+ """Should support various relation types."""
+ for rel_type in ["calls", "imports", "inherits", "implements", "uses"]:
+ rel = CodeRelation(
+ source="A",
+ target="B",
+ relation_type=rel_type,
+ file_path="test.py",
+ )
+ assert rel.relation_type == rel_type
+
+
+class TestCodeStructure:
+ """Tests for CodeStructure dataclass."""
+
+ def test_empty_structure(self):
+ """Should create empty structure."""
+ structure = CodeStructure()
+ assert structure.entities == []
+ assert structure.relations == []
+ assert structure.language == "unknown"
+ assert structure.analysis_method == "unknown"
+
+ def test_structure_with_data(self):
+ """Should hold entities and relations."""
+ structure = CodeStructure(
+ language="python",
+ analysis_method="regex",
+ )
+ structure.entities.append(CodeEntity(
+ name="TestClass",
+ entity_type="class",
+ file_path="test.py",
+ ))
+ structure.relations.append(CodeRelation(
+ source="test.py",
+ target="os",
+ relation_type="imports",
+ file_path="test.py",
+ ))
+
+ assert len(structure.entities) == 1
+ assert len(structure.relations) == 1
+ assert structure.language == "python"
+
+
+class TestAnalyzePythonSimple:
+ """Tests for analyze_python_simple function."""
+
+ def test_finds_classes(self, tmp_path):
+ """Should find class definitions."""
+ py_file = tmp_path / "test.py"
+ py_file.write_text("""
+class UserService:
+ pass
+
+class OrderService(BaseService):
+ pass
+""")
+
+ structure = analyze_python_simple(tmp_path)
+
+ class_entities = [e for e in structure.entities if e.entity_type == "class"]
+ assert len(class_entities) == 2
+ names = [e.name for e in class_entities]
+ assert "UserService" in names
+ assert "OrderService" in names
+
+ def test_finds_functions(self, tmp_path):
+ """Should find function definitions."""
+ py_file = tmp_path / "test.py"
+ py_file.write_text(textwrap.dedent("""\
+ def top_level_func():
+ pass
+
+ def another_function(arg):
+ return arg
+ """))
+
+ structure = analyze_python_simple(tmp_path)
+
+ func_entities = [e for e in structure.entities if e.entity_type == "function"]
+ assert len(func_entities) == 2
+ names = [e.name for e in func_entities]
+ assert "top_level_func" in names
+ assert "another_function" in names
+
+ def test_finds_methods(self, tmp_path):
+ """Should find method definitions."""
+ py_file = tmp_path / "test.py"
+ py_file.write_text("""
+class MyClass:
+ def my_method(self):
+ pass
+
+ def another_method(self):
+ pass
+""")
+
+ structure = analyze_python_simple(tmp_path)
+
+ method_entities = [e for e in structure.entities if e.entity_type == "method"]
+ assert len(method_entities) == 2
+
+ def test_tracks_inheritance(self, tmp_path):
+ """Should track inheritance relations."""
+ py_file = tmp_path / "test.py"
+ py_file.write_text("""
+class Child(Parent):
+ pass
+
+class MultiChild(Parent1, Parent2):
+ pass
+""")
+
+ structure = analyze_python_simple(tmp_path)
+
+ inherit_rels = [r for r in structure.relations if r.relation_type == "inherits"]
+ assert len(inherit_rels) >= 2
+
+ targets = [r.target for r in inherit_rels]
+ assert "Parent" in targets
+ assert "Parent1" in targets
+
+ def test_tracks_imports(self, tmp_path):
+ """Should track import relations."""
+ py_file = tmp_path / "test.py"
+ py_file.write_text("""
+import os
+from pathlib import Path
+from typing import List, Dict
+""")
+
+ structure = analyze_python_simple(tmp_path)
+
+ import_rels = [r for r in structure.relations if r.relation_type == "imports"]
+ assert len(import_rels) >= 3
+
+ targets = [r.target for r in import_rels]
+ assert "os" in targets
+ assert "pathlib.Path" in targets
+
+ def test_handles_single_file(self, tmp_path):
+ """Should handle single file path."""
+ py_file = tmp_path / "single.py"
+ py_file.write_text("""
+class SingleClass:
+ pass
+""")
+
+ structure = analyze_python_simple(py_file)
+
+ assert len(structure.entities) >= 1
+ assert structure.language == "python"
+ assert structure.analysis_method == "regex"
+
+ def test_ignores_object_base(self, tmp_path):
+ """Should ignore 'object' as base class."""
+ py_file = tmp_path / "test.py"
+ py_file.write_text("""
+class MyClass(object):
+ pass
+""")
+
+ structure = analyze_python_simple(tmp_path)
+
+ inherit_rels = [r for r in structure.relations if r.relation_type == "inherits"]
+ targets = [r.target for r in inherit_rels]
+ assert "object" not in targets
+
+ def test_includes_line_numbers(self, tmp_path):
+ """Should include line numbers."""
+ py_file = tmp_path / "test.py"
+ py_file.write_text("""
+# Line 1 is blank
+
+class MyClass:
+ pass
+""")
+
+ structure = analyze_python_simple(tmp_path)
+
+ class_entities = [e for e in structure.entities if e.entity_type == "class"]
+ assert len(class_entities) == 1
+ assert class_entities[0].line_number is not None
+ assert class_entities[0].line_number > 0
+
+
+class TestAnalyzeJavaScriptSimple:
+ """Tests for analyze_javascript_simple function."""
+
+ def test_finds_classes(self, tmp_path):
+ """Should find class definitions."""
+ js_file = tmp_path / "test.js"
+ js_file.write_text("""
+class UserService {
+ constructor() {}
+}
+
+class OrderService extends BaseService {
+ process() {}
+}
+""")
+
+ structure = analyze_javascript_simple(tmp_path)
+
+ class_entities = [e for e in structure.entities if e.entity_type == "class"]
+ assert len(class_entities) == 2
+ names = [e.name for e in class_entities]
+ assert "UserService" in names
+ assert "OrderService" in names
+
+ def test_tracks_extends(self, tmp_path):
+ """Should track extends relations."""
+ js_file = tmp_path / "test.js"
+ js_file.write_text("""
+class Child extends Parent {
+}
+""")
+
+ structure = analyze_javascript_simple(tmp_path)
+
+ extend_rels = [r for r in structure.relations if r.relation_type == "extends"]
+ assert len(extend_rels) == 1
+ assert extend_rels[0].target == "Parent"
+
+ def test_finds_functions(self, tmp_path):
+ """Should find function definitions."""
+ js_file = tmp_path / "test.js"
+ js_file.write_text("""
+function regularFunction() {
+ return 1;
+}
+
+const arrowFunc = () => {
+ return 2;
+};
+
+const anotherArrow = (x) => x * 2;
+""")
+
+ structure = analyze_javascript_simple(tmp_path)
+
+ func_entities = [e for e in structure.entities if e.entity_type == "function"]
+ names = [e.name for e in func_entities]
+ assert "regularFunction" in names
+ assert "arrowFunc" in names
+
+ def test_finds_react_components(self, tmp_path):
+ """Should find React components (PascalCase functions)."""
+ tsx_file = tmp_path / "test.tsx"
+ tsx_file.write_text(textwrap.dedent("""\
+ function UserProfile() {
+ return