#!/usr/bin/env python3
"""
Document Decomposer for Visual Generators

Parses documents (PDF, Markdown, text) and extracts structured content
optimized for sketchnote and infographic generation.

Supports two parsing modes:
- Heuristic: Fast regex-based parsing (default)
- LLM: Intelligent parsing using Google Gemini for complex documents

Author: Claude Code
License: MIT
"""

import os
import re
import json
import argparse
from pathlib import Path
from typing import Optional, List, Dict, Any
from dataclasses import dataclass, field, asdict
from collections import Counter


# Extraction limits; each one can be overridden through an environment variable.
MAX_SECTIONS = int(os.getenv("DOC_DECOMPOSER_MAX_SECTIONS", "6"))
MAX_CONCEPTS = int(os.getenv("DOC_DECOMPOSER_MAX_CONCEPTS", "8"))
MAX_QUOTES = int(os.getenv("DOC_DECOMPOSER_MAX_QUOTES", "3"))

# Model name handed to the generative-AI client for LLM-based parsing.
LLM_MODEL = os.getenv("DOC_DECOMPOSER_LLM_MODEL", "gemini-2.0-flash")

# Fixed caps on generated text (kept short for AI image generation).
MAX_TITLE_CHARS = 40        # characters in the document title
MAX_HEADER_WORDS = 5        # words in a section header
MAX_POINT_WORDS = 6         # words in a key point
MAX_CONCEPT_WORDS = 3       # words in a concept term
MAX_DESCRIPTION_WORDS = 10  # words in a concept description


@dataclass
class Section:
    """A document section with header and key points.

    Serialized by ``DecomposedDocument.to_dict`` under the keys ``header``,
    ``key_points`` and ``icon_hint``.
    """
    # Section heading text.
    header: str
    # Short statements belonging to this section; defaults to a fresh empty list.
    key_points: List[str] = field(default_factory=list)
    # Name of a suggested icon for the section; empty string when none is set.
    icon_hint: str = ""


@dataclass
class Concept:
    """A key concept with term and description.

    Serialized by ``DecomposedDocument.to_dict`` under the keys ``term`` and
    ``description``.
    """
    # The concept's name.
    term: str
    # Short explanation of the term; empty string when none is available.
    description: str = ""


@dataclass
class Relationship:
    """A directed relationship between two concepts.

    Serialized by ``DecomposedDocument.to_dict`` under the keys ``from``,
    ``to`` and ``type``.
    """
    # Name of the concept the relationship starts from.
    from_concept: str
    # Name of the concept the relationship points to.
    to_concept: str
    # Relationship label; the heuristic parser emits "leads_to", and the LLM
    # prompt asks for one of leads_to / supports / requires / enables.
    rel_type: str = "relates_to"


@dataclass
class Metadata:
    """Document metadata recorded alongside the extracted content."""
    # Kind of source; decompose_file sets "pdf", "markdown", "docx" or "text".
    source_type: str = "text"
    # Path of the source file as a string; only decompose_file fills this in.
    source_path: str = ""
    # Number of whitespace-separated words counted in the parsed text.
    word_count: int = 0
    # "low", "medium" or "high", as computed by _assess_complexity.
    complexity: str = "medium"


