#!/usr/bin/env python3
"""
Image Format Conversion Script
Converts a DNG/RAW image (or any other image Pillow can open) to PNG, JPEG, and JPEG XL for quality comparison.

Workflow:
  - Input: DNG (RAW) file - lossless source
  - Output:
    * PNG (lossless reference - baseline for quality)
    * JPEG (traditional lossy codec)
    * JPEG XL (modern codec - better compression/quality)

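This allows a fair comparison of how each codec handles fine details
(like door numbers seen from far away), since every output starts from the
same decoded source. RAW input is demosaiced to 8-bit RGB first, so the PNG
is lossless relative to that rendering, not to the original sensor data.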

Requirements:
  - Python 3.9+
  - Pillow: pip install Pillow
  - rawpy: pip install rawpy (for DNG/RAW support)
  - For JPEG XL support, install ONE of:
    * libjxl (provides cjxl command):
        - macOS:   brew install libjxl
        - Ubuntu:  sudo apt install libjxl-tools
        - Windows: Download from https://github.com/libjxl/libjxl/releases
    * ImageMagick with JXL support:
        - macOS:   brew install imagemagick
        - Ubuntu:  sudo apt install imagemagick
        - Windows: Download from https://imagemagick.org/script/download.php
"""

import argparse
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path

# Pillow is mandatory: without it nothing can be decoded or encoded, so
# explain how to install it and stop immediately.
try:
    from PIL import Image
except ImportError:
    _banner = "=" * 60
    for _line in (
        _banner,
        "ERROR: Pillow library is required but not installed.",
        _banner,
        "\nInstall it with one of these commands:\n",
        "  pip install Pillow",
        "  pip3 install Pillow",
        "  python -m pip install Pillow",
        "  python3 -m pip install Pillow",
        "",
    ):
        print(_line)
    sys.exit(1)

# rawpy is optional: it is only needed for DNG/RAW input, so record whether
# it imported instead of failing here.
try:
    import rawpy
except ImportError:
    RAWPY_AVAILABLE = False
else:
    RAWPY_AVAILABLE = True


def get_file_size_human(size_bytes: int) -> str:
    """Format a byte count as a short human-readable string.

    Sizes below 1024 are shown as plain bytes (``"512 B"``). Larger sizes use
    1024-based KB/MB/GB/TB units with one decimal place (``"1.5 KB"``).

    Args:
        size_bytes: Size in bytes.

    Returns:
        The formatted size string.
    """
    if size_bytes < 1024:
        return f"{size_bytes} B"

    value = size_bytes / 1024
    for unit in ('KB', 'MB', 'GB'):
        # Compare the value as it will be *displayed*. Without rounding first,
        # e.g. 1023.97 KB would print as "1024.0 KB" instead of "1.0 MB".
        if round(value, 1) < 1024:
            return f"{value:.1f} {unit}"
        value /= 1024
    return f"{value:.1f} TB"


def check_jxl_support() -> "tuple[bool, str | None]":
    """Detect an available JPEG XL encoder.

    The libjxl reference encoder ``cjxl`` is preferred. Otherwise ImageMagick
    is probed (``magick`` if present, else ``convert``). It is accepted when
    its format list mentions JXL.

    Returns:
        ``(True, 'cjxl')`` or ``(True, 'imagemagick')`` when an encoder was
        found, otherwise ``(False, None)``.
    """
    if shutil.which('cjxl'):
        return True, 'cjxl'

    if shutil.which('magick'):
        cmd = ['magick', 'identify', '-list', 'format']
    elif shutil.which('convert'):
        # NOTE(review): on Windows `convert` may resolve to an unrelated system
        # utility; its output won't mention JXL, so detection just fails here.
        cmd = ['convert', '-list', 'format']
    else:
        return False, None

    try:
        # Timeout keeps a wedged binary from hanging dependency detection.
        # errors='replace' keeps unexpected output bytes from raising.
        result = subprocess.run(cmd, capture_output=True, text=True,
                                errors='replace', timeout=30)
    except (OSError, subprocess.SubprocessError):
        # Best-effort probe: an unusable ImageMagick just means "no encoder".
        return False, None

    if 'JXL' in result.stdout:
        return True, 'imagemagick'
    return False, None


def load_dng_image(input_path: Path) -> Image.Image:
    """Decode a DNG/RAW file with rawpy and return it as a PIL image.

    The file is processed at full resolution with AHD demosaicing, the
    camera's white balance and auto brightness allowed. It is rendered to
    8 bits per sample.

    If rawpy could not be imported, install instructions are printed and the
    process exits with status 1.
    """
    if not RAWPY_AVAILABLE:
        rule = "=" * 60
        for message in (
            rule,
            "ERROR: rawpy library is required for DNG/RAW files.",
            rule,
            "\nInstall it with one of these commands:\n",
            "  pip install rawpy",
            "  pip3 install rawpy",
            "  python -m pip install rawpy",
            "",
        ):
            print(message)
        sys.exit(1)

    with rawpy.imread(str(input_path)) as raw_file:
        rgb_pixels = raw_file.postprocess(
            use_camera_wb=True,       # white balance recorded by the camera
            half_size=False,          # keep the full resolution
            no_auto_bright=False,     # auto brightness stays enabled
            output_bps=8,             # 8 bits per sample for broad compatibility
            demosaic_algorithm=rawpy.DemosaicAlgorithm.AHD,  # high-quality demosaic
        )
    return Image.fromarray(rgb_pixels)


def convert_to_png(img: Image.Image, output_path: Path) -> dict:
    """Write *img* to *output_path* as an optimized PNG and describe the result.

    Images in a mode written directly here (RGB, RGBA, L, LA, P) are saved
    as-is. Any other mode is converted to RGB first.

    Returns:
        Stats dict with the format name, file name, size in bytes,
        human-readable size and compression type.
    """
    direct_modes = ('RGB', 'RGBA', 'L', 'LA', 'P')
    to_write = img if img.mode in direct_modes else img.convert('RGB')
    to_write.save(output_path, 'PNG', optimize=True)

    written = output_path.stat().st_size
    return dict(
        format='PNG',
        filename=output_path.name,
        size_bytes=written,
        size_human=get_file_size_human(written),
        compression='lossless',
    )


def convert_to_jpeg(img: Image.Image, output_path: Path, quality: int = 85) -> dict:
    """Write *img* to *output_path* as an optimized JPEG and describe the result.

    The encoder is always given an RGB image; any other mode is converted
    to RGB first.

    Returns:
        Stats dict with the format name, file name, size in bytes,
        human-readable size, the quality used and compression type.
    """
    rgb_img = img if img.mode == 'RGB' else img.convert('RGB')
    rgb_img.save(output_path, 'JPEG', quality=quality, optimize=True)

    written = output_path.stat().st_size
    return dict(
        format='JPEG',
        filename=output_path.name,
        size_bytes=written,
        size_human=get_file_size_human(written),
        quality=quality,
        compression='lossy',
    )


def convert_to_jxl(input_png_path: Path, output_path: Path, quality: int = 90, tool: str = 'cjxl', lossless: bool = False) -> dict:
    """Encode a PNG file as JPEG XL using an external encoder.

    Args:
        input_png_path: PNG file to encode.
        output_path: Destination ``.jxl`` path.
        quality: Lossy quality (1-100); ignored when *lossless* is true.
        tool: ``'cjxl'`` (encodes with ``-p`` for progressive decoding) or
            ``'imagemagick'`` (no progressive option).
        lossless: Encode losslessly instead of using *quality*.

    Returns:
        Stats dict: format, filename, size_bytes, size_human, quality
        (``'lossless'`` or the numeric quality), compression, progressive.

    Raises:
        ValueError: *tool* is not a supported encoder name.
        subprocess.CalledProcessError: the encoder exited with a non-zero status.
    """
    src, dst = str(input_png_path), str(output_path)

    if tool == 'cjxl':
        # cjxl takes a distance (0 = lossless, 1 = high quality, higher = more
        # compression). Quality 100..0 maps onto distance 0..10, clamped to
        # 0.1 so lossy mode never silently becomes lossless.
        distance = '0' if lossless else str(max(0.1, (100 - quality) / 10))
        # -p enables progressive/responsive decoding (preview while loading).
        cmd = ['cjxl', src, dst, '-d', distance, '-p']
        progressive = True
    elif tool == 'imagemagick':
        # ImageMagick doesn't support progressive JXL encoding. `magick` is
        # preferred, `convert` is the fallback.
        # NOTE(review): lossless mode relies on `-quality 100`; confirm that
        # the installed ImageMagick maps this to lossless JXL.
        binary = 'magick' if shutil.which('magick') else 'convert'
        im_quality = '100' if lossless else str(quality)
        cmd = [binary, src, '-quality', im_quality, dst]
        progressive = False
    else:
        # Without this guard nothing is encoded and a stale file from an
        # earlier run could be measured and reported as the new result.
        raise ValueError(
            f"Unsupported JPEG XL tool: {tool!r} (expected 'cjxl' or 'imagemagick')"
        )

    subprocess.run(cmd, check=True, capture_output=True)

    size_bytes = output_path.stat().st_size
    return {
        'format': 'JPEG XL',
        'filename': output_path.name,
        'size_bytes': size_bytes,
        'size_human': get_file_size_human(size_bytes),
        'quality': 'lossless' if lossless else quality,
        'compression': 'lossless' if lossless else 'lossy',
        'progressive': progressive
    }


# File extensions decoded with rawpy rather than Pillow.
_RAW_EXTENSIONS = frozenset({'.dng', '.cr2', '.cr3', '.nef', '.arw', '.orf', '.rw2', '.raw'})


def _percent_change(new_size: int, reference_size: int) -> str:
    """Return the signed size change relative to *reference_size*, e.g. ``"-42.3%"``."""
    return f"{((new_size - reference_size) / reference_size * 100):+.1f}%"


def _subprocess_error_detail(err: Exception) -> str:
    """Return a failed subprocess's captured stderr as stripped text ('' if none)."""
    stderr = getattr(err, 'stderr', None)
    if isinstance(stderr, bytes):
        stderr = stderr.decode(errors='replace')
    return (stderr or '').strip()


def _print_jxl_install_help() -> None:
    """Print instructions for installing a JPEG XL encoder."""
    print("\n" + "=" * 60)
    print("JPEG XL encoder not found!")
    print("=" * 60)
    print("\nTo enable JPEG XL conversion, install ONE of the following:\n")
    print("libjxl (recommended):")
    print("  macOS:   brew install libjxl")
    print("  Ubuntu:  sudo apt install libjxl-tools")
    print("  Windows: https://github.com/libjxl/libjxl/releases")
    print("\nImageMagick:")
    print("  macOS:   brew install imagemagick")
    print("  Ubuntu:  sudo apt install imagemagick")
    print("  Windows: https://imagemagick.org/script/download.php")
    print()


def _load_source_image(input_path: Path, is_raw: bool) -> Image.Image:
    """Load *input_path* as a PIL image: RAW via rawpy, else RGB/RGBA via Pillow."""
    if is_raw:
        print("Loading DNG/RAW file...")
        return load_dng_image(input_path)
    # Regular image file: materialize the pixels in a new image so the file
    # handle can be closed when the `with` block exits.
    with Image.open(input_path) as opened:
        if opened.mode in ('RGB', 'RGBA'):
            return opened.copy()
        return opened.convert('RGB')


def process_image(input_path: str, output_dir: str = 'output', jpeg_quality: int = 85, jxl_quality: int = 90, jxl_lossless: bool = False) -> dict:
    """Generate PNG, JPEG and (if an encoder is available) JPEG XL versions of an image.

    Outputs go to ``<output_dir>/<input stem>/`` as ``converted.png``,
    ``converted.jpg`` and ``converted.jxl``, plus a ``stats.json`` summary.
    The JPEG XL file is encoded from the PNG, so all codecs start from the
    same pixels.

    Args:
        input_path: Source image. RAW extensions are decoded with rawpy;
            anything else is opened with Pillow.
        output_dir: Base output directory.
        jpeg_quality: JPEG quality setting.
        jxl_quality: JPEG XL quality (ignored when *jxl_lossless* is true).
        jxl_lossless: Encode JPEG XL losslessly.

    Returns:
        The stats dict that is also written to ``stats.json``, with keys
        ``'source'``, ``'png'``, ``'jpeg'`` and ``'jxl'``.

    Raises:
        FileNotFoundError: *input_path* does not exist.
    """
    source_path = Path(input_path)
    if not source_path.exists():
        raise FileNotFoundError(f"Input file not found: {source_path}")

    # One subdirectory per source image, named after its file stem.
    image_dir = Path(output_dir) / source_path.stem
    image_dir.mkdir(parents=True, exist_ok=True)

    is_raw = source_path.suffix.lower() in _RAW_EXTENSIONS
    suffix_label = source_path.suffix.upper().replace('.', '')

    print(f"Processing: {source_path.name}")
    img = _load_source_image(source_path, is_raw)

    width, height = img.size
    print(f"Image size: {width}x{height}")
    print()

    original_size = source_path.stat().st_size
    stats = {
        'source': {
            # Label RAW inputs by their actual container (DNG, NEF, CR2, ...).
            'format': f"{suffix_label} (RAW)" if is_raw else suffix_label,
            'filename': source_path.name,
            'size_bytes': original_size,
            'size_human': get_file_size_human(original_size),
            'width': width,
            'height': height,
            'compression': 'lossless (RAW)' if is_raw else 'source'
        }
    }

    # PNG: lossless reference that the other formats are compared against.
    print("Converting to PNG (lossless reference)...")
    png_output = image_dir / "converted.png"
    png_stats = convert_to_png(img, png_output)
    png_stats.update(width=width, height=height)
    stats['png'] = png_stats
    print(f"  PNG: {png_stats['size_human']}")
    png_size = png_stats['size_bytes']

    # JPEG
    print(f"Converting to JPEG (quality: {jpeg_quality})...")
    jpeg_output = image_dir / "converted.jpg"
    jpeg_stats = convert_to_jpeg(img, jpeg_output, jpeg_quality)
    jpeg_stats.update(width=width, height=height)
    jpeg_stats['size_vs_png'] = _percent_change(jpeg_stats['size_bytes'], png_size)
    stats['jpeg'] = jpeg_stats
    print(f"  JPEG: {jpeg_stats['size_human']} ({jpeg_stats['size_vs_png']} vs PNG)")

    # JPEG XL (encoded from the PNG to ensure a fair comparison)
    jxl_support, jxl_tool = check_jxl_support()
    if jxl_support:
        mode_str = "lossless" if jxl_lossless else f"quality {jxl_quality}"
        progressive_note = ", progressive" if jxl_tool == 'cjxl' else ""
        print(f"Converting to JPEG XL ({mode_str}{progressive_note}, using {jxl_tool})...")
        jxl_output = image_dir / "converted.jxl"
        try:
            jxl_stats = convert_to_jxl(png_output, jxl_output, jxl_quality, jxl_tool, jxl_lossless)
        except (subprocess.CalledProcessError, OSError) as e:
            # Record the failure instead of aborting, so stats.json is still
            # written for the formats that succeeded.
            print(f"  Warning: JPEG XL conversion failed: {e}")
            failure = {'error': 'Conversion failed', 'available': False}
            detail = _subprocess_error_detail(e)
            if detail:
                print(f"    {detail}")
                failure['details'] = detail
            stats['jxl'] = failure
        else:
            jxl_stats.update(width=width, height=height)
            jxl_stats['size_vs_png'] = _percent_change(jxl_stats['size_bytes'], png_size)
            jxl_stats['size_vs_jpeg'] = _percent_change(jxl_stats['size_bytes'], jpeg_stats['size_bytes'])
            stats['jxl'] = jxl_stats
            prog_str = " [progressive]" if jxl_stats.get('progressive') else ""
            print(f"  JPEG XL{prog_str}: {jxl_stats['size_human']} ({jxl_stats['size_vs_png']} vs PNG, {jxl_stats['size_vs_jpeg']} vs JPEG)")
    else:
        _print_jxl_install_help()
        stats['jxl'] = {
            'error': 'JPEG XL encoder not available',
            'available': False,
            'install_hint': 'Install libjxl (cjxl) or ImageMagick with JXL support'
        }

    stats_path = image_dir / "stats.json"
    with open(stats_path, 'w', encoding='utf-8') as f:
        json.dump(stats, f, indent=2)
    print(f"\nStats saved to: {stats_path}")

    return stats


def check_dependencies():
    """Print a status report for Python, Pillow, rawpy and a JPEG XL encoder.

    Informational only: writes the report to stdout and returns None.
    """
    rule = "=" * 60
    print(rule)
    print("Dependency Check")
    print(rule)
    print()

    # Python interpreter
    ver = sys.version_info
    python_ok = ver >= (3, 9)
    print(f"Python {ver.major}.{ver.minor}.{ver.micro}")
    python_status = 'OK' if python_ok else 'WARNING - Python 3.9+ recommended'
    print(f"  Status: {python_status}")
    print()

    # Pillow
    try:
        import PIL
    except ImportError:
        print("Pillow: NOT INSTALLED")
        print("  Install: pip install Pillow")
    else:
        print(f"Pillow {PIL.__version__}")
        print("  Status: OK")
    print()

    # rawpy (DNG/RAW decoding)
    if not RAWPY_AVAILABLE:
        print("rawpy: NOT INSTALLED")
        print("  Install: pip install rawpy")
        print("  (Required for DNG/RAW file support)")
    else:
        import rawpy
        rawpy_version = getattr(rawpy, '__version__', '(version unknown)')
        print(f"rawpy {rawpy_version}")
        print("  Status: OK (DNG/RAW support enabled)")
    print()

    # JPEG XL encoder
    jxl_found, encoder = check_jxl_support()
    if not jxl_found:
        for hint in (
            "JPEG XL encoder: NOT FOUND",
            "  Install (choose one):",
            "    macOS:   brew install libjxl",
            "    Ubuntu:  sudo apt install libjxl-tools",
            "    Windows: https://github.com/libjxl/libjxl/releases",
        ):
            print(hint)
    else:
        print(f"JPEG XL encoder: {encoder}")
        print("  Status: OK")
    print()

    print(rule)
    everything_ok = python_ok and RAWPY_AVAILABLE and jxl_found
    print("All dependencies are installed!" if everything_ok
          else "Some dependencies are missing. Install them for full functionality.")
    print(rule)


def _quality_arg(text: str) -> int:
    """argparse ``type=`` callable: parse an integer quality in the range 1-100."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if not 1 <= value <= 100:
        raise argparse.ArgumentTypeError(f"must be between 1 and 100, got {value}")
    return value


def main():
    """Command-line entry point: parse arguments, then convert or check dependencies.

    Exits with status 1 when no input is given or when conversion fails.
    Argument errors (including out-of-range quality values) are reported by
    argparse.
    """
    parser = argparse.ArgumentParser(
        description='Convert DNG/RAW images to PNG, JPEG, and JPEG XL for quality comparison',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Workflow:
  Input:  DNG (RAW) file - lossless source with maximum quality
  Output: PNG (lossless reference), JPEG (traditional), JPEG XL (modern)

This allows fair comparison starting from the same lossless source,
showing how each codec handles fine details (door numbers, text, etc.)

Examples:
  python convert_images.py photo.dng
  python convert_images.py photo.dng -o comparison_output
  python convert_images.py photo.dng --jpeg-quality 90 --jxl-quality 95
  python convert_images.py photo.dng --jxl-lossless
  python convert_images.py --check-deps
        """
    )
    parser.add_argument('input', nargs='?', help='Input DNG/RAW image path')
    parser.add_argument('-o', '--output', default='output', help='Output directory (default: output)')
    # Enforce the documented 1-100 range. Out-of-range values would otherwise
    # reach the encoders (e.g. be silently clamped in the cjxl distance mapping).
    parser.add_argument('--jpeg-quality', type=_quality_arg, default=90, help='JPEG quality 1-100 (default: 90)')
    parser.add_argument('--jxl-quality', type=_quality_arg, default=90, help='JPEG XL quality 1-100 (default: 90)')
    parser.add_argument('--jxl-lossless', action='store_true', help='Use lossless JPEG XL encoding')
    parser.add_argument('--check-deps', action='store_true', help='Check dependencies and exit')

    args = parser.parse_args()

    if args.check_deps:
        check_dependencies()
        sys.exit(0)

    if not args.input:
        parser.print_help()
        print("\nError: Input file is required (unless using --check-deps)", file=sys.stderr)
        sys.exit(1)

    try:
        process_image(
            args.input,
            args.output,
            args.jpeg_quality,
            args.jxl_quality,
            args.jxl_lossless
        )
    except Exception as e:
        # Top-level CLI boundary: report any failure (missing input, decode or
        # encoder errors) as a one-line message instead of a traceback.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

    print("\nConversion complete!")
    print("Open index.html in your browser to compare the images.")


if __name__ == '__main__':
    main()
