Files
pi_mcps/mcp/mcp-image-gen/src/server.py
T

634 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""mcp-image-gen — FastMCP server for AI image generation via ComfyUI."""
import asyncio
import base64
import copy
import json
import logging
import os
import random
import re
import subprocess
import time
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import Annotated
import httpx
from fastmcp import FastMCP
from mcp.types import ImageContent, TextContent
from pydantic import Field
# Module-wide logger for everything in this server.
logger = logging.getLogger("mcp-image-gen")

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Base URL of the ComfyUI HTTP API, stored without a trailing slash.
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://localhost:8188").rstrip("/")
# Default directory for saved images; "~" is expanded where the value is used.
IMAGE_OUTPUT_DIR = os.getenv("IMAGE_OUTPUT_DIR", "~/Pictures/mcp-generated")
# Upper bound, in seconds, on how long one generation is polled for.
COMFYUI_TIMEOUT = int(os.getenv("COMFYUI_TIMEOUT", "120"))

# Directory where ComfyUI is installed (used for auto-start only).
# Override via the COMFYUI_DIR env var. Systemd service sets this automatically.
COMFYUI_DIR = Path(os.getenv("COMFYUI_DIR", "~/ComfyUI")).expanduser().resolve()

# Maximum number of images allowed in a single batch call.
MAX_COUNT = 10

# Workflow registry: model filename -> workflow JSON template path.
# This allows us to support multiple models
# (FLUX.1-schnell + FLUX.2 Klein with Heretic encoder).
_WORKFLOW_REGISTRY: dict[str, Path] = {
    model_file: Path(__file__).parent / "workflows" / template_file
    for model_file, template_file in (
        ("flux1-schnell.safetensors", "flux_schnell.json"),
        ("flux-2-klein-4b.safetensors", "flux2_klein_heretic.json"),
    )
}
# Model used when the caller does not name one or names an unregistered one.
_DEFAULT_MODEL = "flux1-schnell.safetensors"
# ---------------------------------------------------------------------------
# ComfyUI health check + auto-start
# ---------------------------------------------------------------------------
async def _ping_comfyui(url: str, timeout: float = 5.0) -> bool:
    """Return True if ComfyUI answers ``GET {url}/system_stats`` with HTTP 200.

    Args:
        url: Base URL of the ComfyUI server, without a trailing slash.
        timeout: Per-request timeout in seconds passed to the HTTP client.

    Returns:
        True on a 200 response. False on any other status code or on any
        transport-level failure; this function never raises for network
        problems.
    """
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.get(f"{url}/system_stats")
    except (httpx.HTTPError, OSError):
        # httpx.HTTPError covers connect errors and timeouts as well as read
        # or protocol errors (e.g. a connection reset while the server is
        # still starting). A health probe reports all of them as "not up".
        return False
    return resp.status_code == 200
async def check_and_start_comfyui() -> None:
    """Ping ComfyUI; if not reachable, attempt to launch it as a subprocess.

    Called once at server startup from the lifespan context manager.
    Uses COMFYUI_DIR to locate the installation and its venv Python.
    The HSA_OVERRIDE_GFX_VERSION=11.0.0 env var is injected (unless already
    set) for AMD ROCm / RX 7900 XTX compatibility.

    The listening port is taken from COMFYUI_URL so the launched instance is
    the one that is subsequently polled; 8188 is used when the URL carries no
    usable port.

    Returns:
        None. Every failure is logged rather than raised so that the server
        still starts; generation calls will fail until ComfyUI is reachable.
    """
    # Local import keeps this block self-contained (stdlib only).
    from urllib.parse import urlsplit

    if await _ping_comfyui(COMFYUI_URL):
        logger.info("ComfyUI is already running at %s", COMFYUI_URL)
        return

    logger.warning(
        "ComfyUI not reachable at %s — attempting to start from %s",
        COMFYUI_URL, COMFYUI_DIR,
    )
    python = COMFYUI_DIR / ".venv" / "bin" / "python"
    main_py = COMFYUI_DIR / "main.py"
    if not python.exists():
        logger.error(
            "ComfyUI venv Python not found at %s. "
            "Install ComfyUI first (see docs/wiki/pages/mcp-image-gen-ComfyUI-Setup.md).",
            python,
        )
        return
    if not main_py.exists():
        logger.error(
            "ComfyUI main.py not found at %s — is COMFYUI_DIR correct?",
            main_py,
        )
        return

    # Launch on the port we are going to poll; a hard-coded port would never
    # become "ready" when COMFYUI_URL points somewhere else.
    try:
        port = urlsplit(COMFYUI_URL).port or 8188
    except ValueError:  # malformed or out-of-range port in the URL
        port = 8188

    # Build environment: inherit current env, set ROCm override for AMD RX 7900 XTX
    env = os.environ.copy()
    env.setdefault("HSA_OVERRIDE_GFX_VERSION", "11.0.0")
    try:
        proc = subprocess.Popen(
            [str(python), str(main_py), "--listen", "--port", str(port)],
            cwd=str(COMFYUI_DIR),
            env=env,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,  # detach from MCP server process group
        )
    except OSError as exc:
        logger.error("Failed to start ComfyUI subprocess: %s", exc)
        return
    logger.info("ComfyUI launched (PID %d) — waiting for readiness…", proc.pid)

    # Wait up to 30 s for ComfyUI to become ready (polls every 2 s)
    wait_limit = 30
    poll_interval = 2
    for attempt in range(wait_limit // poll_interval):
        await asyncio.sleep(poll_interval)
        if await _ping_comfyui(COMFYUI_URL):
            logger.info(
                "ComfyUI ready at %s after ~%ds ✓",
                COMFYUI_URL, (attempt + 1) * poll_interval,
            )
            return
        # Stop waiting as soon as the child has died; it cannot become ready.
        exit_code = proc.poll()
        if exit_code is not None:
            logger.error(
                "ComfyUI process (PID %d) exited with code %s before becoming ready.",
                proc.pid, exit_code,
            )
            return

    logger.warning(
        "ComfyUI did not respond within %ds. "
        "Generation calls will fail until it is ready. "
        "Check logs: journalctl --user -u comfyui -f",
        wait_limit,
    )
@asynccontextmanager
async def lifespan(app):
    """FastMCP lifespan: run ComfyUI health check at server startup.

    ``app`` is the object handed in by the framework; it is not used here.
    Everything before ``yield`` runs before the server starts serving;
    there is no shutdown logic after it.
    """
    await check_and_start_comfyui()
    yield  # server is live here
    # Nothing to tear down — ComfyUI is managed by systemd, not this process
    # NOTE(review): check_and_start_comfyui() may also have launched ComfyUI
    # as a detached subprocess; it is left running on shutdown — confirm that
    # this is intended.


# Server instance; the tool functions in this file register on it via @mcp.tool().
mcp = FastMCP("mcp-image-gen", lifespan=lifespan)
# ---------------------------------------------------------------------------
# ComfyUI client
# ---------------------------------------------------------------------------
class ComfyUIClient:
    """Async HTTP client wrapper for the ComfyUI REST API.

    Each method opens a short-lived ``httpx.AsyncClient`` for its single
    request, so instances hold no connection state beyond ``base_url``.
    Unless stated otherwise, methods let ``httpx`` exceptions propagate
    (``raise_for_status`` is called on every response).
    """

    def __init__(self, base_url: str = COMFYUI_URL):
        # Stored without a trailing slash so paths can be appended directly.
        self.base_url = base_url.rstrip("/")

    async def queue_prompt(self, workflow: dict) -> str:
        """Submit a workflow to ComfyUI and return the prompt_id.

        Top-level keys starting with "_" (e.g. "_meta") are internal
        metadata, not ComfyUI nodes, and are stripped before submission.
        """
        clean_workflow = {k: v for k, v in workflow.items() if not k.startswith("_")}
        payload = {"prompt": clean_workflow}
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(f"{self.base_url}/api/prompt", json=payload)
            resp.raise_for_status()
            return resp.json()["prompt_id"]

    async def get_status(self, prompt_id: str) -> dict:
        """Return the current queue state (queue_running + queue_pending lists).

        ``prompt_id`` is accepted for interface symmetry but is not sent to
        the server: the whole queue is returned and callers search it.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/api/queue")
            resp.raise_for_status()
            return resp.json()

    async def get_history(self, prompt_id: str) -> dict:
        """Return the history payload for *prompt_id* as decoded JSON."""
        # Local import keeps this block self-contained (stdlib only).
        from urllib.parse import quote

        # prompt_id can originate from tool input; escape it so it stays a
        # single path segment and cannot address another endpoint.
        safe_id = quote(str(prompt_id), safe="")
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/api/history/{safe_id}")
            resp.raise_for_status()
            return resp.json()

    async def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
        """Download image bytes from ComfyUI's /api/view endpoint."""
        params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.get(f"{self.base_url}/api/view", params=params)
            resp.raise_for_status()
            return resp.content

    async def get_models(self) -> list[str]:
        """Return the sorted list of available model filenames.

        Combines checkpoints reported by ComfyUI with our internal registry
        (including FLUX.2 Klein with Heretic encoder). Querying ComfyUI is
        best-effort: on any failure the problem is logged and only the
        registered models are returned, so this method does not raise.
        """
        models: set[str] = set(_WORKFLOW_REGISTRY)
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(
                    f"{self.base_url}/object_info/CheckpointLoaderSimple"
                )
                resp.raise_for_status()
                data = resp.json()
            node_info = data.get("CheckpointLoaderSimple", {})
            ckpt_list = (
                node_info.get("input", {})
                .get("required", {})
                .get("ckpt_name", [[]])[0]
            )
            if isinstance(ckpt_list, list):
                # Keep only strings so the final sort cannot hit mixed types.
                models.update(n for n in ckpt_list if isinstance(n, str))
        except Exception as exc:
            # Deliberate best-effort boundary: ComfyUI unreachable or replying
            # with an unexpected shape — fall back to the registry, but say so.
            logger.warning(
                "Could not query ComfyUI checkpoints (%s); "
                "returning registered models only",
                exc,
            )
        return sorted(models)
