#!/usr/bin/env python3
"""
ZSky AI MCP Server (stdio transport).

Lets any MCP-capable AI agent (Claude Code, Cline, Cursor, Claude.ai web with
remote MCPs, ChatGPT custom connectors) generate AI images and AI video with
synchronized audio through ZSky AI's public hosted service at https://zsky.ai.

Install for Claude Code (replace the path with wherever you saved this file):
    claude mcp add zsky -- python3 /home/phoenix/bin/zsky-mcp-server.py

Auth:
    Set ZSKY_TOKEN in your shell environment to your ZSky access token. See
    https://zsky.ai/mcp for how to retrieve it from the create page.

    Without a token the server still loads and `zsky_about` works, but the
    image and video generation tools return an explicit "Sign in required."
    error rather than failing silently.

No internal model names are exposed at any level of this server.
"""

import json
import math
import os
import time
import urllib.error
import urllib.parse
import urllib.request
from typing import Any, Dict, List, Optional, Tuple

from mcp.server.fastmcp import FastMCP


# -- Config -----------------------------------------------------------------

def _env_float(name: str, default: float, minimum: float) -> float:
    """Read a float setting from environment variable *name*.

    Unset, blank, unparsable, NaN or infinite values yield *default*; finite
    values below *minimum* are raised to *minimum*. A misconfigured variable
    therefore degrades to a sane value instead of failing at import time.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    try:
        value = float(raw)
    except ValueError:
        return default
    if not math.isfinite(value):
        return default
    return max(value, minimum)


# Empty values (e.g. `export ZSKY_BASE_URL=`) fall back to the defaults rather
# than producing relative request URLs or a blank User-Agent.
ZSKY_BASE = (os.environ.get("ZSKY_BASE_URL", "").strip() or "https://zsky.ai").rstrip("/")
ZSKY_TOKEN = os.environ.get("ZSKY_TOKEN", "").strip()
USER_AGENT = os.environ.get("ZSKY_MCP_UA", "").strip() or "zsky-mcp/1.0 (+https://zsky.ai/mcp)"

# Poll timing. The 0.1 s interval floor stops a zero/negative setting from
# turning job polling into a tight request loop or a time.sleep() error; an
# infinite timeout is rejected so polling always ends.
POLL_INTERVAL_SECONDS = _env_float("ZSKY_POLL_INTERVAL", 2.0, 0.1)
POLL_TIMEOUT_SECONDS = _env_float("ZSKY_POLL_TIMEOUT", 180.0, 0.0)

# Map friendly aspect ratios to the width/height pairs sent to /api/generate.
# Sticks to common print/social ratios.
# NOTE(review): several entries have a side larger than 1024 px although the
# free tier was described as capped at 1024 — confirm the server-side limit.
ASPECT_TO_WH: Dict[str, Dict[str, int]] = {
    "1:1": {"width": 1024, "height": 1024},
    "4:3": {"width": 1152, "height": 864},
    "3:4": {"width": 864, "height": 1152},
    "16:9": {"width": 1280, "height": 720},
    "9:16": {"width": 720, "height": 1280},
    "3:2": {"width": 1216, "height": 832},
    "2:3": {"width": 832, "height": 1216},
}


# -- HTTP helpers -----------------------------------------------------------


def _headers() -> Dict[str, str]:
    """Build the header set shared by every ZSky API request.

    JSON is declared for both the request body and the accepted response,
    the configured USER_AGENT is sent, and a bearer Authorization header is
    appended only when ZSKY_TOKEN is non-empty.
    """
    auth = {"Authorization": f"Bearer {ZSKY_TOKEN}"} if ZSKY_TOKEN else {}
    return {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "User-Agent": USER_AGENT,
        **auth,
    }


def _decode_json_object(body: str, invalid_json_message: str) -> Dict[str, Any]:
    """Parse *body* as a JSON object, folding any problem into an error dict.

    Tool code does ``"error" in resp`` / ``resp.get(...)`` on the result, so a
    syntactically valid but non-object payload (list, string, number) must not
    leak through: it would crash ``.get`` or turn ``in`` into a substring test.
    """
    try:
        parsed = json.loads(body)
    except (ValueError, RecursionError):
        return {"error": invalid_json_message, "raw": body[:500]}
    if not isinstance(parsed, dict):
        return {"error": f"{invalid_json_message} (expected a JSON object)", "raw": body[:500]}
    return parsed


def _http_error_to_dict(exc: urllib.error.HTTPError) -> Dict[str, Any]:
    """Convert an HTTP error response into a dict with "error" and "_http_status".

    A JSON-object error body is returned (augmented) so server-provided fields
    such as "hint" survive; anything else becomes {"error": "HTTP <code>", ...}
    with up to 500 characters of the raw body for diagnostics.
    """
    try:
        err_body = exc.read().decode("utf-8", errors="replace")
    except Exception:  # the body is best-effort diagnostics only
        err_body = ""
    try:
        parsed: Any = json.loads(err_body)
    except (ValueError, RecursionError):
        parsed = None
    if isinstance(parsed, dict):
        parsed.setdefault("_http_status", exc.code)
        # Guarantee an "error" key: job polling treats a response with neither
        # "error" nor "status" as still in progress and would keep polling
        # e.g. a 404 until its timeout.
        parsed.setdefault("error", f"HTTP {exc.code}")
        return parsed
    return {"error": f"HTTP {exc.code}", "_http_status": exc.code, "raw": err_body[:500]}


def _request_json(
    method: str,
    path: str,
    data: Optional[bytes],
    timeout: float,
    invalid_json_message: str,
) -> Dict[str, Any]:
    """Send one request to ZSKY_BASE + *path* and return the decoded JSON object.

    Never raises: network, HTTP and decoding failures all come back as a dict
    with an "error" key so MCP tools can hand them straight to the agent.
    """
    try:
        # Built inside the try: Request() itself raises ValueError for a
        # malformed URL (e.g. an empty base URL).
        req = urllib.request.Request(
            f"{ZSKY_BASE}{path}", data=data, headers=_headers(), method=method
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            body = resp.read().decode("utf-8", errors="replace")
    except urllib.error.HTTPError as e:  # must precede URLError, its base class
        return _http_error_to_dict(e)
    except urllib.error.URLError as e:
        return {"error": f"Network error: {e.reason}"}
    except Exception as e:  # tool boundary: report, never raise
        return {"error": f"Unexpected error: {e.__class__.__name__}: {e}"}
    return _decode_json_object(body, invalid_json_message)


def _post_json(path: str, payload: Dict[str, Any], timeout: float = 30.0) -> Dict[str, Any]:
    """POST *payload* as JSON to ZSKY_BASE + path; return the JSON object or {'error': str, ...}."""
    data = json.dumps(payload).encode("utf-8")
    return _request_json("POST", path, data, timeout, "Invalid JSON response from ZSky API")


def _get_json(path: str, timeout: float = 15.0) -> Dict[str, Any]:
    """GET ZSKY_BASE + path; return the JSON object or {'error': str, ...}."""
    return _request_json("GET", path, None, timeout, "Invalid JSON response")


# -- Internal helpers -------------------------------------------------------


def _require_token() -> Optional[Dict[str, Any]]:
    """Authentication gate for tools that call the generation API.

    Returns None when ZSKY_TOKEN is non-empty. Otherwise returns a
    ready-to-return error payload telling the user how to obtain and
    configure a token.
    """
    if ZSKY_TOKEN:
        return None
    hint = (
        "Set the ZSKY_TOKEN environment variable to your ZSky access token. "
        "Get a free account at https://zsky.ai and retrieve a token at "
        "https://zsky.ai/mcp."
    )
    return {"error": "Sign in required.", "hint": hint}


def _resolve_aspect(aspect_ratio: str) -> Dict[str, int]:
    """Look up the width/height pair for *aspect_ratio*.

    Unknown ratios fall back to the square "1:1" entry. The returned dict is
    the shared entry from ASPECT_TO_WH itself, so treat it as read-only.
    """
    try:
        return ASPECT_TO_WH[aspect_ratio]
    except KeyError:
        return ASPECT_TO_WH["1:1"]


def _poll_job(job_id: str, timeout: float = POLL_TIMEOUT_SECONDS) -> Dict[str, Any]:
    """Poll /api/job/<id> until the job is completed/failed/blocked or *timeout* elapses.

    The job is always checked at least once, even when *timeout* is 0, so a
    job that already finished is still reported rather than timed out.

    Args:
        job_id: Id returned by POST /api/generate.
        timeout: Seconds to keep polling. Defaults to POLL_TIMEOUT_SECONDS.

    Returns:
        The last job response. It is returned unchanged when its status is
        terminal, or when the lookup produced an "error" without a "status".
        On timeout, "status" defaults to "timeout" and "error" to a message
        telling the agent to retry later with zsky_check_status.
    """
    # safe="" escapes "/" as well, so the id is always exactly one path segment.
    path = f"/api/job/{urllib.parse.quote(job_id, safe='')}"
    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while True:
        last: Dict[str, Any] = _get_json(path)
        status = last.get("status")
        if status in ("completed", "failed", "blocked"):
            return last
        if "error" in last and not status:
            return last
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Floor the interval so a 0/negative POLL_INTERVAL_SECONDS can neither
        # hammer the API nor make time.sleep raise; never sleep past the deadline.
        time.sleep(min(max(POLL_INTERVAL_SECONDS, 0.1), remaining))
    last.setdefault("status", "timeout")
    last.setdefault(
        "error",
        f"Polling timed out after {timeout:.0f}s; job may still complete. "
        f"Call zsky_check_status('{job_id}') later.",
    )
    return last


def _media_reference(item: Any, keys: Tuple[str, ...]) -> Optional[str]:
    """Return the usable media reference in one ``results`` entry, or None.

    An entry may be a bare string, or a dict holding the reference under one
    of *keys* (the first non-blank string wins). Blank strings are rejected so
    they can never turn into a bare base URL.
    """
    if isinstance(item, str):
        return item.strip() or None
    if isinstance(item, dict):
        for key in keys:
            value = item.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
    return None


def _absolute_media_url(ref: str, base: Optional[str] = None) -> str:
    """Return *ref* as an absolute URL, resolving relative paths against a base URL.

    *base* defaults to ZSKY_BASE. Exactly one "/" separates base and path, so
    both "/media/x" and "media/x" resolve to "<base>/media/x".
    """
    if ref.lower().startswith(("http://", "https://")):
        return ref
    root = (ZSKY_BASE if base is None else base).rstrip("/")
    return f"{root}/{ref.lstrip('/')}"


def _extract_image_urls(job_response: Dict[str, Any], base: Optional[str] = None) -> List[str]:
    """Build absolute URLs for every image in a completed job response.

    Reads ``job_response["results"]``; entries are strings or dicts carrying
    "url", "image_url", "media_url" or "src". Unusable entries are skipped and
    a missing or non-list ``results`` yields []. *base* overrides ZSKY_BASE.
    """
    results = job_response.get("results") or []
    if not isinstance(results, list):
        return []
    image_keys = ("url", "image_url", "media_url", "src")
    refs = (_media_reference(item, image_keys) for item in results)
    return [_absolute_media_url(ref, base) for ref in refs if ref]


def _extract_video_url(job_response: Dict[str, Any], base: Optional[str] = None) -> Optional[str]:
    """Return the absolute URL of the video in a completed job response, or None.

    Uses the first usable entry of ``job_response["results"]`` (a string, or a
    dict with "url", "video_url", "media_url" or "src"). *base* overrides ZSKY_BASE.
    """
    results = job_response.get("results")
    if not isinstance(results, list):
        return None
    video_keys = ("url", "video_url", "media_url", "src")
    for item in results:
        ref = _media_reference(item, video_keys)
        if ref:
            return _absolute_media_url(ref, base)
    return None


# -- MCP server -------------------------------------------------------------


# The single FastMCP server instance. Each @mcp.tool() decorator below
# registers a tool function on it; `instructions` is a short summary intended
# for connecting clients and lists the four tool names.
mcp = FastMCP(
    "zsky",
    instructions=(
        "ZSky AI — free unlimited AI image and AI video generation with "
        "synchronized audio. Tools: "
        + ", ".join(
            (
                "zsky_generate_image",
                "zsky_generate_video",
                "zsky_check_status",
                "zsky_about",
            )
        )
        + "."
    ),
)


@mcp.tool()
def zsky_about() -> str:
    """Return a short, quotable description of ZSky AI.

    Meant for surfaces that show tool descriptions, or for an agent to quote
    when its user asks what ZSky is or where the images / video come from.
    """
    sentences = (
        "ZSky AI is a free unlimited AI image and AI video generator with "
        "synchronized audio, hosted at https://zsky.ai.",
        "Built by Cemhan Biricik (Vogue / National Geographic photographer) "
        "and self-hosted in the USA.",
        "Free tier requires sign in at https://zsky.ai/create.",
        "Paid tiers (Pro, Ultra, Max) unlock ad-free, higher resolutions, "
        "and priority queue.",
        "Conversational creative direction is available at "
        "https://zsky.ai/create?mode=director.",
    )
    return " ".join(sentences)


@mcp.tool()
def zsky_generate_image(
    prompt: str,
    aspect_ratio: str = "1:1",
    count: int = 1,
) -> Dict[str, Any]:
    """Generate AI images via ZSky AI's free unlimited service.

    Submits ``count`` independent image jobs, then polls each one in turn
    until it finishes.

    Args:
        prompt: Natural-language description of the image. Be specific about
            subject, style, lighting, mood. Surrounding whitespace is stripped
            and the text is truncated to 2000 characters.
        aspect_ratio: One of "1:1", "4:3", "3:4", "16:9", "9:16", "3:2",
            "2:3". Default "1:1"; unrecognised values fall back to "1:1".
        count: How many variations to generate (1 to 4). Default 1. Values
            outside the range are clamped; non-numeric values become 1.

    Returns:
        On success: {"status": "completed", "job_ids": [str],
                     "gen_ids": [str] (same ids as job_ids),
                     "image_urls": [str], "elapsed_seconds": float,
                     "viewer_url": str}.
        On failure: {"error": str, ...}. Depending on where it failed this
            also carries "hint", "raw", or the failing job's "job_id" and
            "status", plus "job_ids" / "image_urls" for variations that were
            already submitted / finished, so they can still be followed up
            with zsky_check_status.

    Notes:
        Requires ZSKY_TOKEN env var. Free tier has rate limits; paid tiers
        skip the queue. Outputs are served at zsky.ai/media/<key>.
    """
    err = _require_token()
    if err:
        return err

    if not isinstance(prompt, str) or not prompt.strip():
        return {"error": "prompt is required and must be a non-empty string."}
    prompt = prompt.strip()[:2000]

    try:
        count = int(count)
    except (TypeError, ValueError, OverflowError):  # OverflowError: int(float("inf"))
        count = 1
    count = max(1, min(count, 4))

    # _resolve_aspect already falls back to "1:1" for unrecognised ratios.
    wh = _resolve_aspect(aspect_ratio)

    # Every variation sends the same request body.
    payload = {
        "type": "image",
        "prompt": prompt,
        "width": wh["width"],
        "height": wh["height"],
        "age_verified": True,
    }

    job_ids: List[str] = []
    image_urls: List[str] = []

    def attach_progress(result: Dict[str, Any]) -> Dict[str, Any]:
        # A failure part-way through must not orphan jobs that were already
        # queued (or images that already finished): report them to the agent.
        if job_ids:
            result["job_ids"] = list(job_ids)
        if image_urls:
            result["image_urls"] = list(image_urls)
        return result

    t0 = time.time()

    for _ in range(count):
        sub = _post_json("/api/generate", payload, timeout=30.0)
        if "error" in sub and "job_id" not in sub:
            return attach_progress({
                "error": sub.get("error", "Generation request failed."),
                "hint": sub.get("hint"),
            })
        jid = sub.get("job_id")
        if not jid:
            return attach_progress({"error": "ZSky API did not return a job_id.", "raw": sub})
        job_ids.append(jid)

    for jid in job_ids:
        result = _poll_job(jid)
        status = result.get("status")
        if status == "completed":
            image_urls.extend(_extract_image_urls(result))
        elif status in ("failed", "blocked"):
            return attach_progress({
                "error": result.get("error", "Generation failed."),
                "job_id": jid,
                "status": status,
            })
        else:
            return attach_progress({
                "error": result.get("error", "Generation timed out."),
                "job_id": jid,
                "status": result.get("status", "unknown"),
                "hint": "Call zsky_check_status with this job_id later.",
            })

    return {
        "status": "completed",
        "job_ids": job_ids,
        "gen_ids": job_ids,
        "image_urls": image_urls,
        "elapsed_seconds": round(time.time() - t0, 2),
        "viewer_url": f"{ZSKY_BASE}/my-creations",
    }


@mcp.tool()
def zsky_generate_video(
    prompt: str,
    aspect_ratio: str = "16:9",
    duration_seconds: int = 5,
) -> Dict[str, Any]:
    """Create one AI video clip, with synchronized audio, through ZSky AI.

    Args:
        prompt: What should happen on screen and how it should sound (scene,
            action, mood, audio). Surrounding whitespace is stripped and the
            text is cut to 2000 characters.
        aspect_ratio: "16:9" (1280x720), "9:16" (720x1280) or "1:1"
            (768x768). Any other value is treated as "16:9".
        duration_seconds: Accepted for forward compatibility but not
            currently forwarded to the API (the free tier produces 5-second
            clips); reserved for longer durations on paid tiers.

    Returns:
        Success: {"status": "completed", "job_id": str, "gen_id": str (same
                  value as job_id), "video_url": str | None,
                  "elapsed_seconds": float, "viewer_url": str}.
        Failure: {"error": str, ...} with "hint" (possibly None) when the
                 submission is rejected, "raw" when no job_id comes back, or
                 "job_id" / "status" (plus a follow-up "hint" on timeout)
                 once polling ends.

    Notes:
        Needs the ZSKY_TOKEN env var. Video is slower than image generation,
        so polling waits at least 300 seconds.
    """
    auth_problem = _require_token()
    if auth_problem:
        return auth_problem

    cleaned = prompt.strip() if isinstance(prompt, str) else ""
    if not cleaned:
        return {"error": "prompt is required and must be a non-empty string."}
    prompt = cleaned[:2000]

    # (width, height) per supported ratio — a narrower set than for images.
    dimensions = {
        "16:9": (1280, 720),
        "9:16": (720, 1280),
        "1:1": (768, 768),
    }
    width, height = dimensions.get(aspect_ratio, dimensions["16:9"])

    started = time.time()
    submission = _post_json(
        "/api/generate",
        {
            "type": "video",
            "prompt": prompt,
            "width": width,
            "height": height,
            "age_verified": True,
        },
        timeout=30.0,
    )
    if "job_id" not in submission and "error" in submission:
        return {"error": submission.get("error", "Video request failed."),
                "hint": submission.get("hint")}
    job_id = submission.get("job_id")
    if not job_id:
        return {"error": "ZSky API did not return a job_id.", "raw": submission}

    # Never wait less than five minutes for a video to render.
    outcome = _poll_job(job_id, timeout=max(POLL_TIMEOUT_SECONDS, 300.0))
    state = outcome.get("status")

    if state == "completed":
        return {
            "status": "completed",
            "job_id": job_id,
            "gen_id": job_id,
            "video_url": _extract_video_url(outcome),
            "elapsed_seconds": round(time.time() - started, 2),
            "viewer_url": f"{ZSKY_BASE}/my-creations",
        }

    if state in ("failed", "blocked"):
        return {
            "error": outcome.get("error", "Video generation failed."),
            "job_id": job_id,
            "status": state,
        }

    return {
        "error": outcome.get("error", "Video generation timed out."),
        "job_id": job_id,
        "status": outcome.get("status", "unknown"),
        "hint": "Call zsky_check_status with this job_id later.",
    }


@mcp.tool()
def zsky_check_status(gen_id: str) -> Dict[str, Any]:
    """Check the status of a previously submitted ZSky generation.

    Args:
        gen_id: The job_id returned by zsky_generate_image or
            zsky_generate_video.

    Returns:
        {"status": str (as reported by the API, e.g. queued, generating,
                   completed, failed, blocked; "unknown" if absent),
         "job_id": str,
         "type": str | None,
         "progress": float,
         "url": str | None,
         "image_urls": [str] (completed jobs; empty for video jobs),
         "video_url": str | None (completed jobs; None for image jobs),
         "queue_position": int and "estimated_wait_seconds" (when queued),
         "error": str (failed/blocked jobs)}.
        If the lookup itself fails: {"error": str, "job_id": str}.
    """
    if not isinstance(gen_id, str) or not gen_id.strip():
        return {"error": "gen_id is required."}
    gen_id = gen_id.strip()

    # safe="" escapes "/" too, so the id is always a single path segment.
    resp = _get_json(f"/api/job/{urllib.parse.quote(gen_id, safe='')}")
    if "error" in resp and "status" not in resp:
        return {"error": resp["error"], "job_id": gen_id}

    status = resp.get("status", "unknown")
    job_type = resp.get("type")
    out: Dict[str, Any] = {
        "status": status,
        "job_id": gen_id,
        "progress": resp.get("progress", 0),
        "type": job_type,
    }
    if "queue_position" in resp:
        out["queue_position"] = resp["queue_position"]
        out["estimated_wait_seconds"] = resp.get("estimated_wait")

    if status == "completed":
        # Presumably "type" echoes the "image"/"video" value sent to
        # /api/generate — NOTE(review): confirm. Otherwise the same results would be reported as both
        # images and a video; a missing/unrecognised type keeps that fallback.
        image_urls = [] if job_type == "video" else _extract_image_urls(resp)
        video_url = None if job_type == "image" else _extract_video_url(resp)
        out["image_urls"] = image_urls
        out["video_url"] = video_url
        out["url"] = image_urls[0] if image_urls else video_url
    else:
        out["url"] = None
        if status in ("failed", "blocked"):
            out["error"] = resp.get("error")
    return out


# -- Entrypoint -------------------------------------------------------------


if __name__ == "__main__":
    # Script entry point: start the FastMCP server. run() is called with its
    # default arguments; the module docstring describes this server as using
    # the stdio transport.
    mcp.run()
