8112ff2f12
- FastMCP server with 4 tools: generate_image, list_available_models, get_generation_status, get_output_directory - ComfyUI REST API client (httpx) polling lifecycle - FLUX.1-schnell workflow JSON template - Dual output: TextContent (path + seed) + ImageContent (base64 PNG) - 14 passing pytest tests with respx HTTP mocking - ROCm/AMD RX 7900 XTX optimized setup in README - Ollama Linux migration path documented (future)
385 lines
13 KiB
Python
385 lines
13 KiB
Python
"""mcp-image-gen — FastMCP server for AI image generation via ComfyUI."""
|
|
|
|
import asyncio
|
|
import base64
|
|
import copy
|
|
import json
|
|
import os
|
|
import random
|
|
import time
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
from fastmcp import FastMCP
|
|
from mcp.types import ImageContent, TextContent
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Configuration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Base URL of the ComfyUI server; any trailing "/" is removed so endpoint
# paths can be appended directly.
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://localhost:8188").rstrip("/")

# Default directory for saved images ("~" is expanded where it is used).
IMAGE_OUTPUT_DIR = os.getenv("IMAGE_OUTPUT_DIR", "~/Pictures/mcp-generated")

# Maximum number of seconds generate_image waits for a job to leave the queue.
COMFYUI_TIMEOUT = int(os.getenv("COMFYUI_TIMEOUT", "120"))