@dataclass
class DecomposedDocument:
    """Structured result of decomposing one document.

    Holds the title/subtitle, the extracted sections, concepts,
    relationships and quotes, plus metadata about the source. The helper
    methods convert this structure into the dictionaries consumed by the
    JSON writer and by the sketchnote / infographic generators.
    """
    title: str = ""
    subtitle: str = ""
    sections: List[Section] = field(default_factory=list)
    key_concepts: List[Concept] = field(default_factory=list)
    relationships: List[Relationship] = field(default_factory=list)
    quotes: List[str] = field(default_factory=list)
    metadata: Metadata = field(default_factory=Metadata)

    def to_dict(self) -> Dict[str, Any]:
        """Return the whole document as a JSON-serializable dictionary."""
        section_dicts = []
        for sec in self.sections:
            section_dicts.append({
                "header": sec.header,
                "key_points": sec.key_points,
                "icon_hint": sec.icon_hint
            })

        concept_dicts = []
        for concept in self.key_concepts:
            concept_dicts.append(
                {"term": concept.term, "description": concept.description}
            )

        # Relationship fields are renamed to from/to/type in the output.
        relationship_dicts = []
        for rel in self.relationships:
            relationship_dicts.append(
                {"from": rel.from_concept, "to": rel.to_concept, "type": rel.rel_type}
            )

        return {
            "title": self.title,
            "subtitle": self.subtitle,
            "sections": section_dicts,
            "key_concepts": concept_dicts,
            "relationships": relationship_dicts,
            "quotes": self.quotes,
            "metadata": asdict(self.metadata)
        }

    def save(self, path: str) -> None:
        """Write the document to ``path`` as indented JSON and report it."""
        with open(path, "w") as handle:
            json.dump(self.to_dict(), handle, indent=2)
        print(f"Saved to: {path}")

    def for_sketchnote(self) -> Dict[str, Any]:
        """Build the input dictionary expected by the sketchnote generator.

        Uses at most six sections (two key points each) and two quotes.
        """
        concept_lines = []
        for sec in self.sections[:6]:
            leading_points = ", ".join(sec.key_points[:2])
            concept_lines.append(f"{sec.header}: {leading_points}")

        # Icons are collected from every section that carries a hint.
        icons = [sec.icon_hint for sec in self.sections if sec.icon_hint]

        return {
            "title": self.title,
            "concepts": concept_lines,
            "quotes": self.quotes[:2],
            "style_hints": {
                "layout": "dense",
                "icons": icons
            }
        }

    def for_infographic(self) -> Dict[str, Any]:
        """Build the input dictionary expected by the infographic generator.

        The content is rendered as markdown-style text: an H1 title, an
        optional subtitle, then up to five H2 sections with up to three
        bullet points each.
        """
        top_sections = self.sections[:5]

        markdown = [f"# {self.title}"]
        if self.subtitle:
            markdown.append(f"\n{self.subtitle}\n")

        for sec in top_sections:
            markdown.append(f"\n## {sec.header}")
            markdown.extend(f"- {point}" for point in sec.key_points[:3])

        return {
            "raw_text": "\n".join(markdown),
            "headers": [sec.header for sec in top_sections],
            "keywords": [concept.term for concept in self.key_concepts[:5]]
        }


