79f1e6d65f
- Add name (str) param: filename prefix saved as {name}_{timestamp}_{seed}.png
- Add count (int, 1-10) param: generate N images in one call
- Extract _sanitize_name() helper: strips special chars, collapses underscores, caps at 64 chars
- Extract _build_filename() helper: pure function for testable filename construction
- Extract _generate_single() coroutine: clean loop body for batch generation
- Fixed-seed batches increment the seed per image (seed + i - 1) for deterministic variation
- Random-seed (-1) batches give each image an independent random seed
- Partial batch failures continue (error TextContent in slot, remaining images proceed)
- Returns flat interleaved [Text1, Image1, Text2, Image2, ...] list
- 34/34 tests passing (was 19, added 15 new tests)
500 lines
17 KiB
Python
500 lines
17 KiB
Python
"""mcp-image-gen — FastMCP server for AI image generation via ComfyUI."""
|
||
|
||
import asyncio
|
||
import base64
|
||
import copy
|
||
import json
|
||
import os
|
||
import random
|
||
import re
|
||
import time
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
import httpx
|
||
from fastmcp import FastMCP
|
||
from mcp.types import ImageContent, TextContent
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Configuration
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Base URL of the ComfyUI server (trailing slash removed so paths can be appended).
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://localhost:8188").rstrip("/")

# Default directory for saved images; "~" is expanded by the code that uses it.
IMAGE_OUTPUT_DIR = os.getenv("IMAGE_OUTPUT_DIR", "~/Pictures/mcp-generated")

# Seconds to wait for a queued job before giving up (parsed at import time).
COMFYUI_TIMEOUT = int(os.getenv("COMFYUI_TIMEOUT", "120"))

# Upper bound on images produced by a single generate_image call.
MAX_COUNT = 10