# ---------------------------------------------------------------------------
# Workflow builder
# ---------------------------------------------------------------------------
def build_flux_workflow(
    prompt: str,
    neg_prompt: str,
    width: int,
    height: int,
    steps: int,
    seed: int,
    model: str = _DEFAULT_MODEL,
) -> dict:
    """Build a ComfyUI API-format workflow dict for the requested model.

    Supported models are the keys of the workflow registry:
    - "flux1-schnell.safetensors" (original)
    - "flux-2-klein-4b.safetensors" (with Heretic-abliterated Qwen3-4B text encoder)

    An unknown *model* falls back to the default model for both the workflow
    template and the injected ``unet_name``.

    Args:
        prompt: Positive prompt; substituted for "PROMPT_PLACEHOLDER".
        neg_prompt: Negative prompt; substituted for "NEGATIVE_PLACEHOLDER".
        width: Injected into every node input named "width".
        height: Injected into every node input named "height".
        steps: Injected into every node input named "steps".
        seed: Seed to use; -1 draws a random seed in [0, 2**32 - 1].
        model: Model filename selecting the workflow template.

    Returns:
        The workflow dict, plus a top-level "_meta" entry holding
        {"actual_seed": <seed used>} so callers can report it.

    Raises:
        OSError: The workflow template file cannot be read.
        json.JSONDecodeError: The template is not valid JSON.

    Note: this reads the template from disk and may draw a random seed, so it
    is not a pure function.
    """
    if model not in _WORKFLOW_REGISTRY:
        # Keep template and unet_name in agreement: never inject an unknown
        # model name into the default model's workflow.
        model = _DEFAULT_MODEL
    workflow_path = _WORKFLOW_REGISTRY[model]

    # Load workflow as text first — replace the quoted string placeholders.
    # JSON templates are UTF-8; do not depend on the platform default encoding.
    raw = workflow_path.read_text(encoding="utf-8")
    actual_seed = seed if seed != -1 else random.randint(0, 2**32 - 1)
    # json.dumps yields a correctly escaped JSON string literal, quotes included.
    raw = raw.replace('"PROMPT_PLACEHOLDER"', json.dumps(prompt))
    raw = raw.replace('"NEGATIVE_PLACEHOLDER"', json.dumps(neg_prompt))
    wf = json.loads(raw)  # freshly parsed, so no defensive copy is needed

    # Recursively inject values into matching field names
    _inject_workflow_params(wf, {
        "width": width,
        "height": height,
        "steps": steps,
        "seed": actual_seed,
        "noise_seed": actual_seed,
        "unet_name": model,
    })
    # Attach the actual seed as metadata so callers can retrieve it
    wf["_meta"] = {"actual_seed": actual_seed}
    return wf
