#!/usr/bin/env python3
"""
EthoVision local tracker worker — runs DeepLabCut SuperAnimal on the user's GPU.

Polls the EthoVision server for pending tracking jobs, downloads the video,
runs DLC inference, and uploads the keypoints back.

Usage:
    ethovision_worker.py --server https://your-ev-host:3012 --token <worker-token>

Environment variables (override flags):
    ETHOVISION_SERVER   base URL of the EthoVision app (e.g. http://217.160.138.89:3012)
    ETHOVISION_TOKEN    bearer token from the server's .env.local
    ETHOVISION_WORKER_ID  human-readable id reported to the server (default: hostname)

Run with --dry-run to poll without executing DLC — useful to validate your
config before installing the full DeepLabCut stack.
"""
from __future__ import annotations

import argparse
import json
import os
import platform
import signal
import socket
import sys
import time
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

try:
    import requests
except ImportError:
    print("Missing dependency `requests`. Install with: pip install requests", file=sys.stderr)
    sys.exit(2)

# DLC / torch are heavy; import lazily so --dry-run works without them.


SUPPORTED_MODELS = ["superanimal_topviewmouse", "superanimal_quadruped"]


@dataclass
class Config:
    """Validated runtime settings for the worker, built by parse_args()."""

    server: str        # base URL of the EthoVision server (trailing slash stripped)
    token: str         # bearer token sent in the Authorization header on API calls
    worker_id: str     # human-readable id reported to the server when claiming jobs
    poll_seconds: int  # idle sleep (seconds) between job polls and after errors
    work_dir: Path     # cache directory for downloaded videos and DLC outputs
    dry_run: bool      # when True, claim + download jobs but skip DLC inference


def parse_args() -> Config:
    """Build the worker Config from CLI flags, falling back to environment variables.

    Exits with an error message when --server or --token is missing; creates
    the work directory on the way out.
    """
    env = os.environ
    parser = argparse.ArgumentParser(description="EthoVision local tracker worker (DLC SuperAnimal)")
    parser.add_argument("--server", default=env.get("ETHOVISION_SERVER"),
                        help="Base URL of the EthoVision server (e.g. http://217.160.138.89:3012)")
    parser.add_argument("--token", default=env.get("ETHOVISION_TOKEN"),
                        help="Worker auth token")
    parser.add_argument("--worker-id", default=env.get("ETHOVISION_WORKER_ID", socket.gethostname()),
                        help="Human-readable worker id")
    parser.add_argument("--poll-seconds", type=int, default=10, help="Seconds between polls when idle")
    parser.add_argument("--work-dir", default=Path.home() / ".ethovision-worker",
                        help="Where to cache downloaded videos + DLC outputs")
    parser.add_argument("--dry-run", action="store_true",
                        help="Poll and download, but skip DLC inference (useful for setup verification)")
    ns = parser.parse_args()

    # Both of these may come from flags or the environment; either way they are mandatory.
    if not ns.server:
        sys.exit("--server (or ETHOVISION_SERVER) is required")
    if not ns.token:
        sys.exit("--token (or ETHOVISION_TOKEN) is required")

    cache_root = Path(ns.work_dir).expanduser()
    cache_root.mkdir(parents=True, exist_ok=True)

    return Config(
        server=ns.server.rstrip("/"),
        token=ns.token,
        worker_id=ns.worker_id,
        poll_seconds=ns.poll_seconds,
        work_dir=cache_root,
        dry_run=ns.dry_run,
    )


def auth_headers(cfg: Config) -> dict[str, str]:
    """Return the Authorization header dict carrying the worker's bearer token."""
    token = cfg.token
    return {"Authorization": f"Bearer {token}"}


