Files
pi_mcps/mcp/mcp-image-gen/src/server.py
T
Patrick Plate 79f1e6d65f feat(mcp-image-gen): add name and count params to generate_image
- Add name (str) param: filename prefix saved as {name}_{timestamp}_{seed}.png
- Add count (int, 1-10) param: generate N images in one call
- Extract _sanitize_name() helper: strips special chars, collapses underscores, caps at 64 chars
- Extract _build_filename() helper: pure function for testable filename construction
- Extract _generate_single() coroutine: clean loop body for batch generation
- Fixed seed batches increment seed per image (seed+i-1) for deterministic variation
- random seed (-1) batches give independent random seeds per image
- Partial batch failures continue (error TextContent in slot, remaining images proceed)
- Returns flat interleaved [Text1, Image1, Text2, Image2, ...] list
- 34/34 tests passing (was 19, added 15 new tests)
2026-04-06 07:45:37 +02:00

500 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""mcp-image-gen — FastMCP server for AI image generation via ComfyUI."""
import asyncio
import base64
import copy
import json
import os
import random
import re
import time
from datetime import datetime
from pathlib import Path
import httpx
from fastmcp import FastMCP
from mcp.types import ImageContent, TextContent
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Base URL of the ComfyUI server. Any trailing slash is dropped so endpoint
# paths can be appended with a leading "/".
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://localhost:8188").rstrip("/")
# Default folder for saved images; "~" is expanded where the path is used.
IMAGE_OUTPUT_DIR = os.getenv("IMAGE_OUTPUT_DIR", "~/Pictures/mcp-generated")
# Seconds to keep polling a single generation before reporting a timeout.
COMFYUI_TIMEOUT = int(os.getenv("COMFYUI_TIMEOUT", "120"))
# Upper bound on `count` for one generate_image call.
MAX_COUNT = 10
# FLUX.1-schnell workflow template shipped alongside this module.
_WORKFLOW_PATH = Path(__file__).parent.joinpath("workflows", "flux_schnell.json")
mcp = FastMCP("mcp-image-gen")
# ---------------------------------------------------------------------------
# ComfyUI client
# ---------------------------------------------------------------------------
class ComfyUIClient:
    """Async HTTP client wrapper for the ComfyUI REST API.

    Every method opens a short-lived ``httpx.AsyncClient`` with its own
    timeout. Non-2xx responses surface as ``httpx.HTTPStatusError`` (via
    ``raise_for_status``); network problems surface as ``httpx`` transport
    errors such as ``ConnectError``.
    """

    def __init__(self, base_url: str = COMFYUI_URL):
        # Normalise so endpoint paths can always be appended with a leading "/".
        self.base_url = base_url.rstrip("/")

    async def queue_prompt(self, workflow: dict) -> str:
        """Submit a workflow to ComfyUI and return the ``prompt_id`` it assigns.

        Top-level keys starting with "_" (e.g. the "_meta" entry added by
        build_flux_workflow) are internal metadata, not ComfyUI nodes, and are
        stripped before submission.
        """
        clean_workflow = {k: v for k, v in workflow.items() if not k.startswith("_")}
        payload = {"prompt": clean_workflow}
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(f"{self.base_url}/api/prompt", json=payload)
            resp.raise_for_status()
            return resp.json()["prompt_id"]

    async def get_status(self, prompt_id: str) -> dict:
        """Return the raw queue state (``queue_running`` / ``queue_pending`` lists).

        *prompt_id* is currently unused: ``/api/queue`` reports the whole queue
        and callers filter it themselves.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/api/queue")
            resp.raise_for_status()
            return resp.json()

    async def get_history(self, prompt_id: str) -> dict:
        """Return the ``/api/history/{prompt_id}`` payload for *prompt_id*."""
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/api/history/{prompt_id}")
            resp.raise_for_status()
            return resp.json()

    async def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
        """Download image bytes from ComfyUI's ``/api/view`` endpoint."""
        params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.get(f"{self.base_url}/api/view", params=params)
            resp.raise_for_status()
            return resp.content

    async def get_models(
        self,
        node_class: str = "CheckpointLoaderSimple",
        input_name: str = "ckpt_name",
    ) -> list[str]:
        """Return the filenames ComfyUI offers for one loader node's combo input.

        Defaults to checkpoint models (``CheckpointLoaderSimple.ckpt_name``);
        pass e.g. ``node_class="UNETLoader", input_name="unet_name"`` to list
        UNet-only models instead.

        Returns:
            The option strings, or ``[]`` if the node/input is unknown or the
            payload has an unrecognised shape.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/object_info/{node_class}")
            resp.raise_for_status()
            data = resp.json()
        return self._combo_options(data, node_class, input_name)

    @staticmethod
    def _combo_options(object_info: dict, node_class: str, input_name: str) -> list[str]:
        """Extract one required combo input's option list from an object_info payload.

        Expected legacy shape:
        ``{node_class: {"input": {"required": {input_name: [[opt, ...], ...]}}}}``.
        Missing pieces yield ``[]`` instead of raising (the old code raised
        IndexError when the spec list was empty).
        """
        node_info = object_info.get(node_class) if isinstance(object_info, dict) else None
        if not isinstance(node_info, dict):
            return []
        spec = node_info.get("input", {}).get("required", {}).get(input_name)
        if not isinstance(spec, list) or not spec:
            return []
        head = spec[0]
        if isinstance(head, list):
            options = head
        elif head == "COMBO" and len(spec) > 1 and isinstance(spec[1], dict):
            # NOTE(review): newer ComfyUI builds may describe combos as
            # ["COMBO", {"options": [...]}]; accepted defensively.
            options = spec[1].get("options", [])
        else:
            return []
        if not isinstance(options, list):
            return []
        return [opt for opt in options if isinstance(opt, str)]
# ---------------------------------------------------------------------------
# Workflow builder
# ---------------------------------------------------------------------------
def build_flux_workflow(
prompt: str,
neg_prompt: str,
width: int,
height: int,
steps: int,
seed: int,
model: str,
) -> dict:
"""Build a ComfyUI API-format workflow dict for FLUX.1-schnell text-to-image.
This is a pure function — no I/O, fully testable.
"""
with open(_WORKFLOW_PATH) as f:
wf = json.load(f)
wf = copy.deepcopy(wf)
actual_seed = seed if seed != -1 else random.randint(0, 2**32 - 1)
wf["6"]["inputs"]["text"] = prompt
wf["33"]["inputs"]["text"] = neg_prompt
wf["27"]["inputs"]["width"] = width
wf["27"]["inputs"]["height"] = height
wf["13"]["inputs"]["steps"] = steps
wf["13"]["inputs"]["seed"] = actual_seed
# Node 32 = UNETLoader (flux1-schnell.safetensors is UNet-only, not all-in-one checkpoint)
wf["32"]["inputs"]["unet_name"] = model
# Attach the actual seed as metadata so callers can retrieve it
wf["_meta"] = {"actual_seed": actual_seed}
return wf
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _sanitize_name(name: str) -> str:
"""Sanitize a user-provided name for safe use in filenames.
Replaces whitespace with underscores, strips any characters that are not
alphanumeric, underscores, or hyphens, and collapses consecutive
underscores/hyphens. Returns empty string if nothing usable remains.
"""
name = name.strip()
name = re.sub(r"\s+", "_", name) # spaces → underscores
name = re.sub(r"[^\w\-]", "", name) # strip non-alphanum/underscore/hyphen
name = re.sub(r"[_\-]{2,}", "_", name) # collapse runs
name = name.strip("_-") # trim leading/trailing separators
return name[:64] # cap at 64 chars
def _build_filename(name: str, timestamp: str, actual_seed: int) -> str:
    """Return the PNG filename for one generated image.

    The layout is ``{name}_{timestamp}_{seed}.png`` when *name* survives
    _sanitize_name, otherwise ``{timestamp}_{seed}.png``.
    """
    stem = f"{timestamp}_{actual_seed}"
    prefix = _sanitize_name(name)
    if prefix:
        stem = f"{prefix}_{stem}"
    return stem + ".png"
def _text_result(label: str, message: str) -> list:
    """Wrap *message* as a one-item result list, prefixed by *label* when non-empty."""
    prefix = f"{label} " if label else ""
    return [TextContent(type="text", text=f"{prefix}{message}")]


def _first_output_image(outputs: dict) -> dict | None:
    """Return the first image descriptor in a ComfyUI history ``outputs`` map, or None."""
    for node_output in outputs.values():
        images = node_output.get("images") or []
        if images:
            return images[0]
    return None


async def _generate_single(
    client: ComfyUIClient,
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    steps: int,
    seed: int,
    model: str,
    resolved_output_dir: Path,
    name: str,
    label: str,
) -> list:
    """Generate a single image and return [TextContent, ImageContent] or [TextContent] on error.

    Never raises for expected failures (bad template, network/HTTP errors,
    timeouts, missing output, unwritable directory); each is reported as a
    single TextContent so a batch caller can continue with the next image.

    Args:
        client: ComfyUIClient instance.
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        width / height: Image dimensions.
        steps: Inference steps.
        seed: Seed value (-1 = random).
        model: ComfyUI model filename.
        resolved_output_dir: Resolved output directory Path (created if missing).
        name: User-supplied name prefix (unsanitized).
        label: Human-readable prefix for every TextContent (e.g. "[lumen 1/3]");
            may be empty.
    """
    # Build the workflow. Template problems are reported inline rather than
    # aborting the whole batch.
    try:
        workflow = build_flux_workflow(
            prompt=prompt,
            neg_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            seed=seed,
            model=model,
        )
    except (OSError, ValueError, KeyError) as e:
        return _text_result(label, f"Failed to build workflow: {type(e).__name__}: {e}")
    actual_seed = workflow["_meta"]["actual_seed"]

    # Submit it. ConnectError subclasses httpx.HTTPError, so it is handled first
    # to keep its more helpful message.
    try:
        prompt_id = await client.queue_prompt(workflow)
    except httpx.ConnectError:
        return _text_result(
            label,
            f"ComfyUI not reachable at {client.base_url}. "
            "Start it with: python main.py --listen",
        )
    except httpx.HTTPStatusError as e:
        return _text_result(
            label,
            f"ComfyUI returned an error: {e.response.status_code} — {e.response.text}",
        )
    except httpx.HTTPError as e:
        return _text_result(label, f"Failed to submit workflow to ComfyUI: {type(e).__name__}: {e}")
    except (KeyError, ValueError) as e:
        # Response body was not JSON or lacked "prompt_id".
        return _text_result(label, f"Unexpected response from ComfyUI when queueing prompt: {e!r}")

    # Poll until the prompt leaves the queue. A monotonic clock keeps the
    # timeout immune to wall-clock adjustments.
    start = time.monotonic()
    while True:
        if time.monotonic() - start > COMFYUI_TIMEOUT:
            return _text_result(
                label,
                f"Generation timed out after {COMFYUI_TIMEOUT}s. "
                f"prompt_id={prompt_id} — use get_generation_status to check",
            )
        try:
            queue = await client.get_status(prompt_id)
        except httpx.HTTPError:
            # Transient failure (connect, timeout, HTTP status): retry until the
            # timeout check above gives up.
            await asyncio.sleep(2)
            continue
        # Queue entries are lists whose second element is the prompt_id.
        active_ids = {
            item[1]
            for key in ("queue_running", "queue_pending")
            for item in queue.get(key, [])
        }
        if prompt_id not in active_ids:
            break  # Job is done
        await asyncio.sleep(2)
    elapsed = time.monotonic() - start

    # Retrieve history to find the output image.
    try:
        history = await client.get_history(prompt_id)
    except httpx.HTTPError as e:
        return _text_result(label, f"Failed to retrieve generation history: {e}")
    job = history.get(prompt_id, {})
    image_info = _first_output_image(job.get("outputs", {}))
    if not image_info:
        # ComfyUI history entries usually carry status.status_str (e.g. "error");
        # include it when present to explain why nothing was produced.
        status_str = (job.get("status") or {}).get("status_str")
        detail = f" (ComfyUI status: {status_str})" if status_str else ""
        return _text_result(
            label,
            f"No output image found in history for prompt_id={prompt_id}{detail}",
        )

    # Download image bytes.
    try:
        image_bytes = await client.get_image(
            filename=image_info["filename"],
            subfolder=image_info.get("subfolder", ""),
            folder_type=image_info.get("type", "output"),
        )
    except httpx.HTTPError as e:
        return _text_result(label, f"Failed to download generated image: {e}")

    # Save to disk.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = resolved_output_dir / _build_filename(name, timestamp, actual_seed)
    try:
        resolved_output_dir.mkdir(parents=True, exist_ok=True)
        out_path.write_bytes(image_bytes)
    except OSError as e:
        return _text_result(
            label,
            f"Cannot write to output directory: {resolved_output_dir} — {e}",
        )

    # Return a text summary plus the image inline as base64 for chat display.
    prefix = f"{label} " if label else ""
    b64_data = base64.b64encode(image_bytes).decode("utf-8")
    return [
        TextContent(
            type="text",
            text=(
                f"{prefix}Generated: {out_path}\n"
                f"Seed: {actual_seed}\n"
                f"Elapsed: {elapsed:.1f}s\n"
                f"Size: {width}x{height}, Steps: {steps}, Model: {model}"
            ),
        ),
        ImageContent(
            type="image",
            data=b64_data,
            mimeType="image/png",
        ),
    ]
# ---------------------------------------------------------------------------
# Tools
# ---------------------------------------------------------------------------
@mcp.tool()
async def generate_image(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    steps: int = 4,
    model: str = "flux1-schnell.safetensors",
    seed: int = -1,
    negative_prompt: str = "",
    output_dir: str = "",
    name: str = "",
    count: int = 1,
) -> list:
    """Generate an image from a text prompt using ComfyUI.

    Returns both a file path (for persistence) and an inline base64 image
    (for display in Claude / Roo Code chat).

    Args:
        prompt: Text description of the image to generate.
        width: Image width in pixels (default: 1024).
        height: Image height in pixels (default: 1024).
        steps: Number of inference steps. FLUX.1-schnell works well at 4.
        model: ComfyUI model filename (default: flux1-schnell.safetensors).
        seed: Random seed for reproducibility. -1 = random.
            When count > 1 and seed != -1, seeds are incremented per image
            (seed, seed+1, seed+2, ...) to produce deterministic variation.
        negative_prompt: Things to exclude from the image (optional).
        output_dir: Override output directory. Defaults to IMAGE_OUTPUT_DIR env var
            or ~/Pictures/mcp-generated.
        name: Optional filename prefix. Saved as {name}_{timestamp}_{seed}.png.
            The name is sanitized first (whitespace -> "_", other unsafe
            characters dropped, at most 64 chars).
            Useful to avoid confusion with auto-generated timestamp filenames.
        count: Number of images to generate (1-10). Each image is generated
            sequentially. Partial failures are returned inline — the batch
            continues even if one image fails.

    Returns:
        Flat interleaved list: [TextContent1, ImageContent1, TextContent2, ImageContent2, ...]
        On error for any single image, that slot contains only [TextContent(error)].
    """
    if count < 1:
        return [
            TextContent(
                type="text",
                text=f"count must be at least 1 (got {count}).",
            )
        ]
    if count > MAX_COUNT:
        return [
            TextContent(
                type="text",
                text=f"count must be at most {MAX_COUNT} (got {count}). Use multiple calls for larger batches.",
            )
        ]

    # Loop-invariant work: resolve the directory and sanitize the name once.
    resolved_output_dir = Path(output_dir or IMAGE_OUTPUT_DIR).expanduser().resolve()
    safe_name = _sanitize_name(name)
    client = ComfyUIClient(COMFYUI_URL)

    results: list = []
    for i in range(1, count + 1):
        # seed=-1: each image draws its own random seed inside the workflow builder.
        # Fixed seed: offset by i-1 for deterministic variation across the batch.
        image_seed = seed if seed == -1 else seed + (i - 1)
        if count > 1:
            label = f"[{safe_name or 'image'} {i}/{count}]"
        else:
            label = f"[{safe_name}]" if safe_name else ""
        single_result = await _generate_single(
            client=client,
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            seed=image_seed,
            model=model,
            resolved_output_dir=resolved_output_dir,
            name=name,
            label=label,
        )
        results.extend(single_result)
    return results
@mcp.tool()
async def list_available_models() -> list[str]:
    """List all checkpoint models available in ComfyUI.

    Returns a list of model filenames available for use with generate_image.
    Requires ComfyUI to be running at COMFYUI_URL.

    The list comes from ComfyUI's CheckpointLoaderSimple node. The bundled FLUX
    workflow loads its model through a UNETLoader node, so UNet-only files
    (such as the default flux1-schnell.safetensors) may not appear here.

    On failure, the list contains a single human-readable error message
    instead of model names.
    """
    client = ComfyUIClient(COMFYUI_URL)
    try:
        return await client.get_models()
    except httpx.ConnectError:
        return [
            f"ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        ]
    except httpx.HTTPStatusError as e:
        return [f"ComfyUI error: {e.response.status_code}"]
    except httpx.HTTPError as e:
        # Timeouts and other transport failures: report like the cases above
        # instead of letting the exception escape the tool.
        return [f"ComfyUI request failed ({type(e).__name__}): {e}"]
@mcp.tool()
async def get_generation_status(prompt_id: str) -> dict:
    """Check the status of a queued or running generation job.

    Args:
        prompt_id: The prompt ID returned by a previous generate_image call.

    Returns:
        Dict with a 'status' key:
        - "pending" / "running": the prompt is still in ComfyUI's queue;
        - "completed": it left the queue and has a history entry;
        - "not_found": neither queued nor in history;
        - "error": ComfyUI could not be queried; a 'message' key explains why.
    """
    client = ComfyUIClient(COMFYUI_URL)
    try:
        queue = await client.get_status(prompt_id)
        # Queue entries are lists whose second element is the prompt_id.
        running_ids = {item[1] for item in queue.get("queue_running", [])}
        pending_ids = {item[1] for item in queue.get("queue_pending", [])}
        if prompt_id in running_ids:
            return {"status": "running", "prompt_id": prompt_id}
        if prompt_id in pending_ids:
            return {"status": "pending", "prompt_id": prompt_id}
        # Not in queue — check history. Failures here fall through to the
        # handlers below instead of being misreported as "not_found".
        history = await client.get_history(prompt_id)
    except httpx.ConnectError:
        return {
            "status": "error",
            "prompt_id": prompt_id,
            "message": f"ComfyUI not reachable at {COMFYUI_URL}",
        }
    except httpx.HTTPStatusError as e:
        return {
            "status": "error",
            "prompt_id": prompt_id,
            "message": f"HTTP {e.response.status_code}",
        }
    except httpx.HTTPError as e:
        return {
            "status": "error",
            "prompt_id": prompt_id,
            "message": f"Request to ComfyUI failed ({type(e).__name__}): {e}",
        }
    if prompt_id in history:
        return {"status": "completed", "prompt_id": prompt_id}
    return {"status": "not_found", "prompt_id": prompt_id}
@mcp.tool()
def get_output_directory() -> str:
    """Report the default folder that generated images are written to.

    The IMAGE_OUTPUT_DIR setting is user-expanded ("~") and resolved to an
    absolute path. The folder is only created when an image is saved, so it
    may not exist yet.
    """
    default_dir = Path(IMAGE_OUTPUT_DIR).expanduser()
    absolute_dir = default_dir.resolve()
    return str(absolute_dir)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def _serve_stdio() -> None:
    """Run the MCP server over the stdio transport."""
    mcp.run(transport="stdio")


if __name__ == "__main__":
    _serve_stdio()