# Path to the bundled FLUX.1-schnell workflow template
_WORKFLOW_PATH = Path(__file__).parent.joinpath("workflows", "flux_schnell.json")
|
|
|
|
# Server instance; every function decorated with @mcp.tool() below is
# registered on it, and the __main__ block runs it.
mcp = FastMCP("mcp-image-gen")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ComfyUI client
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class ComfyUIClient:
    """Async HTTP client wrapper for the ComfyUI REST API.

    Every request opens its own short-lived ``httpx.AsyncClient``. Transport
    failures (``httpx.ConnectError``, timeouts, ...) and non-2xx responses
    (``httpx.HTTPStatusError`` from ``raise_for_status``) propagate to the
    caller unchanged.
    """

    def __init__(self, base_url: str = COMFYUI_URL):
        # Strip a trailing "/" so f"{base_url}/api/..." never doubles it.
        self.base_url = base_url.rstrip("/")

    async def queue_prompt(self, workflow: dict) -> str:
        """Submit an API-format workflow to ComfyUI and return its prompt_id.

        Top-level keys starting with "_" (for example the ``"_meta"`` entry
        that build_flux_workflow attaches to carry the chosen seed) are
        bookkeeping, not nodes, and are left out of the request: ComfyUI's
        prompt validation treats every top-level key as a node id and rejects
        the whole prompt if an entry has no ``class_type``.
        """
        nodes = {
            node_id: node
            for node_id, node in workflow.items()
            if not str(node_id).startswith("_")
        }
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{self.base_url}/api/prompt", json={"prompt": nodes}
            )
            resp.raise_for_status()
            return resp.json()["prompt_id"]

    async def get_status(self, prompt_id: str) -> dict:
        """Return the current queue state (queue_running + queue_pending lists).

        ``/api/queue`` reports the whole queue, so ``prompt_id`` is not used;
        it is kept for interface compatibility.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/api/queue")
            resp.raise_for_status()
            return resp.json()

    async def get_history(self, prompt_id: str) -> dict:
        """Return the history entry for a completed prompt_id."""
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/api/history/{prompt_id}")
            resp.raise_for_status()
            return resp.json()

    async def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
        """Download image bytes from ComfyUI's /api/view endpoint."""
        params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.get(f"{self.base_url}/api/view", params=params)
            resp.raise_for_status()
            return resp.content

    async def get_models(self) -> list[str]:
        """Return the list of available checkpoint model filenames.

        Returns an empty list when the response holds no recognisable
        ``ckpt_name`` choice list.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                f"{self.base_url}/object_info/CheckpointLoaderSimple"
            )
            resp.raise_for_status()
            data = resp.json()
        return self._parse_ckpt_names(data)

    @staticmethod
    def _parse_ckpt_names(data: object) -> list[str]:
        """Extract checkpoint filenames from an object_info payload.

        ComfyUI returns:
        {"CheckpointLoaderSimple": {"input": {"required":
            {"ckpt_name": [["model1.safetensors", ...], ...]}}}}

        The ``["COMBO", {"options": [...]}]`` spec form is accepted as well
        (NOTE(review): believed to be used by newer ComfyUI node definitions;
        confirm). Missing keys, an empty spec or unexpected types yield []
        rather than raising IndexError/AttributeError.
        """
        node = data
        for key in ("CheckpointLoaderSimple", "input", "required"):
            node = node.get(key) if isinstance(node, dict) else None
        spec = node.get("ckpt_name") if isinstance(node, dict) else None
        if not isinstance(spec, list) or not spec:
            return []

        choices = spec[0]
        if choices == "COMBO" and len(spec) > 1 and isinstance(spec[1], dict):
            choices = spec[1].get("options")
        if not isinstance(choices, list):
            return []
        return [name for name in choices if isinstance(name, str)]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Workflow builder
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def build_flux_workflow(
|
|
prompt: str,
|
|
neg_prompt: str,
|
|
width: int,
|
|
height: int,
|
|
steps: int,
|
|
seed: int,
|
|
model: str,
|
|
) -> dict:
|
|
"""Build a ComfyUI API-format workflow dict for FLUX.1-schnell text-to-image.
|
|
|
|
This is a pure function — no I/O, fully testable.
|
|
"""
|
|
with open(_WORKFLOW_PATH) as f:
|
|
wf = json.load(f)
|
|
wf = copy.deepcopy(wf)
|
|
|
|
actual_seed = seed if seed != -1 else random.randint(0, 2**32 - 1)
|
|
|
|
wf["6"]["inputs"]["text"] = prompt
|
|
wf["33"]["inputs"]["text"] = neg_prompt
|
|
wf["27"]["inputs"]["width"] = width
|
|
wf["27"]["inputs"]["height"] = height
|
|
wf["13"]["inputs"]["steps"] = steps
|
|
wf["13"]["inputs"]["seed"] = actual_seed
|
|
wf["30"]["inputs"]["ckpt_name"] = model
|
|
|
|
# Attach the actual seed as metadata so callers can retrieve it
|
|
wf["_meta"] = {"actual_seed": actual_seed}
|
|
return wf
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tools
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Seconds between queue polls while waiting for a generation to finish.
_POLL_INTERVAL_S = 2.0


def _text_result(message: str) -> list:
    """Wrap *message* in the one-element TextContent list the tools return."""
    return [TextContent(type="text", text=message)]


def _first_image(outputs: dict) -> dict | None:
    """Return the first image record in a history entry's outputs, or None.

    ``outputs`` maps node id -> node output; the first node (in dict order)
    with a non-empty ``"images"`` list wins.
    """
    for node_output in outputs.values():
        images = node_output.get("images", [])
        if images:
            return images[0]
    return None


def _execution_error(job: dict) -> str | None:
    """Return a short error description if a history entry reports failure.

    NOTE(review): assumes ComfyUI's history format, where ``job["status"]``
    holds ``status_str`` ("success"/"error") and a ``messages`` list of
    ``[event_name, details]`` pairs. Confirm against the ComfyUI version in
    use. Returns None when no failure is reported.
    """
    status = job.get("status")
    if not isinstance(status, dict) or status.get("status_str") != "error":
        return None
    for message in status.get("messages") or []:
        if (
            isinstance(message, (list, tuple))
            and len(message) == 2
            and message[0] == "execution_error"
            and isinstance(message[1], dict)
        ):
            details = message[1]
            node_type = details.get("node_type", "unknown node")
            reason = str(details.get("exception_message", "unknown error")).strip()
            return f"{node_type}: {reason}"
    return "unknown error"


async def _wait_until_dequeued(
    client: ComfyUIClient, prompt_id: str, timeout: float
) -> bool:
    """Poll ComfyUI's queue until *prompt_id* is neither running nor pending.

    Returns True once the job has left the queue, or False if *timeout*
    seconds pass first. Any httpx error while polling (connection refused,
    request timeout, error status) is treated as transient and retried.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() <= deadline:
        try:
            queue = await client.get_status(prompt_id)
        except httpx.HTTPError:
            await asyncio.sleep(_POLL_INTERVAL_S)
            continue
        # Queue entries are sequences whose second element is the prompt_id.
        active_ids = {
            item[1]
            for key in ("queue_running", "queue_pending")
            for item in queue.get(key, [])
        }
        if prompt_id not in active_ids:
            return True
        await asyncio.sleep(_POLL_INTERVAL_S)
    return False