class DocumentDecomposer:
    """
    Decomposes documents into structured content for visual generators.

    Handles PDF, Markdown, and plain text formats, extracting:
    - Title and subtitle
    - Section headers with key points
    - Key concepts and relationships
    - Notable quotes
    """

    # Common stop words to filter from keywords
    STOP_WORDS = {
        "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
        "of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
        "being", "have", "has", "had", "do", "does", "did", "will", "would",
        "could", "should", "may", "might", "must", "shall", "can", "this",
        "that", "these", "those", "i", "you", "he", "she", "it", "we", "they",
        "what", "which", "who", "when", "where", "why", "how", "all", "each",
        "every", "both", "few", "more", "most", "other", "some", "such", "no",
        "not", "only", "own", "same", "so", "than", "too", "very", "just",
        "your", "our", "their", "my", "its", "as", "if", "then", "because",
    }

    # Icon hints based on keywords
    ICON_MAPPINGS = {
        "security": "shield",
        "cyber": "shield",
        "governance": "hierarchy",
        "management": "gear",
        "team": "people",
        "regional": "globe",
        "reporting": "chart",
        "incident": "alert",
        "training": "book",
        "assessment": "checklist",
        "compliance": "checkmark",
        "risk": "warning",
        "data": "database",
        "network": "network",
        "cloud": "cloud",
        "automation": "robot",
        "process": "flowchart",
        "budget": "money",
        "time": "clock",
        "communication": "speech",
    }

    # LLM prompt template for document decomposition
    LLM_PROMPT_TEMPLATE = '''Analyze this document and extract structured content for visual generation (infographics/sketchnotes).

IMPORTANT TEXT CONSTRAINTS (for AI image generation accuracy):
- Title: Maximum 40 characters
- Section headers: Maximum 5 words each
- Key points: Maximum 6 words each
- Concept terms: Maximum 3 words each
- Descriptions: Maximum 10 words each

Extract and return ONLY valid JSON in this exact format:
{{
  "title": "Short descriptive title (max 40 chars)",
  "subtitle": "Optional brief subtitle",
  "sections": [
    {{
      "header": "Section Name (max 5 words)",
      "key_points": ["Point 1 (3-6 words)", "Point 2", "Point 3"],
      "icon_hint": "shield|gear|people|globe|chart|alert|book|checklist|checkmark|warning|database|network|cloud|flowchart|money|clock|document"
    }}
  ],
  "key_concepts": [
    {{"term": "Concept (1-3 words)", "description": "Brief description (max 10 words)"}}
  ],
  "relationships": [
    {{"from": "Concept A", "to": "Concept B", "type": "leads_to|supports|requires|enables"}}
  ],
  "quotes": ["Notable short quote if any"],
  "summary": "One sentence summary of the document"
}}

Rules:
1. Extract 4-6 main sections maximum
2. Each section should have 2-4 key points
3. Identify 5-8 key concepts
4. Find logical relationships between concepts
5. Keep all text SHORT and CONCISE - this is critical
6. Use clear, simple language
7. Icon hints should match the content theme

DOCUMENT TO ANALYZE:
---
{document_text}
---

Return ONLY the JSON, no explanation or markdown formatting.'''

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the decomposer.

        Args:
            api_key: Google API key for LLM parsing. Falls back to GOOGLE_API_KEY env var.
        """
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY")
        self._llm_client = None

    def _get_llm_client(self):
        """Get or create the LLM client."""
        if self._llm_client is None:
            if not self.api_key:
                raise ValueError(
                    "LLM parsing requires a Google API key. "
                    "Pass api_key or set GOOGLE_API_KEY environment variable."
                )
            try:
                import google.generativeai as genai
                genai.configure(api_key=self.api_key)
                self._llm_client = genai.GenerativeModel(LLM_MODEL)
            except ImportError:
                raise ImportError(
                    "LLM parsing requires google-generativeai. "
                    "Install with: pip install google-generativeai"
                )
        return self._llm_client

    def decompose_file(
        self,
        path: str,
        use_llm: bool = False
    ) -> DecomposedDocument:
        """
        Decompose a file into structured content.

        Args:
            path: Path to the document file.
            use_llm: Use LLM-based parsing for better results (requires API key).

        Returns:
            DecomposedDocument with extracted structure.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {path}")

        suffix = path.suffix.lower()

        if suffix == ".pdf":
            text = self._extract_pdf(path)
            source_type = "pdf"
        elif suffix in (".md", ".markdown"):
            text = path.read_text(encoding="utf-8")
            source_type = "markdown"
        elif suffix == ".docx":
            text = self._extract_docx(path)
            source_type = "docx"
        else:
            text = path.read_text(encoding="utf-8")
            source_type = "text"

        doc = self.decompose_text(text, source_type, use_llm=use_llm)
        doc.metadata.source_path = str(path)
        return doc

    def decompose_text(
        self,
        text: str,
        source_type: str = "text",
        use_llm: bool = False
    ) -> DecomposedDocument:
        """
        Decompose raw text into structured content.

        Args:
            text: The raw text content.
            source_type: Type of source (pdf, markdown, text).
            use_llm: Use LLM-based parsing for better results.

        Returns:
            DecomposedDocument with extracted structure.
        """
        if use_llm:
            return self._decompose_with_llm(text, source_type)

        # Heuristic-based parsing
        doc = DecomposedDocument()
        doc.metadata.source_type = source_type
        doc.metadata.word_count = len(text.split())

        # Determine complexity
        doc.metadata.complexity = self._assess_complexity(text)

        # Extract title
        doc.title = self._extract_title(text)

        # Extract sections
        doc.sections = self._extract_sections(text)

        # Extract key concepts
        doc.key_concepts = self._extract_concepts(text)

        # Extract relationships (simple heuristic)
        doc.relationships = self._extract_relationships(doc.sections)

        # Extract quotes
        doc.quotes = self._extract_quotes(text)

        return doc

    def _decompose_with_llm(
        self,
        text: str,
        source_type: str = "text"
    ) -> DecomposedDocument:
        """
        Decompose text using LLM for intelligent parsing.

        Args:
            text: The raw text content.
            source_type: Type of source.

        Returns:
            DecomposedDocument with LLM-extracted structure.
        """
        print("Using LLM-based parsing (Gemini)...")

        # Truncate very long documents to stay within token limits
        max_chars = 30000
        if len(text) > max_chars:
            print(f"  Truncating document from {len(text)} to {max_chars} chars")
            text = text[:max_chars]

        # Build prompt
        prompt = self.LLM_PROMPT_TEMPLATE.format(document_text=text)

        # Call LLM
        client = self._get_llm_client()
        response = client.generate_content(prompt)

        # Parse JSON response
        response_text = response.text.strip()

        # Clean up response - remove markdown code blocks if present
        if response_text.startswith("```"):
            lines = response_text.split("\n")
            # Remove first and last lines (```json and ```)
            lines = [l for l in lines if not l.startswith("```")]
            response_text = "\n".join(lines)

        try:
            data = json.loads(response_text)
        except json.JSONDecodeError as e:
            print(f"  Warning: Failed to parse LLM response as JSON: {e}")
            print(f"  Falling back to heuristic parsing")
            return self.decompose_text(text, source_type, use_llm=False)

        # Build DecomposedDocument from LLM response
        doc = DecomposedDocument()
        doc.metadata.source_type = source_type
        doc.metadata.word_count = len(text.split())
        doc.metadata.complexity = self._assess_complexity(text)

        # Extract fields from LLM response
        doc.title = data.get("title", "Untitled")[:MAX_TITLE_CHARS]
        doc.subtitle = data.get("subtitle", "")

        # Sections
        for section_data in data.get("sections", [])[:MAX_SECTIONS]:
            doc.sections.append(Section(
                header=self._truncate_words(section_data.get("header", ""), MAX_HEADER_WORDS),
                key_points=[
                    self._truncate_words(p, MAX_POINT_WORDS)
                    for p in section_data.get("key_points", [])[:4]
                ],
                icon_hint=section_data.get("icon_hint", "document")
            ))

        # Key concepts
        for concept_data in data.get("key_concepts", [])[:MAX_CONCEPTS]:
            doc.key_concepts.append(Concept(
                term=self._truncate_words(concept_data.get("term", ""), MAX_CONCEPT_WORDS),
                description=self._truncate_words(concept_data.get("description", ""), MAX_DESCRIPTION_WORDS)
            ))

        # Relationships
        for rel_data in data.get("relationships", [])[:5]:
            doc.relationships.append(Relationship(
                from_concept=rel_data.get("from", ""),
                to_concept=rel_data.get("to", ""),
                rel_type=rel_data.get("type", "relates_to")
            ))

        # Quotes
        doc.quotes = [
            self._truncate_words(q, 15)
            for q in data.get("quotes", [])[:MAX_QUOTES]
        ]

        print(f"  LLM parsing complete: {len(doc.sections)} sections, {len(doc.key_concepts)} concepts")
        return doc

    def _extract_pdf(self, path: Path) -> str:
        """Extract text from PDF."""
        try:
            import pdfplumber
            text_parts = []
            with pdfplumber.open(path) as pdf:
                for page in pdf.pages:
                    text = page.extract_text()
                    if text:
                        text_parts.append(text)
            return "\n\n".join(text_parts)
        except ImportError:
            pass

        try:
            from PyPDF2 import PdfReader
            reader = PdfReader(path)
            text_parts = []
            for page in reader.pages:
                text = page.extract_text()
                if text:
                    text_parts.append(text)
            return "\n\n".join(text_parts)
        except ImportError:
            raise ImportError(
                "PDF extraction requires pdfplumber or PyPDF2. "
                "Install with: pip install pdfplumber"
            )

    def _extract_docx(self, path: Path) -> str:
        """Extract text from DOCX."""
        try:
            from docx import Document
            doc = Document(path)
            paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
            return "\n\n".join(paragraphs)
        except ImportError:
            raise ImportError(
                "DOCX extraction requires python-docx. "
                "Install with: pip install python-docx"
            )

    def _truncate_words(self, text: str, max_words: int) -> str:
        """Truncate text to max words."""
        words = text.split()
        if len(words) <= max_words:
            return text
        return " ".join(words[:max_words])

    def _truncate_chars(self, text: str, max_chars: int) -> str:
        """Truncate text to max characters."""
        if len(text) <= max_chars:
            return text
        return text[:max_chars-3] + "..."

    def _assess_complexity(self, text: str) -> str:
        """Assess document complexity."""
        word_count = len(text.split())
        if word_count < 500:
            return "low"
        elif word_count < 2000:
            return "medium"
        else:
            return "high"

    def _extract_title(self, text: str) -> str:
        """Extract and truncate the title."""
        lines = text.strip().split("\n")

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Check for markdown H1
            if line.startswith("# "):
                title = line[2:].strip()
                return self._truncate_chars(title, MAX_TITLE_CHARS)

            # Check for substantial first line (likely title)
            if len(line) > 10 and len(line) < 100:
                # Remove common prefixes
                title = re.sub(r"^(title|heading|chapter)[\s:]+", "", line, flags=re.I)
                return self._truncate_chars(title, MAX_TITLE_CHARS)

        return "Untitled Document"

    def _extract_sections(self, text: str) -> List[Section]:
        """Extract sections with headers and key points."""
        sections = []

        # Try markdown headers first
        header_pattern = re.compile(r"^#{2,3}\s+(.+)$", re.MULTILINE)
        matches = list(header_pattern.finditer(text))

        if matches:
            for i, match in enumerate(matches[:MAX_SECTIONS]):
                header = self._truncate_words(match.group(1).strip(), MAX_HEADER_WORDS)

                # Extract content until next header
                start = match.end()
                end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
                content = text[start:end].strip()

                # Extract key points from content
                key_points = self._extract_key_points(content)

                # Suggest icon
                icon_hint = self._suggest_icon(header + " " + content)

                sections.append(Section(
                    header=header,
                    key_points=key_points,
                    icon_hint=icon_hint
                ))
        else:
            # Fallback: use paragraph breaks and heuristics
            paragraphs = re.split(r"\n\s*\n", text)

            for para in paragraphs[:MAX_SECTIONS]:
                para = para.strip()
                if not para or len(para) < 20:
                    continue

                # First line as header
                lines = para.split("\n")
                header = self._truncate_words(lines[0].strip(), MAX_HEADER_WORDS)

                # Rest as content for key points
                content = "\n".join(lines[1:]) if len(lines) > 1 else ""
                key_points = self._extract_key_points(content) if content else []

                icon_hint = self._suggest_icon(header)

                sections.append(Section(
                    header=header,
                    key_points=key_points,
                    icon_hint=icon_hint
                ))

        return sections[:MAX_SECTIONS]

    def _extract_key_points(self, content: str) -> List[str]:
        """Extract key points from section content."""
        points = []

        # Look for bullet points
        bullet_pattern = re.compile(r"^[\s]*[-•*]\s+(.+)$", re.MULTILINE)
        bullets = bullet_pattern.findall(content)

        for bullet in bullets[:4]:
            point = self._truncate_words(bullet.strip(), MAX_POINT_WORDS)
            if point:
                points.append(point)

        # If no bullets, extract sentences
        if not points:
            sentences = re.split(r"[.!?]+", content)
            for sentence in sentences[:3]:
                sentence = sentence.strip()
                if len(sentence) > 10:
                    point = self._truncate_words(sentence, MAX_POINT_WORDS)
                    points.append(point)

        return points[:4]

    def _extract_concepts(self, text: str) -> List[Concept]:
        """Extract key concepts from the document."""
        concepts = []

        # Extract capitalized phrases (likely important terms)
        cap_pattern = re.compile(r"\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\b")
        cap_matches = cap_pattern.findall(text)

        # Count frequency
        freq = Counter(cap_matches)

        # Filter and create concepts
        for term, count in freq.most_common(MAX_CONCEPTS * 2):
            # Skip common words and too-short terms
            if term.lower() in self.STOP_WORDS:
                continue
            if len(term) < 4:
                continue

            # Truncate term
            short_term = self._truncate_words(term, MAX_CONCEPT_WORDS)

            # Find context for description
            desc = self._find_concept_context(text, term)

            concepts.append(Concept(
                term=short_term,
                description=desc
            ))

            if len(concepts) >= MAX_CONCEPTS:
                break

        return concepts

    def _find_concept_context(self, text: str, term: str) -> str:
        """Find contextual description for a concept."""
        # Look for "Term: description" or "Term - description" patterns
        patterns = [
            rf"{re.escape(term)}[:\-–]\s*([^.]+)",
            rf"{re.escape(term)}\s+(?:is|are|means?|refers? to)\s+([^.]+)",
        ]

        for pattern in patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                desc = match.group(1).strip()
                return self._truncate_words(desc, MAX_DESCRIPTION_WORDS)

        return ""

    def _suggest_icon(self, text: str) -> str:
        """Suggest an icon based on content keywords."""
        text_lower = text.lower()

        for keyword, icon in self.ICON_MAPPINGS.items():
            if keyword in text_lower:
                return icon

        return "document"

    def _extract_relationships(self, sections: List[Section]) -> List[Relationship]:
        """Extract relationships between sections/concepts."""
        relationships = []

        # Simple sequential relationship
        for i in range(len(sections) - 1):
            relationships.append(Relationship(
                from_concept=sections[i].header,
                to_concept=sections[i + 1].header,
                rel_type="leads_to"
            ))

        return relationships[:5]

    def _extract_quotes(self, text: str) -> List[str]:
        """Extract notable quotes."""
        quotes = []

        # Look for quoted text
        quote_pattern = re.compile(r'"([^"]{20,100})"')
        matches = quote_pattern.findall(text)

        for match in matches[:MAX_QUOTES]:
            quote = self._truncate_words(match.strip(), 15)
            quotes.append(quote)

        return quotes