# Bundled FLUX.1-schnell workflow template, shipped next to this module.
_WORKFLOW_PATH = Path(__file__).parent.joinpath("workflows", "flux_schnell.json")
|
||
|
||
# Server instance; every @mcp.tool() function in this module registers on it.
mcp = FastMCP("mcp-image-gen")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# ComfyUI client
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class ComfyUIClient:
    """Async HTTP client wrapper for the ComfyUI REST API.

    Every method opens its own short-lived ``httpx.AsyncClient``, so an
    instance holds no connection state and needs no explicit close. Non-2xx
    responses surface as ``httpx.HTTPStatusError`` (via ``raise_for_status``);
    transport failures surface as other ``httpx`` exceptions.
    """

    def __init__(self, base_url: str = COMFYUI_URL):
        # Normalise so f"{base_url}/api/..." never contains a double slash.
        self.base_url = base_url.rstrip("/")

    async def queue_prompt(self, workflow: dict) -> str:
        """Submit a workflow to ComfyUI and return the prompt_id.

        Top-level keys starting with "_" (such as the "_meta" entry added by
        build_flux_workflow) are dropped first because they are not ComfyUI
        nodes.
        """
        clean_workflow = {k: v for k, v in workflow.items() if not k.startswith("_")}
        payload = {"prompt": clean_workflow}
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(f"{self.base_url}/api/prompt", json=payload)
            resp.raise_for_status()
            return resp.json()["prompt_id"]

    async def get_status(self, prompt_id: str) -> dict:
        """Return the current queue state (queue_running + queue_pending lists).

        ``prompt_id`` is not sent to the server: /api/queue returns the whole
        queue and callers filter it themselves.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/api/queue")
            resp.raise_for_status()
            return resp.json()

    async def get_history(self, prompt_id: str) -> dict:
        """Return the history entry for a completed prompt_id."""
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/api/history/{prompt_id}")
            resp.raise_for_status()
            return resp.json()

    async def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
        """Download image bytes from ComfyUI's /api/view endpoint."""
        params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.get(f"{self.base_url}/api/view", params=params)
            resp.raise_for_status()
            return resp.content

    async def get_models(
        self,
        node_class: str = "CheckpointLoaderSimple",
        input_name: str = "ckpt_name",
    ) -> list[str]:
        """Return the model filenames ComfyUI offers for a loader node input.

        With the defaults this lists checkpoint files
        (CheckpointLoaderSimple.ckpt_name), exactly as before. Pass e.g.
        ``("UNETLoader", "unet_name")`` to list the models selectable by the
        UNETLoader node that the bundled FLUX workflow uses.

        Returns an empty list when the node/input is unknown or its spec has
        an unrecognised shape.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/object_info/{node_class}")
            resp.raise_for_status()
            data = resp.json()

        # Expected shape:
        # {node_class: {"input": {"required": {input_name: [[opt1, ...], ...]}}}}
        spec = (
            data.get(node_class, {})
            .get("input", {})
            .get("required", {})
            .get(input_name)
        )
        return self._combo_options(spec)

    @staticmethod
    def _combo_options(spec) -> list[str]:
        """Extract the list of choices from a ComfyUI combo-input spec.

        Handles the classic ``[[opt1, opt2, ...], {...}]`` shape and, as a
        hedge for newer ComfyUI releases, ``["COMBO", {"options": [...]}]``.
        An empty, missing or otherwise unexpected spec yields ``[]`` instead
        of raising (the previous ``[0]`` indexing crashed on ``[]``).
        """
        if not isinstance(spec, list) or not spec:
            return []
        head = spec[0]
        if isinstance(head, list):
            return head
        if head == "COMBO" and len(spec) > 1 and isinstance(spec[1], dict):
            options = spec[1].get("options")
            if isinstance(options, list):
                return options
        return []
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Workflow builder
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def build_flux_workflow(
    prompt: str,
    neg_prompt: str,
    width: int,
    height: int,
    steps: int,
    seed: int,
    model: str,
    template_path: "str | Path | None" = None,
) -> dict:
    """Build a ComfyUI API-format workflow dict for FLUX.1-schnell text-to-image.

    The JSON template (the bundled ``_WORKFLOW_PATH`` unless *template_path*
    is given) is read from disk on every call, then the prompt, size, steps,
    seed and model are written into its nodes. No network I/O is performed.

    Args:
        prompt: Positive prompt text (node "6").
        neg_prompt: Negative prompt text (node "33").
        width: Latent width in pixels (node "27").
        height: Latent height in pixels (node "27").
        steps: Sampler steps (node "13").
        seed: Sampler seed (node "13"); -1 picks a random seed in [0, 2**32 - 1].
        model: Filename written to the UNETLoader's ``unet_name`` (node "32").
        template_path: Optional alternative template with the same node ids.

    Returns:
        The filled-in workflow plus a top-level ``"_meta"`` entry
        ``{"actual_seed": int}`` so callers can report the seed used.

    Raises:
        OSError: if the template cannot be read.
        json.JSONDecodeError: if the template is not valid JSON.
        KeyError: if the template lacks one of the expected nodes.
    """
    path = _WORKFLOW_PATH if template_path is None else template_path
    # Decode explicitly as UTF-8 so the platform's locale encoding
    # (e.g. cp1252 on Windows) cannot change how the template is read.
    # json.load returns a brand-new object, so it is safe to mutate directly.
    with open(path, encoding="utf-8") as f:
        wf = json.load(f)

    actual_seed = seed if seed != -1 else random.randint(0, 2**32 - 1)

    wf["6"]["inputs"]["text"] = prompt
    wf["33"]["inputs"]["text"] = neg_prompt
    wf["27"]["inputs"]["width"] = width
    wf["27"]["inputs"]["height"] = height
    wf["13"]["inputs"]["steps"] = steps
    wf["13"]["inputs"]["seed"] = actual_seed
    # Node 32 = UNETLoader (flux1-schnell.safetensors is UNet-only, not an
    # all-in-one checkpoint).
    wf["32"]["inputs"]["unet_name"] = model

    # Attach the actual seed as metadata so callers can retrieve it.
    wf["_meta"] = {"actual_seed": actual_seed}
    return wf
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _sanitize_name(name: str) -> str:
|
||
"""Sanitize a user-provided name for safe use in filenames.
|
||
|
||
Replaces whitespace with underscores, strips any characters that are not
|
||
alphanumeric, underscores, or hyphens, and collapses consecutive
|
||
underscores/hyphens. Returns empty string if nothing usable remains.
|
||
"""
|
||
name = name.strip()
|
||
name = re.sub(r"\s+", "_", name) # spaces → underscores
|
||
name = re.sub(r"[^\w\-]", "", name) # strip non-alphanum/underscore/hyphen
|
||
name = re.sub(r"[_\-]{2,}", "_", name) # collapse runs
|
||
name = name.strip("_-") # trim leading/trailing separators
|
||
return name[:64] # cap at 64 chars
|
||
|
||
|
||
def _build_filename(name: str, timestamp: str, actual_seed: int) -> str:
    """Compose the output filename ``[<name>_]<timestamp>_<seed>.png``.

    The name prefix (run through _sanitize_name) is included only when
    sanitization leaves something non-empty.
    """
    stem = f"{timestamp}_{actual_seed}"
    prefix = _sanitize_name(name)
    if prefix:
        stem = f"{prefix}_{stem}"
    return stem + ".png"
|
||
|
||
|
||
async def _wait_for_completion(
    client: ComfyUIClient, prompt_id: str, timeout: float
) -> "float | None":
    """Poll the ComfyUI queue until *prompt_id* is neither running nor pending.

    Returns the elapsed seconds once the job has left the queue, or None if
    *timeout* seconds pass first. Any httpx error while polling is treated as
    transient and retried after the poll interval.
    """
    # Monotonic clock: wall-clock adjustments must not cut short or extend the wait.
    start = time.monotonic()
    while True:
        if time.monotonic() - start > timeout:
            return None

        try:
            queue = await client.get_status(prompt_id)
        except httpx.HTTPError:
            await asyncio.sleep(2)
            continue

        # Queue items are sequences whose element [1] is the prompt_id.
        active_ids = {item[1] for item in queue.get("queue_running", [])}
        active_ids.update(item[1] for item in queue.get("queue_pending", []))
        if prompt_id not in active_ids:
            return time.monotonic() - start

        await asyncio.sleep(2)


def _first_output_image(history: dict, prompt_id: str) -> "dict | None":
    """Return the first image record from any node's outputs in *history*, or None."""
    outputs = history.get(prompt_id, {}).get("outputs", {})
    for node_output in outputs.values():
        images = node_output.get("images", [])
        if images:
            return images[0]
    return None


