"""mcp-image-gen — FastMCP server for AI image generation via ComfyUI."""

import asyncio
import base64
import copy
import functools
import json
import os
import random
import time
from datetime import datetime
from pathlib import Path

import httpx
from fastmcp import FastMCP
from mcp.types import ImageContent, TextContent

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

COMFYUI_URL = os.environ.get("COMFYUI_URL", "http://localhost:8188").rstrip("/")
IMAGE_OUTPUT_DIR = os.environ.get("IMAGE_OUTPUT_DIR", "~/Pictures/mcp-generated")
# Maximum number of seconds generate_image polls before reporting a timeout.
COMFYUI_TIMEOUT = int(os.environ.get("COMFYUI_TIMEOUT", "120"))

# Seconds to sleep between /api/queue polls (and after a failed poll).
_POLL_INTERVAL = 2

# Path to the bundled FLUX.1-schnell workflow template
_WORKFLOW_PATH = Path(__file__).parent / "workflows" / "flux_schnell.json"

# Top-level key under which build_flux_workflow stores local metadata
# (e.g. the resolved seed). It is NOT a ComfyUI node, so it must be removed
# before the workflow is submitted (see ComfyUIClient.queue_prompt).
_META_KEY = "_meta"

mcp = FastMCP("mcp-image-gen")


# ---------------------------------------------------------------------------
# ComfyUI client
# ---------------------------------------------------------------------------


def _combo_options(spec) -> list[str]:
    """Return the option list from a ComfyUI combo input spec.

    The classic shape is ``[[option, ...], {...}]``. As a hedge against
    servers that describe combos as ``["COMBO", {"options": [...]}]``, that
    shape is accepted too. Anything else (missing, empty, unexpected types)
    yields an empty list instead of raising.
    """
    if not isinstance(spec, list) or not spec:
        return []
    first = spec[0]
    if isinstance(first, list):
        return first
    if first == "COMBO" and len(spec) > 1 and isinstance(spec[1], dict):
        options = spec[1].get("options", [])
        return options if isinstance(options, list) else []
    return []


class ComfyUIClient:
    """Async HTTP client wrapper for the ComfyUI REST API.

    Every method opens a short-lived ``httpx.AsyncClient``. Non-2xx responses
    raise ``httpx.HTTPStatusError`` (via ``raise_for_status``); transport
    problems raise ``httpx.RequestError`` subclasses such as ``ConnectError``
    or ``ReadTimeout``.
    """

    def __init__(self, base_url: str = COMFYUI_URL):
        self.base_url = base_url.rstrip("/")

    async def queue_prompt(self, workflow: dict) -> str:
        """Submit a workflow to ComfyUI and return the prompt_id.

        The local ``_meta`` entry added by build_flux_workflow is dropped from
        the payload: ComfyUI validates every top-level key of the prompt as a
        node and rejects entries that lack a ``class_type``. The caller's dict
        is left untouched.
        """
        nodes = {key: node for key, node in workflow.items() if key != _META_KEY}
        payload = {"prompt": nodes}
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(f"{self.base_url}/api/prompt", json=payload)
            resp.raise_for_status()
            return resp.json()["prompt_id"]

    async def get_status(self, prompt_id: str) -> dict:
        """Return the current queue state (queue_running + queue_pending lists).

        ``prompt_id`` is accepted for API symmetry but not used: /api/queue
        always returns the whole queue.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/api/queue")
            resp.raise_for_status()
            return resp.json()

    async def get_history(self, prompt_id: str) -> dict:
        """Return the history entry for a completed prompt_id.

        The response maps prompt_id -> entry; it is empty when the id is
        unknown to the server's history.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(f"{self.base_url}/api/history/{prompt_id}")
            resp.raise_for_status()
            return resp.json()

    async def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
        """Download image bytes from ComfyUI's /api/view endpoint."""
        params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.get(f"{self.base_url}/api/view", params=params)
            resp.raise_for_status()
            return resp.content

    async def get_models(self) -> list[str]:
        """Return the list of available checkpoint model filenames.

        Returns an empty list when the server's node description has no
        usable ``ckpt_name`` options.
        """
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                f"{self.base_url}/object_info/CheckpointLoaderSimple"
            )
            resp.raise_for_status()
            data = resp.json()
        # ComfyUI returns: {"CheckpointLoaderSimple": {"input": {"required":
        #   {"ckpt_name": [["model1.safetensors", ...], ...]}}}}
        node_info = data.get("CheckpointLoaderSimple", {})
        ckpt_spec = node_info.get("input", {}).get("required", {}).get("ckpt_name")
        return _combo_options(ckpt_spec)


# ---------------------------------------------------------------------------
# Workflow builder
# ---------------------------------------------------------------------------


