#!/usr/bin/env python3
"""
Sketchnote Generator using Google Gemini API (Nano Banana Pro)

This module provides automated sketchnote-style visual note generation from raw text
using Google's Gemini image generation API. Creates authentic hand-drawn style
visual notes with doodles, icons, connectors, and hand-lettering.

Based on the infographic-generator pattern with sketchnote-specific adaptations
from the PPT2Vid slide generator service.

Author: Claude Code
License: MIT
"""

import base64
import json
import os
import re
import time
from collections import Counter
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

try:
    from google import genai
    from google.genai import types
except ImportError:
    raise ImportError(
        "google-genai package required. Install with: pip install google-genai"
    )


@dataclass
class StyleGuide:
    """Visual style settings for a single sketchnote.

    Every color is kept twice: a descriptive name that reads well in a prompt
    and the matching hex code. The defaults equal the "classic" palette preset.
    """

    # --- output geometry -------------------------------------------------
    aspect_ratio: str = "1:1"
    resolution: str = "2K"

    # --- palette (name, hex) pairs ---------------------------------------
    primary_color: str = "marker blue"
    primary_hex: str = "#1E40AF"
    secondary_color: str = "sketch gray"
    secondary_hex: str = "#4B5563"
    accent_color: str = "highlighter yellow"
    accent_hex: str = "#FDE047"
    background: str = "notebook cream"
    background_hex: str = "#FFFBEB"
    text_color: str = "ink black"
    text_hex: str = "#1C1917"

    # --- free-form rendering descriptors ---------------------------------
    illustration_style: str = "authentic sketchnote visual note-taking style"
    typography_style: str = "authentic hand-lettering with natural variation"
    lighting: str = "flat even lighting"
    camera: str = "flat top-down birds-eye view"


@dataclass
class ParsedContent:
    """Structural pieces pulled out of the raw input text.

    Every list gets its own fresh instance per object (default_factory), so
    mutating one ParsedContent never leaks into another.
    """

    title: str = ""
    # Main ideas, e.g. section headings or bullet items
    concepts: List[str] = field(default_factory=list)
    # Concept-to-concept links; reserved, nothing in this module fills it yet
    relationships: List[Dict[str, str]] = field(default_factory=list)
    # Quoted or block-quoted lines
    quotes: List[str] = field(default_factory=list)
    # Most frequent significant words
    keywords: List[str] = field(default_factory=list)


@dataclass
class GenerationResult:
    """Outcome of one sketchnote generation attempt.

    ``success`` is the only required field; on failure ``error`` carries the
    reason, on success ``output_path`` names the written image file.
    """

    success: bool
    output_path: Optional[str] = None
    error: Optional[str] = None
    # The full prompt sent to the model (kept for debugging/reproducibility)
    prompt_used: Optional[str] = None
    # Extra details such as model id, timestamp and file size
    metadata: Dict[str, Any] = field(default_factory=dict)


# Quality preamble that opens every prompt built by
# SketchnoteGenerator.construct_prompt. This is prompt text sent to the model
# verbatim, so any wording change here changes generation input.
# NOTE(review): the "4K resolution rendering" line is fixed text and does not
# follow the resolution the caller selects (1K/2K/4K) — confirm this is intended.
QUALITY_PREFIX = """QUALITY REQUIREMENTS (Apply to entire image):
- Masterpiece quality, best quality, highly detailed, professional grade
- 4K resolution rendering, crisp edges, authentic hand-drawn feel
- Mike Rohde sketchnote quality, conference notes excellence
- Expert visual note-taking, authentic marker and pen aesthetic
- Perfect composition, balanced information density
"""

# Legibility rules for text inside the image; construct_prompt inserts this
# directly after QUALITY_PREFIX. The "quotation marks" rule is why
# construct_prompt wraps the title and box labels in double quotes.
TEXT_RENDERING_GUIDELINES = """
TEXT RENDERING REQUIREMENTS (CRITICAL FOR ACCURACY):
- Render ALL text EXACTLY as specified in quotation marks - no substitutions
- Use clean, bold, sans-serif style for headings
- Use casual handwritten style for body text
- Each text element must be LEGIBLE and CORRECTLY SPELLED
- Text should be LARGE enough to read clearly
- Limit text per section to SHORT PHRASES (3-6 words per line)
- Position text INSIDE designated areas with clear margins
"""

