diff --git a/src/agentgit/adaptive_model.py b/src/agentgit/adaptive_model.py new file mode 100644 index 0000000..e6962f8 --- /dev/null +++ b/src/agentgit/adaptive_model.py @@ -0,0 +1,564 @@ +""" +Adaptive Mental Model - AI-determined structure with direct manipulation. + +Key changes from previous approach: +1. NO prescribed node types or relationships - the AI decides what structure + makes sense based on what it observes in the codebase +2. Selection + instruction interface - draw a box around any part of the + diagram, describe what you want, and the model adapts + +The mental model becomes a TWO-WAY interface: +- AI → Human: Shows what the AI thinks the system is +- Human → AI: You reshape the model, which guides future AI understanding +""" + +from dataclasses import dataclass, field +from typing import Optional, Callable, Any +from datetime import datetime +import json +import copy + + +@dataclass +class ModelElement: + """ + A generic element in the mental model. + + Unlike the previous approach with fixed types (component, service, etc.), + elements are freeform. The AI decides what they represent and how to + visualize them. + """ + id: str + label: str + + # Freeform properties - AI decides what's relevant + properties: dict[str, Any] = field(default_factory=dict) + + # Visual hints (AI-suggested, can be overridden) + shape: str = "box" # box, circle, diamond, cylinder, etc. + color: Optional[str] = None + + # Provenance + created_by: str = "" # "ai" or "human" + reasoning: str = "" # Why this element exists + + def __hash__(self): + return hash(self.id) + + +@dataclass +class ModelRelation: + """A relationship between elements - also freeform.""" + source_id: str + target_id: str + + # AI decides what the relationship means + label: str = "" + properties: dict[str, Any] = field(default_factory=dict) + + # Visual hints + style: str = "solid" # solid, dashed, dotted + arrow: str = "normal" # normal, none, both + + created_by: str = "" + reasoning: str = "" + + +@dataclass +class SelectionInstruction: + """ + An instruction from the user about a selected region of the diagram. + + This is the core of the "draw a box and describe" interaction. + """ + selected_element_ids: list[str] + instruction: str # What the user wants to change + timestamp: datetime = field(default_factory=datetime.now) + + # After AI processes this + applied: bool = False + ai_interpretation: str = "" + + +@dataclass +class AdaptiveModel: + """ + A mental model where structure emerges from AI observation + human guidance. 
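+
+    A minimal usage sketch (all names below come from this module; the
+    element itself is illustrative):
+
+        model = AdaptiveModel()
+        model.elements["api"] = ModelElement(
+            id="api", label="HTTP API", created_by="ai",
+            reasoning="Entry point for all requests",
+        )
+        model.snapshot()           # save state for time travel
+        print(model.to_mermaid())  # render current understanding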
+ """ + elements: dict[str, ModelElement] = field(default_factory=dict) + relations: list[ModelRelation] = field(default_factory=list) + + # History + version: int = 0 + snapshots: list[dict] = field(default_factory=list) + + # Pending human instructions + pending_instructions: list[SelectionInstruction] = field(default_factory=list) + + # The AI's current "understanding" - a freeform description + ai_summary: str = "" + + def snapshot(self) -> None: + """Save current state for time travel.""" + self.snapshots.append({ + "version": self.version, + "timestamp": datetime.now().isoformat(), + "elements": {eid: self._element_to_dict(e) for eid, e in self.elements.items()}, + "relations": [self._relation_to_dict(r) for r in self.relations], + "ai_summary": self.ai_summary, + }) + + def _element_to_dict(self, e: ModelElement) -> dict: + return { + "id": e.id, "label": e.label, "properties": e.properties, + "shape": e.shape, "color": e.color, + "created_by": e.created_by, "reasoning": e.reasoning, + } + + def _relation_to_dict(self, r: ModelRelation) -> dict: + return { + "source_id": r.source_id, "target_id": r.target_id, + "label": r.label, "properties": r.properties, + "style": r.style, "arrow": r.arrow, + "created_by": r.created_by, "reasoning": r.reasoning, + } + + def get_version(self, version: int) -> Optional[dict]: + """Get state at a specific version.""" + for snap in self.snapshots: + if snap["version"] == version: + return snap + return None + + def to_mermaid(self) -> str: + """Export as Mermaid - let the structure speak for itself.""" + lines = ["graph TD"] + + # Shape mapping + shapes = { + "box": ("[", "]"), + "rounded": ("(", ")"), + "circle": ("((", "))"), + "diamond": ("{", "}"), + "cylinder": ("[(", ")]"), + "hexagon": ("{{", "}}"), + "stadium": ("([", "])"), + } + + for elem in self.elements.values(): + left, right = shapes.get(elem.shape, ("[", "]")) + safe_label = elem.label.replace('"', "'") + lines.append(f' {elem.id}{left}"{safe_label}"{right}') + + # Relations + arrow_styles = { + ("solid", "normal"): "-->", + ("solid", "none"): "---", + ("solid", "both"): "<-->", + ("dashed", "normal"): "-.->", + ("dashed", "none"): "-.-", + ("dotted", "normal"): "..>", + } + + for rel in self.relations: + arrow = arrow_styles.get((rel.style, rel.arrow), "-->") + if rel.label: + lines.append(f' {rel.source_id} {arrow}|"{rel.label}"| {rel.target_id}') + else: + lines.append(f' {rel.source_id} {arrow} {rel.target_id}') + + # Colors + for elem in self.elements.values(): + if elem.color: + lines.append(f' style {elem.id} fill:{elem.color}') + + return "\n".join(lines) + + def to_json(self) -> str: + return json.dumps({ + "elements": {eid: self._element_to_dict(e) for eid, e in self.elements.items()}, + "relations": [self._relation_to_dict(r) for r in self.relations], + "version": self.version, + "ai_summary": self.ai_summary, + "snapshots": self.snapshots, + }, indent=2) + + +def generate_observation_prompt( + current_model: AdaptiveModel, + context: dict, # Freeform context about what changed +) -> str: + """ + Generate a prompt for the AI to observe changes and decide how to update + the mental model. No prescribed structure - AI decides. 
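+
+    The context dict is deliberately freeform. A typical shape (the one the
+    demo at the bottom of this file passes) is:
+
+        {"files_changed": ["src/orders/checkout.py"],
+         "intent": "Building an e-commerce platform"}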
+ """ + + current_diagram = current_model.to_mermaid() if current_model.elements else "(empty)" + + # Include any pending human instructions + human_guidance = "" + if current_model.pending_instructions: + instructions = current_model.pending_instructions + human_guidance = "\n\n## Human Guidance\nThe user has provided these instructions about the model:\n" + for inst in instructions: + selected = ", ".join(inst.selected_element_ids) if inst.selected_element_ids else "general" + human_guidance += f"\n- Selection: [{selected}]\n Instruction: {inst.instruction}\n" + + prompt = f"""You are observing a codebase and maintaining a "mental model" - a diagram that +captures how the system works conceptually. + +## Current Mental Model +```mermaid +{current_diagram} +``` + +{f"Current AI understanding: {current_model.ai_summary}" if current_model.ai_summary else ""} +{human_guidance} + +## Recent Changes +{json.dumps(context, indent=2)} + +## Your Task + +Decide how (if at all) to update the mental model. You have complete freedom in how you +structure it - there are no required categories or relationships. Choose whatever +representation best captures how this system works. + +Consider: +- What are the key concepts/capabilities/flows? +- How do they relate to each other? +- What level of abstraction is most useful? +- Has the user provided guidance that should reshape your understanding? + +Respond with JSON: +```json +{{ + "thinking": "Your reasoning about what this system is and how to represent it", + "summary": "A 1-2 sentence description of what this system does (update your understanding)", + "updates": {{ + "add_elements": [ + {{ + "id": "unique_id", + "label": "Human readable name", + "shape": "box|rounded|circle|diamond|cylinder|hexagon|stadium", + "color": "#hex or null", + "reasoning": "Why this element exists in the model" + }} + ], + "remove_element_ids": ["id1", "id2"], + "modify_elements": [ + {{"id": "existing_id", "label": "new label", "shape": "new_shape"}} + ], + "add_relations": [ + {{ + "source_id": "from", + "target_id": "to", + "label": "relationship description", + "style": "solid|dashed|dotted", + "reasoning": "Why this relationship matters" + }} + ], + "remove_relations": [["source_id", "target_id"]] + }} +}} +``` + +If no updates needed: {{"thinking": "...", "summary": "...", "updates": null}} +""" + return prompt + + +def generate_instruction_prompt( + model: AdaptiveModel, + instruction: SelectionInstruction, +) -> str: + """ + Generate a prompt to process a human instruction about a selected region. + """ + + # Get the selected elements + selected_elements = [model.elements[eid] for eid in instruction.selected_element_ids + if eid in model.elements] + + # Get relations involving selected elements + selected_ids = set(instruction.selected_element_ids) + related_relations = [r for r in model.relations + if r.source_id in selected_ids or r.target_id in selected_ids] + + selection_desc = "" + if selected_elements: + selection_desc = "Selected elements:\n" + for elem in selected_elements: + selection_desc += f"- {elem.label} ({elem.id}): {elem.reasoning}\n" + if related_relations: + selection_desc += "\nRelated connections:\n" + for rel in related_relations: + selection_desc += f"- {rel.source_id} --{rel.label}--> {rel.target_id}\n" + else: + selection_desc = "(No specific elements selected - instruction applies to whole model)" + + prompt = f"""The user has selected part of the mental model and provided an instruction. 
+ +## Current Model +```mermaid +{model.to_mermaid()} +``` + +## Selection +{selection_desc} + +## User Instruction +"{instruction.instruction}" + +## Your Task + +Interpret the user's instruction and determine how to update the model. The user is +telling you how they think about this part of the system - incorporate their perspective. + +Respond with JSON: +```json +{{ + "interpretation": "How you understand the user's instruction", + "updates": {{ + "add_elements": [...], + "remove_element_ids": [...], + "modify_elements": [...], + "add_relations": [...], + "remove_relations": [...] + }} +}} +``` +""" + return prompt + + +class AdaptiveModelController: + """ + Controller that mediates between the model, AI, and human. + """ + + def __init__( + self, + ai_callback: Callable[[str], str], + model: Optional[AdaptiveModel] = None, + ): + self.model = model or AdaptiveModel() + self.ai_callback = ai_callback + + def observe(self, context: dict) -> dict: + """ + Have the AI observe changes and update the model. + Returns the AI's response. + """ + self.model.snapshot() + + prompt = generate_observation_prompt(self.model, context) + response = self.ai_callback(prompt) + + # Parse and apply + result = self._parse_response(response) + if result and result.get("updates"): + self._apply_updates(result["updates"], created_by="ai") + + if result and result.get("summary"): + self.model.ai_summary = result["summary"] + + # Clear processed instructions + self.model.pending_instructions = [] + self.model.version += 1 + + return result or {} + + def instruct(self, selected_ids: list[str], instruction: str) -> dict: + """ + Process a human instruction about selected elements. + This is the "draw a box and describe" interaction. + """ + inst = SelectionInstruction( + selected_element_ids=selected_ids, + instruction=instruction, + ) + + self.model.snapshot() + + prompt = generate_instruction_prompt(self.model, inst) + response = self.ai_callback(prompt) + + result = self._parse_response(response) + if result and result.get("updates"): + self._apply_updates(result["updates"], created_by="human") + + inst.applied = True + inst.ai_interpretation = result.get("interpretation", "") if result else "" + + self.model.version += 1 + + return result or {} + + def queue_instruction(self, selected_ids: list[str], instruction: str) -> None: + """ + Queue a human instruction to be considered on next observation. + Use this when you want to guide future AI updates without immediately applying. 
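+
+        Example (a sketch using this controller's own API):
+
+            controller.queue_instruction(["cart"], "Cart includes the wishlist")
+            controller.observe({"files_changed": []})  # guidance is considered here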
+ """ + self.model.pending_instructions.append(SelectionInstruction( + selected_element_ids=selected_ids, + instruction=instruction, + )) + + def _parse_response(self, response: str) -> Optional[dict]: + """Extract JSON from AI response.""" + import re + + json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL) + if json_match: + try: + return json.loads(json_match.group(1)) + except json.JSONDecodeError: + pass + + # Try raw JSON + try: + return json.loads(response) + except json.JSONDecodeError: + return None + + def _apply_updates(self, updates: dict, created_by: str) -> None: + """Apply updates to the model.""" + + # Remove elements first + for eid in updates.get("remove_element_ids", []): + if eid in self.model.elements: + del self.model.elements[eid] + # Also remove related relations + self.model.relations = [ + r for r in self.model.relations + if r.source_id != eid and r.target_id != eid + ] + + # Remove relations + for source_id, target_id in updates.get("remove_relations", []): + self.model.relations = [ + r for r in self.model.relations + if not (r.source_id == source_id and r.target_id == target_id) + ] + + # Add elements + for elem_data in updates.get("add_elements", []): + elem = ModelElement( + id=elem_data["id"], + label=elem_data["label"], + shape=elem_data.get("shape", "box"), + color=elem_data.get("color"), + properties=elem_data.get("properties", {}), + created_by=created_by, + reasoning=elem_data.get("reasoning", ""), + ) + self.model.elements[elem.id] = elem + + # Modify elements + for mod in updates.get("modify_elements", []): + if mod["id"] in self.model.elements: + elem = self.model.elements[mod["id"]] + if "label" in mod: + elem.label = mod["label"] + if "shape" in mod: + elem.shape = mod["shape"] + if "color" in mod: + elem.color = mod["color"] + if "properties" in mod: + elem.properties.update(mod["properties"]) + + # Add relations + for rel_data in updates.get("add_relations", []): + rel = ModelRelation( + source_id=rel_data["source_id"], + target_id=rel_data["target_id"], + label=rel_data.get("label", ""), + style=rel_data.get("style", "solid"), + arrow=rel_data.get("arrow", "normal"), + properties=rel_data.get("properties", {}), + created_by=created_by, + reasoning=rel_data.get("reasoning", ""), + ) + self.model.relations.append(rel) + + +# Demo with simulated AI +if __name__ == "__main__": + # Simulated AI that responds to different contexts + def mock_ai(prompt: str) -> str: + if "e-commerce" in prompt.lower() or "product" in prompt.lower(): + return """```json +{ + "thinking": "This looks like an e-commerce system. I see product-related code and what appears to be order processing. 
Rather than using generic 'component' types, I'll model this as a flow: browsing -> cart -> checkout.", + "summary": "An e-commerce platform where users browse products, add to cart, and checkout", + "updates": { + "add_elements": [ + {"id": "browse", "label": "Browse & Search", "shape": "stadium", "color": "#e3f2fd", "reasoning": "Entry point - users discovering products"}, + {"id": "cart", "label": "Shopping Cart", "shape": "rounded", "color": "#fff3e0", "reasoning": "Accumulation state before purchase"}, + {"id": "checkout", "label": "Purchase Flow", "shape": "hexagon", "color": "#e8f5e9", "reasoning": "Critical path - where money changes hands"} + ], + "add_relations": [ + {"source_id": "browse", "target_id": "cart", "label": "add items", "style": "solid"}, + {"source_id": "cart", "target_id": "checkout", "label": "proceed", "style": "solid"} + ] + } +} +```""" + elif "actually three separate" in prompt.lower() or "split" in prompt.lower(): + # Response to human instruction to split something + return """```json +{ + "interpretation": "The user sees the checkout as three distinct phases that should be modeled separately", + "updates": { + "remove_element_ids": ["checkout"], + "add_elements": [ + {"id": "shipping", "label": "Shipping Info", "shape": "rounded", "color": "#e8f5e9", "reasoning": "User specified: first phase of checkout"}, + {"id": "payment", "label": "Payment", "shape": "hexagon", "color": "#ffebee", "reasoning": "User specified: critical payment phase"}, + {"id": "confirm", "label": "Confirmation", "shape": "rounded", "color": "#e8f5e9", "reasoning": "User specified: final confirmation"} + ], + "add_relations": [ + {"source_id": "cart", "target_id": "shipping", "label": "begin checkout", "style": "solid"}, + {"source_id": "shipping", "target_id": "payment", "label": "next", "style": "solid"}, + {"source_id": "payment", "target_id": "confirm", "label": "complete", "style": "solid"} + ] + } +} +```""" + else: + return """```json +{"thinking": "No significant changes needed", "summary": "System unchanged", "updates": null} +```""" + + # Create controller + controller = AdaptiveModelController(ai_callback=mock_ai) + + print("=== Initial Observation ===") + print("AI observes: 'e-commerce codebase with products and orders'\n") + + result = controller.observe({ + "files_changed": ["src/products/catalog.py", "src/orders/checkout.py"], + "intent": "Building an e-commerce platform", + }) + + print(f"AI thinking: {result.get('thinking', '')[:100]}...") + print(f"AI summary: {result.get('summary', '')}") + print(f"\nDiagram:\n{controller.model.to_mermaid()}") + + print("\n" + "="*50) + print("=== Human Instruction ===") + print("User selects 'checkout' and says: 'This is actually three separate steps'") + print() + + result = controller.instruct( + selected_ids=["checkout"], + instruction="This is actually three separate steps: shipping info, payment, and confirmation. Split it up." 
+ ) + + print(f"AI interpretation: {result.get('interpretation', '')}") + print(f"\nUpdated diagram:\n{controller.model.to_mermaid()}") + + print("\n" + "="*50) + print("=== Version History ===") + for i, snap in enumerate(controller.model.snapshots): + print(f"v{snap['version']}: {len(snap['elements'])} elements - {snap['ai_summary'][:50]}...") diff --git a/src/agentgit/enhancers/mental_model.py b/src/agentgit/enhancers/mental_model.py new file mode 100644 index 0000000..fdb3371 --- /dev/null +++ b/src/agentgit/enhancers/mental_model.py @@ -0,0 +1,1152 @@ +"""Reframe - A living mental model that evolves with your codebase. + +Reframe maintains a semantic visualization of your system that automatically +adapts as code changes. It's not a file tree or class diagram - it's how +you and an AI collaboratively understand what the system *does*. + +Core concepts: +- AI observes code changes and proposes model updates +- Human can "draw a box and describe" to reshape understanding +- Insights accumulate over time, creating shared context +- Time travel through how understanding evolved + +The model structure is NOT prescribed - the AI decides what abstraction +level and relationships best capture the system. + +Name origin: "Reframe" - to see something through a different frame/lens. +From cognitive psychology, where reframing means shifting your mental model. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable + +from agentgit.plugins import hookimpl, hookspec + +if TYPE_CHECKING: + from agentgit.core import AssistantTurn, FileOperation, Prompt, PromptResponse + +logger = logging.getLogger(__name__) + +ENHANCER_NAME = "mental_model" + + +# ============================================================================ +# Data Structures (freeform - AI decides structure) +# ============================================================================ + + +@dataclass +class ModelElement: + """A freeform element in the mental model. AI decides what it represents.""" + + id: str + label: str + properties: dict[str, Any] = field(default_factory=dict) + shape: str = "box" + color: str | None = None + created_by: str = "" # "ai" or "human" + reasoning: str = "" + + def to_dict(self) -> dict: + return { + "id": self.id, + "label": self.label, + "properties": self.properties, + "shape": self.shape, + "color": self.color, + "created_by": self.created_by, + "reasoning": self.reasoning, + } + + @classmethod + def from_dict(cls, data: dict) -> "ModelElement": + return cls( + id=data["id"], + label=data["label"], + properties=data.get("properties", {}), + shape=data.get("shape", "box"), + color=data.get("color"), + created_by=data.get("created_by", ""), + reasoning=data.get("reasoning", ""), + ) + + +@dataclass +class ModelRelation: + """A freeform relationship. 
AI decides what it means.""" + + source_id: str + target_id: str + label: str = "" + properties: dict[str, Any] = field(default_factory=dict) + style: str = "solid" + created_by: str = "" + reasoning: str = "" + + def to_dict(self) -> dict: + return { + "source_id": self.source_id, + "target_id": self.target_id, + "label": self.label, + "properties": self.properties, + "style": self.style, + "created_by": self.created_by, + "reasoning": self.reasoning, + } + + @classmethod + def from_dict(cls, data: dict) -> "ModelRelation": + return cls( + source_id=data["source_id"], + target_id=data["target_id"], + label=data.get("label", ""), + properties=data.get("properties", {}), + style=data.get("style", "solid"), + created_by=data.get("created_by", ""), + reasoning=data.get("reasoning", ""), + ) + + +@dataclass +class SequenceStep: + """A step in a sequence diagram.""" + + from_participant: str + to_participant: str + message: str + step_type: str = "sync" # sync, async, reply, note + note: str | None = None + + +@dataclass +class SequenceDiagram: + """A sequence diagram showing interactions over time. + + Useful for showing: + - Request/response flows + - Event propagation + - User journeys through the system + """ + + id: str + title: str + participants: list[str] = field(default_factory=list) + steps: list[SequenceStep] = field(default_factory=list) + description: str = "" + created_by: str = "" + trigger: str = "" # What flow this represents (e.g., "user login") + + def to_mermaid(self) -> str: + """Export as Mermaid sequence diagram.""" + lines = ["sequenceDiagram"] + + if self.title: + lines.append(f" title {self.title}") + + # Declare participants in order + for p in self.participants: + # Sanitize participant names for mermaid + safe_p = p.replace(" ", "_") + if p != safe_p: + lines.append(f" participant {safe_p} as {p}") + else: + lines.append(f" participant {p}") + + # Add steps + for step in self.steps: + arrow = { + "sync": "->>", + "async": "-->>", + "reply": "-->>", + "create": "->>+", + "destroy": "->>-", + }.get(step.step_type, "->>") + + from_p = step.from_participant.replace(" ", "_") + to_p = step.to_participant.replace(" ", "_") + lines.append(f" {from_p}{arrow}{to_p}: {step.message}") + + if step.note: + lines.append(f" Note over {from_p},{to_p}: {step.note}") + + return "\n".join(lines) + + def to_dict(self) -> dict: + return { + "id": self.id, + "title": self.title, + "participants": self.participants, + "steps": [ + { + "from": s.from_participant, + "to": s.to_participant, + "message": s.message, + "type": s.step_type, + "note": s.note, + } + for s in self.steps + ], + "description": self.description, + "created_by": self.created_by, + "trigger": self.trigger, + } + + @classmethod + def from_dict(cls, data: dict) -> "SequenceDiagram": + diagram = cls( + id=data.get("id", ""), + title=data.get("title", ""), + participants=data.get("participants", []), + description=data.get("description", ""), + created_by=data.get("created_by", ""), + trigger=data.get("trigger", ""), + ) + for step_data in data.get("steps", []): + diagram.steps.append(SequenceStep( + from_participant=step_data["from"], + to_participant=step_data["to"], + message=step_data["message"], + step_type=step_data.get("type", "sync"), + note=step_data.get("note"), + )) + return diagram + + +@dataclass +class ModelSnapshot: + """A point-in-time snapshot for time travel.""" + + version: int + timestamp: datetime + elements: dict[str, ModelElement] + relations: list[ModelRelation] + ai_summary: str + trigger: str # 
What caused this snapshot + + +@dataclass +class MentalModel: + """The living mental model - structure emerges from AI observation. + + Contains two types of diagrams: + 1. Component diagram (elements + relations) - the "what" + 2. Sequence diagrams - the "how" (flows, interactions) + """ + + # Component diagram + elements: dict[str, ModelElement] = field(default_factory=dict) + relations: list[ModelRelation] = field(default_factory=list) + + # Sequence diagrams for showing flows + sequences: dict[str, SequenceDiagram] = field(default_factory=dict) + + version: int = 0 + ai_summary: str = "" + snapshots: list[ModelSnapshot] = field(default_factory=list) + + # Files associated with each element (for focused prompting) + element_files: dict[str, set[str]] = field(default_factory=dict) + + def snapshot(self, trigger: str = "") -> None: + """Save current state for time travel.""" + import copy + + self.snapshots.append( + ModelSnapshot( + version=self.version, + timestamp=datetime.now(), + elements=copy.deepcopy(self.elements), + relations=copy.deepcopy(self.relations), + ai_summary=self.ai_summary, + trigger=trigger, + ) + ) + + def to_mermaid(self) -> str: + """Export as Mermaid diagram.""" + lines = ["graph TD"] + + shapes = { + "box": ("[", "]"), + "rounded": ("(", ")"), + "circle": ("((", "))"), + "diamond": ("{", "}"), + "cylinder": ("[(", ")]"), + "hexagon": ("{{", "}}"), + "stadium": ("([", "])"), + } + + for elem in self.elements.values(): + left, right = shapes.get(elem.shape, ("[", "]")) + safe_label = elem.label.replace('"', "'") + lines.append(f' {elem.id}{left}"{safe_label}"{right}') + + arrow_styles = { + "solid": "-->", + "dashed": "-.->", + "dotted": "..>", + "thick": "==>", + } + + for rel in self.relations: + arrow = arrow_styles.get(rel.style, "-->") + if rel.label: + lines.append(f' {rel.source_id} {arrow}|"{rel.label}"| {rel.target_id}') + else: + lines.append(f' {rel.source_id} {arrow} {rel.target_id}') + + for elem in self.elements.values(): + if elem.color: + lines.append(f" style {elem.id} fill:{elem.color}") + + return "\n".join(lines) + + def to_dict(self) -> dict: + return { + "elements": {eid: e.to_dict() for eid, e in self.elements.items()}, + "relations": [r.to_dict() for r in self.relations], + "sequences": {sid: s.to_dict() for sid, s in self.sequences.items()}, + "version": self.version, + "ai_summary": self.ai_summary, + "element_files": {k: list(v) for k, v in self.element_files.items()}, + } + + def to_json(self) -> str: + return json.dumps(self.to_dict(), indent=2) + + def to_full_mermaid(self) -> str: + """Export both component diagram and sequence diagrams.""" + parts = ["## Component Diagram\n", "```mermaid", self.to_mermaid(), "```"] + + for seq_id, seq in self.sequences.items(): + parts.append(f"\n## {seq.title or seq_id}\n") + if seq.description: + parts.append(f"{seq.description}\n") + parts.append("```mermaid") + parts.append(seq.to_mermaid()) + parts.append("```") + + return "\n".join(parts) + + def to_force_graph(self) -> dict: + """Export model for react-force-graph visualization. + + Returns a dict with 'nodes' and 'links' arrays compatible with + react-force-graph-2d/3d libraries. 
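+
+        A minimal example of the returned shape (one node, one link):
+
+            {"nodes": [{"id": "cart", "name": "Cart", "group": "rounded", "val": 1}],
+             "links": [{"source": "cart", "target": "checkout", "label": "proceed"}]}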
+ + Node properties: + - id: unique identifier + - name: display label + - group: for coloring/clustering (from properties or shape) + - color: explicit color if set + - val: node size (from properties.size or default 1) + + Link properties: + - source: source node id + - target: target node id + - label: relationship label + - style: line style (solid, dashed, etc.) + """ + nodes = [] + for elem in self.elements.values(): + node = { + "id": elem.id, + "name": elem.label, + "group": elem.properties.get("group", elem.shape), + "val": elem.properties.get("size", 1), + } + if elem.color: + node["color"] = elem.color + if elem.reasoning: + node["description"] = elem.reasoning + # Include associated files for context + if elem.id in self.element_files: + node["files"] = list(self.element_files[elem.id]) + nodes.append(node) + + links = [] + for rel in self.relations: + link = { + "source": rel.source_id, + "target": rel.target_id, + } + if rel.label: + link["label"] = rel.label + if rel.style and rel.style != "solid": + link["style"] = rel.style + if rel.reasoning: + link["description"] = rel.reasoning + links.append(link) + + return {"nodes": nodes, "links": links} + + @classmethod + def from_dict(cls, data: dict) -> "MentalModel": + model = cls( + version=data.get("version", 0), + ai_summary=data.get("ai_summary", ""), + ) + for eid, edata in data.get("elements", {}).items(): + model.elements[eid] = ModelElement.from_dict(edata) + for rdata in data.get("relations", []): + model.relations.append(ModelRelation.from_dict(rdata)) + for sid, sdata in data.get("sequences", {}).items(): + model.sequences[sid] = SequenceDiagram.from_dict(sdata) + for eid, files in data.get("element_files", {}).items(): + model.element_files[eid] = set(files) + return model + + +# ============================================================================ +# LLM Integration (uses same pattern as llm.py) +# ============================================================================ + + +def _get_llm_model(model: str = "claude-cli-haiku"): + """Get LLM model instance.""" + try: + import llm + + return llm.get_model(model) + except ImportError: + logger.warning("llm not installed") + return None + except Exception as e: + logger.warning("Failed to get model: %s", e) + return None + + +def _run_llm(prompt: str, model: str = "claude-cli-haiku") -> str | None: + """Run prompt through LLM.""" + llm_model = _get_llm_model(model) + if not llm_model: + return None + try: + return llm_model.prompt(prompt).text() + except Exception as e: + logger.warning("LLM request failed: %s", e) + return None + + +def _truncate(text: str, max_len: int) -> str: + if len(text) <= max_len: + return text + return text[: max_len - 3] + "..." + + +# ============================================================================ +# AI Prompts (AI decides structure - no prescribed types) +# ============================================================================ + + +def build_observation_prompt( + model: MentalModel, + context: dict, + accumulated_insights: str = "", +) -> str: + """Build prompt for AI to observe changes and update model. 
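+
+    The context dict is freeform; the shape observe_changes builds is:
+
+        {"prompts": [...], "files_changed": [...], "reasoning": [...]}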
+ + The AI receives: + - Current model structure (JSON/Mermaid) + - Code changes (files, intents, reasoning) + - Accumulated insights (previous AI + human observations) + + And outputs: + - Updated model structure + - New insight to append to the insights document + """ + + current_diagram = model.to_mermaid() if model.elements else "(empty)" + + insights_section = "" + if accumulated_insights: + # Truncate if too long, keeping most recent + if len(accumulated_insights) > 3000: + accumulated_insights = "...(earlier insights truncated)...\n" + accumulated_insights[-3000:] + insights_section = f""" +## Accumulated Insights +These are previous observations and refinements from both AI and human: + +{accumulated_insights} +""" + + return f"""You are maintaining a "mental model" of a codebase - a conceptual diagram +of how the system works, not its file structure. + +## Current Mental Model +```mermaid +{current_diagram} +``` + +{f"Current understanding: {model.ai_summary}" if model.ai_summary else ""} +{insights_section} + +## Recent Code Changes +{json.dumps(context, indent=2)} + +## Your Task + +1. **Review** the accumulated insights - they contain important context from + previous observations and human refinements. + +2. **Analyze** how these code changes affect the conceptual structure. + +3. **Update** the model if needed. You have complete freedom in structure: + - Choose whatever elements/relationships capture the system best + - Pick shapes (box, rounded, circle, diamond, cylinder, hexagon, stadium) + - Use colors to indicate groupings or importance + - Let the right abstraction level emerge from what you observe + +4. **Document** your insight - what did you learn about the system? + This will be added to the insights document for future reference. + +Respond with JSON: +```json +{{ + "insight": "What you learned or observed about the system (1-3 sentences). This gets appended to the insights document.", + "summary": "1-2 sentence description of what this system does (updated if needed)", + "updates": {{ + "add_elements": [ + {{"id": "unique_id", "label": "Name", "shape": "box", "color": "#hex", "reasoning": "why"}} + ], + "remove_element_ids": ["id"], + "modify_elements": [{{"id": "x", "label": "new", "shape": "new"}}], + "add_relations": [ + {{"source_id": "a", "target_id": "b", "label": "relationship", "style": "solid", "reasoning": "why"}} + ], + "remove_relations": [["source_id", "target_id"]] + }} +}} +``` + +If no structural changes needed, still provide an insight: +{{"insight": "...", "summary": "...", "updates": null}} +""" + + +def build_instruction_prompt(model: MentalModel, selected_ids: list[str], instruction: str) -> str: + """Build prompt for AI to process human instruction about selected region.""" + + selected_elements = [model.elements[eid] for eid in selected_ids if eid in model.elements] + + selection_desc = "" + if selected_elements: + selection_desc = "Selected elements:\n" + for elem in selected_elements: + selection_desc += f"- {elem.label} ({elem.id}): {elem.reasoning}\n" + else: + selection_desc = "(No specific elements - instruction applies to whole model)" + + return f"""The user has selected part of the mental model and given an instruction. + +## Current Model +```mermaid +{model.to_mermaid()} +``` + +## Selection +{selection_desc} + +## User Instruction +"{instruction}" + +Interpret this instruction and update the model accordingly. +The user is reshaping how they think about this part of the system. 
+ +Respond with JSON: +```json +{{ + "interpretation": "How you understand the instruction", + "updates": {{ + "add_elements": [...], + "remove_element_ids": [...], + "modify_elements": [...], + "add_relations": [...], + "remove_relations": [...] + }} +}} +``` +""" + + +def build_focused_prompt(model: MentalModel, element_id: str, intent: str) -> str: + """Generate a focused prompt for interacting with a specific element.""" + + elem = model.elements.get(element_id) + if not elem: + return f"Element {element_id} not found in model." + + # Find connected elements + connected = [] + for rel in model.relations: + if rel.source_id == element_id and rel.target_id in model.elements: + connected.append(f"{model.elements[rel.target_id].label} ({rel.label})") + elif rel.target_id == element_id and rel.source_id in model.elements: + connected.append(f"{model.elements[rel.source_id].label} ({rel.label})") + + files = list(model.element_files.get(element_id, [])) + + context = f"**{elem.label}**" + if elem.reasoning: + context += f"\n{elem.reasoning}" + if connected: + context += f"\nConnected to: {', '.join(connected)}" + if files: + context += f"\nRelated files: {', '.join(files[:5])}" + + prompts = { + "explore": f"Tell me about the {elem.label}. What does it do and how does it fit in?", + "modify": f"I want to change the {elem.label}. What files should I look at?", + "debug": f"I'm having issues with {elem.label}. Help me understand how it works.", + "test": f"I want to add tests for {elem.label}. What should I test?", + "refactor": f"I'm considering refactoring {elem.label}. What would be the impact?", + } + + prompt = prompts.get(intent, prompts["explore"]) + if files: + prompt += f"\n\nRelevant files: {', '.join(files[:3])}" + + return prompt + + +# ============================================================================ +# Enhancer Plugin +# ============================================================================ + + +class MentalModelEnhancer: + """Enhancer that maintains a living mental model of the codebase. + + The mental model is stored in the agentgit output repository at: + .agentgit/mental_model.json + + This keeps it alongside the git history, making it part of the + artifact that agentgit produces. + """ + + # Standard location within agentgit repos + MODEL_FILENAME = ".agentgit/mental_model.json" + + def __init__(self, repo_path: Path | None = None): + """Initialize the enhancer. + + Args: + repo_path: Path to the agentgit output repo. If provided, + loads existing model from .agentgit/mental_model.json + """ + self.model = MentalModel() + self._repo_path: Path | None = None + + if repo_path: + self.set_repo_path(repo_path) + + def set_repo_path(self, repo_path: Path) -> None: + """Set the agentgit repo path and load existing model if present.""" + self._repo_path = repo_path + model_path = repo_path / self.MODEL_FILENAME + if model_path.exists(): + self.model = MentalModel.from_dict(json.loads(model_path.read_text())) + + @property + def model_path(self) -> Path | None: + """Get the full path to the mental model file.""" + if self._repo_path: + return self._repo_path / self.MODEL_FILENAME + return None + + def save(self) -> None: + """Save model to the agentgit repo.""" + if self.model_path: + self.model_path.parent.mkdir(parents=True, exist_ok=True) + self.model_path.write_text(self.model.to_json()) + + def save_insights(self, new_insight: str | None = None) -> Path | None: + """Append insights to the collaborative insights document. 
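+
+        Appended entries take the form (timestamp format matches the code
+        below; the content is illustrative):
+
+            ## 2024-06-01 14:30
+
+            **AI observation:**
+
+            Checkout is really three phases, not one.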
+ + The insights file is a living document where both AI and human + contribute understanding. On each update, the AI reads this file + along with the JSON and code diff to inform its interpretation. + + Args: + new_insight: New insight to append (from AI or human). + + Returns: + Path to the insights file, or None if no repo configured. + """ + if not self._repo_path: + return None + + insights_path = self._repo_path / ".agentgit" / "mental_model.md" + insights_path.parent.mkdir(parents=True, exist_ok=True) + + # Initialize file if it doesn't exist + if not insights_path.exists(): + initial_content = """# Reframe Insights + +This document captures the evolving understanding of this system's architecture. +Both AI observations and human refinements accumulate here, creating a shared +mental model that improves over time. + +Edit this file directly to add your own insights - they'll be read by the AI +on the next update to inform its understanding. + +--- + +""" + insights_path.write_text(initial_content) + + # Append new insight if provided + if new_insight: + from datetime import datetime + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M") + with open(insights_path, "a") as f: + f.write(f"\n## {timestamp}\n\n{new_insight}\n") + + return insights_path + + def load_insights(self) -> str: + """Load the accumulated insights for context. + + Returns: + The contents of the insights file, or empty string if not found. + """ + if not self._repo_path: + return "" + + insights_path = self._repo_path / ".agentgit" / "mental_model.md" + if insights_path.exists(): + return insights_path.read_text() + return "" + + def add_human_insight(self, insight: str) -> None: + """Add a human-provided insight to the document. + + Use this when the human wants to correct or refine the model's + understanding without making structural changes. + """ + self.save_insights(f"**Human insight:**\n\n{insight}") + + @hookimpl + def agentgit_get_enhancer_info(self) -> dict[str, str]: + return { + "name": ENHANCER_NAME, + "description": "Maintain a living mental model diagram of the codebase", + } + + def observe_changes( + self, + prompt_responses: list["PromptResponse"], + model: str = "claude-cli-haiku", + ) -> dict | None: + """ + Observe code changes and update the mental model. + + This is the main entry point - call after processing a transcript. 
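+
+        A usage sketch (prompt_responses come from agentgit's transcript
+        processing; the repo path is illustrative):
+
+            enhancer = MentalModelEnhancer(Path("out-repo"))
+            result = enhancer.observe_changes(prompt_responses)
+            if result and result.get("insight"):
+                print(result["insight"])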
+ + The AI receives: + - Current model structure + - Code changes (files, intents, reasoning) + - Accumulated insights from previous observations + human input + + The AI outputs: + - Updated model structure + - New insight to append to the insights document + """ + # Build context from changes + context = {"prompts": [], "files_changed": [], "reasoning": []} + + for pr in prompt_responses[:5]: # Look at recent prompts + context["prompts"].append(_truncate(pr.prompt.text, 200)) + for turn in pr.turns: + for op in turn.operations: + context["files_changed"].append(op.file_path) + if turn.context and turn.context.summary: + context["reasoning"].append(_truncate(turn.context.summary, 200)) + + context["files_changed"] = list(set(context["files_changed"]))[:20] + + # Load accumulated insights for context + accumulated_insights = self.load_insights() + + # Snapshot before update + self.model.snapshot(trigger="observation") + + # Ask AI to interpret changes (with insights as context) + prompt = build_observation_prompt(self.model, context, accumulated_insights) + response = _run_llm(prompt, model) + + if not response: + return None + + result = self._parse_response(response) + + # Apply structural updates + if result and result.get("updates"): + self._apply_updates(result["updates"], "ai", context["files_changed"]) + + if result and result.get("summary"): + self.model.ai_summary = result["summary"] + + # Save new insight to the insights document + if result and result.get("insight"): + self.save_insights(f"**AI observation:**\n\n{result['insight']}") + + self.model.version += 1 + self.save() + + return result + + def process_instruction( + self, + selected_ids: list[str], + instruction: str, + model: str = "claude-cli-haiku", + ) -> dict | None: + """ + Process a human instruction about selected elements. + + This is the "draw a box and describe" interaction. 
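+
+        Example (a sketch; the selected ids must exist in the current model):
+
+            enhancer.process_instruction(
+                selected_ids=["checkout"],
+                instruction="Split into shipping, payment, and confirmation",
+            )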
+ """ + self.model.snapshot(trigger=f"instruction: {instruction[:50]}") + + prompt = build_instruction_prompt(self.model, selected_ids, instruction) + response = _run_llm(prompt, model) + + if not response: + return None + + result = self._parse_response(response) + if result and result.get("updates"): + self._apply_updates(result["updates"], "human", []) + + self.model.version += 1 + self.save() + + return result + + def generate_focused_prompt(self, element_id: str, intent: str = "explore") -> str: + """Generate a focused prompt for a specific element.""" + return build_focused_prompt(self.model, element_id, intent) + + def get_timeline(self) -> list[dict]: + """Get version history for time travel UI.""" + timeline = [] + for snap in self.model.snapshots: + timeline.append({ + "version": snap.version, + "timestamp": snap.timestamp.isoformat(), + "trigger": snap.trigger, + "element_count": len(snap.elements), + "summary": snap.ai_summary[:50] if snap.ai_summary else "", + }) + timeline.append({ + "version": self.model.version, + "timestamp": datetime.now().isoformat(), + "trigger": "current", + "element_count": len(self.model.elements), + "summary": self.model.ai_summary[:50] if self.model.ai_summary else "", + }) + return timeline + + def get_version(self, version: int) -> MentalModel | None: + """Get model state at a specific version.""" + import copy + + for snap in self.model.snapshots: + if snap.version == version: + model = MentalModel( + version=snap.version, + ai_summary=snap.ai_summary, + ) + model.elements = copy.deepcopy(snap.elements) + model.relations = copy.deepcopy(snap.relations) + return model + if version == self.model.version: + return copy.deepcopy(self.model) + return None + + def _parse_response(self, response: str) -> dict | None: + """Parse JSON from AI response.""" + import re + + json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL) + if json_match: + try: + return json.loads(json_match.group(1)) + except json.JSONDecodeError: + pass + try: + return json.loads(response) + except json.JSONDecodeError: + return None + + def _apply_updates(self, updates: dict, created_by: str, files: list[str]) -> None: + """Apply updates to the model.""" + + # Remove elements + for eid in updates.get("remove_element_ids", []): + if eid in self.model.elements: + del self.model.elements[eid] + self.model.relations = [ + r + for r in self.model.relations + if r.source_id != eid and r.target_id != eid + ] + + # Remove relations + for source_id, target_id in updates.get("remove_relations", []): + self.model.relations = [ + r + for r in self.model.relations + if not (r.source_id == source_id and r.target_id == target_id) + ] + + # Add elements + for elem_data in updates.get("add_elements", []): + elem = ModelElement( + id=elem_data["id"], + label=elem_data["label"], + shape=elem_data.get("shape", "box"), + color=elem_data.get("color"), + properties=elem_data.get("properties", {}), + created_by=created_by, + reasoning=elem_data.get("reasoning", ""), + ) + self.model.elements[elem.id] = elem + + # Associate files with new elements + if files: + if elem.id not in self.model.element_files: + self.model.element_files[elem.id] = set() + self.model.element_files[elem.id].update(files) + + # Modify elements + for mod in updates.get("modify_elements", []): + if mod["id"] in self.model.elements: + elem = self.model.elements[mod["id"]] + if "label" in mod: + elem.label = mod["label"] + if "shape" in mod: + elem.shape = mod["shape"] + if "color" in mod: + elem.color = mod["color"] + 
+ # Add relations + for rel_data in updates.get("add_relations", []): + rel = ModelRelation( + source_id=rel_data["source_id"], + target_id=rel_data["target_id"], + label=rel_data.get("label", ""), + style=rel_data.get("style", "solid"), + properties=rel_data.get("properties", {}), + created_by=created_by, + reasoning=rel_data.get("reasoning", ""), + ) + self.model.relations.append(rel) + + +# Global instance (for integration with agentgit flow) +_enhancer_instance: MentalModelEnhancer | None = None + + +def get_mental_model_enhancer(repo_path: Path | None = None) -> MentalModelEnhancer: + """Get the mental model enhancer, optionally bound to a repo. + + Args: + repo_path: Path to agentgit output repo. If provided and different + from current instance, creates new instance bound to that repo. + + Returns: + MentalModelEnhancer instance. + """ + global _enhancer_instance + + if repo_path: + # Create new instance for this repo + if _enhancer_instance is None or _enhancer_instance._repo_path != repo_path: + _enhancer_instance = MentalModelEnhancer(repo_path) + elif _enhancer_instance is None: + _enhancer_instance = MentalModelEnhancer() + + return _enhancer_instance + + +def reframe( + repo_path: Path, + prompt_responses: list["PromptResponse"], + model: str = "claude-cli-haiku", + verbose: bool = False, +) -> MentalModel | None: + """Reframe the mental model based on recent code changes. + + This is the main entry point - call after processing a transcript + to evolve the mental model. + + The result is stored in the repo at: + - .agentgit/mental_model.json (structural data) + - .agentgit/mental_model.md (accumulated insights) + + Args: + repo_path: Path to the agentgit output repo. + prompt_responses: The prompt responses that were processed. + model: LLM model to use for interpretation. + verbose: If True, print progress. + + Returns: + The updated MentalModel, or None if update failed. + """ + enhancer = get_mental_model_enhancer(repo_path) + + if verbose: + print("Reframing...", flush=True) + + result = enhancer.observe_changes(prompt_responses, model) + if result: + # Note: observe_changes already saves the model and insights + if verbose: + print(f" Model: v{enhancer.model.version} ({len(enhancer.model.elements)} elements)") + if result.get("insight"): + print(f" Insight: {result['insight'][:60]}...") + + logger.info( + "Reframed: v%d with %d elements", + enhancer.model.version, + len(enhancer.model.elements), + ) + return enhancer.model + + return None + + +# Alias for backwards compatibility +update_mental_model_after_build = reframe + + +def load_mental_model(repo_path: Path) -> MentalModel | None: + """Load a mental model from an agentgit repo. + + Args: + repo_path: Path to the agentgit output repo. + + Returns: + The MentalModel if found, None otherwise. 
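+
+    Example:
+
+        model = load_mental_model(Path("my-agentgit-repo"))
+        if model:
+            print(model.to_mermaid())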
+ """ + model_path = repo_path / MentalModelEnhancer.MODEL_FILENAME + if model_path.exists(): + return MentalModel.from_dict(json.loads(model_path.read_text())) + return None + + +# ============================================================================ +# Demo +# ============================================================================ + +if __name__ == "__main__": + # Demo: Direct manipulation of the model (no LLM needed) + print("=== Reframe Demo ===\n") + + model = MentalModel() + + # Simulate what AI would return for an e-commerce observation + print("Step 1: AI observes e-commerce codebase") + model.snapshot("initial observation") + + model.elements["browse"] = ModelElement( + id="browse", + label="Product Discovery", + shape="stadium", + color="#e3f2fd", + reasoning="How users find products", + created_by="ai", + ) + model.elements["cart"] = ModelElement( + id="cart", + label="Cart", + shape="rounded", + color="#fff3e0", + reasoning="Accumulates items before purchase", + created_by="ai", + ) + model.elements["checkout"] = ModelElement( + id="checkout", + label="Purchase", + shape="hexagon", + color="#e8f5e9", + reasoning="Where transactions happen", + created_by="ai", + ) + model.relations.append( + ModelRelation(source_id="browse", target_id="cart", label="add to") + ) + model.relations.append( + ModelRelation(source_id="cart", target_id="checkout", label="proceed") + ) + model.version = 1 + model.ai_summary = "E-commerce platform for browsing and buying products" + + print(f"AI summary: {model.ai_summary}") + print(f"\n{model.to_mermaid()}") + + # Simulate human instruction: split checkout + print("\n" + "=" * 50) + print("Step 2: Human draws box around 'checkout' and says:") + print(' "This is actually three steps: shipping, payment, confirmation"') + + model.snapshot("human instruction") + + # Remove checkout + del model.elements["checkout"] + model.relations = [r for r in model.relations if r.target_id != "checkout"] + + # Add the three steps + model.elements["shipping"] = ModelElement( + id="shipping", + label="Shipping", + shape="rounded", + color="#e8f5e9", + created_by="human", + ) + model.elements["payment"] = ModelElement( + id="payment", + label="Payment", + shape="hexagon", + color="#ffebee", + created_by="human", + ) + model.elements["confirm"] = ModelElement( + id="confirm", + label="Confirmation", + shape="rounded", + color="#e8f5e9", + created_by="human", + ) + model.relations.append( + ModelRelation(source_id="cart", target_id="shipping", label="begin") + ) + model.relations.append( + ModelRelation(source_id="shipping", target_id="payment", label="next") + ) + model.relations.append( + ModelRelation(source_id="payment", target_id="confirm", label="complete") + ) + model.version = 2 + + print(f"\n{model.to_mermaid()}") + + # Show timeline + print("\n" + "=" * 50) + print("Timeline (for time travel):") + model.snapshot("current") + for snap in model.snapshots: + print(f" v{snap.version}: {snap.trigger} ({len(snap.elements)} elements)") + + # Show focused prompt + print("\n" + "=" * 50) + print("Focused prompt for 'payment' element:") + print(build_focused_prompt(model, "payment", "debug")) diff --git a/src/agentgit/enhancers/static_analysis.py b/src/agentgit/enhancers/static_analysis.py new file mode 100644 index 0000000..59f6483 --- /dev/null +++ b/src/agentgit/enhancers/static_analysis.py @@ -0,0 +1,443 @@ +"""Static analysis support for Reframe. 
+ +Uses available tools to extract structural information from code: +- tree-sitter for multi-language AST parsing +- pyan for Python call graphs +- Simple regex fallbacks when tools unavailable + +This provides a foundation for the AI to work with, rather than +having to infer everything from file names and diffs. +""" + +from __future__ import annotations + +import re +import subprocess +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +@dataclass +class CodeEntity: + """A code entity (function, class, module, etc.).""" + + name: str + entity_type: str # function, class, module, interface, etc. + file_path: str + line_number: int | None = None + docstring: str | None = None + properties: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class CodeRelation: + """A relationship between code entities.""" + + source: str # Entity name + target: str # Entity name + relation_type: str # calls, imports, inherits, implements, uses + file_path: str | None = None + line_number: int | None = None + + +@dataclass +class CodeStructure: + """Extracted code structure from static analysis.""" + + entities: list[CodeEntity] = field(default_factory=list) + relations: list[CodeRelation] = field(default_factory=list) + language: str = "unknown" + analysis_method: str = "unknown" + + +def analyze_python_with_pyan(path: Path) -> CodeStructure | None: + """Use pyan to analyze Python code and extract call graph. + + Requires: pip install pyan3 + """ + try: + # Check if pyan is available + result = subprocess.run( + ["pyan3", "--help"], + capture_output=True, + text=True, + ) + if result.returncode != 0: + return None + except FileNotFoundError: + return None + + # Find Python files + if path.is_file(): + py_files = [path] if path.suffix == ".py" else [] + else: + py_files = list(path.glob("**/*.py")) + + if not py_files: + return None + + # Run pyan to get call graph in dot format + try: + result = subprocess.run( + ["pyan3", "--dot", "--no-defines", "--grouped"] + [str(f) for f in py_files[:50]], + capture_output=True, + text=True, + timeout=30, + ) + except (subprocess.TimeoutExpired, Exception): + return None + + if result.returncode != 0: + return None + + # Parse dot output + structure = CodeStructure(language="python", analysis_method="pyan") + + # Extract nodes (functions/methods) + for match in re.finditer(r'"([^"]+)"\s*\[', result.stdout): + name = match.group(1) + # pyan uses module.class.method format + parts = name.split(".") + entity_type = "function" if len(parts) <= 2 else "method" + structure.entities.append(CodeEntity( + name=name, + entity_type=entity_type, + file_path="", # pyan doesn't give us this easily + )) + + # Extract edges (calls) + for match in re.finditer(r'"([^"]+)"\s*->\s*"([^"]+)"', result.stdout): + source, target = match.groups() + structure.relations.append(CodeRelation( + source=source, + target=target, + relation_type="calls", + )) + + return structure + + +def analyze_with_ctags(path: Path) -> CodeStructure | None: + """Use universal-ctags to extract code structure. + + Requires: universal-ctags installed + Works with many languages. 
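+
+    Each stdout line is a JSON object; a representative line (hedged - exact
+    fields vary by ctags version and flags) looks like:
+
+        {"_type": "tag", "name": "main", "path": "app.py",
+         "kind": "function", "line": 12}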
+ """ + try: + result = subprocess.run( + ["ctags", "--version"], + capture_output=True, + text=True, + ) + if "Universal Ctags" not in result.stdout: + return None + except FileNotFoundError: + return None + + # Run ctags + try: + result = subprocess.run( + [ + "ctags", + "-R", + "--output-format=json", + "--fields=+n+S+K", # line number, signature, kind + str(path), + ], + capture_output=True, + text=True, + timeout=30, + ) + except (subprocess.TimeoutExpired, Exception): + return None + + structure = CodeStructure(analysis_method="ctags") + + import json + for line in result.stdout.strip().split("\n"): + if not line: + continue + try: + tag = json.loads(line) + kind = tag.get("kind", "unknown") + entity_type = { + "function": "function", + "method": "method", + "class": "class", + "interface": "interface", + "module": "module", + "variable": "variable", + }.get(kind, kind) + + structure.entities.append(CodeEntity( + name=tag.get("name", ""), + entity_type=entity_type, + file_path=tag.get("path", ""), + line_number=tag.get("line"), + )) + except json.JSONDecodeError: + continue + + return structure + + +def analyze_python_simple(path: Path) -> CodeStructure: + """Simple regex-based Python analysis (no dependencies). + + Less accurate than AST parsing but always available. + """ + structure = CodeStructure(language="python", analysis_method="regex") + + if path.is_file(): + py_files = [path] if path.suffix == ".py" else [] + else: + py_files = list(path.glob("**/*.py")) + + for py_file in py_files[:100]: # Limit for performance + try: + content = py_file.read_text(errors="ignore") + except Exception: + continue + + rel_path = str(py_file) + + # Find classes + for match in re.finditer(r'^class\s+(\w+)(?:\(([^)]*)\))?:', content, re.MULTILINE): + class_name = match.group(1) + bases = match.group(2) + structure.entities.append(CodeEntity( + name=class_name, + entity_type="class", + file_path=rel_path, + line_number=content[:match.start()].count("\n") + 1, + )) + + # Track inheritance + if bases: + for base in bases.split(","): + base = base.strip().split("(")[0].split("[")[0] + if base and base not in ("object", "ABC", "Protocol"): + structure.relations.append(CodeRelation( + source=class_name, + target=base, + relation_type="inherits", + file_path=rel_path, + )) + + # Find functions/methods + # Use [ \t]* instead of \s* to avoid matching newlines as indentation + for match in re.finditer(r'^([ \t]*)def\s+(\w+)\s*\(', content, re.MULTILINE): + indent = len(match.group(1)) + func_name = match.group(2) + entity_type = "method" if indent > 0 else "function" + structure.entities.append(CodeEntity( + name=func_name, + entity_type=entity_type, + file_path=rel_path, + line_number=content[:match.start()].count("\n") + 1, + )) + + # Find imports + for match in re.finditer(r'^(?:from\s+(\S+)\s+)?import\s+(.+)$', content, re.MULTILINE): + module = match.group(1) or "" + imports = match.group(2) + for imp in imports.split(","): + imp = imp.strip().split(" as ")[0].strip() + if imp: + target = f"{module}.{imp}" if module else imp + structure.relations.append(CodeRelation( + source=rel_path, + target=target, + relation_type="imports", + file_path=rel_path, + )) + + return structure + + +def analyze_javascript_simple(path: Path) -> CodeStructure: + """Simple regex-based JavaScript/TypeScript analysis.""" + structure = CodeStructure(language="javascript", analysis_method="regex") + + if path.is_file(): + js_files = [path] if path.suffix in (".js", ".ts", ".jsx", ".tsx") else [] + else: + js_files = [] + for 
ext in ("*.js", "*.ts", "*.jsx", "*.tsx"): + js_files.extend(path.glob(f"**/{ext}")) + + for js_file in js_files[:100]: + # Skip node_modules (check for path segment, not just substring) + path_parts = js_file.parts + if "node_modules" in path_parts: + continue + + try: + content = js_file.read_text(errors="ignore") + except Exception: + continue + + rel_path = str(js_file) + + # Find classes + for match in re.finditer(r'class\s+(\w+)(?:\s+extends\s+(\w+))?', content): + class_name = match.group(1) + base = match.group(2) + structure.entities.append(CodeEntity( + name=class_name, + entity_type="class", + file_path=rel_path, + line_number=content[:match.start()].count("\n") + 1, + )) + if base: + structure.relations.append(CodeRelation( + source=class_name, + target=base, + relation_type="extends", + file_path=rel_path, + )) + + # Find functions + for match in re.finditer(r'(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\([^)]*\)\s*=>)', content): + func_name = match.group(1) or match.group(2) + if func_name: + structure.entities.append(CodeEntity( + name=func_name, + entity_type="function", + file_path=rel_path, + line_number=content[:match.start()].count("\n") + 1, + )) + + # Find React components (PascalCase functions returning JSX) + for match in re.finditer(r'(?:function|const)\s+([A-Z]\w+)', content): + comp_name = match.group(1) + if comp_name not in [e.name for e in structure.entities]: + structure.entities.append(CodeEntity( + name=comp_name, + entity_type="component", + file_path=rel_path, + line_number=content[:match.start()].count("\n") + 1, + )) + + # Find imports + for match in re.finditer(r'import\s+(?:{[^}]+}|\w+)\s+from\s+[\'"]([^\'"]+)[\'"]', content): + module = match.group(1) + structure.relations.append(CodeRelation( + source=rel_path, + target=module, + relation_type="imports", + file_path=rel_path, + )) + + return structure + + +def analyze_code(path: Path) -> CodeStructure: + """Analyze code using best available method. + + Tries in order: + 1. pyan (for Python, if available) + 2. ctags (multi-language, if available) + 3. 
Simple regex (always works) + """ + path = Path(path) + + # Detect primary language + if path.is_file(): + suffix = path.suffix.lower() + else: + # Look at file distribution + py_count = len(list(path.glob("**/*.py"))) + js_count = len(list(path.glob("**/*.js"))) + len(list(path.glob("**/*.ts"))) + suffix = ".py" if py_count >= js_count else ".js" + + # Try specialized tools first + if suffix == ".py": + result = analyze_python_with_pyan(path) + if result and result.entities: + return result + + # Try ctags + result = analyze_with_ctags(path) + if result and result.entities: + return result + + # Fall back to simple regex + if suffix == ".py": + return analyze_python_simple(path) + elif suffix in (".js", ".ts", ".jsx", ".tsx"): + return analyze_javascript_simple(path) + else: + # Try both + py_result = analyze_python_simple(path) + js_result = analyze_javascript_simple(path) + if len(py_result.entities) >= len(js_result.entities): + return py_result + return js_result + + +def structure_to_context(structure: CodeStructure) -> dict: + """Convert code structure to context for the AI prompt.""" + # Group entities by type + by_type: dict[str, list[str]] = {} + for entity in structure.entities: + if entity.entity_type not in by_type: + by_type[entity.entity_type] = [] + by_type[entity.entity_type].append(entity.name) + + # Summarize relations + relation_summary: dict[str, int] = {} + for rel in structure.relations: + if rel.relation_type not in relation_summary: + relation_summary[rel.relation_type] = 0 + relation_summary[rel.relation_type] += 1 + + return { + "language": structure.language, + "analysis_method": structure.analysis_method, + "entities": { + etype: names[:20] # Limit for prompt size + for etype, names in by_type.items() + }, + "entity_counts": {etype: len(names) for etype, names in by_type.items()}, + "relation_counts": relation_summary, + "sample_relations": [ + {"from": r.source, "to": r.target, "type": r.relation_type} + for r in structure.relations[:10] + ], + } + + +# Example usage +if __name__ == "__main__": + import sys + + path = Path(sys.argv[1]) if len(sys.argv) > 1 else Path(".") + + print(f"Analyzing: {path}") + structure = analyze_code(path) + + print(f"\nLanguage: {structure.language}") + print(f"Method: {structure.analysis_method}") + print(f"Entities: {len(structure.entities)}") + print(f"Relations: {len(structure.relations)}") + + # Group by type + by_type: dict[str, list] = {} + for e in structure.entities: + if e.entity_type not in by_type: + by_type[e.entity_type] = [] + by_type[e.entity_type].append(e.name) + + print("\nEntities by type:") + for etype, names in by_type.items(): + print(f" {etype}: {len(names)}") + for name in names[:5]: + print(f" - {name}") + if len(names) > 5: + print(f" ... and {len(names) - 5} more") + + print("\nSample relations:") + for rel in structure.relations[:10]: + print(f" {rel.source} --{rel.relation_type}--> {rel.target}") diff --git a/src/agentgit/enhancers/viewer/ReframeViewer.tsx b/src/agentgit/enhancers/viewer/ReframeViewer.tsx new file mode 100644 index 0000000..d0d4eb5 --- /dev/null +++ b/src/agentgit/enhancers/viewer/ReframeViewer.tsx @@ -0,0 +1,296 @@ +/** + * ReframeViewer - React component for visualizing mental models + * + * Uses react-force-graph for interactive force-directed graph visualization. 
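+ *
+ * The model JSON is the Python side's to_json() export
+ * (elements / relations / version / ai_summary). A rough sketch of the
+ * expected shape (field values here are illustrative):
+ *
+ *   { "elements": { "auth": { "id": "auth", "label": "Auth", "shape": "hexagon" } },
+ *     "relations": [ { "source_id": "auth", "target_id": "db", "label": "uses" } ],
+ *     "version": 3, "ai_summary": "..." }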
+ * Loads model data from .agentgit/mental_model.json
+ *
+ * Install dependencies:
+ *   npm install react-force-graph-2d
+ *
+ * Usage:
+ *   <ReframeViewer modelPath=".agentgit/mental_model.json" />
+ */
+
+import React, { useCallback, useEffect, useRef, useState } from "react";
+import ForceGraph2D from "react-force-graph-2d";
+
+interface Node {
+  id: string;
+  name: string;
+  group: string;
+  color?: string;
+  val?: number;
+  description?: string;
+  files?: string[];
+  // x/y are assigned by the force simulation at runtime
+  x?: number;
+  y?: number;
+}
+
+interface Link {
+  source: string;
+  target: string;
+  label?: string;
+  style?: string;
+  description?: string;
+}
+
+interface GraphData {
+  nodes: Node[];
+  links: Link[];
+}
+
+interface MentalModelData {
+  elements: Record<string, any>;
+  relations: any[];
+  version: number;
+  ai_summary: string;
+}
+
+// Color palette for groups
+const GROUP_COLORS: Record<string, string> = {
+  box: "#90caf9",
+  rounded: "#a5d6a7",
+  circle: "#ffcc80",
+  diamond: "#ce93d8",
+  cylinder: "#80deea",
+  hexagon: "#f48fb1",
+  stadium: "#bcaaa4",
+  default: "#e0e0e0",
+};
+
+function transformToForceGraph(model: MentalModelData): GraphData {
+  const nodes: Node[] = Object.values(model.elements).map((elem: any) => ({
+    id: elem.id,
+    name: elem.label,
+    group: elem.properties?.group || elem.shape || "default",
+    color: elem.color,
+    val: elem.properties?.size || 1,
+    description: elem.reasoning,
+  }));
+
+  const links: Link[] = model.relations.map((rel: any) => ({
+    source: rel.source_id,
+    target: rel.target_id,
+    label: rel.label,
+    style: rel.style,
+    description: rel.reasoning,
+  }));
+
+  return { nodes, links };
+}
+
+interface ReframeViewerProps {
+  modelPath?: string;
+  modelData?: MentalModelData;
+  width?: number;
+  height?: number;
+  onNodeClick?: (node: Node) => void;
+  onNodeHover?: (node: Node | null) => void;
+}
+
+export function ReframeViewer({
+  modelPath,
+  modelData,
+  width = 800,
+  height = 600,
+  onNodeClick,
+  onNodeHover,
+}: ReframeViewerProps) {
+  const graphRef = useRef<any>();
+  const [data, setData] = useState<GraphData>({ nodes: [], links: [] });
+  const [selectedNode, setSelectedNode] = useState<Node | null>(null);
+  const [hoveredNode, setHoveredNode] = useState<Node | null>(null);
+
+  // Load model from file or use provided data
+  useEffect(() => {
+    if (modelData) {
+      setData(transformToForceGraph(modelData));
+    } else if (modelPath) {
+      fetch(modelPath)
+        .then((res) => res.json())
+        .then((model) => setData(transformToForceGraph(model)))
+        .catch((err) => console.error("Failed to load model:", err));
+    }
+  }, [modelPath, modelData]);
+
+  // Node color based on group
+  const nodeColor = useCallback(
+    (node: Node) => {
+      if (node.color) return node.color;
+      if (selectedNode?.id === node.id) return "#ff5722";
+      if (hoveredNode?.id === node.id) return "#ffc107";
+      return GROUP_COLORS[node.group] || GROUP_COLORS.default;
+    },
+    [selectedNode, hoveredNode]
+  );
+
+  // Link styling
+  const linkColor = useCallback((link: Link) => {
+    return link.style === "dashed" ? "#999" : "#666";
+  }, []);
+
+  const linkWidth = useCallback((link: Link) => {
+    return link.style === "thick" ? 3 : 1;
+  }, []);
+
+  // Node click handler
+  const handleNodeClick = useCallback(
+    (node: Node) => {
+      setSelectedNode(node);
+      onNodeClick?.(node);
+
+      // Center on clicked node
+      graphRef.current?.centerAt(node.x, node.y, 1000);
+      graphRef.current?.zoom(2, 1000);
+    },
+    [onNodeClick]
+  );
+
+  // Node hover handler
+  const handleNodeHover = useCallback(
+    (node: Node | null) => {
+      setHoveredNode(node);
+      onNodeHover?.(node);
+    },
+    [onNodeHover]
+  );
+
+  // Custom node rendering
+  const nodeCanvasObject = useCallback(
+    (node: any, ctx: CanvasRenderingContext2D, globalScale: number) => {
+      const label = node.name;
+      const fontSize = 12 / globalScale;
+      ctx.font = `${fontSize}px Sans-Serif`;
+
+      // Node circle
+      const size = Math.sqrt(node.val || 1) * 5;
+      ctx.beginPath();
+      ctx.arc(node.x, node.y, size, 0, 2 * Math.PI);
+      ctx.fillStyle = nodeColor(node);
+      ctx.fill();
+
+      // Border for selected/hovered
+      if (selectedNode?.id === node.id || hoveredNode?.id === node.id) {
+        ctx.strokeStyle = selectedNode?.id === node.id ? "#ff5722" : "#ffc107";
+        ctx.lineWidth = 2 / globalScale;
+        ctx.stroke();
+      }
+
+      // Label
+      ctx.textAlign = "center";
+      ctx.textBaseline = "middle";
+      ctx.fillStyle = "#333";
+      ctx.fillText(label, node.x, node.y + size + fontSize);
+    },
+    [nodeColor, selectedNode, hoveredNode]
+  );
+
+  // Link label rendering
+  const linkCanvasObjectMode = () => "after";
+  const linkCanvasObject = useCallback(
+    (link: any, ctx: CanvasRenderingContext2D, globalScale: number) => {
+      if (!link.label) return;
+
+      const fontSize = 10 / globalScale;
+      ctx.font = `${fontSize}px Sans-Serif`;
+      ctx.textAlign = "center";
+      ctx.textBaseline = "middle";
+      ctx.fillStyle = "#666";
+
+      // Position at midpoint
+      const midX = (link.source.x + link.target.x) / 2;
+      const midY = (link.source.y + link.target.y) / 2;
+      ctx.fillText(link.label, midX, midY);
+    },
+    []
+  );
+
+  return (
+    <div style={{ position: "relative", width, height }}>
+      <ForceGraph2D
+        ref={graphRef}
+        graphData={data}
+        width={width}
+        height={height}
+        nodeVal={(node: any) => node.val || 1}
+        nodeCanvasObject={nodeCanvasObject}
+        nodeCanvasObjectMode={() => "replace"}
+        linkColor={linkColor}
+        linkWidth={linkWidth}
+        linkDirectionalArrowLength={6}
+        linkDirectionalArrowRelPos={1}
+        linkCanvasObject={linkCanvasObject}
+        linkCanvasObjectMode={linkCanvasObjectMode}
+        onNodeClick={handleNodeClick}
+        onNodeHover={handleNodeHover}
+        cooldownTicks={100}
+        onEngineStop={() => graphRef.current?.zoomToFit(400)}
+      />
+
+      {/* Info panel for selected node */}
+      {selectedNode && (
+        <div
+          style={{
+            position: "absolute",
+            top: 10,
+            right: 10,
+            background: "rgba(255, 255, 255, 0.95)",
+            padding: "12px",
+            borderRadius: "8px",
+            maxWidth: "280px",
+          }}
+        >
+          <h3 style={{ margin: 0 }}>{selectedNode.name}</h3>
+          {selectedNode.description && (
+            <p style={{ margin: "8px 0" }}>{selectedNode.description}</p>
+          )}
+          {selectedNode.files && selectedNode.files.length > 0 && (
+            <div style={{ marginTop: "8px" }}>
+              <strong>Files:</strong>
+              {selectedNode.files.map((f) => (
+                <div key={f}>• {f}</div>
+              ))}
+            </div>
+          )}
+        </div>
+      )}
+
+      {/* Hover tooltip */}
+      {hoveredNode && hoveredNode !== selectedNode && (
+        <div
+          style={{
+            position: "absolute",
+            bottom: 10,
+            left: 10,
+            background: "rgba(0, 0, 0, 0.75)",
+            color: "#fff",
+            padding: "6px 10px",
+            borderRadius: "4px",
+          }}
+        >
+          {hoveredNode.name}
+          {hoveredNode.description && `: ${hoveredNode.description}`}
+        </div>
+      )}
+    </div>
+  );
+}
+
+export default ReframeViewer;
diff --git a/src/agentgit/enhancers/viewer/index.html b/src/agentgit/enhancers/viewer/index.html
new file mode 100644
index 0000000..6fd36fb
--- /dev/null
+++ b/src/agentgit/enhancers/viewer/index.html
@@ -0,0 +1,220 @@
+<!-- Reframe - Mental Model Viewer (standalone HTML page; markup not recoverable from this copy of the diff) -->
diff --git a/src/agentgit/interactive_model.py b/src/agentgit/interactive_model.py
new file mode 100644
index 0000000..e60a1f8
--- /dev/null
+++ b/src/agentgit/interactive_model.py
@@ -0,0 +1,562 @@
+"""
+Interactive Mental Model - Time travel and focused prompting.
+
+Extends the living mental model with:
+1. Version history / snapshots for "time travel" through the model's evolution
+2. Interactive prompting - clicking on diagram elements generates contextual prompts
+
+The key insight: the diagram isn't just a visualization, it's an INTERFACE for
+directing the AI agent. Clicking on a component lets you ask focused questions
+or make targeted changes to that part of the system.
+"""
+
+from dataclasses import dataclass, field
+from typing import Optional
+from datetime import datetime
+from pathlib import Path
+import json
+import copy
+
+from .mental_model import (
+    MentalModel,
+    MentalModelNode,
+    MentalModelEdge,
+    ModelUpdate,
+)
+from .core import Prompt
+
+
+@dataclass
+class ModelSnapshot:
+    """A point-in-time snapshot of the mental model."""
+
+    version: int
+    timestamp: datetime
+    model_state: MentalModel
+    trigger_prompt_id: Optional[str] = None
+    trigger_prompt_text: Optional[str] = None
+    change_summary: str = ""
+
+    def to_dict(self) -> dict:
+        return {
+            "version": self.version,
+            "timestamp": self.timestamp.isoformat(),
+            "trigger_prompt_id": self.trigger_prompt_id,
+            "trigger_prompt_text": self.trigger_prompt_text,
+            "change_summary": self.change_summary,
+            "model_state": json.loads(self.model_state.to_json()),
+        }
+
+
+@dataclass
+class FocusedPrompt:
+    """A prompt generated from interacting with the diagram."""
+
+    prompt_text: str
+    context: str  # What part of the diagram this relates to
+    suggested_files: list[str] = field(default_factory=list)
+    node_ids: list[str] = field(default_factory=list)  # Related nodes
+
+
+class InteractiveMentalModel:
+    """
+    A mental model with time travel and interactive prompting capabilities.
+
+    This transforms the diagram from a passive visualization into an
+    active interface for directing an AI coding agent.
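+
+    Usage (a minimal sketch; `update` is assumed to be a ModelUpdate and
+    `prompt` a Prompt produced elsewhere in agentgit):
+
+        model = InteractiveMentalModel(project_name="Shop")
+        model.apply_update(update, prompt, affected_files=["src/cart/cart.py"])
+        focused = model.generate_node_prompt("cart", intent="modify")
+        print(focused.prompt_text)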
+ """ + + def __init__(self, project_name: str = ""): + self.current_model = MentalModel(project_name=project_name) + self.history: list[ModelSnapshot] = [] + self.current_version: int = 0 + + # Map from node IDs to associated files (for focused context) + self.node_files: dict[str, set[str]] = {} + + # Map from node IDs to prompts that modified them (for provenance) + self.node_prompt_history: dict[str, list[str]] = {} + + def apply_update( + self, + update: ModelUpdate, + prompt: Optional[Prompt] = None, + affected_files: Optional[list[str]] = None, + ) -> None: + """Apply an update and save a snapshot for history.""" + + # Save snapshot BEFORE applying update + self._save_snapshot(prompt, update.reasoning) + + # Apply the update + for node_id in update.remove_node_ids: + self.current_model.remove_node(node_id) + + for source_id, target_id, relationship in update.remove_edges: + self.current_model.edges = [ + e for e in self.current_model.edges + if not (e.source_id == source_id and + e.target_id == target_id and + e.relationship == relationship) + ] + + for node in update.add_nodes: + self.current_model.add_node(node, prompt.prompt_id if prompt else None) + + # Track file associations + if affected_files: + if node.id not in self.node_files: + self.node_files[node.id] = set() + self.node_files[node.id].update(affected_files) + + # Track prompt history + if prompt: + if node.id not in self.node_prompt_history: + self.node_prompt_history[node.id] = [] + self.node_prompt_history[node.id].append(prompt.prompt_id) + + for edge in update.add_edges: + edge.introduced_by_prompt_id = prompt.prompt_id if prompt else None + self.current_model.add_edge(edge) + + self.current_model.version += 1 + self.current_model.last_updated = datetime.now() + self.current_version = self.current_model.version + + def _save_snapshot(self, prompt: Optional[Prompt], change_summary: str) -> None: + """Save current state as a snapshot.""" + snapshot = ModelSnapshot( + version=self.current_version, + timestamp=datetime.now(), + model_state=copy.deepcopy(self.current_model), + trigger_prompt_id=prompt.prompt_id if prompt else None, + trigger_prompt_text=prompt.text if prompt else None, + change_summary=change_summary, + ) + self.history.append(snapshot) + + # ===== TIME TRAVEL ===== + + def get_version(self, version: int) -> Optional[MentalModel]: + """Get the model state at a specific version.""" + for snapshot in self.history: + if snapshot.version == version: + return copy.deepcopy(snapshot.model_state) + if version == self.current_version: + return copy.deepcopy(self.current_model) + return None + + def get_timeline(self) -> list[dict]: + """Get a summary of all versions for timeline display.""" + timeline = [] + for snapshot in self.history: + timeline.append({ + "version": snapshot.version, + "timestamp": snapshot.timestamp.isoformat(), + "prompt_preview": ( + snapshot.trigger_prompt_text[:50] + "..." 
+                if snapshot.trigger_prompt_text and len(snapshot.trigger_prompt_text) > 50
+                else snapshot.trigger_prompt_text
+            ),
+            "change_summary": snapshot.change_summary,
+            "node_count": len(snapshot.model_state.nodes),
+            "edge_count": len(snapshot.model_state.edges),
+        })
+
+        # Add current state
+        timeline.append({
+            "version": self.current_version,
+            "timestamp": self.current_model.last_updated.isoformat() if self.current_model.last_updated else None,
+            "prompt_preview": None,
+            "change_summary": "Current state",
+            "node_count": len(self.current_model.nodes),
+            "edge_count": len(self.current_model.edges),
+        })
+
+        return timeline
+
+    def diff_versions(self, from_version: int, to_version: int) -> dict:
+        """Get the diff between two versions."""
+        from_model = self.get_version(from_version)
+        to_model = self.get_version(to_version)
+
+        if not from_model or not to_model:
+            return {"error": "Version not found"}
+
+        from_nodes = set(from_model.nodes.keys())
+        to_nodes = set(to_model.nodes.keys())
+
+        added_nodes = to_nodes - from_nodes
+        removed_nodes = from_nodes - to_nodes
+
+        # Find modified nodes (same ID but different content)
+        modified_nodes = []
+        for node_id in from_nodes & to_nodes:
+            from_node = from_model.nodes[node_id]
+            to_node = to_model.nodes[node_id]
+            if from_node.label != to_node.label or from_node.description != to_node.description:
+                modified_nodes.append(node_id)
+
+        return {
+            "from_version": from_version,
+            "to_version": to_version,
+            "added_nodes": list(added_nodes),
+            "removed_nodes": list(removed_nodes),
+            "modified_nodes": modified_nodes,
+        }
+
+    # ===== INTERACTIVE PROMPTING =====
+
+    def generate_node_prompt(self, node_id: str, intent: str = "explore") -> FocusedPrompt:
+        """
+        Generate a focused prompt for a specific node in the diagram.
+
+        This is called when a user clicks on a node and wants to interact
+        with the AI about that specific part of the system.
+
+        Intent options:
+        - "explore": Learn more about this component
+        - "modify": Make changes to this component
+        - "connect": Add connections to/from this component
+        - "debug": Investigate issues with this component
+        - "test": Add tests for this component
+        - "explain": Get a newcomer-friendly explanation of this component
+        - "refactor": Assess refactoring options for this component
+        """
+        node = self.current_model.nodes.get(node_id)
+        if not node:
+            return FocusedPrompt(
+                prompt_text=f"I don't see a component with ID '{node_id}' in the current model.",
+                context="Unknown node",
+            )
+
+        # Get related context
+        connected_ids = self._get_connected_nodes(node_id)
+        connected_labels = [self.current_model.nodes[nid].label for nid in connected_ids]
+        associated_files = list(self.node_files.get(node_id, []))
+
+        # Build context string
+        context_parts = [f"**{node.label}** ({node.node_type})"]
+        if node.description:
+            context_parts.append(f"Description: {node.description}")
+        if connected_labels:
+            context_parts.append(f"Connected to: {', '.join(connected_labels)}")
+        if associated_files:
+            context_parts.append(f"Related files: {', '.join(associated_files[:5])}")
+
+        context = "\n".join(context_parts)
+
+        # Generate prompt based on intent
+        prompts = {
+            "explore": f"Tell me more about the {node.label} component. What does it do, how does it work, and how does it fit into the overall system?",
+
+            "modify": f"I want to make changes to the {node.label} component. What are the key files I should look at, and what should I be careful about when modifying it?",
+
+            "connect": f"I want to add a new connection to/from the {node.label} component. What other parts of the system might it need to interact with?",
+
+            "debug": f"I'm having issues with the {node.label} component. Can you help me understand how it works and identify potential problem areas?",
+
+            "test": f"I want to add tests for the {node.label} component. What are the key behaviors I should test, and what edge cases should I consider?",
+
+            "explain": f"Explain the {node.label} component to me as if I'm new to this codebase. What's its purpose and how does it relate to the rest of the system?",
+
+            "refactor": f"I'm considering refactoring the {node.label} component. What improvements could be made, and what would be the impact on connected components?",
+        }
+
+        prompt_text = prompts.get(intent, prompts["explore"])
+
+        # Add context to make it more specific
+        if associated_files:
+            prompt_text += f"\n\nRelevant files: {', '.join(associated_files[:3])}"
+
+        return FocusedPrompt(
+            prompt_text=prompt_text,
+            context=context,
+            suggested_files=associated_files,
+            node_ids=[node_id] + list(connected_ids),
+        )
+
+    def generate_edge_prompt(self, source_id: str, target_id: str) -> FocusedPrompt:
+        """Generate a prompt focused on a relationship between two nodes."""
+        source = self.current_model.nodes.get(source_id)
+        target = self.current_model.nodes.get(target_id)
+
+        if not source or not target:
+            return FocusedPrompt(
+                prompt_text="Could not find one or both nodes.",
+                context="Unknown relationship",
+            )
+
+        # Find the edge
+        edge = None
+        for e in self.current_model.edges:
+            if e.source_id == source_id and e.target_id == target_id:
+                edge = e
+                break
+
+        relationship_desc = edge.relationship if edge else "relates to"
+
+        prompt_text = f"Explain how {source.label} {relationship_desc} {target.label}. How do these components interact, and what data or control flows between them?"
+
+        return FocusedPrompt(
+            prompt_text=prompt_text,
+            context=f"Relationship: {source.label} → {target.label}",
+            suggested_files=list(
+                self.node_files.get(source_id, set()) |
+                self.node_files.get(target_id, set())
+            ),
+            node_ids=[source_id, target_id],
+        )
+
+    def generate_area_prompt(self, node_ids: list[str]) -> FocusedPrompt:
+        """Generate a prompt for a selected area of the diagram (multiple nodes)."""
+        nodes = [self.current_model.nodes[nid] for nid in node_ids if nid in self.current_model.nodes]
+
+        if not nodes:
+            return FocusedPrompt(
+                prompt_text="No valid nodes selected.",
+                context="Empty selection",
+            )
+
+        node_names = [n.label for n in nodes]
+        all_files = set()
+        for nid in node_ids:
+            all_files.update(self.node_files.get(nid, set()))
+
+        prompt_text = f"I'm looking at this area of the system: {', '.join(node_names)}. How do these components work together? What's the data flow and control flow between them?"
+
+        return FocusedPrompt(
+            prompt_text=prompt_text,
+            context=f"Selected components: {', '.join(node_names)}",
+            suggested_files=list(all_files),
+            node_ids=node_ids,
+        )
+
+    def _get_connected_nodes(self, node_id: str) -> set[str]:
+        """Get the IDs of all nodes connected to the given node.
+
+        Returns IDs (not labels) so callers can pass them straight back
+        into FocusedPrompt.node_ids; look labels up when displaying.
+        """
+        connected = set()
+        for edge in self.current_model.edges:
+            if edge.source_id == node_id:
+                if edge.target_id in self.current_model.nodes:
+                    connected.add(edge.target_id)
+            elif edge.target_id == node_id:
+                if edge.source_id in self.current_model.nodes:
+                    connected.add(edge.source_id)
+        return connected
+
+    # ===== EXPORT =====
+
+    def to_interactive_html(self) -> str:
+        """
+        Export the model as an interactive HTML page.
+
+        This would render the Mermaid diagram with click handlers that
+        trigger the focused prompt generation.
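+
+        Example (sketch):
+
+            html = model.to_interactive_html()
+            Path("mental_model.html").write_text(html)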
+ """ + mermaid_code = self.current_model.to_mermaid() + timeline_data = json.dumps(self.get_timeline()) + + # This is a simplified template - a real implementation would use + # a proper frontend framework and Mermaid's click callbacks + html = f""" + + + Living Mental Model + + + + +

🧠 Living Mental Model

+ +
+
+

System Architecture

+
+{mermaid_code} +
+

Click on any node to generate a focused prompt

+ +

Time Travel

+ + Version {self.current_version} +
+ + +
+ + + +""" + return html + + def save(self, path: Path) -> None: + """Save the complete interactive model state.""" + state = { + "current_model": json.loads(self.current_model.to_json()), + "history": [s.to_dict() for s in self.history], + "current_version": self.current_version, + "node_files": {k: list(v) for k, v in self.node_files.items()}, + "node_prompt_history": self.node_prompt_history, + } + Path(path).write_text(json.dumps(state, indent=2)) + + +# Example usage +if __name__ == "__main__": + from .mental_model import MentalModelNode, MentalModelEdge, ModelUpdate + from .core import Prompt + + # Create an interactive model + model = InteractiveMentalModel(project_name="E-commerce Platform") + + # Simulate evolution through multiple prompts + prompts = [ + Prompt(text="Create a basic product catalog", timestamp=datetime.now()), + Prompt(text="Add shopping cart functionality", timestamp=datetime.now()), + Prompt(text="Implement checkout with Stripe", timestamp=datetime.now()), + ] + + updates = [ + ModelUpdate( + reasoning="Adding product catalog components", + add_nodes=[ + MentalModelNode(id="products", label="Product Catalog", node_type="component"), + MentalModelNode(id="db", label="Product Database", node_type="data"), + ], + add_edges=[ + MentalModelEdge(source_id="products", target_id="db", relationship="uses"), + ], + ), + ModelUpdate( + reasoning="Adding shopping cart", + add_nodes=[ + MentalModelNode(id="cart", label="Shopping Cart", node_type="component"), + ], + add_edges=[ + MentalModelEdge(source_id="cart", target_id="products", relationship="uses"), + ], + ), + ModelUpdate( + reasoning="Adding checkout and payment processing", + add_nodes=[ + MentalModelNode(id="checkout", label="Checkout Flow", node_type="process"), + MentalModelNode(id="stripe", label="Stripe Integration", node_type="interface"), + ], + add_edges=[ + MentalModelEdge(source_id="checkout", target_id="cart", relationship="uses"), + MentalModelEdge(source_id="checkout", target_id="stripe", relationship="uses"), + ], + ), + ] + + # Track file associations + file_associations = [ + ["src/products/catalog.py", "src/products/models.py"], + ["src/cart/cart.py", "src/cart/session.py"], + ["src/checkout/flow.py", "src/payments/stripe.py"], + ] + + # Apply updates + for prompt, update, files in zip(prompts, updates, file_associations): + model.apply_update(update, prompt, files) + print(f"Applied: {update.reasoning}") + + # Show timeline + print("\n=== Timeline ===") + for item in model.get_timeline(): + print(f" v{item['version']}: {item['change_summary']} ({item['node_count']} nodes)") + + # Show diff + print("\n=== Diff v0 → v3 ===") + diff = model.diff_versions(0, 3) + print(f" Added: {diff['added_nodes']}") + + # Generate focused prompts + print("\n=== Focused Prompts ===") + + prompt = model.generate_node_prompt("checkout", "explore") + print(f"\n[Explore Checkout]") + print(f" Context: {prompt.context}") + print(f" Prompt: {prompt.prompt_text[:100]}...") + + prompt = model.generate_node_prompt("cart", "modify") + print(f"\n[Modify Cart]") + print(f" Context: {prompt.context}") + print(f" Prompt: {prompt.prompt_text[:100]}...") + + prompt = model.generate_edge_prompt("checkout", "stripe") + print(f"\n[Checkout → Stripe relationship]") + print(f" Prompt: {prompt.prompt_text}") + + # Final diagram + print("\n=== Current Diagram ===") + print(model.current_model.to_mermaid()) diff --git a/src/agentgit/living_model.py b/src/agentgit/living_model.py new file mode 100644 index 0000000..0b15832 --- /dev/null +++ 
b/src/agentgit/living_model.py @@ -0,0 +1,363 @@ +""" +Living Mental Model - Real-time visualization that evolves with your codebase. + +This module ties together the mental model with the transcript watcher, +creating a visualization that automatically adapts as an AI coding agent +makes changes to the codebase. + +The key insight: as you work with AI coding agents, the important thing isn't +the specific code structure, but maintaining a clear mental model of how +the system works conceptually. This tool keeps that model current. +""" + +from pathlib import Path +from typing import Optional, Callable +from datetime import datetime +import os + +from .core import FileOperation, Transcript +from .mental_model import ( + MentalModel, + ModelUpdate, + apply_update, + generate_update_prompt, + parse_update_response, +) + + +class LivingMentalModel: + """ + A mental model that automatically updates as code changes. + + Usage: + model = LivingMentalModel(project_name="My App") + + # When changes happen (e.g., from watcher callback): + model.process_changes(operations) + + # Get current visualization: + print(model.to_mermaid()) + + # Save/load state: + model.save("mental_model.json") + model = LivingMentalModel.load("mental_model.json") + """ + + def __init__( + self, + project_name: str = "", + model_path: Optional[Path] = None, + ai_callback: Optional[Callable[[str], str]] = None, + ): + """ + Initialize a living mental model. + + Args: + project_name: Name for this project's model + model_path: Optional path to persist the model (auto-saves on update) + ai_callback: Function to call AI for model updates. + Signature: (prompt: str) -> str (AI response) + If None, uses a simple rule-based updater. + """ + self.model = MentalModel(project_name=project_name) + self.model_path = Path(model_path) if model_path else None + self.ai_callback = ai_callback + + # Track what we've seen for incremental updates + self._processed_tool_ids: set[str] = set() + + # Load existing model if path provided and exists + if self.model_path and self.model_path.exists(): + self._load() + + def process_changes( + self, + operations: list[FileOperation], + force: bool = False, + ) -> Optional[ModelUpdate]: + """ + Process a batch of file operations and update the mental model. + + Args: + operations: List of file operations from a transcript + force: If True, reprocess even already-seen operations + + Returns: + The ModelUpdate that was applied, or None if no update needed + """ + # Filter to new operations only (unless forced) + if not force: + new_ops = [ + op for op in operations + if op.tool_id and op.tool_id not in self._processed_tool_ids + ] + else: + new_ops = operations + + if not new_ops: + return None + + # Mark as processed + for op in new_ops: + if op.tool_id: + self._processed_tool_ids.add(op.tool_id) + + # Get the update + update = self._compute_update(new_ops) + + if update and not update.is_empty(): + # Get prompt_id from first operation with a prompt + prompt_id = None + for op in new_ops: + if op.prompt: + prompt_id = op.prompt.prompt_id + break + + apply_update(self.model, update, prompt_id) + + # Auto-save if path configured + if self.model_path: + self._save() + + return update + + def process_transcript(self, transcript: Transcript) -> list[ModelUpdate]: + """ + Process an entire transcript and return all updates made. + + This is useful for building a mental model from a complete session. 
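+
+        Example (a minimal sketch; assumes `transcript` was already parsed
+        by agentgit):
+
+            model = LivingMentalModel(project_name="Demo")
+            updates = model.process_transcript(transcript)
+            for u in updates:
+                print(u.reasoning)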
+ """ + updates = [] + + # Group operations by prompt for better context + ops_by_prompt: dict[str, list[FileOperation]] = {} + for op in transcript.operations: + key = op.prompt.prompt_id if op.prompt else "unknown" + if key not in ops_by_prompt: + ops_by_prompt[key] = [] + ops_by_prompt[key].append(op) + + # Process each group + for prompt_id, ops in ops_by_prompt.items(): + update = self.process_changes(ops) + if update: + updates.append(update) + + return updates + + def _compute_update(self, operations: list[FileOperation]) -> Optional[ModelUpdate]: + """Compute an update to the mental model based on operations.""" + + if self.ai_callback: + return self._compute_update_with_ai(operations) + else: + return self._compute_update_rule_based(operations) + + def _compute_update_with_ai(self, operations: list[FileOperation]) -> Optional[ModelUpdate]: + """Use AI to interpret changes and propose model updates.""" + + prompt = generate_update_prompt(self.model, operations) + response = self.ai_callback(prompt) + return parse_update_response(response) + + def _compute_update_rule_based(self, operations: list[FileOperation]) -> Optional[ModelUpdate]: + """ + Simple rule-based updater for when no AI is available. + + This provides basic structure based on file patterns, but won't + capture semantic meaning like an AI would. + """ + from .mental_model import MentalModelNode, MentalModelEdge + + update = ModelUpdate(reasoning="Rule-based inference from file patterns") + + # Simple heuristics based on file paths + for op in operations: + path = op.file_path.lower() + path_parts = Path(op.file_path).parts + + # Skip if we already have a node for this + # (very simplistic - AI would do much better) + if len(path_parts) < 2: + continue + + # Infer component from directory structure + if "src" in path_parts: + idx = path_parts.index("src") + if idx + 1 < len(path_parts): + component_name = path_parts[idx + 1] + node_id = f"component_{component_name}" + + if node_id not in self.model.nodes: + # Infer type from common patterns + node_type = "component" + if "api" in component_name or "routes" in component_name: + node_type = "interface" + elif "service" in component_name: + node_type = "service" + elif "model" in component_name or "db" in component_name: + node_type = "data" + + update.add_nodes.append(MentalModelNode( + id=node_id, + label=component_name.replace("_", " ").title(), + node_type=node_type, + description=f"Inferred from {op.file_path}", + )) + + return update if not update.is_empty() else None + + def to_mermaid(self) -> str: + """Get current model as Mermaid diagram.""" + return self.model.to_mermaid() + + def to_json(self) -> str: + """Get current model as JSON.""" + return self.model.to_json() + + def save(self, path: Optional[Path] = None) -> None: + """Save the model to a file.""" + save_path = Path(path) if path else self.model_path + if not save_path: + raise ValueError("No path provided and no default path configured") + save_path.write_text(self.model.to_json()) + + def _save(self) -> None: + """Internal save to configured path.""" + if self.model_path: + self.model_path.parent.mkdir(parents=True, exist_ok=True) + self.model_path.write_text(self.model.to_json()) + + def _load(self) -> None: + """Internal load from configured path.""" + if self.model_path and self.model_path.exists(): + self.model = MentalModel.from_json(self.model_path.read_text()) + + @classmethod + def load(cls, path: Path, ai_callback: Optional[Callable[[str], str]] = None) -> "LivingMentalModel": + """Load a 
model from a file.""" + instance = cls(model_path=path, ai_callback=ai_callback) + return instance + + +def create_watcher_callback( + living_model: LivingMentalModel, + output_path: Optional[Path] = None, + on_update: Optional[Callable[[MentalModel, ModelUpdate], None]] = None, +): + """ + Create a callback function for use with TranscriptWatcher. + + This integrates the living model with agentgit's watch functionality, + automatically updating the mental model as the transcript grows. + + Usage: + from agentgit import watch_transcript + from agentgit.living_model import LivingMentalModel, create_watcher_callback + + model = LivingMentalModel("My Project") + callback = create_watcher_callback(model, output_path=Path("model.md")) + + # Watch will call our callback on each update + watch_transcript( + "transcript.jsonl", + output_dir="./output", + on_update=lambda count: callback(latest_operations) + ) + """ + + def callback(transcript: Transcript): + """Process new operations and update model.""" + update = living_model.process_changes(transcript.operations) + + if update and not update.is_empty(): + # Write updated visualization + if output_path: + mermaid = living_model.to_mermaid() + output_path.write_text(f"# Mental Model\n\n```mermaid\n{mermaid}\n```\n") + + # Call user callback + if on_update: + on_update(living_model.model, update) + + return callback + + +# Convenience function for quick visualization +def visualize_transcript( + transcript: Transcript, + ai_callback: Optional[Callable[[str], str]] = None, +) -> str: + """ + Generate a mental model visualization from a transcript. + + This is a one-shot function for visualizing a complete transcript. + For live updates, use LivingMentalModel with a watcher. + """ + model = LivingMentalModel( + project_name=transcript.session_id or "Unknown Project", + ai_callback=ai_callback, + ) + model.process_transcript(transcript) + return model.to_mermaid() + + +# Example with mock AI for testing +if __name__ == "__main__": + from .core import FileOperation, OperationType, Prompt, AssistantContext + + # Mock AI callback that returns a simple response + def mock_ai(prompt: str) -> str: + return """```json +{ + "reasoning": "The changes introduce a new authentication system with user management", + "updates": { + "add_nodes": [ + {"id": "auth", "label": "Authentication", "node_type": "service", "description": "Handles user auth"}, + {"id": "users", "label": "User Management", "node_type": "component", "description": "User CRUD"} + ], + "add_edges": [ + {"source_id": "auth", "target_id": "users", "relationship": "uses", "label": "validates"} + ], + "remove_node_ids": [], + "remove_edges": [] + } +} +```""" + + # Create living model + model = LivingMentalModel(project_name="Test App", ai_callback=mock_ai) + + # Simulate some operations + ops = [ + FileOperation( + file_path="src/auth/login.py", + operation_type=OperationType.WRITE, + content="# Login handler", + timestamp=datetime.now(), + tool_id="op1", + prompt=Prompt(text="Add user authentication", timestamp=datetime.now()), + assistant_context=AssistantContext( + thinking="I'll create an authentication module with login/logout" + ), + ), + FileOperation( + file_path="src/users/models.py", + operation_type=OperationType.WRITE, + content="# User model", + timestamp=datetime.now(), + tool_id="op2", + ), + ] + + # Process changes + update = model.process_changes(ops) + + print("=== Update Applied ===") + if update: + print(f"Reasoning: {update.reasoning}") + print(f"Added nodes: {[n.label for n in 
update.add_nodes]}") + print(f"Added edges: {len(update.add_edges)}") + + print("\n=== Current Mental Model ===") + print(model.to_mermaid()) diff --git a/src/agentgit/mental_model.py b/src/agentgit/mental_model.py new file mode 100644 index 0000000..72c07b0 --- /dev/null +++ b/src/agentgit/mental_model.py @@ -0,0 +1,444 @@ +""" +Living Mental Model - A visualization that evolves with your codebase. + +This module provides a "mental model" diagram that automatically adapts as code changes, +capturing the semantic structure of a system rather than just its file layout. + +The key insight: when working with AI coding agents, what matters isn't the exact code +structure, but understanding *how the system works conceptually* and how that evolves. +""" + +from dataclasses import dataclass, field +from typing import Optional +from datetime import datetime +import json +import hashlib + +from .core import FileOperation, Prompt, AssistantContext + + +@dataclass +class MentalModelNode: + """A node in the mental model - represents a concept, component, or capability.""" + + id: str # Stable identifier + label: str # Human-readable name + node_type: str # e.g., "component", "service", "concept", "data", "interface" + description: str = "" # What this node represents + + # Provenance - where did this understanding come from? + introduced_by_prompt_id: Optional[str] = None + last_updated_prompt_id: Optional[str] = None + + def __hash__(self): + return hash(self.id) + + +@dataclass +class MentalModelEdge: + """A relationship between nodes in the mental model.""" + + source_id: str + target_id: str + relationship: str # e.g., "uses", "contains", "transforms", "depends_on" + label: str = "" # Optional label for the edge + + # Provenance + introduced_by_prompt_id: Optional[str] = None + + +@dataclass +class MentalModel: + """ + A living representation of the system's conceptual structure. + + This is NOT a file tree or class diagram - it's a semantic model of + what the system *does* and how its parts relate conceptually. 
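+
+    Example (a minimal sketch; mirrors the demo at the bottom of this module):
+
+        model = MentalModel(project_name="Example System")
+        model.add_node(MentalModelNode(id="api", label="REST API", node_type="interface"))
+        model.add_node(MentalModelNode(id="db", label="Data Store", node_type="data"))
+        model.add_edge(MentalModelEdge(source_id="api", target_id="db", relationship="uses"))
+        print(model.to_mermaid())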
+ """ + + nodes: dict[str, MentalModelNode] = field(default_factory=dict) + edges: list[MentalModelEdge] = field(default_factory=list) + + # Metadata + project_name: str = "" + version: int = 0 # Incremented on each update + last_updated: Optional[datetime] = None + + # History of what changed + changelog: list[dict] = field(default_factory=list) + + def add_node(self, node: MentalModelNode, prompt_id: Optional[str] = None) -> None: + """Add or update a node in the model.""" + if node.id in self.nodes: + # Update existing + existing = self.nodes[node.id] + existing.label = node.label + existing.description = node.description + existing.last_updated_prompt_id = prompt_id + else: + # Add new + node.introduced_by_prompt_id = prompt_id + self.nodes[node.id] = node + + def add_edge(self, edge: MentalModelEdge) -> None: + """Add an edge if it doesn't already exist.""" + for existing in self.edges: + if (existing.source_id == edge.source_id and + existing.target_id == edge.target_id and + existing.relationship == edge.relationship): + return # Already exists + self.edges.append(edge) + + def remove_node(self, node_id: str) -> None: + """Remove a node and all its edges.""" + if node_id in self.nodes: + del self.nodes[node_id] + self.edges = [e for e in self.edges + if e.source_id != node_id and e.target_id != node_id] + + def to_mermaid(self) -> str: + """Export the mental model as a Mermaid diagram.""" + lines = ["graph TD"] + + # Group nodes by type for styling + node_types = {} + for node in self.nodes.values(): + if node.node_type not in node_types: + node_types[node.node_type] = [] + node_types[node.node_type].append(node) + + # Add nodes with type-specific shapes + shape_map = { + "component": ("[", "]"), # Rectangle + "service": ("[[", "]]"), # Subroutine shape + "concept": ("(", ")"), # Rounded + "data": ("[(", ")]"), # Cylinder (database) + "interface": ("{{", "}}"), # Hexagon + "process": ("([", "])"), # Stadium + } + + for node in self.nodes.values(): + left, right = shape_map.get(node.node_type, ("[", "]")) + # Escape quotes in label + safe_label = node.label.replace('"', "'") + lines.append(f' {node.id}{left}"{safe_label}"{right}') + + # Add edges + arrow_map = { + "uses": "-->", + "contains": "-->", + "transforms": "==>", + "depends_on": "-.->", + "implements": "-->", + "extends": "-->", + } + + for edge in self.edges: + arrow = arrow_map.get(edge.relationship, "-->") + if edge.label: + lines.append(f' {edge.source_id} {arrow}|"{edge.label}"| {edge.target_id}') + else: + lines.append(f' {edge.source_id} {arrow} {edge.target_id}') + + # Add styling by node type + style_map = { + "component": "fill:#e1f5fe", + "service": "fill:#fff3e0", + "concept": "fill:#f3e5f5", + "data": "fill:#e8f5e9", + "interface": "fill:#fce4ec", + "process": "fill:#fff8e1", + } + + for node_type, nodes in node_types.items(): + if node_type in style_map: + node_ids = ",".join(n.id for n in nodes) + if node_ids: + lines.append(f' style {node_ids} {style_map[node_type]}') + + return "\n".join(lines) + + def to_json(self) -> str: + """Serialize the mental model to JSON.""" + return json.dumps({ + "project_name": self.project_name, + "version": self.version, + "last_updated": self.last_updated.isoformat() if self.last_updated else None, + "nodes": { + nid: { + "id": n.id, + "label": n.label, + "node_type": n.node_type, + "description": n.description, + "introduced_by_prompt_id": n.introduced_by_prompt_id, + "last_updated_prompt_id": n.last_updated_prompt_id, + } + for nid, n in self.nodes.items() + }, + "edges": [ 
+ { + "source_id": e.source_id, + "target_id": e.target_id, + "relationship": e.relationship, + "label": e.label, + "introduced_by_prompt_id": e.introduced_by_prompt_id, + } + for e in self.edges + ], + "changelog": self.changelog, + }, indent=2) + + @classmethod + def from_json(cls, json_str: str) -> "MentalModel": + """Deserialize a mental model from JSON.""" + data = json.loads(json_str) + model = cls( + project_name=data.get("project_name", ""), + version=data.get("version", 0), + changelog=data.get("changelog", []), + ) + if data.get("last_updated"): + model.last_updated = datetime.fromisoformat(data["last_updated"]) + + for nid, ndata in data.get("nodes", {}).items(): + model.nodes[nid] = MentalModelNode( + id=ndata["id"], + label=ndata["label"], + node_type=ndata["node_type"], + description=ndata.get("description", ""), + introduced_by_prompt_id=ndata.get("introduced_by_prompt_id"), + last_updated_prompt_id=ndata.get("last_updated_prompt_id"), + ) + + for edata in data.get("edges", []): + model.edges.append(MentalModelEdge( + source_id=edata["source_id"], + target_id=edata["target_id"], + relationship=edata["relationship"], + label=edata.get("label", ""), + introduced_by_prompt_id=edata.get("introduced_by_prompt_id"), + )) + + return model + + +@dataclass +class ModelUpdate: + """Describes a proposed update to the mental model.""" + + add_nodes: list[MentalModelNode] = field(default_factory=list) + remove_node_ids: list[str] = field(default_factory=list) + add_edges: list[MentalModelEdge] = field(default_factory=list) + remove_edges: list[tuple[str, str, str]] = field(default_factory=list) # (source, target, relationship) + + # Explanation of why this update was made + reasoning: str = "" + + def is_empty(self) -> bool: + return (not self.add_nodes and not self.remove_node_ids and + not self.add_edges and not self.remove_edges) + + +def apply_update(model: MentalModel, update: ModelUpdate, prompt_id: Optional[str] = None) -> MentalModel: + """Apply an update to the mental model.""" + + # Remove nodes first + for node_id in update.remove_node_ids: + model.remove_node(node_id) + + # Remove edges + for source_id, target_id, relationship in update.remove_edges: + model.edges = [e for e in model.edges + if not (e.source_id == source_id and + e.target_id == target_id and + e.relationship == relationship)] + + # Add nodes + for node in update.add_nodes: + model.add_node(node, prompt_id) + + # Add edges + for edge in update.add_edges: + edge.introduced_by_prompt_id = prompt_id + model.add_edge(edge) + + # Update metadata + model.version += 1 + model.last_updated = datetime.now() + + if update.reasoning: + model.changelog.append({ + "version": model.version, + "prompt_id": prompt_id, + "reasoning": update.reasoning, + "timestamp": model.last_updated.isoformat(), + }) + + return model + + +def generate_update_prompt( + current_model: MentalModel, + operations: list[FileOperation], +) -> str: + """ + Generate a prompt for an AI to analyze changes and propose mental model updates. + + This is the core of the "living mental model" concept - we're asking the AI + to interpret code changes in terms of conceptual/architectural impact. + """ + + # Build context about what changed + changes_summary = [] + for op in operations: + change_desc = f"- {op.operation_type.value}: {op.file_path}" + if op.prompt: + change_desc += f"\n Intent: {op.prompt.text[:200]}..." + if op.assistant_context and op.assistant_context.summary: + change_desc += f"\n Reasoning: {op.assistant_context.summary[:200]}..." 
+ changes_summary.append(change_desc) + + current_diagram = current_model.to_mermaid() if current_model.nodes else "(empty - no existing model)" + + prompt = f"""You are analyzing code changes to update a "mental model" diagram. + +The mental model represents the CONCEPTUAL structure of the system - not file paths or +class hierarchies, but the semantic components, capabilities, and how they relate. + +## Current Mental Model +```mermaid +{current_diagram} +``` + +## Recent Code Changes +{chr(10).join(changes_summary)} + +## Your Task +Analyze these changes and determine if/how the mental model should be updated. + +Consider: +1. Are new concepts or components being introduced? +2. Are existing relationships changing? +3. Is there a shift in how parts of the system interact? +4. Should any concepts be removed or renamed? + +Respond with a JSON object: +{{ + "reasoning": "Explain your interpretation of how these changes affect the system's conceptual structure", + "updates": {{ + "add_nodes": [ + {{"id": "unique_id", "label": "Human Name", "node_type": "component|service|concept|data|interface|process", "description": "What this represents"}} + ], + "remove_node_ids": ["node_id_to_remove"], + "add_edges": [ + {{"source_id": "from_node", "target_id": "to_node", "relationship": "uses|contains|transforms|depends_on", "label": "optional edge label"}} + ], + "remove_edges": [["source_id", "target_id", "relationship"]] + }} +}} + +If no updates are needed, return: {{"reasoning": "...", "updates": null}} +""" + + return prompt + + +def parse_update_response(response: str) -> Optional[ModelUpdate]: + """Parse an AI response into a ModelUpdate.""" + import re + + # Try to extract JSON from the response + # Look for JSON block or the whole response + json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL) + if json_match: + json_str = json_match.group(1) + else: + # Try to find raw JSON + json_match = re.search(r'\{.*\}', response, re.DOTALL) + if json_match: + json_str = json_match.group(0) + else: + return None + + try: + data = json.loads(json_str) + except json.JSONDecodeError: + return None + + if data.get("updates") is None: + return ModelUpdate(reasoning=data.get("reasoning", "No changes needed")) + + updates = data["updates"] + + update = ModelUpdate(reasoning=data.get("reasoning", "")) + + for node_data in updates.get("add_nodes", []): + update.add_nodes.append(MentalModelNode( + id=node_data["id"], + label=node_data["label"], + node_type=node_data.get("node_type", "component"), + description=node_data.get("description", ""), + )) + + update.remove_node_ids = updates.get("remove_node_ids", []) + + for edge_data in updates.get("add_edges", []): + update.add_edges.append(MentalModelEdge( + source_id=edge_data["source_id"], + target_id=edge_data["target_id"], + relationship=edge_data.get("relationship", "uses"), + label=edge_data.get("label", ""), + )) + + update.remove_edges = [tuple(e) for e in updates.get("remove_edges", [])] + + return update + + +# Example usage and testing +if __name__ == "__main__": + # Create a simple mental model + model = MentalModel(project_name="Example System") + + # Add some nodes + model.add_node(MentalModelNode( + id="api", + label="REST API", + node_type="interface", + description="External API for client applications" + )) + + model.add_node(MentalModelNode( + id="processor", + label="Data Processor", + node_type="service", + description="Transforms and validates incoming data" + )) + + model.add_node(MentalModelNode( + id="storage", + 
label="Data Store", + node_type="data", + description="Persistent storage for processed data" + )) + + # Add relationships + model.add_edge(MentalModelEdge( + source_id="api", + target_id="processor", + relationship="uses", + label="sends requests" + )) + + model.add_edge(MentalModelEdge( + source_id="processor", + target_id="storage", + relationship="transforms", + label="persists" + )) + + print("=== Mental Model (Mermaid) ===") + print(model.to_mermaid()) + print() + print("=== Mental Model (JSON) ===") + print(model.to_json()) diff --git a/tests/test_reframe.py b/tests/test_reframe.py new file mode 100644 index 0000000..7588b10 --- /dev/null +++ b/tests/test_reframe.py @@ -0,0 +1,1172 @@ +"""Tests for Reframe (mental model enhancer).""" + +import json +import pytest +from datetime import datetime +from pathlib import Path + +from agentgit.enhancers.mental_model import ( + MentalModel, + ModelElement, + ModelRelation, + ModelSnapshot, + SequenceStep, + SequenceDiagram, + MentalModelEnhancer, + build_observation_prompt, + build_instruction_prompt, + build_focused_prompt, + reframe, + load_mental_model, + get_mental_model_enhancer, + _truncate, +) + + +class TestModelElement: + """Tests for ModelElement dataclass.""" + + def test_element_creation(self): + """Should create element with required fields.""" + elem = ModelElement(id="auth", label="Authentication") + assert elem.id == "auth" + assert elem.label == "Authentication" + assert elem.shape == "box" # default + assert elem.color is None + assert elem.created_by == "" + assert elem.reasoning == "" + + def test_element_with_all_fields(self): + """Should create element with all fields.""" + elem = ModelElement( + id="auth", + label="Authentication", + properties={"complexity": "high"}, + shape="hexagon", + color="#ff0000", + created_by="ai", + reasoning="Handles user identity", + ) + assert elem.properties == {"complexity": "high"} + assert elem.shape == "hexagon" + assert elem.color == "#ff0000" + assert elem.created_by == "ai" + assert elem.reasoning == "Handles user identity" + + def test_element_to_dict(self): + """Should serialize to dictionary.""" + elem = ModelElement( + id="auth", + label="Auth", + shape="circle", + color="#00ff00", + created_by="human", + reasoning="test", + ) + d = elem.to_dict() + assert d["id"] == "auth" + assert d["label"] == "Auth" + assert d["shape"] == "circle" + assert d["color"] == "#00ff00" + assert d["created_by"] == "human" + assert d["reasoning"] == "test" + + def test_element_from_dict(self): + """Should deserialize from dictionary.""" + d = { + "id": "cart", + "label": "Shopping Cart", + "properties": {"items": 0}, + "shape": "rounded", + "color": "#0000ff", + "created_by": "ai", + "reasoning": "Stores items", + } + elem = ModelElement.from_dict(d) + assert elem.id == "cart" + assert elem.label == "Shopping Cart" + assert elem.properties == {"items": 0} + assert elem.shape == "rounded" + assert elem.color == "#0000ff" + + def test_element_equality_by_id(self): + """Elements with same id should be distinguishable.""" + elem1 = ModelElement(id="test", label="Test") + elem2 = ModelElement(id="test", label="Different Label") + # Dataclasses compare by all fields, so these are different + assert elem1.id == elem2.id + assert elem1.label != elem2.label + + +class TestModelRelation: + """Tests for ModelRelation dataclass.""" + + def test_relation_creation(self): + """Should create relation with required fields.""" + rel = ModelRelation(source_id="a", target_id="b") + assert rel.source_id == "a" + assert 
rel.target_id == "b" + assert rel.label == "" + assert rel.style == "solid" + + def test_relation_with_all_fields(self): + """Should create relation with all fields.""" + rel = ModelRelation( + source_id="auth", + target_id="users", + label="validates", + properties={"async": True}, + style="dashed", + created_by="ai", + reasoning="Auth checks user DB", + ) + assert rel.label == "validates" + assert rel.properties == {"async": True} + assert rel.style == "dashed" + + def test_relation_to_dict(self): + """Should serialize to dictionary.""" + rel = ModelRelation( + source_id="a", + target_id="b", + label="uses", + style="dotted", + ) + d = rel.to_dict() + assert d["source_id"] == "a" + assert d["target_id"] == "b" + assert d["label"] == "uses" + assert d["style"] == "dotted" + + def test_relation_from_dict(self): + """Should deserialize from dictionary.""" + d = { + "source_id": "x", + "target_id": "y", + "label": "depends", + "style": "solid", + } + rel = ModelRelation.from_dict(d) + assert rel.source_id == "x" + assert rel.target_id == "y" + assert rel.label == "depends" + + +class TestSequenceStep: + """Tests for SequenceStep dataclass.""" + + def test_step_creation(self): + """Should create step with required fields.""" + step = SequenceStep( + from_participant="Client", + to_participant="Server", + message="GET /api/users", + ) + assert step.from_participant == "Client" + assert step.to_participant == "Server" + assert step.message == "GET /api/users" + assert step.step_type == "sync" # default + assert step.note is None + + def test_step_with_all_fields(self): + """Should create step with all fields.""" + step = SequenceStep( + from_participant="API", + to_participant="Database", + message="SELECT * FROM users", + step_type="async", + note="May take a while", + ) + assert step.step_type == "async" + assert step.note == "May take a while" + + +class TestSequenceDiagram: + """Tests for SequenceDiagram dataclass.""" + + def test_diagram_creation(self): + """Should create empty diagram.""" + diagram = SequenceDiagram(id="login", title="User Login Flow") + assert diagram.id == "login" + assert diagram.title == "User Login Flow" + assert diagram.participants == [] + assert diagram.steps == [] + + def test_diagram_with_participants(self): + """Should create diagram with participants.""" + diagram = SequenceDiagram( + id="api-flow", + title="API Request", + participants=["Client", "API", "Database"], + ) + assert len(diagram.participants) == 3 + assert "API" in diagram.participants + + def test_to_mermaid_basic(self): + """Should generate valid mermaid sequence diagram.""" + diagram = SequenceDiagram( + id="test", + title="Test Flow", + participants=["Client", "Server"], + ) + diagram.steps.append(SequenceStep( + from_participant="Client", + to_participant="Server", + message="Request", + )) + + mermaid = diagram.to_mermaid() + + assert "sequenceDiagram" in mermaid + assert "title Test Flow" in mermaid + assert "participant Client" in mermaid + assert "participant Server" in mermaid + assert "Client->>Server: Request" in mermaid + + def test_to_mermaid_arrow_styles(self): + """Should use correct arrow styles for step types.""" + diagram = SequenceDiagram( + id="test", + title="", + participants=["A", "B"], + ) + diagram.steps.append(SequenceStep( + from_participant="A", + to_participant="B", + message="sync", + step_type="sync", + )) + diagram.steps.append(SequenceStep( + from_participant="B", + to_participant="A", + message="async", + step_type="async", + )) + 
diagram.steps.append(SequenceStep( + from_participant="A", + to_participant="B", + message="reply", + step_type="reply", + )) + + mermaid = diagram.to_mermaid() + + assert "A->>B: sync" in mermaid + assert "B-->>A: async" in mermaid + assert "A-->>B: reply" in mermaid + + def test_to_mermaid_with_notes(self): + """Should include notes in mermaid output.""" + diagram = SequenceDiagram( + id="test", + title="", + participants=["A", "B"], + ) + diagram.steps.append(SequenceStep( + from_participant="A", + to_participant="B", + message="Request", + note="Important step", + )) + + mermaid = diagram.to_mermaid() + + assert "Note over A,B: Important step" in mermaid + + def test_to_mermaid_sanitizes_spaces(self): + """Should sanitize participant names with spaces.""" + diagram = SequenceDiagram( + id="test", + title="", + participants=["API Gateway"], + ) + diagram.steps.append(SequenceStep( + from_participant="API Gateway", + to_participant="API Gateway", + message="Self call", + )) + + mermaid = diagram.to_mermaid() + + assert "participant API_Gateway as API Gateway" in mermaid + assert "API_Gateway->>API_Gateway: Self call" in mermaid + + def test_to_dict(self): + """Should serialize to dictionary.""" + diagram = SequenceDiagram( + id="login", + title="Login", + participants=["User", "Auth"], + description="User authentication flow", + created_by="ai", + trigger="user login", + ) + diagram.steps.append(SequenceStep( + from_participant="User", + to_participant="Auth", + message="Login", + step_type="sync", + note="With credentials", + )) + + d = diagram.to_dict() + + assert d["id"] == "login" + assert d["title"] == "Login" + assert d["participants"] == ["User", "Auth"] + assert d["description"] == "User authentication flow" + assert len(d["steps"]) == 1 + assert d["steps"][0]["from"] == "User" + assert d["steps"][0]["to"] == "Auth" + assert d["steps"][0]["message"] == "Login" + assert d["steps"][0]["type"] == "sync" + assert d["steps"][0]["note"] == "With credentials" + + def test_from_dict(self): + """Should deserialize from dictionary.""" + d = { + "id": "checkout", + "title": "Checkout Flow", + "participants": ["Cart", "Payment", "Order"], + "steps": [ + {"from": "Cart", "to": "Payment", "message": "Pay", "type": "sync"}, + {"from": "Payment", "to": "Order", "message": "Create", "type": "async", "note": "Background"}, + ], + "description": "Purchase flow", + "created_by": "human", + "trigger": "checkout button", + } + + diagram = SequenceDiagram.from_dict(d) + + assert diagram.id == "checkout" + assert diagram.title == "Checkout Flow" + assert len(diagram.participants) == 3 + assert len(diagram.steps) == 2 + assert diagram.steps[0].from_participant == "Cart" + assert diagram.steps[1].note == "Background" + + +class TestMentalModel: + """Tests for MentalModel dataclass.""" + + def test_empty_model(self): + """Should create empty model.""" + model = MentalModel() + assert len(model.elements) == 0 + assert len(model.relations) == 0 + assert model.version == 0 + assert model.ai_summary == "" + + def test_add_elements(self): + """Should add elements to model.""" + model = MentalModel() + model.elements["auth"] = ModelElement(id="auth", label="Auth") + model.elements["db"] = ModelElement(id="db", label="Database") + assert len(model.elements) == 2 + assert "auth" in model.elements + assert "db" in model.elements + + def test_add_relations(self): + """Should add relations to model.""" + model = MentalModel() + model.elements["a"] = ModelElement(id="a", label="A") + model.elements["b"] = 
ModelElement(id="b", label="B") + model.relations.append(ModelRelation(source_id="a", target_id="b")) + assert len(model.relations) == 1 + + def test_snapshot_saves_state(self): + """Should save current state as snapshot.""" + model = MentalModel() + model.elements["test"] = ModelElement(id="test", label="Test") + model.version = 1 + + model.snapshot("first snapshot") + + assert len(model.snapshots) == 1 + assert model.snapshots[0].version == 1 + assert model.snapshots[0].trigger == "first snapshot" + assert "test" in model.snapshots[0].elements + + def test_snapshot_is_deep_copy(self): + """Snapshot should be independent of current state.""" + model = MentalModel() + model.elements["test"] = ModelElement(id="test", label="Original") + model.snapshot("before change") + + # Modify current state + model.elements["test"].label = "Modified" + + # Snapshot should still have original + assert model.snapshots[0].elements["test"].label == "Original" + + def test_to_mermaid_empty(self): + """Empty model should produce minimal mermaid.""" + model = MentalModel() + mermaid = model.to_mermaid() + assert "graph TD" in mermaid + + def test_to_mermaid_with_elements(self): + """Should generate valid mermaid with elements.""" + model = MentalModel() + model.elements["api"] = ModelElement(id="api", label="REST API", shape="hexagon") + model.elements["db"] = ModelElement(id="db", label="Database", shape="cylinder") + + mermaid = model.to_mermaid() + + assert "graph TD" in mermaid + assert 'api{{"REST API"}}' in mermaid + assert 'db[("Database")]' in mermaid + + def test_to_mermaid_with_relations(self): + """Should generate mermaid with relationships.""" + model = MentalModel() + model.elements["a"] = ModelElement(id="a", label="A") + model.elements["b"] = ModelElement(id="b", label="B") + model.relations.append(ModelRelation(source_id="a", target_id="b", label="uses")) + + mermaid = model.to_mermaid() + + assert 'a -->|"uses"| b' in mermaid + + def test_to_mermaid_with_colors(self): + """Should include style directives for colors.""" + model = MentalModel() + model.elements["test"] = ModelElement(id="test", label="Test", color="#ff0000") + + mermaid = model.to_mermaid() + + assert "style test fill:#ff0000" in mermaid + + def test_to_mermaid_shapes(self): + """Should use correct mermaid shapes.""" + model = MentalModel() + model.elements["box"] = ModelElement(id="box", label="Box", shape="box") + model.elements["rounded"] = ModelElement(id="rounded", label="Rounded", shape="rounded") + model.elements["circle"] = ModelElement(id="circle", label="Circle", shape="circle") + model.elements["diamond"] = ModelElement(id="diamond", label="Diamond", shape="diamond") + model.elements["cylinder"] = ModelElement(id="cylinder", label="Cylinder", shape="cylinder") + model.elements["hexagon"] = ModelElement(id="hexagon", label="Hexagon", shape="hexagon") + model.elements["stadium"] = ModelElement(id="stadium", label="Stadium", shape="stadium") + + mermaid = model.to_mermaid() + + assert 'box["Box"]' in mermaid + assert 'rounded("Rounded")' in mermaid + assert 'circle(("Circle"))' in mermaid + assert 'diamond{"Diamond"}' in mermaid + assert 'cylinder[("Cylinder")]' in mermaid + assert 'hexagon{{"Hexagon"}}' in mermaid + assert 'stadium(["Stadium"])' in mermaid + + def test_to_mermaid_relation_styles(self): + """Should use correct arrow styles for relations.""" + model = MentalModel() + model.elements["a"] = ModelElement(id="a", label="A") + model.elements["b"] = ModelElement(id="b", label="B") + model.elements["c"] = 
ModelElement(id="c", label="C") + model.elements["d"] = ModelElement(id="d", label="D") + + model.relations.append(ModelRelation(source_id="a", target_id="b", style="solid")) + model.relations.append(ModelRelation(source_id="b", target_id="c", style="dashed")) + model.relations.append(ModelRelation(source_id="c", target_id="d", style="thick")) + + mermaid = model.to_mermaid() + + assert "a --> b" in mermaid + assert "b -.-> c" in mermaid + assert "c ==> d" in mermaid + + def test_to_json_roundtrip(self): + """Should serialize and deserialize correctly.""" + model = MentalModel() + model.elements["test"] = ModelElement( + id="test", label="Test", shape="hexagon", color="#123456" + ) + model.relations.append(ModelRelation(source_id="test", target_id="test", label="self")) + model.version = 5 + model.ai_summary = "A test system" + + json_str = model.to_json() + restored = MentalModel.from_dict(json.loads(json_str)) + + assert len(restored.elements) == 1 + assert restored.elements["test"].label == "Test" + assert restored.elements["test"].shape == "hexagon" + assert len(restored.relations) == 1 + assert restored.relations[0].label == "self" + assert restored.version == 5 + assert restored.ai_summary == "A test system" + + def test_to_dict(self): + """Should convert to dictionary.""" + model = MentalModel() + model.elements["x"] = ModelElement(id="x", label="X") + model.version = 3 + + d = model.to_dict() + + assert "elements" in d + assert "relations" in d + assert "version" in d + assert d["version"] == 3 + + def test_sequences_in_model(self): + """Should support sequence diagrams.""" + model = MentalModel() + diagram = SequenceDiagram( + id="login", + title="Login Flow", + participants=["User", "Auth", "DB"], + ) + diagram.steps.append(SequenceStep( + from_participant="User", + to_participant="Auth", + message="Login", + )) + model.sequences["login"] = diagram + + assert len(model.sequences) == 1 + assert "login" in model.sequences + assert model.sequences["login"].title == "Login Flow" + + def test_to_dict_includes_sequences(self): + """Should include sequences in to_dict.""" + model = MentalModel() + model.sequences["test"] = SequenceDiagram( + id="test", + title="Test", + participants=["A", "B"], + ) + + d = model.to_dict() + + assert "sequences" in d + assert "test" in d["sequences"] + assert d["sequences"]["test"]["title"] == "Test" + + def test_from_dict_restores_sequences(self): + """Should restore sequences from dict.""" + d = { + "elements": {}, + "relations": [], + "sequences": { + "checkout": { + "id": "checkout", + "title": "Checkout", + "participants": ["Cart", "Payment"], + "steps": [ + {"from": "Cart", "to": "Payment", "message": "Pay", "type": "sync"} + ], + } + }, + "version": 1, + "ai_summary": "", + } + + model = MentalModel.from_dict(d) + + assert len(model.sequences) == 1 + assert "checkout" in model.sequences + assert model.sequences["checkout"].title == "Checkout" + assert len(model.sequences["checkout"].steps) == 1 + + def test_to_full_mermaid(self): + """Should export both component and sequence diagrams.""" + model = MentalModel() + model.elements["api"] = ModelElement(id="api", label="API") + model.sequences["flow"] = SequenceDiagram( + id="flow", + title="Request Flow", + participants=["Client", "Server"], + description="How requests work", + ) + model.sequences["flow"].steps.append(SequenceStep( + from_participant="Client", + to_participant="Server", + message="Request", + )) + + full = model.to_full_mermaid() + + # Should have component diagram + assert "## 
Component Diagram" in full + assert "graph TD" in full + assert "API" in full + + # Should have sequence diagram + assert "## Request Flow" in full + assert "How requests work" in full + assert "sequenceDiagram" in full + assert "Client->>Server: Request" in full + + def test_to_force_graph_empty(self): + """Empty model should produce empty nodes and links.""" + model = MentalModel() + graph = model.to_force_graph() + + assert graph["nodes"] == [] + assert graph["links"] == [] + + def test_to_force_graph_with_elements(self): + """Should export elements as nodes.""" + model = MentalModel() + model.elements["api"] = ModelElement( + id="api", + label="API Gateway", + shape="hexagon", + color="#e1f5fe", + reasoning="Entry point for all requests", + ) + model.elements["db"] = ModelElement( + id="db", + label="Database", + shape="cylinder", + properties={"size": 2, "group": "storage"}, + ) + + graph = model.to_force_graph() + + assert len(graph["nodes"]) == 2 + + api_node = next(n for n in graph["nodes"] if n["id"] == "api") + assert api_node["name"] == "API Gateway" + assert api_node["group"] == "hexagon" # falls back to shape + assert api_node["color"] == "#e1f5fe" + assert api_node["description"] == "Entry point for all requests" + assert api_node["val"] == 1 # default size + + db_node = next(n for n in graph["nodes"] if n["id"] == "db") + assert db_node["name"] == "Database" + assert db_node["group"] == "storage" # from properties + assert db_node["val"] == 2 # from properties.size + + def test_to_force_graph_with_relations(self): + """Should export relations as links.""" + model = MentalModel() + model.elements["a"] = ModelElement(id="a", label="A") + model.elements["b"] = ModelElement(id="b", label="B") + model.relations.append(ModelRelation( + source_id="a", + target_id="b", + label="calls", + style="dashed", + reasoning="A invokes B for processing", + )) + + graph = model.to_force_graph() + + assert len(graph["links"]) == 1 + link = graph["links"][0] + assert link["source"] == "a" + assert link["target"] == "b" + assert link["label"] == "calls" + assert link["style"] == "dashed" + assert link["description"] == "A invokes B for processing" + + def test_to_force_graph_includes_files(self): + """Should include associated files in nodes.""" + model = MentalModel() + model.elements["auth"] = ModelElement(id="auth", label="Auth") + model.element_files["auth"] = {"src/auth.py", "src/auth_utils.py"} + + graph = model.to_force_graph() + + auth_node = graph["nodes"][0] + assert "files" in auth_node + assert set(auth_node["files"]) == {"src/auth.py", "src/auth_utils.py"} + + def test_to_force_graph_omits_empty_fields(self): + """Should not include empty optional fields.""" + model = MentalModel() + model.elements["simple"] = ModelElement(id="simple", label="Simple") + model.relations.append(ModelRelation(source_id="simple", target_id="simple")) + + graph = model.to_force_graph() + + node = graph["nodes"][0] + assert "color" not in node + assert "description" not in node + assert "files" not in node + + link = graph["links"][0] + assert "label" not in link + assert "style" not in link + assert "description" not in link + + +class TestMentalModelEnhancer: + """Tests for MentalModelEnhancer class.""" + + def test_init_without_repo(self): + """Should initialize without repo path.""" + enhancer = MentalModelEnhancer() + assert enhancer.model is not None + assert enhancer._repo_path is None + + def test_init_with_nonexistent_repo(self, tmp_path): + """Should initialize with new repo path.""" + repo_path 
= tmp_path / "new_repo" + repo_path.mkdir() + + enhancer = MentalModelEnhancer(repo_path) + + assert enhancer._repo_path == repo_path + assert len(enhancer.model.elements) == 0 + + def test_init_loads_existing_model(self, tmp_path): + """Should load existing model from repo.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + # Create existing model file + model_dir = repo_path / ".agentgit" + model_dir.mkdir() + model_file = model_dir / "mental_model.json" + + existing_model = { + "elements": { + "existing": { + "id": "existing", + "label": "Existing Element", + "shape": "box", + } + }, + "relations": [], + "version": 10, + "ai_summary": "Existing model", + } + model_file.write_text(json.dumps(existing_model)) + + enhancer = MentalModelEnhancer(repo_path) + + assert len(enhancer.model.elements) == 1 + assert "existing" in enhancer.model.elements + assert enhancer.model.version == 10 + + def test_model_path_property(self, tmp_path): + """Should return correct model path.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + enhancer = MentalModelEnhancer(repo_path) + + assert enhancer.model_path == repo_path / ".agentgit" / "mental_model.json" + + def test_model_path_none_without_repo(self): + """Should return None if no repo configured.""" + enhancer = MentalModelEnhancer() + assert enhancer.model_path is None + + def test_save_creates_directory(self, tmp_path): + """Should create .agentgit directory if needed.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + enhancer = MentalModelEnhancer(repo_path) + enhancer.model.elements["test"] = ModelElement(id="test", label="Test") + enhancer.save() + + assert (repo_path / ".agentgit").exists() + assert (repo_path / ".agentgit" / "mental_model.json").exists() + + def test_save_writes_model(self, tmp_path): + """Should write model to file.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + enhancer = MentalModelEnhancer(repo_path) + enhancer.model.elements["test"] = ModelElement(id="test", label="Test Element") + enhancer.model.version = 42 + enhancer.save() + + content = (repo_path / ".agentgit" / "mental_model.json").read_text() + data = json.loads(content) + + assert "test" in data["elements"] + assert data["elements"]["test"]["label"] == "Test Element" + assert data["version"] == 42 + + def test_save_insights_creates_file(self, tmp_path): + """Should create insights file with header.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + enhancer = MentalModelEnhancer(repo_path) + path = enhancer.save_insights() + + assert path is not None + assert path.exists() + content = path.read_text() + assert "# Reframe Insights" in content + + def test_save_insights_appends(self, tmp_path): + """Should append insights to file.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + enhancer = MentalModelEnhancer(repo_path) + enhancer.save_insights("First insight") + enhancer.save_insights("Second insight") + + content = (repo_path / ".agentgit" / "mental_model.md").read_text() + + assert "First insight" in content + assert "Second insight" in content + + def test_save_insights_includes_timestamp(self, tmp_path): + """Should include timestamp with each insight.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + enhancer = MentalModelEnhancer(repo_path) + enhancer.save_insights("Test insight") + + content = (repo_path / ".agentgit" / "mental_model.md").read_text() + + # Should have a timestamp header like "## 2026-01-10 14:30" + import re + assert re.search(r"## \d{4}-\d{2}-\d{2} \d{2}:\d{2}", content) + + def 
test_load_insights_returns_content(self, tmp_path): + """Should load existing insights.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + enhancer = MentalModelEnhancer(repo_path) + enhancer.save_insights("Test insight content") + + loaded = enhancer.load_insights() + + assert "Test insight content" in loaded + + def test_load_insights_empty_if_no_file(self, tmp_path): + """Should return empty string if no insights file.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + enhancer = MentalModelEnhancer(repo_path) + loaded = enhancer.load_insights() + + assert loaded == "" + + def test_add_human_insight(self, tmp_path): + """Should add human insight with marker.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + enhancer = MentalModelEnhancer(repo_path) + enhancer.add_human_insight("The checkout is actually three steps") + + content = (repo_path / ".agentgit" / "mental_model.md").read_text() + + assert "**Human insight:**" in content + assert "The checkout is actually three steps" in content + + +class TestPromptBuilding: + """Tests for prompt generation functions.""" + + def test_build_observation_prompt_empty_model(self): + """Should handle empty model.""" + model = MentalModel() + context = {"prompts": ["Add auth"], "files_changed": ["auth.py"]} + + prompt = build_observation_prompt(model, context) + + assert "mental model" in prompt.lower() + assert "(empty)" in prompt + assert "auth.py" in prompt + + def test_build_observation_prompt_with_model(self): + """Should include current model diagram.""" + model = MentalModel() + model.elements["test"] = ModelElement(id="test", label="Test Component") + context = {"prompts": [], "files_changed": []} + + prompt = build_observation_prompt(model, context) + + assert "Test Component" in prompt + assert "```mermaid" in prompt + + def test_build_observation_prompt_with_insights(self): + """Should include accumulated insights.""" + model = MentalModel() + context = {"prompts": [], "files_changed": []} + insights = "Previous insight: The system uses microservices" + + prompt = build_observation_prompt(model, context, insights) + + assert "Previous insight" in prompt + assert "microservices" in prompt + + def test_build_observation_prompt_truncates_long_insights(self): + """Should truncate very long insights.""" + model = MentalModel() + context = {} + long_insights = "x" * 5000 + + prompt = build_observation_prompt(model, context, long_insights) + + assert "(earlier insights truncated)" in prompt + + def test_build_observation_prompt_includes_summary(self): + """Should include AI summary if present.""" + model = MentalModel() + model.ai_summary = "An e-commerce platform" + context = {} + + prompt = build_observation_prompt(model, context) + + assert "An e-commerce platform" in prompt + + def test_build_instruction_prompt_with_selection(self): + """Should include selected elements.""" + model = MentalModel() + model.elements["checkout"] = ModelElement( + id="checkout", + label="Checkout", + reasoning="Where purchases happen", + ) + + prompt = build_instruction_prompt( + model, + selected_ids=["checkout"], + instruction="Split this into three steps", + ) + + assert "Checkout" in prompt + assert "Split this into three steps" in prompt + assert "Where purchases happen" in prompt + + def test_build_instruction_prompt_empty_selection(self): + """Should handle empty selection.""" + model = MentalModel() + + prompt = build_instruction_prompt( + model, + selected_ids=[], + instruction="Add a database component", + ) + + assert "Add a 
database component" in prompt + assert "No specific elements" in prompt + + def test_build_focused_prompt_explore(self): + """Should generate explore prompt.""" + model = MentalModel() + model.elements["auth"] = ModelElement(id="auth", label="Authentication") + + prompt = build_focused_prompt(model, "auth", "explore") + + assert "Authentication" in prompt + assert "Tell me" in prompt or "tell me" in prompt + + def test_build_focused_prompt_debug(self): + """Should generate debug prompt.""" + model = MentalModel() + model.elements["auth"] = ModelElement(id="auth", label="Authentication") + + prompt = build_focused_prompt(model, "auth", "debug") + + assert "issues" in prompt.lower() or "help" in prompt.lower() + + def test_build_focused_prompt_unknown_element(self): + """Should handle unknown element.""" + model = MentalModel() + + prompt = build_focused_prompt(model, "nonexistent", "explore") + + assert "not found" in prompt.lower() + + def test_build_focused_prompt_mentions_element(self): + """Should mention the element in the prompt.""" + model = MentalModel() + model.elements["auth"] = ModelElement(id="auth", label="Authentication") + model.elements["users"] = ModelElement(id="users", label="User Store") + model.relations.append( + ModelRelation(source_id="auth", target_id="users", label="validates") + ) + + prompt = build_focused_prompt(model, "auth", "explore") + + # Prompt should mention the element being explored + assert "Authentication" in prompt + + +class TestHelperFunctions: + """Tests for helper functions.""" + + def test_truncate_short_text(self): + """Should not truncate short text.""" + result = _truncate("short", 100) + assert result == "short" + + def test_truncate_long_text(self): + """Should truncate and add ellipsis.""" + result = _truncate("a" * 100, 10) + assert len(result) == 10 + assert result.endswith("...") + + def test_truncate_exact_length(self): + """Should not truncate at exact length.""" + result = _truncate("exact", 5) + assert result == "exact" + + +class TestGlobalFunctions: + """Tests for module-level functions.""" + + def test_load_mental_model_exists(self, tmp_path): + """Should load model from repo.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + model_dir = repo_path / ".agentgit" + model_dir.mkdir() + + model_data = { + "elements": {"test": {"id": "test", "label": "Test", "shape": "box"}}, + "relations": [], + "version": 5, + "ai_summary": "Test model", + } + (model_dir / "mental_model.json").write_text(json.dumps(model_data)) + + loaded = load_mental_model(repo_path) + + assert loaded is not None + assert "test" in loaded.elements + assert loaded.version == 5 + + def test_load_mental_model_not_exists(self, tmp_path): + """Should return None if model doesn't exist.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + loaded = load_mental_model(repo_path) + + assert loaded is None + + def test_get_mental_model_enhancer_creates_instance(self): + """Should create enhancer instance.""" + # Reset global state + import agentgit.enhancers.mental_model as mm + mm._enhancer_instance = None + + enhancer = get_mental_model_enhancer() + + assert enhancer is not None + assert isinstance(enhancer, MentalModelEnhancer) + + def test_get_mental_model_enhancer_reuses_instance(self): + """Should reuse existing instance.""" + # Reset global state + import agentgit.enhancers.mental_model as mm + mm._enhancer_instance = None + + enhancer1 = get_mental_model_enhancer() + enhancer2 = get_mental_model_enhancer() + + assert enhancer1 is enhancer2 + + def 
test_get_mental_model_enhancer_new_repo(self, tmp_path): + """Should create new instance for different repo.""" + import agentgit.enhancers.mental_model as mm + mm._enhancer_instance = None + + repo1 = tmp_path / "repo1" + repo1.mkdir() + repo2 = tmp_path / "repo2" + repo2.mkdir() + + enhancer1 = get_mental_model_enhancer(repo1) + enhancer2 = get_mental_model_enhancer(repo2) + + assert enhancer1._repo_path == repo1 + assert enhancer2._repo_path == repo2 + + +class TestIntegration: + """Integration tests for the full reframe flow.""" + + def test_full_model_evolution(self, tmp_path): + """Test a complete model evolution scenario.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + enhancer = MentalModelEnhancer(repo_path) + + # Step 1: Initial model + enhancer.model.elements["browse"] = ModelElement( + id="browse", label="Product Browser", shape="stadium" + ) + enhancer.model.elements["cart"] = ModelElement( + id="cart", label="Shopping Cart", shape="rounded" + ) + enhancer.model.relations.append( + ModelRelation(source_id="browse", target_id="cart", label="add to") + ) + enhancer.model.version = 1 + enhancer.model.ai_summary = "E-commerce product browsing" + enhancer.save() + + # Step 2: Human adds insight + enhancer.add_human_insight("The cart should support wishlists too") + + # Step 3: Snapshot and evolve + enhancer.model.snapshot("before wishlist") + enhancer.model.elements["wishlist"] = ModelElement( + id="wishlist", label="Wishlist", shape="rounded" + ) + enhancer.model.relations.append( + ModelRelation(source_id="browse", target_id="wishlist", label="save for later") + ) + enhancer.model.version = 2 + enhancer.save() + + # Verify final state + loaded = load_mental_model(repo_path) + assert len(loaded.elements) == 3 + assert "wishlist" in loaded.elements + assert loaded.version == 2 + + # Verify insights + insights = enhancer.load_insights() + assert "wishlists" in insights + + # Verify snapshots + assert len(enhancer.model.snapshots) == 1 + assert len(enhancer.model.snapshots[0].elements) == 2 # Before wishlist + + def test_mermaid_output_is_valid(self, tmp_path): + """Test that generated mermaid is syntactically valid.""" + model = MentalModel() + + # Add various elements and relations + model.elements["api"] = ModelElement( + id="api", label="API Gateway", shape="hexagon", color="#e1f5fe" + ) + model.elements["auth"] = ModelElement( + id="auth", label="Auth Service", shape="box", color="#fff3e0" + ) + model.elements["db"] = ModelElement( + id="db", label="Database", shape="cylinder", color="#e8f5e9" + ) + + model.relations.append( + ModelRelation(source_id="api", target_id="auth", label="authenticates", style="solid") + ) + model.relations.append( + ModelRelation(source_id="auth", target_id="db", label="queries", style="dashed") + ) + + mermaid = model.to_mermaid() + + # Check structure + lines = mermaid.strip().split("\n") + assert lines[0].strip() == "graph TD" + + # Check no syntax errors (basic validation) + assert mermaid.count("(") == mermaid.count(")") + assert mermaid.count("[") == mermaid.count("]") + assert mermaid.count("{") == mermaid.count("}") diff --git a/tests/test_static_analysis.py b/tests/test_static_analysis.py new file mode 100644 index 0000000..be835b1 --- /dev/null +++ b/tests/test_static_analysis.py @@ -0,0 +1,516 @@ +"""Tests for static analysis module.""" + +import pytest +import textwrap +from pathlib import Path + +from agentgit.enhancers.static_analysis import ( + CodeEntity, + CodeRelation, + CodeStructure, + analyze_python_simple, + 
analyze_javascript_simple, + analyze_code, + structure_to_context, +) + + +class TestCodeEntity: + """Tests for CodeEntity dataclass.""" + + def test_entity_creation(self): + """Should create entity with required fields.""" + entity = CodeEntity( + name="UserService", + entity_type="class", + file_path="src/services/user.py", + ) + assert entity.name == "UserService" + assert entity.entity_type == "class" + assert entity.file_path == "src/services/user.py" + assert entity.line_number is None + assert entity.docstring is None + assert entity.properties == {} + + def test_entity_with_all_fields(self): + """Should create entity with all fields.""" + entity = CodeEntity( + name="authenticate", + entity_type="function", + file_path="auth.py", + line_number=42, + docstring="Authenticate a user.", + properties={"async": True, "public": True}, + ) + assert entity.line_number == 42 + assert entity.docstring == "Authenticate a user." + assert entity.properties["async"] is True + + +class TestCodeRelation: + """Tests for CodeRelation dataclass.""" + + def test_relation_creation(self): + """Should create relation with required fields.""" + rel = CodeRelation( + source="UserService", + target="Database", + relation_type="uses", + file_path="services.py", + ) + assert rel.source == "UserService" + assert rel.target == "Database" + assert rel.relation_type == "uses" + assert rel.file_path == "services.py" + + def test_relation_types(self): + """Should support various relation types.""" + for rel_type in ["calls", "imports", "inherits", "implements", "uses"]: + rel = CodeRelation( + source="A", + target="B", + relation_type=rel_type, + file_path="test.py", + ) + assert rel.relation_type == rel_type + + +class TestCodeStructure: + """Tests for CodeStructure dataclass.""" + + def test_empty_structure(self): + """Should create empty structure.""" + structure = CodeStructure() + assert structure.entities == [] + assert structure.relations == [] + assert structure.language == "unknown" + assert structure.analysis_method == "unknown" + + def test_structure_with_data(self): + """Should hold entities and relations.""" + structure = CodeStructure( + language="python", + analysis_method="regex", + ) + structure.entities.append(CodeEntity( + name="TestClass", + entity_type="class", + file_path="test.py", + )) + structure.relations.append(CodeRelation( + source="test.py", + target="os", + relation_type="imports", + file_path="test.py", + )) + + assert len(structure.entities) == 1 + assert len(structure.relations) == 1 + assert structure.language == "python" + + +class TestAnalyzePythonSimple: + """Tests for analyze_python_simple function.""" + + def test_finds_classes(self, tmp_path): + """Should find class definitions.""" + py_file = tmp_path / "test.py" + py_file.write_text(""" +class UserService: + pass + +class OrderService(BaseService): + pass +""") + + structure = analyze_python_simple(tmp_path) + + class_entities = [e for e in structure.entities if e.entity_type == "class"] + assert len(class_entities) == 2 + names = [e.name for e in class_entities] + assert "UserService" in names + assert "OrderService" in names + + def test_finds_functions(self, tmp_path): + """Should find function definitions.""" + py_file = tmp_path / "test.py" + py_file.write_text(textwrap.dedent("""\ + def top_level_func(): + pass + + def another_function(arg): + return arg + """)) + + structure = analyze_python_simple(tmp_path) + + func_entities = [e for e in structure.entities if e.entity_type == "function"] + assert len(func_entities) 
== 2 + names = [e.name for e in func_entities] + assert "top_level_func" in names + assert "another_function" in names + + def test_finds_methods(self, tmp_path): + """Should find method definitions.""" + py_file = tmp_path / "test.py" + py_file.write_text(""" +class MyClass: + def my_method(self): + pass + + def another_method(self): + pass +""") + + structure = analyze_python_simple(tmp_path) + + method_entities = [e for e in structure.entities if e.entity_type == "method"] + assert len(method_entities) == 2 + + def test_tracks_inheritance(self, tmp_path): + """Should track inheritance relations.""" + py_file = tmp_path / "test.py" + py_file.write_text(""" +class Child(Parent): + pass + +class MultiChild(Parent1, Parent2): + pass +""") + + structure = analyze_python_simple(tmp_path) + + inherit_rels = [r for r in structure.relations if r.relation_type == "inherits"] + assert len(inherit_rels) >= 2 + + targets = [r.target for r in inherit_rels] + assert "Parent" in targets + assert "Parent1" in targets + + def test_tracks_imports(self, tmp_path): + """Should track import relations.""" + py_file = tmp_path / "test.py" + py_file.write_text(""" +import os +from pathlib import Path +from typing import List, Dict +""") + + structure = analyze_python_simple(tmp_path) + + import_rels = [r for r in structure.relations if r.relation_type == "imports"] + assert len(import_rels) >= 3 + + targets = [r.target for r in import_rels] + assert "os" in targets + assert "pathlib.Path" in targets + + def test_handles_single_file(self, tmp_path): + """Should handle single file path.""" + py_file = tmp_path / "single.py" + py_file.write_text(""" +class SingleClass: + pass +""") + + structure = analyze_python_simple(py_file) + + assert len(structure.entities) >= 1 + assert structure.language == "python" + assert structure.analysis_method == "regex" + + def test_ignores_object_base(self, tmp_path): + """Should ignore 'object' as base class.""" + py_file = tmp_path / "test.py" + py_file.write_text(""" +class MyClass(object): + pass +""") + + structure = analyze_python_simple(tmp_path) + + inherit_rels = [r for r in structure.relations if r.relation_type == "inherits"] + targets = [r.target for r in inherit_rels] + assert "object" not in targets + + def test_includes_line_numbers(self, tmp_path): + """Should include line numbers.""" + py_file = tmp_path / "test.py" + py_file.write_text(""" +# Line 1 is blank + +class MyClass: + pass +""") + + structure = analyze_python_simple(tmp_path) + + class_entities = [e for e in structure.entities if e.entity_type == "class"] + assert len(class_entities) == 1 + assert class_entities[0].line_number is not None + assert class_entities[0].line_number > 0 + + +class TestAnalyzeJavaScriptSimple: + """Tests for analyze_javascript_simple function.""" + + def test_finds_classes(self, tmp_path): + """Should find class definitions.""" + js_file = tmp_path / "test.js" + js_file.write_text(""" +class UserService { + constructor() {} +} + +class OrderService extends BaseService { + process() {} +} +""") + + structure = analyze_javascript_simple(tmp_path) + + class_entities = [e for e in structure.entities if e.entity_type == "class"] + assert len(class_entities) == 2 + names = [e.name for e in class_entities] + assert "UserService" in names + assert "OrderService" in names + + def test_tracks_extends(self, tmp_path): + """Should track extends relations.""" + js_file = tmp_path / "test.js" + js_file.write_text(""" +class Child extends Parent { +} +""") + + structure = 
analyze_javascript_simple(tmp_path) + + extend_rels = [r for r in structure.relations if r.relation_type == "extends"] + assert len(extend_rels) == 1 + assert extend_rels[0].target == "Parent" + + def test_finds_functions(self, tmp_path): + """Should find function definitions.""" + js_file = tmp_path / "test.js" + js_file.write_text(""" +function regularFunction() { + return 1; +} + +const arrowFunc = () => { + return 2; +}; + +const anotherArrow = (x) => x * 2; +""") + + structure = analyze_javascript_simple(tmp_path) + + func_entities = [e for e in structure.entities if e.entity_type == "function"] + names = [e.name for e in func_entities] + assert "regularFunction" in names + assert "arrowFunc" in names + + def test_finds_react_components(self, tmp_path): + """Should find React components (PascalCase functions).""" + tsx_file = tmp_path / "test.tsx" + tsx_file.write_text(textwrap.dedent("""\ + function UserProfile() { + return
<div>Profile</div>
; + } + + const OrderList = () => { + return <ul></ul>
    ; + }; + """)) + + structure = analyze_javascript_simple(tmp_path) + + # React components are detected as functions (PascalCase naming) + names = [e.name for e in structure.entities] + # Should find the PascalCase function declarations + assert "UserProfile" in names + assert "OrderList" in names + + def test_tracks_imports(self, tmp_path): + """Should track import relations.""" + js_file = tmp_path / "test.js" + js_file.write_text(""" +import React from 'react'; +import { useState, useEffect } from 'react'; +import axios from './utils/axios'; +""") + + structure = analyze_javascript_simple(tmp_path) + + import_rels = [r for r in structure.relations if r.relation_type == "imports"] + targets = [r.target for r in import_rels] + assert "react" in targets + assert "./utils/axios" in targets + + def test_handles_typescript(self, tmp_path): + """Should handle TypeScript files.""" + ts_file = tmp_path / "test.ts" + ts_file.write_text(""" +interface User { + name: string; +} + +class UserService { + getUser(): User { + return { name: 'test' }; + } +} +""") + + structure = analyze_javascript_simple(tmp_path) + + class_entities = [e for e in structure.entities if e.entity_type == "class"] + assert len(class_entities) == 1 + assert class_entities[0].name == "UserService" + + def test_skips_node_modules(self, tmp_path): + """Should skip node_modules directory.""" + node_modules = tmp_path / "node_modules" / "some-package" + node_modules.mkdir(parents=True) + (node_modules / "index.js").write_text("class ExternalClass {}") + + src = tmp_path / "src" + src.mkdir() + (src / "app.js").write_text("class AppClass {}") + + structure = analyze_javascript_simple(tmp_path) + + names = [e.name for e in structure.entities] + assert "AppClass" in names + assert "ExternalClass" not in names + + +class TestAnalyzeCode: + """Tests for analyze_code function.""" + + def test_detects_python_project(self, tmp_path): + """Should detect and analyze Python project.""" + (tmp_path / "main.py").write_text("class Main: pass") + (tmp_path / "utils.py").write_text("def helper(): pass") + + structure = analyze_code(tmp_path) + + assert structure.language == "python" + assert len(structure.entities) >= 2 + + def test_detects_javascript_project(self, tmp_path): + """Should detect and analyze JavaScript project.""" + (tmp_path / "app.js").write_text("function app() {}") + (tmp_path / "utils.js").write_text("const helper = () => {}") + (tmp_path / "index.ts").write_text("class Index {}") + + structure = analyze_code(tmp_path) + + assert structure.language == "javascript" + assert len(structure.entities) >= 2 + + def test_handles_single_python_file(self, tmp_path): + """Should handle single Python file.""" + py_file = tmp_path / "single.py" + py_file.write_text("class Single: pass") + + structure = analyze_code(py_file) + + assert structure.language == "python" + assert len(structure.entities) == 1 + + def test_handles_single_javascript_file(self, tmp_path): + """Should handle single JavaScript file.""" + js_file = tmp_path / "single.js" + js_file.write_text("class Single {}") + + structure = analyze_code(js_file) + + assert structure.language == "javascript" + assert len(structure.entities) >= 1 + + def test_mixed_project_uses_majority(self, tmp_path): + """Should use majority language for mixed projects.""" + # More Python files + (tmp_path / "a.py").write_text("class A: pass") + (tmp_path / "b.py").write_text("class B: pass") + (tmp_path / "c.py").write_text("class C: pass") + (tmp_path / "one.js").write_text("class One {}") + 
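+ # Assumed heuristic: analyze_code counts candidate source files per language + # and keeps the majority; three .py files outvote the single .js file here.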
+ structure = analyze_code(tmp_path) + + assert structure.language == "python" + + +class TestStructureToContext: + """Tests for structure_to_context function.""" + + def test_groups_entities_by_type(self): + """Should group entities by type.""" + structure = CodeStructure(language="python", analysis_method="regex") + structure.entities.append(CodeEntity(name="UserService", entity_type="class", file_path="a.py")) + structure.entities.append(CodeEntity(name="OrderService", entity_type="class", file_path="b.py")) + structure.entities.append(CodeEntity(name="helper", entity_type="function", file_path="c.py")) + + context = structure_to_context(structure) + + assert context["language"] == "python" + assert context["analysis_method"] == "regex" + assert "class" in context["entities"] + assert "function" in context["entities"] + assert len(context["entities"]["class"]) == 2 + assert len(context["entities"]["function"]) == 1 + + def test_includes_entity_counts(self): + """Should include entity counts.""" + structure = CodeStructure() + for i in range(25): + structure.entities.append(CodeEntity(name=f"Class{i}", entity_type="class", file_path="x.py")) + + context = structure_to_context(structure) + + assert context["entity_counts"]["class"] == 25 + # But entities list should be limited + assert len(context["entities"]["class"]) <= 20 + + def test_summarizes_relations(self): + """Should summarize relation counts.""" + structure = CodeStructure() + for i in range(5): + structure.relations.append(CodeRelation( + source=f"Class{i}", + target="Base", + relation_type="inherits", + file_path="x.py", + )) + for i in range(3): + structure.relations.append(CodeRelation( + source="x.py", + target=f"module{i}", + relation_type="imports", + file_path="x.py", + )) + + context = structure_to_context(structure) + + assert context["relation_counts"]["inherits"] == 5 + assert context["relation_counts"]["imports"] == 3 + + def test_includes_sample_relations(self): + """Should include sample relations.""" + structure = CodeStructure() + for i in range(15): + structure.relations.append(CodeRelation( + source=f"A{i}", + target=f"B{i}", + relation_type="calls", + file_path="x.py", + )) + + context = structure_to_context(structure) + + assert len(context["sample_relations"]) <= 10 + assert context["sample_relations"][0]["from"] == "A0" + assert context["sample_relations"][0]["to"] == "B0" + assert context["sample_relations"][0]["type"] == "calls"
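+ + +# --------------------------------------------------------------------------- +# Illustrative wiring (a sketch, not exercised by this suite): how the static +# analysis output could feed the mental-model observation prompt. Every name +# below comes from the imports used in these test files; the composition +# itself is an assumption. +# +#     from pathlib import Path +#     from agentgit.enhancers.mental_model import MentalModel, build_observation_prompt +#     from agentgit.enhancers.static_analysis import analyze_code, structure_to_context +# +#     structure = analyze_code(Path("src")) +#     context = structure_to_context(structure) +#     prompt = build_observation_prompt(MentalModel(), context) +# ---------------------------------------------------------------------------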