def _inject_workflow_params(node: dict | list, params: dict) -> None:
"""Recursively walk a workflow dict/list and inject parameter values.
For each dict encountered, if it has an "inputs" sub-dict, update
any matching field names from params. This is model-agnostic and
works regardless of ComfyUI node IDs.
"""
if isinstance(node, dict):
if "inputs" in node and isinstance(node["inputs"], dict):
for key, value in params.items():
if key in node["inputs"] and not isinstance(node["inputs"][key], list):
node["inputs"][key] = value
for v in node.values():
_inject_workflow_params(v, params)
elif isinstance(node, list):
for item in node:
_inject_workflow_params(item, params)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _sanitize_name(name: str) -> str:
    """Sanitize a user-provided name for safe use in filenames.

    Replaces whitespace runs with underscores, strips any characters that are
    not word characters (``\\w``: letters, digits, underscore) or hyphens,
    collapses consecutive underscores/hyphens into one underscore, trims
    leading/trailing separators and caps the result at 64 characters.

    Returns:
        The sanitized name, or an empty string if nothing usable remains.
        The result never starts or ends with "_" or "-".
    """
    name = re.sub(r"\s+", "_", name.strip())  # spaces → underscores
    name = re.sub(r"[^\w\-]", "", name)       # strip non-word/hyphen characters
    name = re.sub(r"[_\-]{2,}", "_", name)    # collapse runs
    name = name.strip("_-")                   # trim leading/trailing separators
    # Cutting at 64 chars can land right after a separator; trim again so the
    # "no trailing separator" guarantee also holds for truncated names.
    return name[:64].rstrip("_-")