@mcp.tool()
async def generate_image(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    steps: int = 4,
    model: str = "flux1-schnell.safetensors",
    seed: int = -1,
    negative_prompt: str = "",
    output_dir: str = "",
) -> list:
    """Generate an image from a text prompt using ComfyUI.

    Returns both a file path (for persistence) and an inline base64 image
    (for display in Claude / Roo Code chat).

    Args:
        prompt: Text description of the image to generate.
        width: Image width in pixels (default: 1024).
        height: Image height in pixels (default: 1024).
        steps: Number of inference steps. FLUX.1-schnell works well at 4.
        model: ComfyUI model filename (default: flux1-schnell.safetensors).
        seed: Random seed for reproducibility. -1 = random.
        negative_prompt: Things to exclude from the image (optional).
        output_dir: Override output directory. Defaults to IMAGE_OUTPUT_DIR env var
            or ~/Pictures/mcp-generated.

    Returns:
        On success: [TextContent(path + metadata), ImageContent(base64 PNG)].
        On any failure: a single TextContent describing the problem.
    """
    resolved_output_dir = Path(output_dir or IMAGE_OUTPUT_DIR).expanduser().resolve()
    client = ComfyUIClient(COMFYUI_URL)

    # Build and submit workflow
    try:
        workflow = build_flux_workflow(
            prompt=prompt,
            neg_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            seed=seed,
            model=model,
        )
        # The chosen seed travels in a top-level "_meta" entry. Remove it
        # before submitting: ComfyUI treats every top-level key as a node id
        # and rejects the whole prompt when an entry has no "class_type".
        actual_seed = workflow.pop("_meta")["actual_seed"]
        prompt_id = await client.queue_prompt(workflow)
    except httpx.ConnectError:
        return _text_result(
            f"ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        )
    except httpx.HTTPStatusError as e:
        return _text_result(
            f"ComfyUI returned an error: {e.response.status_code} — {e.response.text}"
        )
    except httpx.HTTPError as e:
        # Timeouts and other transport failures used to escape unhandled.
        return _text_result(f"Failed to submit workflow to ComfyUI: {e}")

    # Poll until the job has left the queue
    start = time.monotonic()
    if not await _wait_until_dequeued(client, prompt_id, COMFYUI_TIMEOUT):
        return _text_result(
            f"Generation timed out after {COMFYUI_TIMEOUT}s. "
            f"prompt_id={prompt_id} — use get_generation_status to check"
        )
    elapsed = time.monotonic() - start

    # Retrieve history to find output filename
    try:
        history = await client.get_history(prompt_id)
    except httpx.HTTPError as e:
        return _text_result(f"Failed to retrieve generation history: {e}")

    job = history.get(prompt_id, {})
    image_info = _first_image(job.get("outputs", {}))
    if not image_info:
        # Surface ComfyUI's own failure reason when it recorded one.
        error = _execution_error(job)
        if error is not None:
            return _text_result(
                f"Generation failed for prompt_id={prompt_id}: {error}"
            )
        return _text_result(
            f"No output image found in history for prompt_id={prompt_id}"
        )

    # Download image bytes
    try:
        image_bytes = await client.get_image(
            filename=image_info["filename"],
            subfolder=image_info.get("subfolder", ""),
            folder_type=image_info.get("type", "output"),
        )
    except httpx.HTTPError as e:
        return _text_result(f"Failed to download generated image: {e}")

    # Save to disk
    try:
        resolved_output_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = resolved_output_dir / f"{timestamp}_{actual_seed}.png"
        out_path.write_bytes(image_bytes)
    except OSError as e:
        return _text_result(
            f"Cannot write to output directory: {resolved_output_dir} — {e}"
        )

    # Encode as base64 for inline display
    b64_data = base64.b64encode(image_bytes).decode("utf-8")

    return [
        TextContent(
            type="text",
            text=(
                f"Generated: {out_path}\n"
                f"Seed: {actual_seed}\n"
                f"Elapsed: {elapsed:.1f}s\n"
                f"Size: {width}x{height}, Steps: {steps}, Model: {model}"
            ),
        ),
        ImageContent(
            type="image",
            data=b64_data,
            mimeType="image/png",
        ),
    ]
|
|
|
|
|
|
@mcp.tool()
async def list_available_models() -> list[str]:
    """List all checkpoint models available in ComfyUI.

    Returns a list of model filenames available for use with generate_image.
    Requires ComfyUI to be running at COMFYUI_URL.

    On failure the list instead holds a single human-readable error message
    (kept for compatibility with existing callers).
    """
    client = ComfyUIClient(COMFYUI_URL)
    try:
        return await client.get_models()
    except httpx.ConnectError:
        return [
            f"ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        ]
    except httpx.HTTPStatusError as e:
        return [f"ComfyUI error: {e.response.status_code}"]
    except httpx.HTTPError as e:
        # Timeouts and other transport errors previously escaped as raw
        # exceptions; report them the same way as the cases above.
        return [f"ComfyUI request failed: {e}"]
|
|
|
|
|
|
@mcp.tool()
async def get_generation_status(prompt_id: str) -> dict:
    """Check the status of a queued or running generation job.

    Args:
        prompt_id: The prompt ID returned by a previous generate_image call.

    Returns:
        Dict with 'status' key: "pending", "running", "completed", or
        "not_found" (alongside 'prompt_id'). When ComfyUI cannot be queried,
        {"status": "error", "message": ...} is returned instead.
    """
    client = ComfyUIClient(COMFYUI_URL)
    try:
        queue = await client.get_status(prompt_id)
        # Queue entries are sequences whose second element is the prompt_id.
        if prompt_id in {item[1] for item in queue.get("queue_running", [])}:
            return {"status": "running", "prompt_id": prompt_id}
        if prompt_id in {item[1] for item in queue.get("queue_pending", [])}:
            return {"status": "pending", "prompt_id": prompt_id}

        # Not in queue — check history. Failures here now reach the handlers
        # below instead of being silently misreported as "not_found".
        history = await client.get_history(prompt_id)
    except httpx.ConnectError:
        return {
            "status": "error",
            "message": f"ComfyUI not reachable at {COMFYUI_URL}",
        }
    except httpx.HTTPStatusError as e:
        return {"status": "error", "message": f"HTTP {e.response.status_code}"}
    except httpx.HTTPError as e:
        # Timeouts and other transport failures.
        return {"status": "error", "message": f"ComfyUI request failed: {e}"}

    status = "completed" if prompt_id in history else "not_found"
    return {"status": status, "prompt_id": prompt_id}
|
|
|
|
|
|
@mcp.tool()
def get_output_directory() -> str:
    """Report where generate_image saves files when no output_dir is given.

    Returns:
        The IMAGE_OUTPUT_DIR setting with "~" expanded and made absolute.
        Nothing is created here, so the directory may not exist yet.
    """
    configured = Path(IMAGE_OUTPUT_DIR)
    absolute = configured.expanduser().resolve()
    return os.fspath(absolute)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Entry point
|
|
# ---------------------------------------------------------------------------
|
|
|
|
if __name__ == "__main__":
    # When executed as a script, serve the registered MCP tools over the
    # stdio transport (stdin/stdout).
    mcp.run(transport="stdio")
|