def claim_next_job(cfg: Config) -> Optional[dict[str, Any]]:
    """Ask the server for the next pending job; return its payload, or None when idle.

    Exits the process outright on 401 — a bad token never fixes itself.
    """
    payload = {"worker_id": cfg.worker_id, "models": SUPPORTED_MODELS}
    headers = {**auth_headers(cfg), "content-type": "application/json"}
    resp = requests.post(f"{cfg.server}/api/jobs/next", json=payload,
                         headers=headers, timeout=20)
    if resp.status_code == 401:
        sys.exit("Unauthorized — check ETHOVISION_TOKEN matches the server's .env.local")
    resp.raise_for_status()
    return resp.json().get("job")


def report_progress(cfg: Config, job_id: str, **fields: Any) -> None:
    """Best-effort POST of progress fields for a job.

    Failures are logged to stderr and swallowed on purpose: a dropped progress
    update must never kill the job that is producing it.
    """
    url = f"{cfg.server}/api/jobs/{job_id}/progress"
    headers = {**auth_headers(cfg), "content-type": "application/json"}
    try:
        requests.post(url, json=fields, headers=headers, timeout=20)
    except Exception as exc:
        print(f"  ! progress report failed: {exc}", file=sys.stderr)


def download_video(cfg: Config, video_id: str, dest: Path) -> None:
    """Stream the original video for *video_id* from the server into *dest*.

    Fix: this was the only server call made without the Authorization header
    (every other endpoint goes through auth_headers), so downloads failed with
    401 on servers that protect the stream endpoint. The header is now sent.
    """
    url = f"{cfg.server}/api/videos/{video_id}/stream?original=1"
    with requests.get(url, headers=auth_headers(cfg), stream=True, timeout=60) as r:
        r.raise_for_status()
        with dest.open("wb") as f:
            # 1 MiB chunks; skip keep-alive chunks that arrive empty.
            for chunk in r.iter_content(chunk_size=1 << 20):
                if chunk:
                    f.write(chunk)


def upload_result(cfg: Config, job_id: str, result_path: Path,
                  n_frames: int, n_keypoints: int) -> None:
    """Upload the result CSV plus its summary counts as a multipart POST.

    Raises requests.HTTPError when the server rejects the upload.
    """
    endpoint = f"{cfg.server}/api/jobs/{job_id}/result"
    form = {"n_frames": str(n_frames), "n_keypoints": str(n_keypoints)}
    with result_path.open("rb") as fh:
        resp = requests.post(
            endpoint,
            files={"file": (result_path.name, fh, "application/octet-stream")},
            data=form,
            headers=auth_headers(cfg),
            timeout=300,
        )
        resp.raise_for_status()


POSE_MODEL = "hrnet_w32"
DETECTOR_MODEL = "fasterrcnn_mobilenet_v3_large_fpn"


def _expected_output_stem(video_stem: str, model: str) -> str:
    """Stem DLC gives SuperAnimal outputs:
    '<video>_<superanimal model>_<pose model>_<detector model>', e.g.
    '<video>_superanimal_topviewmouse_hrnet_w32_fasterrcnn_mobilenet_v3_large_fpn'."""
    return "_".join((video_stem, model, POSE_MODEL, DETECTOR_MODEL))


def _find_existing_h5(video_path: Path, model: str) -> Optional[Path]:
    """Return a previously produced DLC .h5 for this (video, model), or None.

    Checks the video's folder and its 'pose_estimation' subfolder; prefers the
    exact expected filename, otherwise the newest loose match.
    """
    stem = _expected_output_stem(video_path.stem, model)
    for folder in (video_path.parent, video_path.parent / "pose_estimation"):
        if not folder.exists():
            continue
        exact = folder / f"{stem}.h5"
        if exact.exists():
            return exact
        # DLC sometimes varies suffixes, so accept any plausible match
        # (but never the companion *_meta.h5 files).
        loose = [
            p for p in folder.glob(f"{video_path.stem}_{model}_*.h5")
            if not p.name.endswith("_meta.h5")
        ]
        if loose:
            return max(loose, key=lambda p: p.stat().st_mtime)
    return None


