feat(mcp-image-gen): scaffold ComfyUI-backed image generation MCP server
- FastMCP server with 4 tools: generate_image, list_available_models, get_generation_status, get_output_directory - ComfyUI REST API client (httpx) polling lifecycle - FLUX.1-schnell workflow JSON template - Dual output: TextContent (path + seed) + ImageContent (base64 PNG) - 14 passing pytest tests with respx HTTP mocking - ROCm/AMD RX 7900 XTX optimized setup in README - Ollama Linux migration path documented (future)
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
"""Pytest fixtures for mcp-image-gen tests."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Make src/ importable
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def comfyui_url(monkeypatch):
|
||||
"""Set COMFYUI_URL to a test URL for all tests."""
|
||||
monkeypatch.setenv("COMFYUI_URL", "http://test-comfyui:8188")
|
||||
# Also patch the module-level constant in server
|
||||
import server
|
||||
monkeypatch.setattr(server, "COMFYUI_URL", "http://test-comfyui:8188")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_image_bytes():
|
||||
"""Generate a 1x1 red pixel PNG as bytes using Pillow."""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.new("RGB", (1, 1), color=(255, 0, 0))
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="PNG")
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_history_response():
|
||||
"""Sample ComfyUI history response for prompt_id='test-uuid-1234'."""
|
||||
return {
|
||||
"test-uuid-1234": {
|
||||
"outputs": {
|
||||
"9": {
|
||||
"images": [
|
||||
{
|
||||
"filename": "mcp-image-gen_00001_.png",
|
||||
"subfolder": "",
|
||||
"type": "output",
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"status": {"completed": True},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue_empty():
|
||||
"""ComfyUI queue response with nothing running or pending."""
|
||||
return {"queue_running": [], "queue_pending": []}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue_with_pending():
|
||||
"""ComfyUI queue response with our test prompt pending."""
|
||||
return {
|
||||
"queue_running": [],
|
||||
"queue_pending": [[1, "test-uuid-1234", {}, {}]],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue_with_running():
|
||||
"""ComfyUI queue response with our test prompt running."""
|
||||
return {
|
||||
"queue_running": [[1, "test-uuid-1234", {}, {}]],
|
||||
"queue_pending": [],
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
"""Tests for mcp-image-gen server — all ComfyUI HTTP calls mocked via respx."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
# Import the server module (sys.path set by conftest.py)
|
||||
import server
|
||||
from server import (
|
||||
ComfyUIClient,
|
||||
build_flux_workflow,
|
||||
generate_image,
|
||||
get_generation_status,
|
||||
get_output_directory,
|
||||
list_available_models,
|
||||
)
|
||||
|
||||
COMFYUI_BASE = "http://test-comfyui:8188"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_flux_workflow — pure function, no mocking needed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_flux_workflow_structure():
|
||||
"""Verify build_flux_workflow returns a dict with correct node types."""
|
||||
wf = build_flux_workflow(
|
||||
prompt="a red cat",
|
||||
neg_prompt="ugly",
|
||||
width=512,
|
||||
height=768,
|
||||
steps=8,
|
||||
seed=42,
|
||||
model="flux1-schnell.safetensors",
|
||||
)
|
||||
assert wf["6"]["class_type"] == "CLIPTextEncode"
|
||||
assert wf["8"]["class_type"] == "VAEDecode"
|
||||
assert wf["9"]["class_type"] == "SaveImage"
|
||||
assert wf["13"]["class_type"] == "KSampler"
|
||||
assert wf["27"]["class_type"] == "EmptySD3LatentImage"
|
||||
assert wf["30"]["class_type"] == "CheckpointLoaderSimple"
|
||||
assert wf["33"]["class_type"] == "CLIPTextEncode"
|
||||
|
||||
|
||||
def test_build_flux_workflow_params_injected():
|
||||
"""Verify all parameters are injected into correct nodes."""
|
||||
wf = build_flux_workflow(
|
||||
prompt="a blue whale",
|
||||
neg_prompt="cartoonish",
|
||||
width=512,
|
||||
height=768,
|
||||
steps=8,
|
||||
seed=12345,
|
||||
model="sdxl.safetensors",
|
||||
)
|
||||
assert wf["6"]["inputs"]["text"] == "a blue whale"
|
||||
assert wf["33"]["inputs"]["text"] == "cartoonish"
|
||||
assert wf["27"]["inputs"]["width"] == 512
|
||||
assert wf["27"]["inputs"]["height"] == 768
|
||||
assert wf["13"]["inputs"]["steps"] == 8
|
||||
assert wf["13"]["inputs"]["seed"] == 12345
|
||||
assert wf["30"]["inputs"]["ckpt_name"] == "sdxl.safetensors"
|
||||
|
||||
|
||||
def test_negative_prompt_included():
|
||||
"""Verify negative prompt appears in workflow node 33 when provided."""
|
||||
wf = build_flux_workflow(
|
||||
prompt="forest",
|
||||
neg_prompt="blurry, dark",
|
||||
width=1024,
|
||||
height=1024,
|
||||
steps=4,
|
||||
seed=1,
|
||||
model="flux1-schnell.safetensors",
|
||||
)
|
||||
assert wf["33"]["inputs"]["text"] == "blurry, dark"
|
||||
|
||||
|
||||
def test_random_seed_generated():
|
||||
"""seed=-1 generates a random seed each call."""
|
||||
wf1 = build_flux_workflow("cat", "", 512, 512, 4, -1, "flux1-schnell.safetensors")
|
||||
wf2 = build_flux_workflow("cat", "", 512, 512, 4, -1, "flux1-schnell.safetensors")
|
||||
seed1 = wf1["_meta"]["actual_seed"]
|
||||
seed2 = wf2["_meta"]["actual_seed"]
|
||||
# Both are valid integers
|
||||
assert isinstance(seed1, int)
|
||||
assert 0 <= seed1 < 2**32
|
||||
# With overwhelming probability they differ
|
||||
# (1/2^32 chance of collision — negligible for a test)
|
||||
# We just verify _meta is populated
|
||||
assert "_meta" in wf1
|
||||
assert "_meta" in wf2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_available_models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@respx.mock
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_available_models():
|
||||
"""Mock /object_info, verify model list is returned."""
|
||||
mock_response = {
|
||||
"CheckpointLoaderSimple": {
|
||||
"input": {
|
||||
"required": {
|
||||
"ckpt_name": [
|
||||
["flux1-schnell.safetensors", "sdxl.safetensors"],
|
||||
{},
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
respx.get(f"{COMFYUI_BASE}/object_info/CheckpointLoaderSimple").mock(
|
||||
return_value=httpx.Response(200, json=mock_response)
|
||||
)
|
||||
|
||||
result = await list_available_models()
|
||||
assert "flux1-schnell.safetensors" in result
|
||||
assert "sdxl.safetensors" in result
|
||||
|
||||
|
||||
@respx.mock
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_available_models_comfyui_offline():
|
||||
"""When ComfyUI is unreachable, list_available_models returns error message."""
|
||||
respx.get(f"{COMFYUI_BASE}/object_info/CheckpointLoaderSimple").mock(
|
||||
side_effect=httpx.ConnectError("connection refused")
|
||||
)
|
||||
|
||||
result = await list_available_models()
|
||||
assert len(result) == 1
|
||||
assert "not reachable" in result[0].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_generation_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@respx.mock
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_generation_status_pending(queue_with_pending):
|
||||
"""prompt_id in queue_pending → status is 'pending'."""
|
||||
respx.get(f"{COMFYUI_BASE}/api/queue").mock(
|
||||
return_value=httpx.Response(200, json=queue_with_pending)
|
||||
)
|
||||
|
||||
result = await get_generation_status("test-uuid-1234")
|
||||
assert result["status"] == "pending"
|
||||
assert result["prompt_id"] == "test-uuid-1234"
|
||||
|
||||
|
||||
@respx.mock
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_generation_status_running(queue_with_running):
|
||||
"""prompt_id in queue_running → status is 'running'."""
|
||||
respx.get(f"{COMFYUI_BASE}/api/queue").mock(
|
||||
return_value=httpx.Response(200, json=queue_with_running)
|
||||
)
|
||||
|
||||
result = await get_generation_status("test-uuid-1234")
|
||||
assert result["status"] == "running"
|
||||
|
||||
|
||||
@respx.mock
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_generation_status_complete(queue_empty, mock_history_response):
|
||||
"""prompt_id not in queue + found in history → status is 'completed'."""
|
||||
respx.get(f"{COMFYUI_BASE}/api/queue").mock(
|
||||
return_value=httpx.Response(200, json=queue_empty)
|
||||
)
|
||||
respx.get(f"{COMFYUI_BASE}/api/history/test-uuid-1234").mock(
|
||||
return_value=httpx.Response(200, json=mock_history_response)
|
||||
)
|
||||
|
||||
result = await get_generation_status("test-uuid-1234")
|
||||
assert result["status"] == "completed"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_output_directory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_output_directory_default(monkeypatch):
|
||||
"""No IMAGE_OUTPUT_DIR env var → returns expanded ~/Pictures/mcp-generated."""
|
||||
monkeypatch.delenv("IMAGE_OUTPUT_DIR", raising=False)
|
||||
monkeypatch.setattr(server, "IMAGE_OUTPUT_DIR", "~/Pictures/mcp-generated")
|
||||
|
||||
result = get_output_directory()
|
||||
assert result == str(Path("~/Pictures/mcp-generated").expanduser().resolve())
|
||||
assert "~" not in result # expanded
|
||||
|
||||
|
||||
def test_get_output_directory_custom(monkeypatch, tmp_path):
|
||||
"""IMAGE_OUTPUT_DIR set → returns that path."""
|
||||
custom = str(tmp_path / "custom-output")
|
||||
monkeypatch.setenv("IMAGE_OUTPUT_DIR", custom)
|
||||
monkeypatch.setattr(server, "IMAGE_OUTPUT_DIR", custom)
|
||||
|
||||
result = get_output_directory()
|
||||
assert result == str(Path(custom).expanduser().resolve())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_image
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@respx.mock
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_image_success(
|
||||
tmp_path, sample_image_bytes, mock_history_response, queue_empty, monkeypatch
|
||||
):
|
||||
"""Mock full lifecycle: queue → poll done → history → view. Verify outputs."""
|
||||
monkeypatch.setattr(server, "IMAGE_OUTPUT_DIR", str(tmp_path))
|
||||
|
||||
# 1. POST /api/prompt → prompt_id
|
||||
respx.post(f"{COMFYUI_BASE}/api/prompt").mock(
|
||||
return_value=httpx.Response(200, json={"prompt_id": "test-uuid-1234"})
|
||||
)
|
||||
# 2. GET /api/queue → empty (job done immediately)
|
||||
respx.get(f"{COMFYUI_BASE}/api/queue").mock(
|
||||
return_value=httpx.Response(200, json=queue_empty)
|
||||
)
|
||||
# 3. GET /api/history/test-uuid-1234
|
||||
respx.get(f"{COMFYUI_BASE}/api/history/test-uuid-1234").mock(
|
||||
return_value=httpx.Response(200, json=mock_history_response)
|
||||
)
|
||||
# 4. GET /api/view → image bytes
|
||||
respx.get(f"{COMFYUI_BASE}/api/view").mock(
|
||||
return_value=httpx.Response(200, content=sample_image_bytes)
|
||||
)
|
||||
|
||||
result = await generate_image(
|
||||
prompt="a red cat",
|
||||
output_dir=str(tmp_path),
|
||||
)
|
||||
|
||||
# Should return [TextContent, ImageContent]
|
||||
assert len(result) == 2
|
||||
text_content = result[0]
|
||||
image_content = result[1]
|
||||
|
||||
# TextContent has path info
|
||||
assert "Generated:" in text_content.text
|
||||
assert str(tmp_path) in text_content.text
|
||||
|
||||
# ImageContent has valid base64 PNG
|
||||
assert image_content.type == "image"
|
||||
assert image_content.mimeType == "image/png"
|
||||
decoded = base64.b64decode(image_content.data)
|
||||
assert decoded[:8] == b"\x89PNG\r\n\x1a\n" # PNG magic bytes
|
||||
|
||||
# File was actually saved
|
||||
saved_files = list(tmp_path.glob("*.png"))
|
||||
assert len(saved_files) == 1
|
||||
|
||||
|
||||
@respx.mock
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_image_comfyui_unavailable():
|
||||
"""ComfyUI unreachable → returns graceful error message as single TextContent."""
|
||||
respx.post(f"{COMFYUI_BASE}/api/prompt").mock(
|
||||
side_effect=httpx.ConnectError("connection refused")
|
||||
)
|
||||
|
||||
result = await generate_image(prompt="a cat")
|
||||
|
||||
assert len(result) == 1
|
||||
assert "not reachable" in result[0].text.lower()
|
||||
|
||||
|
||||
@respx.mock
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_image_timeout(monkeypatch, queue_with_pending):
|
||||
"""Poll loop never completes within timeout → returns timeout error."""
|
||||
monkeypatch.setattr(server, "COMFYUI_TIMEOUT", 0) # instant timeout
|
||||
|
||||
respx.post(f"{COMFYUI_BASE}/api/prompt").mock(
|
||||
return_value=httpx.Response(200, json={"prompt_id": "test-uuid-1234"})
|
||||
)
|
||||
# Queue always shows job pending → never finishes
|
||||
respx.get(f"{COMFYUI_BASE}/api/queue").mock(
|
||||
return_value=httpx.Response(200, json=queue_with_pending)
|
||||
)
|
||||
|
||||
result = await generate_image(prompt="slow image")
|
||||
|
||||
assert len(result) == 1
|
||||
assert "timed out" in result[0].text.lower()
|
||||
assert "test-uuid-1234" in result[0].text
|
||||
Reference in New Issue
Block a user