feat(mcp-image-gen): scaffold ComfyUI-backed image generation MCP server
- FastMCP server with 4 tools: generate_image, list_available_models, get_generation_status, get_output_directory - ComfyUI REST API client (httpx) polling lifecycle - FLUX.1-schnell workflow JSON template - Dual output: TextContent (path + seed) + ImageContent (base64 PNG) - 14 passing pytest tests with respx HTTP mocking - ROCm/AMD RX 7900 XTX optimized setup in README - Ollama Linux migration path documented (future)
This commit is contained in:
@@ -0,0 +1,384 @@
|
||||
"""mcp-image-gen — FastMCP server for AI image generation via ComfyUI."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from fastmcp import FastMCP
|
||||
from mcp.types import ImageContent, TextContent
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Base URL of the ComfyUI server. The trailing slash is dropped so request
# paths can be appended with a leading "/".
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://localhost:8188").rstrip("/")

# Default directory for saved images. "~" is left unexpanded here; the tools
# call expanduser()/resolve() at the point of use.
IMAGE_OUTPUT_DIR = os.getenv("IMAGE_OUTPUT_DIR", "~/Pictures/mcp-generated")

# Maximum number of seconds generate_image polls before giving up.
COMFYUI_TIMEOUT = int(os.getenv("COMFYUI_TIMEOUT", "120"))

# Path to the bundled FLUX.1-schnell workflow template
_WORKFLOW_PATH = Path(__file__).parent.joinpath("workflows", "flux_schnell.json")

# Server instance that the @mcp.tool() decorators below register against.
mcp = FastMCP("mcp-image-gen")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ComfyUI client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ComfyUIClient:
    """Async HTTP client wrapper for the ComfyUI REST API.

    Every method opens its own short-lived ``httpx.AsyncClient`` and raises
    ``httpx.HTTPStatusError`` (via ``raise_for_status``) on a non-2xx reply;
    transport failures propagate as the corresponding ``httpx`` exceptions.
    """

    def __init__(self, base_url: str = COMFYUI_URL):
        # Normalise so paths can always be appended with a leading "/".
        self.base_url = base_url.rstrip("/")

    async def _get_json(self, path: str, timeout: float):
        """GET ``base_url + path`` and return the decoded JSON body."""
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.get(f"{self.base_url}{path}")
            resp.raise_for_status()
            return resp.json()

    @staticmethod
    def _extract_checkpoint_names(data) -> list[str]:
        """Pull the checkpoint filename list out of an object_info payload.

        Expected shape:
        {"CheckpointLoaderSimple": {"input": {"required":
            {"ckpt_name": [["model1.safetensors", ...], ...]}}}}

        Any deviation from that shape (missing keys, null values, an empty
        ``ckpt_name`` entry) yields an empty list instead of an exception.
        """
        if not isinstance(data, dict):
            return []
        node_info = data.get("CheckpointLoaderSimple") or {}
        if not isinstance(node_info, dict):
            return []
        inputs = node_info.get("input") or {}
        required = inputs.get("required") or {} if isinstance(inputs, dict) else {}
        ckpt_spec = required.get("ckpt_name") if isinstance(required, dict) else None
        # ckpt_spec[0] is the list of choices; guard against [] and non-sequences.
        if not isinstance(ckpt_spec, (list, tuple)) or not ckpt_spec:
            return []
        choices = ckpt_spec[0]
        return choices if isinstance(choices, list) else []

    async def queue_prompt(self, workflow: dict) -> str:
        """Submit a workflow to ComfyUI and return the prompt_id."""
        payload = {"prompt": workflow}
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(f"{self.base_url}/api/prompt", json=payload)
            resp.raise_for_status()
            return resp.json()["prompt_id"]

    async def get_status(self, prompt_id: str) -> dict:
        """Return the current queue state (queue_running + queue_pending lists).

        ``prompt_id`` is accepted for interface symmetry but is not sent to the
        server: the whole queue is returned and callers filter it themselves.
        """
        return await self._get_json("/api/queue", timeout=10.0)

    async def get_history(self, prompt_id: str) -> dict:
        """Return the history entry for a completed prompt_id."""
        return await self._get_json(f"/api/history/{prompt_id}", timeout=10.0)

    async def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
        """Download image bytes from ComfyUI's /api/view endpoint."""
        params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.get(f"{self.base_url}/api/view", params=params)
            resp.raise_for_status()
            return resp.content

    async def get_models(self) -> list[str]:
        """Return the list of available checkpoint model filenames.

        Returns an empty list when the response does not have the expected
        structure.
        """
        # NOTE: unlike the other endpoints this route is requested without the
        # "/api" prefix; kept unchanged from the original implementation.
        data = await self._get_json(
            "/object_info/CheckpointLoaderSimple", timeout=10.0
        )
        return self._extract_checkpoint_names(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_flux_workflow(
    prompt: str,
    neg_prompt: str,
    width: int,
    height: int,
    steps: int,
    seed: int,
    model: str,
) -> dict:
    """Build a ComfyUI API-format workflow dict for FLUX.1-schnell text-to-image.

    The bundled JSON template at ``_WORKFLOW_PATH`` is read on every call and
    its placeholder inputs are overwritten with the given values. The function
    is therefore not pure: it performs file I/O and, when ``seed`` is -1,
    draws a random seed.

    Args:
        prompt: Positive prompt text (node "6").
        neg_prompt: Negative prompt text (node "33").
        width: Latent image width (node "27").
        height: Latent image height (node "27").
        steps: Sampler step count (node "13").
        seed: Sampler seed; -1 selects a random value in [0, 2**32 - 1].
        model: Checkpoint filename (node "30").

    Returns:
        The populated workflow dict. An extra top-level ``"_meta"`` entry of
        the form ``{"actual_seed": <int>}`` carries the seed actually used;
        it is bookkeeping for the caller, not a workflow node.

    Raises:
        OSError: The template file cannot be read.
        json.JSONDecodeError: The template is not valid JSON.
        KeyError: The template lacks one of the expected node ids or inputs.
    """
    # json.load builds a fresh object graph on each call, so the result can be
    # mutated freely without a defensive copy.
    with open(_WORKFLOW_PATH, encoding="utf-8") as f:
        wf = json.load(f)

    actual_seed = seed if seed != -1 else random.randint(0, 2**32 - 1)

    wf["6"]["inputs"]["text"] = prompt
    wf["33"]["inputs"]["text"] = neg_prompt
    wf["27"]["inputs"]["width"] = width
    wf["27"]["inputs"]["height"] = height
    wf["13"]["inputs"]["steps"] = steps
    wf["13"]["inputs"]["seed"] = actual_seed
    wf["30"]["inputs"]["ckpt_name"] = model

    # Attach the actual seed as metadata so callers can retrieve it
    wf["_meta"] = {"actual_seed": actual_seed}
    return wf
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _text_result(message: str) -> list:
    """Wrap *message* in the single-element TextContent list used for failures."""
    return [TextContent(type="text", text=message)]


@mcp.tool()
async def generate_image(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    steps: int = 4,
    model: str = "flux1-schnell.safetensors",
    seed: int = -1,
    negative_prompt: str = "",
    output_dir: str = "",
) -> list:
    """Generate an image from a text prompt using ComfyUI.

    Returns both a file path (for persistence) and an inline base64 image
    (for display in Claude / Roo Code chat).

    Args:
        prompt: Text description of the image to generate.
        width: Image width in pixels (default: 1024).
        height: Image height in pixels (default: 1024).
        steps: Number of inference steps. FLUX.1-schnell works well at 4.
        model: ComfyUI model filename (default: flux1-schnell.safetensors).
        seed: Random seed for reproducibility. -1 = random.
        negative_prompt: Things to exclude from the image (optional).
        output_dir: Override output directory. Defaults to IMAGE_OUTPUT_DIR env var
                    or ~/Pictures/mcp-generated.

    Returns:
        [TextContent(path + metadata), ImageContent(base64 PNG)]
    """
    # Seconds to wait between queue polls.
    poll_interval = 2
    # Network-level failures (connect, read timeout, ...) and non-2xx replies.
    request_errors = (httpx.TransportError, httpx.HTTPStatusError)

    # Resolve output directory
    resolved_output_dir = Path(output_dir or IMAGE_OUTPUT_DIR).expanduser().resolve()

    client = ComfyUIClient(COMFYUI_URL)

    # Build and submit workflow
    try:
        workflow = build_flux_workflow(
            prompt=prompt,
            neg_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            seed=seed,
            model=model,
        )
        # "_meta" is bookkeeping added by build_flux_workflow, not a node.
        # Remove it so that only real nodes (entries with a class_type) are
        # submitted to ComfyUI.
        actual_seed = workflow.pop("_meta")["actual_seed"]

        prompt_id = await client.queue_prompt(workflow)
    except httpx.ConnectError:
        return _text_result(
            f"ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        )
    except httpx.HTTPStatusError as e:
        return _text_result(
            f"ComfyUI returned an error: {e.response.status_code} — {e.response.text}"
        )
    except httpx.TransportError as e:
        # E.g. a read timeout while submitting; must come after ConnectError,
        # which is a TransportError subclass with its own message.
        return _text_result(f"Request to ComfyUI at {COMFYUI_URL} failed: {e}")

    # Poll until done. A monotonic clock keeps the timeout immune to
    # wall-clock adjustments.
    start = time.monotonic()
    while True:
        if time.monotonic() - start > COMFYUI_TIMEOUT:
            return _text_result(
                f"Generation timed out after {COMFYUI_TIMEOUT}s. "
                f"prompt_id={prompt_id} — use get_generation_status to check"
            )

        try:
            queue = await client.get_status(prompt_id)
        except request_errors:
            # Transient failure: keep polling until the overall timeout hits.
            await asyncio.sleep(poll_interval)
            continue

        # Each queue item carries its prompt id at index 1.
        active_ids = [
            item[1]
            for item in (
                *queue.get("queue_running", []),
                *queue.get("queue_pending", []),
            )
        ]
        if prompt_id not in active_ids:
            break  # Job is done

        await asyncio.sleep(poll_interval)

    elapsed = time.monotonic() - start

    # Retrieve history to find output filename
    try:
        history = await client.get_history(prompt_id)
    except request_errors as e:
        return _text_result(f"Failed to retrieve generation history: {e}")

    outputs = history.get(prompt_id, {}).get("outputs", {})

    # Take the first image of the first node that produced any
    # (the SaveImage node, "9", in the bundled workflow).
    image_info = next(
        (
            node_output["images"][0]
            for node_output in outputs.values()
            if node_output.get("images")
        ),
        None,
    )
    if not image_info:
        return _text_result(
            f"No output image found in history for prompt_id={prompt_id}"
        )

    # Download image bytes
    try:
        image_bytes = await client.get_image(
            filename=image_info["filename"],
            subfolder=image_info.get("subfolder", ""),
            folder_type=image_info.get("type", "output"),
        )
    except request_errors as e:
        return _text_result(f"Failed to download generated image: {e}")

    # Save to disk
    try:
        resolved_output_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = resolved_output_dir / f"{timestamp}_{actual_seed}.png"
        out_path.write_bytes(image_bytes)
    except OSError as e:
        return _text_result(
            f"Cannot write to output directory: {resolved_output_dir} — {e}"
        )

    # Encode as base64 for inline display
    b64_data = base64.b64encode(image_bytes).decode("utf-8")

    return [
        TextContent(
            type="text",
            text=(
                f"Generated: {out_path}\n"
                f"Seed: {actual_seed}\n"
                f"Elapsed: {elapsed:.1f}s\n"
                f"Size: {width}x{height}, Steps: {steps}, Model: {model}"
            ),
        ),
        ImageContent(
            type="image",
            data=b64_data,
            mimeType="image/png",
        ),
    ]
|
||||
|
||||
|
||||
@mcp.tool()
async def list_available_models() -> list[str]:
    """List all checkpoint models available in ComfyUI.

    Returns a list of model filenames available for use with generate_image.
    Requires ComfyUI to be running at COMFYUI_URL.
    """
    comfy = ComfyUIClient(COMFYUI_URL)
    try:
        model_names = await comfy.get_models()
    except httpx.ConnectError:
        # Failures are reported in-band as a one-element list holding the
        # message, so the return type stays list[str].
        unreachable = (
            f"ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        )
        return [unreachable]
    except httpx.HTTPStatusError as exc:
        status_code = exc.response.status_code
        return [f"ComfyUI error: {status_code}"]
    else:
        return model_names
|
||||
|
||||
|
||||
@mcp.tool()
async def get_generation_status(prompt_id: str) -> dict:
    """Check the status of a queued or running generation job.

    Args:
        prompt_id: The prompt ID returned by a previous generate_image call.

    Returns:
        Dict with 'status' key: "pending", "running", "completed", or "not_found".
        If ComfyUI cannot be queried, 'status' is "error" and 'message'
        describes the failure.
    """
    client = ComfyUIClient(COMFYUI_URL)
    try:
        queue = await client.get_status(prompt_id)
        # Each queue item carries its prompt id at index 1.
        running_ids = [item[1] for item in queue.get("queue_running", [])]
        pending_ids = [item[1] for item in queue.get("queue_pending", [])]

        if prompt_id in running_ids:
            return {"status": "running", "prompt_id": prompt_id}
        if prompt_id in pending_ids:
            return {"status": "pending", "prompt_id": prompt_id}

        # Not in queue — check history
        try:
            history = await client.get_history(prompt_id)
        except httpx.HTTPStatusError as e:
            # Only a 404 means "no such job". Any other HTTP failure is a
            # real error and is re-raised to the handlers below rather than
            # being misreported as not_found.
            if e.response.status_code != 404:
                raise
            history = {}

        if prompt_id in history:
            return {"status": "completed", "prompt_id": prompt_id}
        return {"status": "not_found", "prompt_id": prompt_id}

    except httpx.ConnectError:
        return {
            "status": "error",
            "message": f"ComfyUI not reachable at {COMFYUI_URL}",
        }
    except httpx.HTTPStatusError as e:
        return {"status": "error", "message": f"HTTP {e.response.status_code}"}
    except httpx.TransportError as e:
        # Remaining transport failures (e.g. timeouts); ConnectError is a
        # subclass and is handled above with its own message.
        return {"status": "error", "message": f"Request to ComfyUI failed: {e}"}
|
||||
|
||||
|
||||
@mcp.tool()
def get_output_directory() -> str:
    """Return the directory where generated images are saved.

    Returns:
        Absolute path to the output directory (may not exist yet).
    """
    # Expand "~" first, then make the path absolute; nothing is created.
    configured = Path(IMAGE_OUTPUT_DIR)
    absolute = configured.expanduser().resolve()
    return str(absolute)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
    # Run the server over the stdio transport when executed as a script.
    mcp.run(transport="stdio")
|
||||
@@ -0,0 +1,59 @@
|
||||
{
|
||||
"6": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {
|
||||
"clip": ["30", 1],
|
||||
"text": "PROMPT_PLACEHOLDER"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"class_type": "VAEDecode",
|
||||
"inputs": {
|
||||
"samples": ["13", 0],
|
||||
"vae": ["30", 2]
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"class_type": "SaveImage",
|
||||
"inputs": {
|
||||
"filename_prefix": "mcp-image-gen",
|
||||
"images": ["8", 0]
|
||||
}
|
||||
},
|
||||
"13": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"cfg": 1.0,
|
||||
"denoise": 1.0,
|
||||
"latent_image": ["27", 0],
|
||||
"model": ["30", 0],
|
||||
"negative": ["33", 0],
|
||||
"positive": ["6", 0],
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"seed": 42,
|
||||
"steps": 4
|
||||
}
|
||||
},
|
||||
"27": {
|
||||
"class_type": "EmptySD3LatentImage",
|
||||
"inputs": {
|
||||
"batch_size": 1,
|
||||
"height": 1024,
|
||||
"width": 1024
|
||||
}
|
||||
},
|
||||
"30": {
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"inputs": {
|
||||
"ckpt_name": "flux1-schnell.safetensors"
|
||||
}
|
||||
},
|
||||
"33": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {
|
||||
"clip": ["30", 1],
|
||||
"text": "NEGATIVE_PLACEHOLDER"
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user