async def _generate_single(
    client: ComfyUIClient,
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    steps: int,
    seed: int,
    model: str,
    resolved_output_dir: Path,
    name: str,
    label: str,
) -> list:
    """Generate a single image and return [TextContent, ImageContent] or [TextContent] on error.

    Every network or filesystem failure is reported as a single error
    TextContent rather than raised, so a batch caller can continue with the
    remaining images. Errors building the workflow itself (e.g. a missing
    template) still propagate.

    Args:
        client: ComfyUIClient instance.
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        width / height: Image dimensions.
        steps: Inference steps.
        seed: Seed value (-1 = random).
        model: ComfyUI model filename.
        resolved_output_dir: Resolved output directory Path.
        name: User-supplied name prefix (unsanitized).
        label: Human-readable label for TextContent prefix (e.g. "[lumen 1/3]").
            May be empty, in which case no prefix (and no stray space) is added.
    """
    prefix = f"{label} " if label else ""

    def fail(message: str) -> list:
        return [TextContent(type="text", text=f"{prefix}{message}")]

    workflow = build_flux_workflow(
        prompt=prompt,
        neg_prompt=negative_prompt,
        width=width,
        height=height,
        steps=steps,
        seed=seed,
        model=model,
    )
    actual_seed = workflow["_meta"]["actual_seed"]

    # Submit the workflow.
    try:
        prompt_id = await client.queue_prompt(workflow)
    except httpx.ConnectError:
        return fail(
            f"ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        )
    except httpx.HTTPStatusError as e:
        return fail(
            f"ComfyUI returned an error: {e.response.status_code} — {e.response.text}"
        )
    except httpx.HTTPError as e:
        # Timeouts and other transport errors: report in this slot instead of
        # aborting the whole batch.
        return fail(f"Failed to submit workflow to ComfyUI: {e}")
    except (KeyError, TypeError, ValueError) as e:
        # Response body was not JSON or had no "prompt_id".
        return fail(f"Unexpected response from ComfyUI when queueing: {e!r}")

    # Wait for the job to leave the queue.
    elapsed = await _wait_for_completion(client, prompt_id, COMFYUI_TIMEOUT)
    if elapsed is None:
        return fail(
            f"Generation timed out after {COMFYUI_TIMEOUT}s. "
            f"prompt_id={prompt_id} — use get_generation_status to check"
        )

    # Retrieve history to find the output filename.
    try:
        history = await client.get_history(prompt_id)
    except httpx.HTTPError as e:
        return fail(f"Failed to retrieve generation history: {e}")

    image_info = _first_output_image(history, prompt_id)
    if not image_info:
        return fail(f"No output image found in history for prompt_id={prompt_id}")

    # Download image bytes.
    try:
        image_bytes = await client.get_image(
            filename=image_info["filename"],
            subfolder=image_info.get("subfolder", ""),
            folder_type=image_info.get("type", "output"),
        )
    except httpx.HTTPError as e:
        return fail(f"Failed to download generated image: {e}")

    # Save to disk.
    try:
        resolved_output_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = resolved_output_dir / _build_filename(name, timestamp, actual_seed)
        out_path.write_bytes(image_bytes)
    except OSError as e:
        return fail(f"Cannot write to output directory: {resolved_output_dir} — {e}")

    # Encode as base64 for inline display.
    b64_data = base64.b64encode(image_bytes).decode("utf-8")

    return [
        TextContent(
            type="text",
            text=(
                f"{prefix}Generated: {out_path}\n"
                f"Seed: {actual_seed}\n"
                f"Elapsed: {elapsed:.1f}s\n"
                f"Size: {width}x{height}, Steps: {steps}, Model: {model}"
            ),
        ),
        ImageContent(
            type="image",
            data=b64_data,
            mimeType="image/png",
        ),
    ]
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Tools
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@mcp.tool()
async def generate_image(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    steps: int = 4,
    model: str = "flux1-schnell.safetensors",
    seed: int = -1,
    negative_prompt: str = "",
    output_dir: str = "",
    name: str = "",
    count: int = 1,
) -> list:
    """Generate an image from a text prompt using ComfyUI.

    Returns both a file path (for persistence) and an inline base64 image
    (for display in Claude / Roo Code chat).

    Args:
        prompt: Text description of the image to generate.
        width: Image width in pixels (default: 1024).
        height: Image height in pixels (default: 1024).
        steps: Number of inference steps. FLUX.1-schnell works well at 4.
        model: ComfyUI model filename (default: flux1-schnell.safetensors).
        seed: Random seed for reproducibility. -1 = random; otherwise must be
            a non-negative integer.
            When count > 1 and seed != -1, seeds are incremented per image
            (seed, seed+1, seed+2, ...) to produce deterministic variation.
        negative_prompt: Things to exclude from the image (optional).
        output_dir: Override output directory. Defaults to IMAGE_OUTPUT_DIR env var
            or ~/Pictures/mcp-generated.
        name: Optional filename prefix. Saved as {name}_{timestamp}_{seed}.png.
            Useful to avoid confusion with auto-generated timestamp filenames.
        count: Number of images to generate (1–10). Each image is generated
            sequentially. Partial failures are returned inline — the batch
            continues even if one image fails.

    Returns:
        Flat interleaved list: [TextContent1, ImageContent1, TextContent2, ImageContent2, ...]
        On error for any single image, that slot contains only [TextContent(error)].
    """
    # Validate arguments up front.
    if count < 1:
        return [
            TextContent(
                type="text",
                text=f"count must be at least 1 (got {count}).",
            )
        ]
    if count > MAX_COUNT:
        return [
            TextContent(
                type="text",
                text=f"count must be at most {MAX_COUNT} (got {count}). Use multiple calls for larger batches.",
            )
        ]
    # Without this check a negative seed such as -2 would reach -1 part-way
    # through a batch (seed + i - 1) and that image would silently become random.
    if seed < -1:
        return [
            TextContent(
                type="text",
                text=f"seed must be -1 (random) or a non-negative integer (got {seed}).",
            )
        ]

    # Resolve output directory once.
    resolved_output_dir = Path(output_dir or IMAGE_OUTPUT_DIR).expanduser().resolve()
    client = ComfyUIClient(COMFYUI_URL)

    # The sanitized name is loop-invariant; compute it once for the labels.
    display_name = _sanitize_name(name)

    results: list = []
    for i in range(1, count + 1):
        # seed=-1 -> each image gets an independent random seed;
        # fixed seed -> offset by i-1 for deterministic variation across the batch.
        image_seed = seed if seed == -1 else seed + (i - 1)

        if count > 1:
            label = f"[{display_name or 'image'} {i}/{count}]"
        else:
            label = f"[{display_name}]" if display_name else ""

        single_result = await _generate_single(
            client=client,
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            seed=image_seed,
            model=model,
            resolved_output_dir=resolved_output_dir,
            name=name,
            label=label,
        )
        results.extend(single_result)

    return results
|
||
|
||
|
||
@mcp.tool()
async def list_available_models() -> list[str]:
    """List all checkpoint models available in ComfyUI.

    Returns a list of model filenames available for use with generate_image.
    Requires ComfyUI to be running at COMFYUI_URL.
    """
    # NOTE(review): this lists CheckpointLoaderSimple checkpoints, but
    # build_flux_workflow passes `model` to a UNETLoader node (unet_name), so
    # UNet-only files may not appear here — confirm which list is intended.
    # Errors are returned as single-string lists rather than raised.
    client = ComfyUIClient(COMFYUI_URL)
    try:
        models = await client.get_models()
    except httpx.ConnectError:
        return [
            f"ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        ]
    except httpx.HTTPStatusError as e:
        return [f"ComfyUI error: {e.response.status_code}"]
    return models
|
||
|
||
|
||
@mcp.tool()
async def get_generation_status(prompt_id: str) -> dict:
    """Check the status of a queued or running generation job.

    Args:
        prompt_id: The prompt ID returned by a previous generate_image call.

    Returns:
        Dict with 'status' key: "pending", "running", "completed", or "not_found"
        (plus 'prompt_id'). If ComfyUI cannot be queried, returns
        {"status": "error", "message": ...} instead.
    """
    client = ComfyUIClient(COMFYUI_URL)
    try:
        queue = await client.get_status(prompt_id)
        running_ids = {item[1] for item in queue.get("queue_running", [])}
        pending_ids = {item[1] for item in queue.get("queue_pending", [])}

        if prompt_id in running_ids:
            return {"status": "running", "prompt_id": prompt_id}
        if prompt_id in pending_ids:
            return {"status": "pending", "prompt_id": prompt_id}

        # Not in the queue — check history. A failure here is reported as an
        # error (handled below) rather than "not_found": an unreachable history
        # endpoint says nothing about whether the job exists.
        history = await client.get_history(prompt_id)
    except httpx.ConnectError:
        return {
            "status": "error",
            "message": f"ComfyUI not reachable at {COMFYUI_URL}",
        }
    except httpx.HTTPStatusError as e:
        return {"status": "error", "message": f"HTTP {e.response.status_code}"}
    except httpx.HTTPError as e:
        # Timeouts and other transport errors.
        return {"status": "error", "message": f"ComfyUI request failed: {e}"}

    if prompt_id in history:
        return {"status": "completed", "prompt_id": prompt_id}
    return {"status": "not_found", "prompt_id": prompt_id}
|
||
|
||
|
||
@mcp.tool()
def get_output_directory() -> str:
    """Return the directory where generated images are saved.

    Returns:
        Absolute path to the output directory (may not exist yet).
    """
    # Reflects only the IMAGE_OUTPUT_DIR default; a per-call output_dir passed
    # to generate_image is not visible here.
    default_dir = Path(IMAGE_OUTPUT_DIR).expanduser()
    return str(default_dir.resolve())
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Entry point
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def main() -> None:
    """Run the MCP server over stdio."""
    mcp.run(transport="stdio")


if __name__ == "__main__":
    main()
|