Files
pi_mcps/mcp/mcp-image-gen/src/server.py
T
Patrick Plate 8112ff2f12 feat(mcp-image-gen): scaffold ComfyUI-backed image generation MCP server
- FastMCP server with 4 tools: generate_image, list_available_models,
  get_generation_status, get_output_directory
- ComfyUI REST API client (httpx) polling lifecycle
- FLUX.1-schnell workflow JSON template
- Dual output: TextContent (path + seed) + ImageContent (base64 PNG)
- 14 passing pytest tests with respx HTTP mocking
- ROCm/AMD RX 7900 XTX optimized setup in README
- Ollama Linux migration path documented (future)
2026-04-04 11:49:31 +02:00

385 lines
13 KiB
Python

"""mcp-image-gen — FastMCP server for AI image generation via ComfyUI."""
import asyncio
import base64
import copy
import json
import os
import random
import time
from datetime import datetime
from pathlib import Path
import httpx
from fastmcp import FastMCP
from mcp.types import ImageContent, TextContent
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Base URL of the ComfyUI server. A trailing slash is stripped so endpoint
# paths can be appended with a leading "/".
COMFYUI_URL = os.environ.get("COMFYUI_URL", "http://localhost:8188").rstrip("/")
# Default directory for saved images. "~" is left unexpanded here and is
# expanded at the point of use.
IMAGE_OUTPUT_DIR = os.environ.get("IMAGE_OUTPUT_DIR", "~/Pictures/mcp-generated")
# Maximum number of seconds generate_image polls the queue before giving up.
# NOTE: a non-integer COMFYUI_TIMEOUT value raises ValueError at import time.
COMFYUI_TIMEOUT = int(os.environ.get("COMFYUI_TIMEOUT", "120"))
# Path to the bundled FLUX.1-schnell workflow template
_WORKFLOW_PATH = Path(__file__).parent / "workflows" / "flux_schnell.json"
# Server instance on which the @mcp.tool() decorators in this module register
# their tools.
mcp = FastMCP("mcp-image-gen")
# ---------------------------------------------------------------------------
# ComfyUI client
# ---------------------------------------------------------------------------
class ComfyUIClient:
    """Async HTTP client wrapper for the ComfyUI REST API.

    Each method opens a short-lived ``httpx.AsyncClient`` with its own
    timeout, issues a single request and calls ``raise_for_status()``, so
    non-2xx responses surface as ``httpx.HTTPStatusError`` and transport
    problems as the corresponding ``httpx`` exception.
    """

    def __init__(self, base_url: str = COMFYUI_URL):
        """Store *base_url* without a trailing slash."""
        self.base_url = base_url.rstrip("/")

    async def queue_prompt(self, workflow: dict) -> str:
        """Submit a workflow to ComfyUI and return the prompt_id.

        The workflow is posted as ``{"prompt": workflow}``. Raises
        ``KeyError`` if the response JSON has no ``"prompt_id"`` field.
        """
        payload = {"prompt": workflow}
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(f"{self.base_url}/api/prompt", json=payload)
            resp.raise_for_status()
            return resp.json()["prompt_id"]

    async def get_status(self, prompt_id: str) -> dict:
        """Return the current queue state (queue_running + queue_pending lists).

        ``prompt_id`` is accepted for interface symmetry but is not used:
        the whole queue is returned and callers search it themselves.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/api/queue")
            resp.raise_for_status()
            return resp.json()

    async def get_history(self, prompt_id: str) -> dict:
        """Return the decoded history response for *prompt_id*.

        Callers in this module look the job up under the ``prompt_id`` key
        of the returned mapping.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/api/history/{prompt_id}")
            resp.raise_for_status()
            return resp.json()

    async def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
        """Download image bytes from ComfyUI's /api/view endpoint.

        ``folder_type`` is sent as the ``type`` query parameter.
        """
        params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.get(f"{self.base_url}/api/view", params=params)
            resp.raise_for_status()
            return resp.content

    async def get_models(self) -> list[str]:
        """Return the list of available checkpoint model filenames.

        Returns an empty list when the response does not contain a usable
        ``ckpt_name`` specification (missing, empty, or not a list).
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                f"{self.base_url}/object_info/CheckpointLoaderSimple"
            )
            resp.raise_for_status()
            data = resp.json()
        # ComfyUI returns: {"CheckpointLoaderSimple": {"input": {"required": {"ckpt_name": [["model1.safetensors", ...], ...]}}}}
        node_info = data.get("CheckpointLoaderSimple", {})
        ckpt_spec = (
            node_info.get("input", {})
            .get("required", {})
            .get("ckpt_name")
        )
        # Guard before indexing: an empty or non-list spec used to raise
        # IndexError / index into an unexpected type.
        if not isinstance(ckpt_spec, list) or not ckpt_spec:
            return []
        ckpt_list = ckpt_spec[0]
        return ckpt_list if isinstance(ckpt_list, list) else []
# ---------------------------------------------------------------------------
# Workflow builder
# ---------------------------------------------------------------------------
def build_flux_workflow(
    prompt: str,
    neg_prompt: str,
    width: int,
    height: int,
    steps: int,
    seed: int,
    model: str,
) -> dict:
    """Build a ComfyUI API-format workflow dict for FLUX.1-schnell text-to-image.

    The bundled JSON template at ``_WORKFLOW_PATH`` is read on every call and
    the request-specific values are written into its nodes. The function is
    therefore not pure: it performs file I/O and, when ``seed`` is -1, draws
    a random seed.

    Args:
        prompt: Positive prompt text (written to node "6").
        neg_prompt: Negative prompt text (written to node "33").
        width: Image width (written to node "27").
        height: Image height (written to node "27").
        steps: Number of steps (written to node "13").
        seed: Seed to use; -1 selects a random value in [0, 2**32 - 1].
        model: Checkpoint filename (written to node "30" as ``ckpt_name``).

    Returns:
        The filled-in workflow dict. It additionally carries a ``"_meta"``
        entry holding ``{"actual_seed": <int>}``; that entry is bookkeeping
        for the caller and is not a workflow node.

    Raises:
        OSError: If the template file cannot be opened.
        json.JSONDecodeError: If the template is not valid JSON.
        KeyError: If the template lacks one of the expected node ids.
    """
    # json.load returns a fresh object graph on every call, so no defensive
    # copy is needed before mutating it.
    with open(_WORKFLOW_PATH, encoding="utf-8") as f:
        wf = json.load(f)

    actual_seed = seed if seed != -1 else random.randint(0, 2**32 - 1)

    wf["6"]["inputs"]["text"] = prompt
    wf["33"]["inputs"]["text"] = neg_prompt
    wf["27"]["inputs"]["width"] = width
    wf["27"]["inputs"]["height"] = height
    wf["13"]["inputs"]["steps"] = steps
    wf["13"]["inputs"]["seed"] = actual_seed
    wf["30"]["inputs"]["ckpt_name"] = model

    # Attach the actual seed as metadata so callers can retrieve it
    wf["_meta"] = {"actual_seed": actual_seed}
    return wf
# ---------------------------------------------------------------------------
# Tools
# ---------------------------------------------------------------------------
def _text_result(message: str) -> list:
    """Wrap *message* in the single-item TextContent list used for text-only replies."""
    return [TextContent(type="text", text=message)]


@mcp.tool()
async def generate_image(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    steps: int = 4,
    model: str = "flux1-schnell.safetensors",
    seed: int = -1,
    negative_prompt: str = "",
    output_dir: str = "",
) -> list:
    """Generate an image from a text prompt using ComfyUI.

    Returns both a file path (for persistence) and an inline base64 image
    (for display in Claude / Roo Code chat).

    Args:
        prompt: Text description of the image to generate.
        width: Image width in pixels (default: 1024).
        height: Image height in pixels (default: 1024).
        steps: Number of inference steps. FLUX.1-schnell works well at 4.
        model: ComfyUI model filename (default: flux1-schnell.safetensors).
        seed: Random seed for reproducibility. -1 = random.
        negative_prompt: Things to exclude from the image (optional).
        output_dir: Override output directory. Defaults to IMAGE_OUTPUT_DIR env var
            or ~/Pictures/mcp-generated.

    Returns:
        [TextContent(path + metadata), ImageContent(base64 PNG)] on success.
        On failure (ComfyUI unreachable, HTTP error, timeout, missing output,
        unwritable directory) a single TextContent describing the problem.
    """
    # Resolve output directory
    resolved_output_dir = Path(output_dir or IMAGE_OUTPUT_DIR).expanduser().resolve()
    client = ComfyUIClient(COMFYUI_URL)

    # Build and submit workflow
    try:
        workflow = build_flux_workflow(
            prompt=prompt,
            neg_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            seed=seed,
            model=model,
        )
        # "_meta" is local bookkeeping, not a node. ComfyUI interprets every
        # top-level key of the submitted prompt as a node, so the entry is
        # removed before the workflow goes over the wire.
        actual_seed = workflow.pop("_meta")["actual_seed"]
        prompt_id = await client.queue_prompt(workflow)
    except httpx.ConnectError:
        return _text_result(
            f"ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        )
    except httpx.HTTPStatusError as e:
        return _text_result(
            f"ComfyUI returned an error: {e.response.status_code} — {e.response.text}"
        )
    except httpx.RequestError as e:
        # Timeouts and other transport failures that are not ConnectError.
        return _text_result(f"Request to ComfyUI failed: {type(e).__name__}: {e}")

    # Poll until done. A monotonic clock keeps the timeout and the reported
    # duration immune to wall-clock adjustments.
    start = time.monotonic()
    while True:
        if time.monotonic() - start > COMFYUI_TIMEOUT:
            return _text_result(
                f"Generation timed out after {COMFYUI_TIMEOUT}s. "
                f"prompt_id={prompt_id} — use get_generation_status to check"
            )
        try:
            queue = await client.get_status(prompt_id)
        except httpx.HTTPError:
            # Transient failure (including read timeouts): retry until the
            # overall deadline expires instead of aborting the generation.
            await asyncio.sleep(2)
            continue
        still_queued = any(
            item[1] == prompt_id
            for item in (
                *queue.get("queue_running", []),
                *queue.get("queue_pending", []),
            )
        )
        if not still_queued:
            break  # Job is done
        await asyncio.sleep(2)
    elapsed = time.monotonic() - start

    # Retrieve history to find output filename
    try:
        history = await client.get_history(prompt_id)
    except httpx.HTTPError as e:
        return _text_result(f"Failed to retrieve generation history: {e}")

    outputs = history.get(prompt_id, {}).get("outputs", {})

    # Take the first image reported by any output node.
    image_info = None
    for node_output in outputs.values():
        images = node_output.get("images", [])
        if images:
            image_info = images[0]
            break
    if not image_info:
        return _text_result(
            f"No output image found in history for prompt_id={prompt_id}"
        )

    # Download image bytes
    try:
        image_bytes = await client.get_image(
            filename=image_info["filename"],
            subfolder=image_info.get("subfolder", ""),
            folder_type=image_info.get("type", "output"),
        )
    except httpx.HTTPError as e:
        return _text_result(f"Failed to download generated image: {e}")

    # Save to disk
    try:
        resolved_output_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = resolved_output_dir / f"{timestamp}_{actual_seed}.png"
        out_path.write_bytes(image_bytes)
    except OSError as e:
        return _text_result(
            f"Cannot write to output directory: {resolved_output_dir} — {e}"
        )

    # Encode as base64 for inline display
    b64_data = base64.b64encode(image_bytes).decode("utf-8")
    return [
        TextContent(
            type="text",
            text=(
                f"Generated: {out_path}\n"
                f"Seed: {actual_seed}\n"
                f"Elapsed: {elapsed:.1f}s\n"
                f"Size: {width}x{height}, Steps: {steps}, Model: {model}"
            ),
        ),
        ImageContent(
            type="image",
            data=b64_data,
            mimeType="image/png",
        ),
    ]
@mcp.tool()
async def list_available_models() -> list[str]:
    """List the checkpoint models that ComfyUI currently offers.

    The returned filenames can be passed as ``model`` to generate_image.
    ComfyUI must be running at COMFYUI_URL; if it is unreachable or answers
    with an HTTP error, the list contains a single explanatory message
    instead of model names.
    """
    comfy = ComfyUIClient(COMFYUI_URL)
    try:
        models = await comfy.get_models()
    except httpx.ConnectError:
        unreachable = (
            f"ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        )
        return [unreachable]
    except httpx.HTTPStatusError as err:
        return [f"ComfyUI error: {err.response.status_code}"]
    return models
@mcp.tool()
async def get_generation_status(prompt_id: str) -> dict:
    """Check the status of a queued or running generation job.

    The queue is consulted first; only when the job is in neither the
    running nor the pending list is the history endpoint queried.

    Args:
        prompt_id: The prompt ID returned by a previous generate_image call.

    Returns:
        Dict with 'status' key: "pending", "running", "completed", or
        "not_found" (each with the 'prompt_id'), or "error" with a 'message'
        when ComfyUI cannot be reached or answers with an HTTP error.
    """
    client = ComfyUIClient(COMFYUI_URL)
    try:
        queue = await client.get_status(prompt_id)
        if any(item[1] == prompt_id for item in queue.get("queue_running", [])):
            return {"status": "running", "prompt_id": prompt_id}
        if any(item[1] == prompt_id for item in queue.get("queue_pending", [])):
            return {"status": "pending", "prompt_id": prompt_id}
        # Not in queue — check history. A failed lookup is reported through
        # the handlers below rather than being mistaken for "not_found".
        history = await client.get_history(prompt_id)
    except httpx.ConnectError:
        return {
            "status": "error",
            "message": f"ComfyUI not reachable at {COMFYUI_URL}",
        }
    except httpx.HTTPStatusError as e:
        return {"status": "error", "message": f"HTTP {e.response.status_code}"}

    if prompt_id in history:
        return {"status": "completed", "prompt_id": prompt_id}
    return {"status": "not_found", "prompt_id": prompt_id}
@mcp.tool()
def get_output_directory() -> str:
    """Report where generated images are written by default.

    Returns:
        The absolute form of IMAGE_OUTPUT_DIR with "~" expanded. The
        directory is not created here and may not exist yet.
    """
    configured = Path(IMAGE_OUTPUT_DIR)
    absolute = configured.expanduser().resolve()
    return str(absolute)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Start the server on the stdio transport when this module is executed
    # directly rather than imported.
    mcp.run(transport="stdio")