# "Avoid" list appended at the end of every prompt. No separate
# negative-prompt parameter is sent to the API anywhere in this file, so these
# exclusions are expressed as plain in-prompt instructions.
NEGATIVE_PROMPT = """
AVOID (CRITICAL - Do NOT include any of these):
- Misspelled words or garbled text
- Random gibberish letters or illegible text
- Blurry or unreadable text
- Text that differs from what was specified
- Digital or computer-generated aesthetic
- Perfect geometric shapes or lines
- Stock photo look, clip art aesthetic
- Watermarks, logos, or signatures
- Empty whitespace - fill the page with content
- 3D effects or drop shadows
"""

# Layout presets, keyed by the lower-cased `layout` argument. "description" is
# for humans; "layout_instruction" is pasted verbatim into the prompt's LAYOUT
# section.
LAYOUT_PRESETS = {
    "dense": dict(
        description="Maximum information density, conference notes style",
        layout_instruction=(
            "Pack the entire page densely with interconnected visual notes. "
            "Use every available space. "
            "Multiple idea clusters connected by arrows. "
            "Margin annotations. "
            "Banner headings. "
            "No empty space."
        ),
    ),
    "centered": dict(
        description="Central topic with radiating concepts",
        layout_instruction=(
            "Place the main topic in a decorated banner at the center. "
            "Radiate related concepts outward like a sun burst. "
            "Use connectors and arrows to link ideas back to center."
        ),
    ),
    "linear": dict(
        description="Top-to-bottom sequential flow",
        layout_instruction=(
            "Arrange concepts in a clear top-to-bottom flow. "
            "Use numbered sections. "
            "Draw flow arrows between sections. "
            "Add margin icons and annotations along the sides."
        ),
    ),
    "grid": dict(
        description="Organized grid of concept boxes",
        layout_instruction=(
            "Organize concepts in a loose hand-drawn grid. "
            "Each cell contains one key idea with supporting doodles. "
            "Connect related cells with arrows across the grid."
        ),
    ),
    "mind-map": dict(
        description="Branching tree structure",
        layout_instruction=(
            "Create a mind-map structure with the main topic at center. "
            "Branch out with major concepts. "
            "Sub-branch with details. "
            "Use organic curved connectors."
        ),
    ),
}

# Color palette presets, keyed by the lower-cased `palette` argument. Each
# color role has a descriptive name (for prompt wording) and a hex code.
PALETTE_PRESETS = {
    "classic": dict(
        primary="marker blue",
        primary_hex="#1E40AF",
        secondary="sketch gray",
        secondary_hex="#4B5563",
        accent="highlighter yellow",
        accent_hex="#FDE047",
        background="notebook cream",
        background_hex="#FFFBEB",
        text="ink black",
        text_hex="#1C1917",
    ),
    "warm": dict(
        primary="sepia brown",
        primary_hex="#8B4513",
        secondary="warm gray",
        secondary_hex="#6B7280",
        accent="burnt orange",
        accent_hex="#EA580C",
        background="aged cream",
        background_hex="#FEF3C7",
        text="dark brown",
        text_hex="#451A03",
    ),
    "cool": dict(
        primary="teal",
        primary_hex="#0D9488",
        secondary="slate",
        secondary_hex="#64748B",
        accent="sky blue",
        accent_hex="#38BDF8",
        background="cool white",
        background_hex="#F8FAFC",
        text="dark slate",
        text_hex="#0F172A",
    ),
    "mono": dict(
        primary="pure black",
        primary_hex="#000000",
        secondary="dark gray",
        secondary_hex="#374151",
        accent="medium gray",
        accent_hex="#6B7280",
        background="cream paper",
        background_hex="#FFFBEB",
        text="black",
        text_hex="#000000",
    ),
    "vibrant": dict(
        primary="electric blue",
        primary_hex="#2563EB",
        secondary="magenta",
        secondary_hex="#DB2777",
        accent="lime green",
        accent_hex="#84CC16",
        background="white",
        background_hex="#FFFFFF",
        text="dark purple",
        text_hex="#581C87",
    ),
}