def main():
    """
    CLI entry point.

    Parses the command line, decomposes the given file and emits the result
    in one of three ways: written to ``--output`` as JSON, printed to stdout
    as JSON (``--json``), or printed as a human-readable summary. ``--for``
    selects the sketchnote or infographic input format instead of the full
    dictionary.

    When JSON is printed to stdout, any progress messages printed during
    decomposition are sent to stderr so stdout holds only the JSON document.
    """
    # Local imports: only this entry point needs them.
    import contextlib
    import sys

    parser = argparse.ArgumentParser(
        description="Decompose documents for visual generators"
    )
    parser.add_argument("file", help="Path to document file")
    parser.add_argument("-o", "--output", help="Output JSON file path")
    parser.add_argument(
        "--for",
        dest="generator",
        choices=["sketchnote", "infographic"],
        help="Format output for specific generator"
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Output raw JSON to stdout"
    )
    parser.add_argument(
        "--llm",
        action="store_true",
        help="Use LLM-based parsing (requires GOOGLE_API_KEY)"
    )

    args = parser.parse_args()

    decomposer = DocumentDecomposer()

    # --output takes precedence over --json below, so stdout only has to stay
    # clean when --json is given without --output.
    json_to_stdout = args.json and not args.output
    if json_to_stdout:
        stdout_guard = contextlib.redirect_stdout(sys.stderr)
    else:
        stdout_guard = contextlib.nullcontext()
    with stdout_guard:
        doc = decomposer.decompose_file(args.file, use_llm=args.llm)

    if args.generator == "sketchnote":
        output = doc.for_sketchnote()
    elif args.generator == "infographic":
        output = doc.for_infographic()
    else:
        output = doc.to_dict()

    if args.output:
        with open(args.output, "w") as f:
            json.dump(output, f, indent=2)
        print(f"Saved to: {args.output}")
    elif args.json:
        print(json.dumps(output, indent=2))
    else:
        # Pretty print summary
        print(f"\n{'='*60}")
        print("Document Decomposition")
        print(f"{'='*60}")
        print(f"\nTitle: {doc.title}")
        print(f"Source: {doc.metadata.source_type}")
        print(f"Words: {doc.metadata.word_count}")
        print(f"Complexity: {doc.metadata.complexity}")
        print(f"\nSections ({len(doc.sections)}):")
        for i, section in enumerate(doc.sections, 1):
            print(f"  {i}. {section.header}")
            # Only the first two key points are shown in the summary.
            for point in section.key_points[:2]:
                print(f"     - {point}")
        print(f"\nKey Concepts ({len(doc.key_concepts)}):")
        for concept in doc.key_concepts[:5]:
            desc = f": {concept.description}" if concept.description else ""
            print(f"  • {concept.term}{desc}")
        if doc.quotes:
            print(f"\nQuotes ({len(doc.quotes)}):")
            for quote in doc.quotes:
                print(f'  "{quote}"')
        print(f"\n{'='*60}")


if __name__ == "__main__":
    main()