def _h5_to_csv(h5_path: Path) -> Path:
    """Export a DLC .h5 result as a sibling .csv (same stem) and return its path.

    The stored pandas DataFrame carries DLC's 3-level column index
    (scorer, bodypart, coord), so a plain to_csv emits exactly the three
    header rows our server-side parser expects.
    """
    import pandas as pd  # type: ignore[import-not-found]
    out_path = h5_path.with_suffix(".csv")
    pd.read_hdf(h5_path).to_csv(out_path)
    return out_path


def run_dlc(cfg: Config, job: dict[str, Any], video_path: Path) -> tuple[Path, int, int]:
    """Produce keypoints for *video_path* with DLC SuperAnimal.

    Reuses a cached .h5 when a previous run left one behind; otherwise runs
    fresh inference. Returns (csv_path, n_frames, n_keypoints).
    Raises RuntimeError for an unsupported model or when DLC leaves no output.
    """
    job_id, model = job["id"], job["model"]
    if model not in SUPPORTED_MODELS:
        raise RuntimeError(f"unsupported model: {model}")

    # Fast path: an .h5 from a previous run means inference can be skipped.
    cached = _find_existing_h5(video_path, model)
    if cached is not None:
        report_progress(cfg, job_id, status="running", progress=80,
                        progress_msg=f"Reusing cached inference ({cached.name})")
        print(f"  reusing existing DLC output: {cached.name}")
        csv_path = _h5_to_csv(cached)
        n_frames, n_keypoints = _summarize_dlc_csv(csv_path)
        report_progress(cfg, job_id, progress=90,
                        progress_msg=f"Uploading {n_frames} frames")
        return csv_path, n_frames, n_keypoints

    report_progress(cfg, job_id, status="running", progress=5, progress_msg="Loading DLC")
    import deeplabcut  # type: ignore[import-not-found]

    out_dir = video_path.parent
    report_progress(cfg, job_id, status="running", progress=10, progress_msg="Starting inference")

    deeplabcut.video_inference_superanimal(
        videos=[str(video_path)],
        superanimal_name=model,
        model_name=POSE_MODEL,
        detector_name=DETECTOR_MODEL,
        videotype=video_path.suffix.lstrip("."),
        dest_folder=str(out_dir),
    )

    # DLC 3.x writes .h5 by default; check the expected spots first, then fall
    # back to the newest .h5 anywhere under out_dir.
    h5 = _find_existing_h5(video_path, model)
    if h5 is None:
        newest_first = sorted(out_dir.rglob("*.h5"), key=lambda p: p.stat().st_mtime, reverse=True)
        h5 = newest_first[0] if newest_first else None
    if h5 is None:
        raise RuntimeError(
            "DLC ran but no .h5 output was found in "
            f"{out_dir} or its subfolders. Contents: {[p.name for p in out_dir.iterdir()]}"
        )

    csv_path = _h5_to_csv(h5)
    n_frames, n_keypoints = _summarize_dlc_csv(csv_path)
    report_progress(cfg, job_id, progress=90, progress_msg=f"Uploading {n_frames} frames")
    return csv_path, n_frames, n_keypoints


def _summarize_dlc_csv(csv_path: Path) -> tuple[int, int]:
    """Return (n_frames, n_distinct_bodyparts) for a DLC-format CSV.

    DLC CSVs start with header rows: scorer, [individuals (multi-animal only),]
    bodyparts, coords. Each bodypart contributes x/y/likelihood columns.

    Fix: when the 'individuals' row is present there are FOUR header lines, so
    the fourth line read is the coords header, not data. The old code counted
    it as a frame, overstating n_frames by one for multi-animal CSVs.
    """
    with csv_path.open() as f:
        header = [f.readline() for _ in range(4)]
        multi_animal = header[1].lower().startswith("individuals")
        body_idx = 2 if multi_animal else 1
        bodypart_row = header[body_idx].strip().split(",")
        # Line 4 is data only in the single-animal layout (3 header rows) and
        # only when the file actually has a first frame.
        first_data = 0 if multi_animal or header[3].strip() == "" else 1
        rows = sum(1 for _ in f) + first_data
    # Drop the row label in column 0 ('bodyparts'); dedupe the x/y/likelihood repeats.
    bodyparts = {name for name in bodypart_row[1:] if name}
    return rows, len(bodyparts)