class SketchnoteGenerator:
    """
    Generate sketchnote-style visual notes from text using Google Gemini's image generation API.

    Implements a workflow optimized for authentic hand-drawn visual notes:
    1. Content Parsing - Extract the title, key concepts, quotes, and keywords
    2. Prompt Construction - Build sketchnote-specific prompts
    3. Image Generation - Create authentic hand-drawn style visuals

    Example:
        generator = SketchnoteGenerator(api_key="your-api-key")
        result = generator.generate(
            raw_text="Your meeting notes or article content...",
            layout="dense",
            aspect_ratio="1:1"
        )
    """

    # Model configurations - Nano Banana Pro
    # gemini-3-pro-image-preview is best quality (same as PPT2Vid)
    # gemini-2.5-flash-image is faster alternative with higher quota
    MODELS = {
        "nano_banana_pro": "gemini-3-pro-image-preview",  # Best quality
        "nano_banana": "gemini-2.5-flash-image",          # Faster/higher quota
    }

    # Valid aspect ratios per API docs
    VALID_ASPECT_RATIOS = [
        "1:1", "2:3", "3:2", "3:4", "4:3",
        "4:5", "5:4", "9:16", "16:9", "21:9"
    ]

    # Valid resolutions
    VALID_RESOLUTIONS = ["1K", "2K", "4K"]

    # Maximum number of concept boxes placed in a single prompt.
    MAX_BOXES = 6

    # Patterns used by parse_content, compiled once for the class.
    _HEADER_PATTERN = re.compile(r"^#{2,3}\s+(.+)$")
    _BULLET_PATTERN = re.compile(r"^[-*•]\s+(.+)$")
    # A line wrapped in straight or curly double quotes (the closing quote is
    # optional and is NOT captured), or a Markdown blockquote ("> text").
    _QUOTE_PATTERN = re.compile(r'^["“](.+?)["”]?$|^>\s*(.+)$')
    _WORD_PATTERN = re.compile(r"\b[a-zA-Z]{4,}\b")

    # Words ignored when ranking keywords by frequency.
    _STOP_WORDS = frozenset({
        "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
        "of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
        "being", "have", "has", "had", "do", "does", "did", "will", "would",
        "could", "should", "may", "might", "must", "shall", "can", "this",
        "that", "these", "those", "i", "you", "he", "she", "it", "we", "they",
        "what", "which", "who", "when", "where", "why", "how", "all", "each",
    })

    # Lower-cased substrings that mark an error message as a rate-limit/quota
    # error. A bare "rate" is deliberately NOT used: it also matches "generate".
    _RATE_LIMIT_MARKERS = (
        "resource_exhausted",
        "rate limit",
        "rate-limit",
        "too many requests",
        "quota",
    )

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "nano_banana_pro",
        output_dir: Optional[str] = None,
    ):
        """
        Initialize the SketchnoteGenerator.

        Args:
            api_key: Google API key. Falls back to GOOGLE_API_KEY env var.
            model: Model alias ("nano_banana_pro" or "nano_banana") or a raw model ID.
            output_dir: Default output directory for generated images. Falls back
                to the SKETCHNOTE_OUTPUT_DIR env var, then the current directory.
                The directory is created if missing.

        Raises:
            ValueError: If no API key is supplied or found in the environment.
        """
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY")
        if not self.api_key:
            raise ValueError(
                "API key required. Pass api_key or set GOOGLE_API_KEY environment variable."
            )

        # Initialize client (new google.genai SDK)
        self.client = genai.Client(api_key=self.api_key)

        # Known aliases map to model IDs; anything else is used as a direct model ID.
        self.model = self.MODELS.get(model, model)

        self.output_dir = Path(
            output_dir or os.environ.get("SKETCHNOTE_OUTPUT_DIR", ".")
        )
        self.output_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------ #
    # Content parsing
    # ------------------------------------------------------------------ #

    @staticmethod
    def _extract_title(lines: List[str]) -> str:
        """
        Return the title from already-stripped lines, or "" if none qualifies.

        Scans top to bottom and returns for the first non-empty line that is:
        a Markdown ATX H1 ("# Title"); a Setext H1 (a line whose next line is
        made only of "="); or a non-bullet, non-underline line longer than 10
        characters (leading "#" marks removed, truncated to 80 characters).
        """
        for idx, line in enumerate(lines):
            if not line:
                continue
            if line.startswith("# "):
                return line[2:].strip()
            next_line = lines[idx + 1] if idx + 1 < len(lines) else ""
            if next_line and set(next_line) == {"="}:
                return line
            if len(line) > 10 and not line.startswith(("-", "*", "•", "=")):
                # e.g. a leading "## Heading" must not keep its hashes.
                candidate = re.sub(r"^#+\s*", "", line)
                if candidate:
                    return candidate[:80]
        return ""

    def parse_content(self, raw_text: str) -> ParsedContent:
        """
        Parse raw text to extract key elements for sketchnote visualization.

        Extracts:
        - title: first Markdown H1, Setext H1 (line underlined with "="), or
          first substantial non-bullet line (see _extract_title)
        - concepts: up to 7 H2/H3 headings (excluding one equal to the title);
          if there are none, up to 7 bullet items longer than 10 characters,
          each truncated to 60 characters
        - quotes: up to 3 lines wrapped in double quotes or starting with ">",
          without the surrounding quote marks, truncated to 100 characters
        - keywords: the 5 most frequent words of 4+ letters, excluding common
          stop words (ties keep first-occurrence order)

        ``relationships`` is left empty; no relationship extraction is done.

        Args:
            raw_text: The full text content (article, notes, transcript).

        Returns:
            ParsedContent with title, concepts, quotes, and keywords populated.
        """
        parsed = ParsedContent()
        lines = [line.strip() for line in raw_text.strip().split("\n")]

        parsed.title = self._extract_title(lines)

        # Concepts from H2/H3 headers.
        for line in lines:
            match = self._HEADER_PATTERN.match(line)
            if match and len(parsed.concepts) < 7:
                concept = match.group(1).strip()
                if concept and concept != parsed.title:
                    parsed.concepts.append(concept)

        # No headers: fall back to substantial bullet items.
        if not parsed.concepts:
            for line in lines:
                match = self._BULLET_PATTERN.match(line)
                if match and len(parsed.concepts) < 7:
                    concept = match.group(1).strip()
                    if len(concept) > 10:
                        parsed.concepts.append(concept[:60])

        # Quotes.
        for line in lines:
            if len(parsed.quotes) >= 3:
                break
            match = self._QUOTE_PATTERN.match(line)
            if match:
                quote = (match.group(1) or match.group(2) or "").strip()
                if quote:
                    parsed.quotes.append(quote[:100])

        # Keywords by frequency.
        counts = Counter(
            word
            for word in self._WORD_PATTERN.findall(raw_text.lower())
            if word not in self._STOP_WORDS
        )
        parsed.keywords = [word for word, _ in counts.most_common(5)]

        return parsed

    # ------------------------------------------------------------------ #
    # Style and prompt construction
    # ------------------------------------------------------------------ #

    def generate_style_guide(
        self,
        palette: str = "classic",
        aspect_ratio: str = "1:1",
        resolution: str = "2K",
    ) -> StyleGuide:
        """
        Generate a style guide from a palette preset.

        Invalid values are replaced by defaults with a printed warning:
        aspect ratio -> "1:1", resolution -> "2K", palette -> "classic".
        Resolution and palette names are matched case-insensitively.

        Args:
            palette: Color palette name (see PALETTE_PRESETS).
            aspect_ratio: Image aspect ratio.
            resolution: Image resolution (1K, 2K, 4K).

        Returns:
            Configured StyleGuide instance.
        """
        if aspect_ratio not in self.VALID_ASPECT_RATIOS:
            print(f"Warning: Invalid aspect ratio '{aspect_ratio}', defaulting to 1:1")
            aspect_ratio = "1:1"

        normalized_resolution = (
            resolution.upper() if isinstance(resolution, str) else resolution
        )
        if normalized_resolution not in self.VALID_RESOLUTIONS:
            print(f"Warning: Invalid resolution '{resolution}', defaulting to 2K")
            normalized_resolution = "2K"

        palette_key = palette.lower()
        if palette_key not in PALETTE_PRESETS:
            print(f"Warning: Unknown palette '{palette}', defaulting to classic")
            palette_key = "classic"
        colors = PALETTE_PRESETS[palette_key]

        return StyleGuide(
            aspect_ratio=aspect_ratio,
            resolution=normalized_resolution,
            primary_color=colors["primary"],
            primary_hex=colors["primary_hex"],
            secondary_color=colors["secondary"],
            secondary_hex=colors["secondary_hex"],
            accent_color=colors["accent"],
            accent_hex=colors["accent_hex"],
            background=colors["background"],
            background_hex=colors["background_hex"],
            text_color=colors["text"],
            text_hex=colors["text_hex"],
        )

    @staticmethod
    def _prompt_text(text: Any) -> str:
        """Normalize text for embedding in a double-quoted prompt literal.

        Collapses whitespace and turns embedded double quotes into single
        quotes so they cannot terminate the surrounding quotation.
        """
        return " ".join(str(text).replace('"', "'").split())

    @staticmethod
    def _truncate_words(text: str, limit: int) -> str:
        """Shorten text to at most ``limit`` characters without splitting a word.

        Falls back to a hard character cut when the first word alone is longer
        than ``limit``.
        """
        if len(text) <= limit:
            return text
        head = text[: limit + 1]
        shortened = head.rsplit(" ", 1)[0].rstrip() if " " in head else ""
        return shortened or text[:limit]

    def _format_boxes(
        self,
        parsed: ParsedContent,
        sections: Optional[List[Dict[str, Any]]],
    ) -> List[str]:
        """Build the "BOX n: ..." prompt lines (at most MAX_BOXES).

        Uses ``sections`` (dicts with 'header' and 'key_points') when given;
        otherwise the parsed concepts, or the parsed keywords if there are no
        concepts. Every string the model must render is double-quoted.
        """
        boxes: List[str] = []
        if sections:
            for number, section in enumerate(sections[: self.MAX_BOXES], 1):
                header = self._truncate_words(
                    self._prompt_text(section.get("header") or ""), 30
                )
                # At most 5 bullets, each cut to its first 4 words.
                bullets = [
                    self._prompt_text(" ".join(str(point).split()[:4]))
                    for point in (section.get("key_points") or [])[:5]
                ]
                bullets = [bullet for bullet in bullets if bullet]
                line = f'BOX {number}: Header "{header}"'
                if bullets:
                    line += " with bullets: " + ", ".join(f'"{b}"' for b in bullets)
                boxes.append(line)
        else:
            labels = parsed.concepts or [word.capitalize() for word in parsed.keywords]
            for number, label in enumerate(labels[: self.MAX_BOXES], 1):
                short_label = self._prompt_text(" ".join(label.split()[:5]))
                boxes.append(
                    f'BOX {number}: Render text "{short_label}" in clean bold sans-serif'
                )
        return boxes

    def construct_prompt(
        self,
        parsed: ParsedContent,
        style: StyleGuide,
        layout: str = "dense",
        sections: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """
        Construct the image generation prompt for sketchnote style.

        Uses Google's best practices for text rendering:
        - Put exact text in quotation marks
        - Keep text SHORT (3-6 words per element)
        - Describe typography descriptively
        - Be explicit about text placement

        The colors named in the VISUAL STYLE section come from ``style``. This
        keeps the color names consistent with the hex codes for every palette.
        The announced box count matches the boxes actually listed.

        Args:
            parsed: Parsed content with title, concepts, and keywords.
            style: Style guide configuration.
            layout: Layout preset name (unknown names fall back to "dense"
                with a printed warning).
            sections: Optional list of section dicts with 'header' and 'key_points'.

        Returns:
            Complete prompt string for image generation.
        """
        layout_key = layout.lower()
        if layout_key not in LAYOUT_PRESETS:
            print(f"Warning: Unknown layout '{layout}', defaulting to dense")
            layout_key = "dense"
        layout_instruction = LAYOUT_PRESETS[layout_key]["layout_instruction"]

        # Short title renders more reliably; cut on a word boundary so the model
        # is never asked to render a half word "EXACTLY".
        short_title = self._truncate_words(self._prompt_text(parsed.title), 40)

        box_lines = self._format_boxes(parsed, sections)
        if box_lines:
            count = len(box_lines)
            boxes_heading = (
                f"CONTENT BOXES ({count} labeled section{'' if count == 1 else 's'})"
            )
            concepts_section = "\n".join(box_lines)
        else:
            boxes_heading = "CONTENT BOXES (none)"
            concepts_section = (
                "No content boxes: decorate around the title banner with related doodles only."
            )

        prompt = f"""{QUALITY_PREFIX}
{TEXT_RENDERING_GUIDELINES}

Create a SKETCHNOTE visual summary with the following EXACT TEXT content:

=== TITLE (in decorative banner at top) ===
Render the text "{short_title}" in large, bold, clean sans-serif lettering inside a hand-drawn banner ribbon.

=== {boxes_heading} ===
{concepts_section}

=== VISUAL STYLE ===
- Authentic sketchnote hand-drawn aesthetic
- Hand-drawn boxes with inked borders for each concept
- Small doodle icons next to each box
- Arrows connecting related boxes
- Paper background: {style.background} ({style.background_hex})
- Border and heading ink: {style.primary_color} ({style.primary_hex})
- Doodles and connectors: {style.secondary_color} ({style.secondary_hex})
- Body text: {style.text_color} ({style.text_hex})
- Highlight accents: {style.accent_color} ({style.accent_hex})
- Margin decorations: stars, checkmarks, lightbulbs

=== TEXT RENDERING (CRITICAL) ===
- The TITLE text must read EXACTLY: "{short_title}"
- Each BOX must contain its specified text EXACTLY as written above
- Use clean, legible, hand-lettered style
- Text must be LARGE and READABLE
- NO spelling errors, NO gibberish, NO random letters
- SPELL OUT difficult words letter-by-letter if needed to ensure accuracy
- Do NOT substitute similar words (e.g., "Requirements" not "Risks", "Responsibilities" not "Roles")

=== LAYOUT ===
{layout_instruction}

ASPECT RATIO: {style.aspect_ratio}
RESOLUTION: {style.resolution}

{NEGATIVE_PROMPT}

Generate the sketchnote image now with ALL text rendered EXACTLY as specified above."""

        return prompt

    # ------------------------------------------------------------------ #
    # API interaction
    # ------------------------------------------------------------------ #

    def _build_generation_config(
        self,
        aspect_ratio: Optional[str] = None,
        resolution: Optional[str] = None,
    ) -> Any:
        """
        Build the GenerateContentConfig for an image-only request.

        Aspect ratio and resolution are sent through ``types.ImageConfig``.
        Prompt text alone is not a reliable way to control the output
        dimensions. Each setting is skipped when the installed google-genai SDK
        does not expose the corresponding ImageConfig field. In that case the
        request carries only the prompt text, as before.
        """
        config_kwargs: Dict[str, Any] = {"response_modalities": ["IMAGE"]}
        image_config_cls = getattr(types, "ImageConfig", None)
        supported = getattr(image_config_cls, "model_fields", None) or {}

        image_kwargs: Dict[str, Any] = {}
        if aspect_ratio and "aspect_ratio" in supported:
            image_kwargs["aspect_ratio"] = aspect_ratio
        # Selectable output sizes (1K/2K/4K) are documented for the Gemini 3 Pro
        # image model only, so other image models get aspect ratio alone.
        if resolution and "image_size" in supported and "pro-image" in self.model:
            image_kwargs["image_size"] = resolution

        if image_kwargs:
            config_kwargs["image_config"] = image_config_cls(**image_kwargs)
        return types.GenerateContentConfig(**config_kwargs)

    def _call_api_with_retry(
        self,
        prompt: str,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        config: Optional[Any] = None,
    ) -> Any:
        """
        Call the Gemini API, retrying transient failures with exponential backoff.

        Rate-limit errors and other unexpected errors are retried, up to
        ``max_retries`` attempts in total. A rate-limit error is HTTP 429,
        RESOURCE_EXHAUSTED, or a message mentioning quota or a rate limit.
        Safety blocks are raised immediately, since resending the same request
        cannot succeed. So are other 4xx client errors, except 408 and 429.
        No sleep happens after the final attempt.

        Args:
            prompt: The generation prompt.
            max_retries: Maximum number of attempts.
            initial_delay: Delay before the first retry in seconds; doubled after each retry.
            config: GenerateContentConfig to send; defaults to an image-only config.

        Returns:
            API response object.

        Raises:
            Exception: A non-retryable error immediately, or the last error once
                all attempts are exhausted.
        """
        if config is None:
            config = types.GenerateContentConfig(response_modalities=["IMAGE"])

        delay = initial_delay
        last_error: Optional[Exception] = None

        for attempt in range(1, max_retries + 1):
            try:
                return self.client.models.generate_content(
                    model=self.model,
                    contents=prompt,
                    config=config,
                )
            except Exception as e:
                last_error = e
                error_str = str(e).lower()
                # google.genai API errors expose the HTTP status as `.code`.
                status = getattr(e, "code", None)
                status = status if isinstance(status, int) else None

                is_rate_limited = status == 429 or any(
                    marker in error_str for marker in self._RATE_LIMIT_MARKERS
                )

                if not is_rate_limited and ("safety" in error_str or "blocked" in error_str):
                    print(f"Safety block encountered: {e}")
                    raise

                if (
                    not is_rate_limited
                    and status is not None
                    and 400 <= status < 500
                    and status != 408
                ):
                    print(f"Non-retryable API error ({status}): {e}")
                    raise

                if is_rate_limited:
                    print(f"Rate limited (attempt {attempt}/{max_retries}): {e}")
                else:
                    print(f"API error (attempt {attempt}/{max_retries}): {e}")

                if attempt < max_retries:
                    print(f"Waiting {delay}s before retry...")
                    time.sleep(delay)
                    delay *= 2

        raise last_error or Exception("Max retries exceeded")

    @staticmethod
    def _extract_image(response: Any) -> Any:
        """Return ``(data, mime_type)`` of the first non-empty inline image part
        of the first candidate, or ``(None, None)`` when there is none."""
        candidates = getattr(response, "candidates", None) or []
        content = getattr(candidates[0], "content", None) if candidates else None
        for part in getattr(content, "parts", None) or []:
            inline = getattr(part, "inline_data", None)
            if inline is not None and getattr(inline, "data", None):
                return inline.data, (getattr(inline, "mime_type", None) or "image/png")
        return None, None

    @staticmethod
    def _describe_missing_image(response: Any) -> str:
        """Build an error message for a response that carried no image.

        Adds whichever of these the response exposes: the prompt block reason,
        the first candidate's finish reason, and any text the model returned.
        """
        details = []
        feedback = getattr(response, "prompt_feedback", None)
        block_reason = getattr(feedback, "block_reason", None)
        if block_reason:
            details.append(f"block_reason={block_reason}")
        candidates = getattr(response, "candidates", None) or []
        if candidates:
            finish_reason = getattr(candidates[0], "finish_reason", None)
            if finish_reason:
                details.append(f"finish_reason={finish_reason}")
            content = getattr(candidates[0], "content", None)
            texts = [
                part.text
                for part in getattr(content, "parts", None) or []
                if getattr(part, "text", None)
            ]
            if texts:
                details.append("model text: " + " ".join(texts)[:200])
        message = "No image data in response"
        return f"{message} ({'; '.join(details)})" if details else message

    @staticmethod
    def _extension_for(mime_type: str) -> str:
        """Map an image MIME type to a file extension: png, webp, otherwise jpg."""
        mime = (mime_type or "").lower()
        if "png" in mime:
            return "png"
        if "webp" in mime:
            return "webp"
        return "jpg"

    def _default_output_path(self, ext: str) -> Path:
        """Return a timestamped path in output_dir that does not exist yet.

        Appends ``_1``, ``_2``, ... when several images are generated within the
        same second, so earlier outputs are never overwritten.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        candidate = self.output_dir / f"sketchnote_{timestamp}.{ext}"
        suffix = 1
        while candidate.exists():
            candidate = self.output_dir / f"sketchnote_{timestamp}_{suffix}.{ext}"
            suffix += 1
        return candidate

    def generate_image(
        self,
        prompt: str,
        output_path: Optional[str] = None,
        aspect_ratio: Optional[str] = None,
        resolution: Optional[str] = None,
    ) -> GenerationResult:
        """
        Generate a sketchnote image using the Gemini API and write it to disk.

        Args:
            prompt: The complete generation prompt.
            output_path: Where to save the image. Auto-generated (unique,
                timestamped, in output_dir) if not provided.
            aspect_ratio: Optional aspect ratio sent via the image config.
            resolution: Optional output size (1K/2K/4K) sent via the image config
                where the model supports it.

        Returns:
            GenerationResult with success status and output path. Errors are
            reported through ``result.error`` rather than raised.
        """
        result = GenerationResult(success=False, prompt_used=prompt)

        try:
            config = self._build_generation_config(aspect_ratio, resolution)
            response = self._call_api_with_retry(prompt, config=config)

            image_data, mime_type = self._extract_image(response)
            if not image_data:
                result.error = self._describe_missing_image(response)
                return result

            if output_path:
                target = Path(output_path)
            else:
                target = self._default_output_path(self._extension_for(mime_type))

            # The SDK normally returns raw bytes; accept base64 text defensively.
            if isinstance(image_data, str):
                image_bytes = base64.b64decode(image_data)
            else:
                image_bytes = image_data

            target.parent.mkdir(parents=True, exist_ok=True)
            target.write_bytes(image_bytes)

            result.success = True
            result.output_path = str(target)
            result.metadata = {
                "model": self.model,
                "timestamp": datetime.now().isoformat(),
                "file_size": len(image_bytes),
                "mime_type": mime_type,
            }

        except Exception as e:
            # Top-level boundary: report failure through the result object.
            result.error = str(e)

        return result

    def generate(
        self,
        raw_text: str,
        layout: str = "dense",
        aspect_ratio: str = "1:1",
        resolution: str = "2K",
        palette: str = "classic",
        output_path: Optional[str] = None,
    ) -> GenerationResult:
        """
        Full pipeline: parse content, build prompt, generate sketchnote.

        This is the main entry point for generating sketchnotes.

        Args:
            raw_text: Raw text content (article, notes, transcript).
            layout: Layout preset name.
            aspect_ratio: Image aspect ratio.
            resolution: Image resolution.
            palette: Color palette name.
            output_path: Where to save the output image.

        Returns:
            GenerationResult with success status and output details. Fails
            without calling the API when no title can be extracted.
        """
        print("Step 1: Parsing content structure...")
        parsed = self.parse_content(raw_text)

        if not parsed.title:
            return GenerationResult(
                success=False,
                error="Could not extract title from content"
            )

        print(f"  Title: {parsed.title}")
        print(f"  Concepts: {parsed.concepts}")
        print(f"  Keywords: {parsed.keywords}")

        print(f"Step 2: Applying '{palette}' palette with '{layout}' layout...")
        style = self.generate_style_guide(
            palette=palette,
            aspect_ratio=aspect_ratio,
            resolution=resolution,
        )

        print("Step 3: Constructing sketchnote prompt...")
        prompt = self.construct_prompt(parsed, style, layout)

        print("Step 4: Generating sketchnote via Gemini API...")
        # Pass the validated values so the API, not just the prompt, gets them.
        result = self.generate_image(
            prompt,
            output_path,
            aspect_ratio=style.aspect_ratio,
            resolution=style.resolution,
        )

        if result.success:
            print(f"Success! Sketchnote saved to: {result.output_path}")
        else:
            print(f"Generation failed: {result.error}")

        return result


# Convenience function for quick generation
def create_sketchnote(
    text: str,
    layout: str = "dense",
    aspect_ratio: str = "1:1",
    resolution: str = "2K",
    palette: str = "classic",
    output_path: Optional[str] = None,
    api_key: Optional[str] = None,
) -> str:
    """
    Generate a sketchnote from text in a single call.

    Creates a one-off SketchnoteGenerator and runs its full pipeline.

    Args:
        text: Raw text content.
        layout: Layout preset.
        aspect_ratio: Image dimensions ratio.
        resolution: Output resolution.
        palette: Color palette name.
        output_path: Where to save the image.
        api_key: Google API key (or set GOOGLE_API_KEY env var).

    Returns:
        The saved image path on success, otherwise the string "Error: <reason>".
    """
    pipeline_options = dict(
        layout=layout,
        aspect_ratio=aspect_ratio,
        resolution=resolution,
        palette=palette,
        output_path=output_path,
    )
    outcome = SketchnoteGenerator(api_key=api_key).generate(
        raw_text=text, **pipeline_options
    )
    return outcome.output_path if outcome.success else f"Error: {outcome.error}"


if __name__ == "__main__":
    # Demo entry point: with GOOGLE_API_KEY set, render the sample text as a
    # sketchnote; without it, only print what parse_content extracts.
    sample_text = """
    # Cybersecurity Best Practices

    ## Defense in Depth
    Layer multiple security controls. No single point of failure.

    ## Zero Trust Architecture
    Never trust, always verify. Authenticate every access request.

    ## Incident Response
    Prepare, detect, contain, eradicate, recover, learn.

    ## Security Awareness
    Train users regularly. Phishing is the top attack vector.

    ## Risk Management
    Identify, assess, mitigate, monitor risks continuously.
    """

    if os.environ.get("GOOGLE_API_KEY"):
        outcome = create_sketchnote(
            text=sample_text,
            layout="dense",
            aspect_ratio="1:1",
            resolution="2K",
            palette="classic",
        )
        print(f"Result: {outcome}")
    else:
        print("Set GOOGLE_API_KEY environment variable to run this example")
        print("\nExample parsed content:")
        # __new__ skips __init__ (which demands an API key and builds a
        # client); parse_content needs neither.
        preview = SketchnoteGenerator.__new__(SketchnoteGenerator).parse_content(sample_text)
        for label, value in (
            ("Title", preview.title),
            ("Concepts", preview.concepts),
            ("Keywords", preview.keywords),
        ):
            print(f"{label}: {value}")