def _build_filename(name: str, timestamp: str, actual_seed: int) -> str:
    """Compose the PNG filename for a generated image.

    The result is ``{timestamp}_{actual_seed}.png``, prefixed with the
    sanitized *name* and an underscore when that sanitized name is non-empty.
    """
    stem = f"{timestamp}_{actual_seed}.png"
    prefix = _sanitize_name(name)
    return f"{prefix}_{stem}" if prefix else stem
async def _generate_single(
    client: ComfyUIClient,
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    steps: int,
    seed: int,
    model: str,
    resolved_output_dir: Path,
    name: str,
    label: str,
) -> list:
    """Generate a single image and return [TextContent, ImageContent] or [TextContent] on error.

    Supports two models:
    - flux1-schnell.safetensors (default, fast 4-step)
    - flux-2-klein-4b.safetensors (with Heretic-abliterated Qwen3-4B text encoder — no refusals)
    Any other *model* falls back to the default with a logged warning.

    Steps: build and queue the workflow, poll the queue every 2 s until the
    job is no longer running or pending (or COMFYUI_TIMEOUT seconds pass),
    read the history entry, download the first output image, save it under
    *resolved_output_dir* and return it inline as base64.

    Args:
        client: ComfyUI API client used for all requests.
        seed: Seed to use; -1 lets the workflow builder pick a random one.
        resolved_output_dir: Directory to save into; created if missing.
        name: Optional user-supplied filename prefix (sanitized before use).
        label: Prefix for every returned message; may be empty.

    Returns:
        ``[TextContent, ImageContent]`` on success, or a one-element
        ``[TextContent]`` describing the failure. HTTP/transport errors,
        template-loading errors and file-system errors are reported this way
        rather than raised, so a batch can continue after a failed image.
    """
    # An empty label must not leave a stray leading space in messages.
    prefix = f"{label} " if label else ""

    if model not in _WORKFLOW_REGISTRY:
        # Log before reassigning, otherwise the warning names the fallback
        # model instead of the unknown one.
        logger.warning("Unknown model %s, falling back to %s", model, _DEFAULT_MODEL)
        model = _DEFAULT_MODEL

    # Build workflow (reads the template from disk)
    try:
        workflow = build_flux_workflow(
            prompt=prompt,
            neg_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            seed=seed,
            model=model,
        )
    except (OSError, ValueError) as e:  # missing file / invalid JSON template
        return _text_result(f"{prefix}Cannot load workflow for model {model}: {e}")
    actual_seed = workflow["_meta"]["actual_seed"]

    # Submit workflow
    try:
        prompt_id = await client.queue_prompt(workflow)
    except httpx.ConnectError:
        return _text_result(
            f"{prefix}ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        )
    except httpx.HTTPStatusError as e:
        return _text_result(
            f"{prefix}ComfyUI returned an error: "
            f"{e.response.status_code} {e.response.text}"
        )
    except httpx.HTTPError as e:  # timeouts and other transport failures
        return _text_result(f"{prefix}Failed to submit job to ComfyUI: {e}")

    # Poll until done. A monotonic clock keeps the timeout immune to
    # wall-clock adjustments.
    start = time.monotonic()
    while True:
        if time.monotonic() - start > COMFYUI_TIMEOUT:
            return _text_result(
                f"{prefix}Generation timed out after {COMFYUI_TIMEOUT}s. "
                f"prompt_id={prompt_id} — use get_generation_status to check"
            )
        try:
            queue = await client.get_status(prompt_id)
        except httpx.HTTPError:
            # Transient failure (including timeouts): retry until the deadline.
            await asyncio.sleep(2)
            continue
        queued = queue.get("queue_running", []) + queue.get("queue_pending", [])
        # Queue items are indexed so that element 1 is the prompt_id.
        if not any(item[1] == prompt_id for item in queued):
            break  # Job is done
        await asyncio.sleep(2)
    elapsed = time.monotonic() - start

    # Retrieve history to find output filename
    try:
        history = await client.get_history(prompt_id)
    except httpx.HTTPError as e:
        return _text_result(f"{prefix}Failed to retrieve generation history: {e}")

    outputs = history.get(prompt_id, {}).get("outputs", {})
    # Take the first image of the first node that produced any images.
    image_info = next(
        (out["images"][0] for out in outputs.values() if out.get("images")),
        None,
    )
    if not image_info:
        return _text_result(
            f"{prefix}No output image found in history for prompt_id={prompt_id}"
        )

    # Download image bytes
    try:
        image_bytes = await client.get_image(
            filename=image_info["filename"],
            subfolder=image_info.get("subfolder", ""),
            folder_type=image_info.get("type", "output"),
        )
    except httpx.HTTPError as e:
        return _text_result(f"{prefix}Failed to download generated image: {e}")

    # Save to disk
    try:
        resolved_output_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = resolved_output_dir / _build_filename(name, timestamp, actual_seed)
        out_path.write_bytes(image_bytes)
    except OSError as e:
        return _text_result(
            f"{prefix}Cannot write to output directory: {resolved_output_dir}: {e}"
        )

    # Encode as base64 for inline display
    b64_data = base64.b64encode(image_bytes).decode("utf-8")
    return [
        TextContent(
            type="text",
            text=(
                f"{prefix}Generated: {out_path}\n"
                f"Seed: {actual_seed}\n"
                f"Elapsed: {elapsed:.1f}s\n"
                f"Size: {width}x{height}, Steps: {steps}, Model: {model}"
            ),
        ),
        ImageContent(
            type="image",
            data=b64_data,
            mimeType="image/png",
        ),
    ]


def _text_result(text: str) -> list:
    """Wrap *text* in the one-element TextContent list used for error replies."""
    return [TextContent(type="text", text=text)]
# ---------------------------------------------------------------------------
# Tools
# ---------------------------------------------------------------------------
@mcp.tool()
async def generate_image(
    prompt: Annotated[str, Field(description="Text description of the image to generate.")],
    width: Annotated[int, Field(description="Image width in pixels (default: 1024).")] = 1024,
    height: Annotated[int, Field(description="Image height in pixels (default: 1024).")] = 1024,
    steps: Annotated[int, Field(description="Number of inference steps. FLUX.1-schnell works well at 4.")] = 4,
    model: Annotated[str, Field(description="ComfyUI model filename (default: flux1-schnell.safetensors).")] = _DEFAULT_MODEL,
    seed: Annotated[int, Field(description="Random seed for reproducibility. -1 = random. When count > 1 and seed != -1, seeds are incremented per image (seed, seed+1, seed+2, ...) to produce deterministic variation.")] = -1,
    negative_prompt: Annotated[str, Field(description="Things to exclude from the image (optional).")] = "",
    output_dir: Annotated[str, Field(description="Override output directory. Defaults to IMAGE_OUTPUT_DIR env var or ~/Pictures/mcp-generated.")] = "",
    name: Annotated[str, Field(description="Optional filename prefix. Saved as {name}_{timestamp}_{seed}.png. Useful to avoid confusion with auto-generated timestamp filenames.")] = "",
    count: Annotated[int, Field(description=f"Number of images to generate (1-{MAX_COUNT}). Each image is generated sequentially. Partial failures are returned inline — the batch continues even if one image fails.")] = 1,
) -> list:
    """Generate an image from a text prompt using ComfyUI.
    Returns both a file path (for persistence) and an inline base64 image
    (for display in Claude / Roo Code chat).
    Returns:
        Flat interleaved list: [TextContent1, ImageContent1, TextContent2, ImageContent2, ...]
        On error for any single image, that slot contains only [TextContent(error)].
    """
    # Validate count; out-of-range values are reported as a text reply, not raised.
    if count < 1:
        return [
            TextContent(
                type="text",
                text=f"count must be at least 1 (got {count}).",
            )
        ]
    if count > MAX_COUNT:
        return [
            TextContent(
                type="text",
                text=f"count must be at most {MAX_COUNT} (got {count}). Use multiple calls for larger batches.",
            )
        ]

    # Resolve output directory once
    resolved_output_dir = Path(
        output_dir or IMAGE_OUTPUT_DIR
    ).expanduser().resolve()
    client = ComfyUIClient(COMFYUI_URL)
    # Sanitizing is loop-invariant: do it once instead of per image.
    safe_name = _sanitize_name(name)

    results = []
    for i in range(1, count + 1):
        # Compute seed for this image:
        # - seed=-1 → each image gets an independent random seed
        # - fixed seed → increment by i-1 for deterministic variation across the batch
        image_seed = seed if seed == -1 else seed + (i - 1)

        # Batches are labelled "[name i/count]"; a single image is labelled
        # "[name]" only when a usable name was given.
        if count > 1:
            label = f"[{safe_name or 'image'} {i}/{count}]"
        elif safe_name:
            label = f"[{safe_name}]"
        else:
            label = ""

        single_result = await _generate_single(
            client=client,
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            seed=image_seed,
            model=model,
            resolved_output_dir=resolved_output_dir,
            name=name,
            label=label,
        )
        # Flat list: each image contributes one or two content items.
        results.extend(single_result)
    return results
@mcp.tool()
async def list_available_models() -> list[str]:
    """List all checkpoint models available in ComfyUI.
    Returns a list of model filenames available for use with generate_image.
    Requires ComfyUI to be running at COMFYUI_URL.
    """
    # The reply is always a list of strings: model filenames on success, or a
    # single human-readable message when the request to ComfyUI fails.
    comfy = ComfyUIClient(COMFYUI_URL)
    try:
        model_names = await comfy.get_models()
    except httpx.ConnectError:
        unreachable = (
            f"ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        )
        return [unreachable]
    except httpx.HTTPStatusError as err:
        return [f"ComfyUI error: {err.response.status_code}"]
    return model_names
@mcp.tool()
async def get_generation_status(
    prompt_id: Annotated[str, Field(description="The prompt ID returned by a previous generate_image call.")],
) -> dict:
    """Check the status of a queued or running generation job.
    Returns:
        Dict with 'status' key: "pending", "running", "completed", "not_found",
        or "error" (with a 'message' key) when ComfyUI cannot be queried.
    """
    client = ComfyUIClient(COMFYUI_URL)
    try:
        queue = await client.get_status(prompt_id)
    except httpx.ConnectError:
        return {
            "status": "error",
            "message": f"ComfyUI not reachable at {COMFYUI_URL}",
        }
    except httpx.HTTPStatusError as e:
        return {"status": "error", "message": f"HTTP {e.response.status_code}"}
    except httpx.HTTPError as e:
        # Timeouts and other transport failures: report them instead of
        # letting the tool call raise.
        return {"status": "error", "message": f"ComfyUI request failed: {e}"}

    # Queue items are indexed so that element 1 is the prompt_id.
    if any(item[1] == prompt_id for item in queue.get("queue_running", [])):
        return {"status": "running", "prompt_id": prompt_id}
    if any(item[1] == prompt_id for item in queue.get("queue_pending", [])):
        return {"status": "pending", "prompt_id": prompt_id}

    # Not in queue — check history
    try:
        history = await client.get_history(prompt_id)
    except httpx.HTTPError:
        # Best-effort lookup: a failed history request is treated the same as
        # "no history entry" and falls through to not_found.
        history = {}
    if prompt_id in history:
        return {"status": "completed", "prompt_id": prompt_id}
    return {"status": "not_found", "prompt_id": prompt_id}
@mcp.tool()
def get_output_directory() -> str:
    """Return the directory where generated images are saved.
    Returns:
        Absolute path to the output directory (may not exist yet).
    """
    # Expand "~" first, then make the path absolute; nothing is created on disk.
    configured = Path(IMAGE_OUTPUT_DIR)
    absolute = configured.expanduser().resolve()
    return str(absolute)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
# Start the server over the stdio transport only when this file is executed
# directly; importing the module does not start it.
if __name__ == "__main__":
    mcp.run(transport="stdio")