def handle_job(cfg: Config, job: dict[str, Any]) -> None:
    """Process one claimed job end to end: download, run DLC, upload.

    Any failure is reported back to the server as status=error; this function
    never raises, so the polling loop keeps going.
    """
    job_id = job["id"]
    video_id = job["video_id"]
    print(f"> claimed job {job_id} for video {video_id} ({job['model']})")
    target_dir = cfg.work_dir / video_id
    target_dir.mkdir(parents=True, exist_ok=True)
    video_path = target_dir / f"{video_id}.mp4"

    try:
        report_progress(cfg, job_id, status="running", progress=1, progress_msg="Downloading video")
        download_video(cfg, video_id, video_path)
        size_mb = video_path.stat().st_size / 1_048_576
        print(f"  downloaded → {video_path} ({size_mb:.1f} MB)")

        if cfg.dry_run:
            # Deliberate: dry-run jobs are marked as errors so they return to the queue.
            report_progress(cfg, job_id, status="error", error="dry-run: DLC skipped")
            print("  dry-run: marking as error and moving on")
            return

        csv_path, frames, keypoints = run_dlc(cfg, job, video_path)
        upload_result(cfg, job_id, csv_path, frames, keypoints)
        print(f"  done · {frames} frames · {keypoints} keypoints → uploaded {csv_path.name}")
    except Exception as exc:
        print(f"  ! job {job_id} failed: {exc}\n{traceback.format_exc()}", file=sys.stderr)
        report_progress(cfg, job_id, status="error", error=str(exc))


# Set by _handle_signal; the main loop checks it between jobs.
_stopping = False


def _handle_signal(signum: int, _frame: Any) -> None:
    """Signal handler: flag the main loop to exit once the current job finishes."""
    print(f"\nCaught signal {signum}, finishing current job then exiting…")
    global _stopping
    _stopping = True


def main() -> None:
    """Entry point: parse config, install signal handlers, poll for jobs until stopped."""
    cfg = parse_args()
    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, _handle_signal)

    print(f"EthoVision worker · id={cfg.worker_id} · host={platform.node()} · server={cfg.server}")
    print(f"Supported models: {', '.join(SUPPORTED_MODELS)}")
    if cfg.dry_run:
        print("DRY-RUN mode: jobs will be claimed + downloaded but DLC will be skipped.")
    else:
        # Informational GPU check — DLC will still run on CPU if this fails, just slowly.
        try:
            import torch  # type: ignore[import-not-found]
            if torch.cuda.is_available():
                print(f"GPU detected: {torch.cuda.get_device_name(0)} · CUDA {torch.version.cuda}")
            else:
                print("WARNING: no CUDA GPU detected — DLC will run on CPU (slow).")
        except Exception:
            print("WARNING: torch not importable — DLC will fail. Install with: pip install torch deeplabcut")

    while not _stopping:
        try:
            job = claim_next_job(cfg)
            if not job:
                time.sleep(cfg.poll_seconds)
                continue
            handle_job(cfg, job)
        except requests.exceptions.ConnectionError:
            print(f"  server unreachable, retrying in {cfg.poll_seconds}s…")
            time.sleep(cfg.poll_seconds)
        except Exception as exc:
            print(f"  ! worker loop error: {exc}", file=sys.stderr)
            time.sleep(cfg.poll_seconds)

    print("Worker exiting.")


if __name__ == "__main__":
    main()