@functools.lru_cache(maxsize=8)
def _load_workflow_template(path: Path) -> dict:
    """Read and parse a workflow template JSON file, cached per path.

    The returned dict is shared between calls; callers must deep-copy it
    before mutating. Edits to the file on disk are not picked up until the
    process restarts.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def build_flux_workflow(
    prompt: str,
    neg_prompt: str,
    width: int,
    height: int,
    steps: int,
    seed: int,
    model: str,
) -> dict:
    """Build a ComfyUI API-format workflow dict for FLUX.1-schnell text-to-image.

    The bundled template (``_WORKFLOW_PATH``) is read from disk on first use
    and cached. Every call returns an independent deep copy, so results may be
    mutated freely. No network I/O is performed.

    Args:
        prompt: Positive prompt text (template node "6").
        neg_prompt: Negative prompt text (template node "33").
        width: Image width (template node "27").
        height: Image height (template node "27").
        steps: Sampler steps (template node "13").
        seed: Sampler seed (template node "13"); any negative value, including
            the documented -1 sentinel, means "choose a random seed".
        model: Checkpoint filename (template node "30").

    Returns:
        The workflow dict plus a ``"_meta"`` entry ``{"actual_seed": int}``
        recording the seed actually used.

    Raises:
        OSError / json.JSONDecodeError: if the template cannot be read/parsed.
        KeyError: if the template lacks one of the node ids above.
    """
    wf = copy.deepcopy(_load_workflow_template(_WORKFLOW_PATH))

    actual_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)

    wf["6"]["inputs"]["text"] = prompt
    wf["33"]["inputs"]["text"] = neg_prompt
    wf["27"]["inputs"]["width"] = width
    wf["27"]["inputs"]["height"] = height
    wf["13"]["inputs"]["steps"] = steps
    wf["13"]["inputs"]["seed"] = actual_seed
    wf["30"]["inputs"]["ckpt_name"] = model

    # Attach the actual seed as metadata so callers can retrieve it.
    # queue_prompt strips this key before sending the workflow to ComfyUI.
    wf[_META_KEY] = {"actual_seed": actual_seed}
    return wf


# ---------------------------------------------------------------------------
# Internal helpers shared by the tools
# ---------------------------------------------------------------------------


def _text_result(text: str) -> list[TextContent]:
    """Wrap a message in the single-TextContent list returned by the tools."""
    return [TextContent(type="text", text=text)]


def _queue_ids(queue: dict, section: str) -> set:
    """Return the prompt_ids listed in one section of a /api/queue response.

    Each queue item is a list whose element at index 1 is the prompt_id.
    """
    return {item[1] for item in queue.get(section, [])}


def _find_output_image(outputs: dict) -> dict | None:
    """Pick the image descriptor from a history entry's ``outputs`` map.

    Prefers the SaveImage node ("9" in the bundled workflow), falling back to
    the first node that reported any images. Returns None if none did.
    """
    preferred = (outputs.get("9") or {}).get("images") or []
    if preferred:
        return preferred[0]
    for node_output in outputs.values():
        images = node_output.get("images", [])
        if images:
            return images[0]
    return None


# ---------------------------------------------------------------------------
# Tools
# ---------------------------------------------------------------------------


@mcp.tool()
async def generate_image(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    steps: int = 4,
    model: str = "flux1-schnell.safetensors",
    seed: int = -1,
    negative_prompt: str = "",
    output_dir: str = "",
) -> list:
    """Generate an image from a text prompt using ComfyUI.

    Returns both a file path (for persistence) and an inline base64 image
    (for display in Claude / Roo Code chat).

    Args:
        prompt: Text description of the image to generate.
        width: Image width in pixels (default: 1024).
        height: Image height in pixels (default: 1024).
        steps: Number of inference steps. FLUX.1-schnell works well at 4.
        model: ComfyUI model filename (default: flux1-schnell.safetensors).
        seed: Random seed for reproducibility. -1 (or any negative) = random.
        negative_prompt: Things to exclude from the image (optional).
        output_dir: Override output directory. Defaults to IMAGE_OUTPUT_DIR
            env var or ~/Pictures/mcp-generated.

    Returns:
        [TextContent(path + metadata), ImageContent(base64 PNG)] on success,
        or a single TextContent describing the failure.
    """
    # Resolve output directory
    resolved_output_dir = Path(output_dir or IMAGE_OUTPUT_DIR).expanduser().resolve()

    client = ComfyUIClient(COMFYUI_URL)

    # Build the workflow (local file read only).
    try:
        workflow = build_flux_workflow(
            prompt=prompt,
            neg_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            seed=seed,
            model=model,
        )
    except (OSError, json.JSONDecodeError, KeyError) as e:
        return _text_result(f"Cannot build workflow from {_WORKFLOW_PATH}: {e!r}")
    actual_seed = workflow[_META_KEY]["actual_seed"]

    # Submit workflow
    try:
        prompt_id = await client.queue_prompt(workflow)
    except httpx.ConnectError:
        return _text_result(
            f"ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        )
    except httpx.HTTPStatusError as e:
        return _text_result(
            f"ComfyUI returned an error: {e.response.status_code} — {e.response.text}"
        )
    except httpx.RequestError as e:
        # Timeouts and other transport failures (ConnectError handled above).
        return _text_result(f"Request to ComfyUI failed: {type(e).__name__}: {e}")

    # Poll until the job leaves both the running and pending queues.
    # monotonic() is immune to wall-clock adjustments during the wait.
    start = time.monotonic()
    while True:
        if time.monotonic() - start > COMFYUI_TIMEOUT:
            return _text_result(
                f"Generation timed out after {COMFYUI_TIMEOUT}s. "
                f"prompt_id={prompt_id} — use get_generation_status to check"
            )

        try:
            queue = await client.get_status(prompt_id)
        except (httpx.RequestError, httpx.HTTPStatusError):
            # Transient failure: keep retrying until the timeout above fires.
            await asyncio.sleep(_POLL_INTERVAL)
            continue

        still_queued = _queue_ids(queue, "queue_running") | _queue_ids(
            queue, "queue_pending"
        )
        if prompt_id not in still_queued:
            break  # Job is done

        await asyncio.sleep(_POLL_INTERVAL)

    elapsed = time.monotonic() - start

    # Retrieve history to find output filename
    try:
        history = await client.get_history(prompt_id)
    except (httpx.RequestError, httpx.HTTPStatusError) as e:
        return _text_result(f"Failed to retrieve generation history: {e}")

    outputs = history.get(prompt_id, {}).get("outputs", {})
    image_info = _find_output_image(outputs)
    if not image_info:
        return _text_result(
            f"No output image found in history for prompt_id={prompt_id}"
        )

    # Download image bytes
    try:
        image_bytes = await client.get_image(
            filename=image_info["filename"],
            subfolder=image_info.get("subfolder", ""),
            folder_type=image_info.get("type", "output"),
        )
    except (httpx.RequestError, httpx.HTTPStatusError) as e:
        return _text_result(f"Failed to download generated image: {e}")

    # Save to disk
    try:
        resolved_output_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = resolved_output_dir / f"{timestamp}_{actual_seed}.png"
        out_path.write_bytes(image_bytes)
    except OSError as e:
        return _text_result(
            f"Cannot write to output directory: {resolved_output_dir} — {e}"
        )

    # Encode as base64 for inline display
    b64_data = base64.b64encode(image_bytes).decode("utf-8")

    return [
        TextContent(
            type="text",
            text=(
                f"Generated: {out_path}\n"
                f"Seed: {actual_seed}\n"
                f"Elapsed: {elapsed:.1f}s\n"
                f"Size: {width}x{height}, Steps: {steps}, Model: {model}"
            ),
        ),
        ImageContent(
            type="image",
            data=b64_data,
            mimeType="image/png",
        ),
    ]


@mcp.tool()
async def list_available_models() -> list[str]:
    """List all checkpoint models available in ComfyUI.

    Returns a list of model filenames available for use with generate_image.
    Requires ComfyUI to be running at COMFYUI_URL. On failure the list holds
    a single human-readable error message instead of model names.
    """
    client = ComfyUIClient(COMFYUI_URL)
    try:
        return await client.get_models()
    except httpx.ConnectError:
        return [
            f"ComfyUI not reachable at {COMFYUI_URL}. "
            "Start it with: python main.py --listen"
        ]
    except httpx.HTTPStatusError as e:
        return [f"ComfyUI error: {e.response.status_code}"]
    except httpx.RequestError as e:
        return [f"ComfyUI request failed: {type(e).__name__}"]


@mcp.tool()
async def get_generation_status(prompt_id: str) -> dict:
    """Check the status of a queued or running generation job.

    Args:
        prompt_id: The prompt ID returned by a previous generate_image call.

    Returns:
        Dict with 'status' key: "pending", "running", "completed", or
        "not_found" (plus 'prompt_id'); or "error" with a 'message' key when
        the queue itself could not be read.
    """
    client = ComfyUIClient(COMFYUI_URL)
    try:
        queue = await client.get_status(prompt_id)
    except httpx.ConnectError:
        return {
            "status": "error",
            "message": f"ComfyUI not reachable at {COMFYUI_URL}",
        }
    except httpx.HTTPStatusError as e:
        return {"status": "error", "message": f"HTTP {e.response.status_code}"}
    except httpx.RequestError as e:
        return {
            "status": "error",
            "message": f"ComfyUI request failed: {type(e).__name__}",
        }

    if prompt_id in _queue_ids(queue, "queue_running"):
        return {"status": "running", "prompt_id": prompt_id}
    if prompt_id in _queue_ids(queue, "queue_pending"):
        return {"status": "pending", "prompt_id": prompt_id}

    # Not in queue — check history. This lookup is best-effort: if it fails
    # we report "not_found" rather than an error.
    try:
        history = await client.get_history(prompt_id)
    except (httpx.RequestError, httpx.HTTPStatusError):
        history = {}
    if prompt_id in history:
        return {"status": "completed", "prompt_id": prompt_id}
    return {"status": "not_found", "prompt_id": prompt_id}


@mcp.tool()
def get_output_directory() -> str:
    """Return the directory where generated images are saved.

    Returns:
        Absolute path to the output directory (may not exist yet).
    """
    return str(Path(IMAGE_OUTPUT_DIR).expanduser().resolve())


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    mcp.run(transport="stdio")