Files
pi_mcps/mcp/mcp-image-gen/tests/test_server.py
T
Patrick Plate 8112ff2f12 feat(mcp-image-gen): scaffold ComfyUI-backed image generation MCP server
- FastMCP server with 4 tools: generate_image, list_available_models,
  get_generation_status, get_output_directory
- ComfyUI REST API client (httpx) polling lifecycle
- FLUX.1-schnell workflow JSON template
- Dual output: TextContent (path + seed) + ImageContent (base64 PNG)
- 14 passing pytest tests with respx HTTP mocking
- ROCm/AMD RX 7900 XTX optimized setup in README
- Ollama Linux migration path documented (future)
2026-04-04 11:49:31 +02:00

303 lines
9.9 KiB
Python

"""Tests for mcp-image-gen server — all ComfyUI HTTP calls mocked via respx."""
import base64
import json
import os
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
import respx
# Import the server module (sys.path set by conftest.py)
import server
from server import (
ComfyUIClient,
build_flux_workflow,
generate_image,
get_generation_status,
get_output_directory,
list_available_models,
)
# Base URL that every respx route in this module is registered against.
# NOTE(review): presumably conftest.py points the server at this host — confirm.
COMFYUI_BASE = "http://test-comfyui:8188"
# ---------------------------------------------------------------------------
# build_flux_workflow — pure function, no mocking needed
# ---------------------------------------------------------------------------
def test_build_flux_workflow_structure():
    """Each expected node id in the built workflow has the right class_type."""
    workflow = build_flux_workflow(
        prompt="a red cat",
        neg_prompt="ugly",
        width=512,
        height=768,
        steps=8,
        seed=42,
        model="flux1-schnell.safetensors",
    )
    # node id -> class_type, listed in the order the checks are performed.
    expected_class_types = {
        "6": "CLIPTextEncode",
        "8": "VAEDecode",
        "9": "SaveImage",
        "13": "KSampler",
        "27": "EmptySD3LatentImage",
        "30": "CheckpointLoaderSimple",
        "33": "CLIPTextEncode",
    }
    for node_id, class_type in expected_class_types.items():
        assert workflow[node_id]["class_type"] == class_type
def test_build_flux_workflow_params_injected():
    """Every argument shows up under the inputs of the expected workflow node."""
    workflow = build_flux_workflow(
        prompt="a blue whale",
        neg_prompt="cartoonish",
        width=512,
        height=768,
        steps=8,
        seed=12345,
        model="sdxl.safetensors",
    )
    # (node id, input name, expected value), in the order the checks run.
    expected_inputs = [
        ("6", "text", "a blue whale"),
        ("33", "text", "cartoonish"),
        ("27", "width", 512),
        ("27", "height", 768),
        ("13", "steps", 8),
        ("13", "seed", 12345),
        ("30", "ckpt_name", "sdxl.safetensors"),
    ]
    for node_id, input_name, expected in expected_inputs:
        assert workflow[node_id]["inputs"][input_name] == expected
def test_negative_prompt_included():
    """A supplied negative prompt is written to the text input of node 33."""
    negative = "blurry, dark"
    workflow = build_flux_workflow(
        prompt="forest",
        neg_prompt=negative,
        width=1024,
        height=1024,
        steps=4,
        seed=1,
        model="flux1-schnell.safetensors",
    )
    node_33_inputs = workflow["33"]["inputs"]
    assert node_33_inputs["text"] == negative
def test_random_seed_generated():
    """seed=-1 makes build_flux_workflow record a concrete seed under ``_meta``.

    Both workflows are validated, so the second call is no longer left
    unchecked. The two seeds are deliberately NOT compared for inequality:
    a genuine collision is possible and would make the test flaky.
    """
    workflows = [
        build_flux_workflow("cat", "", 512, 512, 4, -1, "flux1-schnell.safetensors")
        for _ in range(2)
    ]
    for wf in workflows:
        # Check for the key before subscripting so a missing ``_meta`` fails
        # as a readable assertion instead of a KeyError.
        assert "_meta" in wf
        seed = wf["_meta"]["actual_seed"]
        assert isinstance(seed, int)
        # The -1 sentinel must have been replaced by a seed in [0, 2**32).
        assert 0 <= seed < 2**32
# ---------------------------------------------------------------------------
# list_available_models
# ---------------------------------------------------------------------------
@respx.mock
@pytest.mark.asyncio
async def test_list_available_models():
    """Checkpoint names from the mocked /object_info payload are returned."""
    object_info_payload = {
        "CheckpointLoaderSimple": {
            "input": {
                "required": {
                    "ckpt_name": [
                        ["flux1-schnell.safetensors", "sdxl.safetensors"],
                        {},
                    ]
                }
            }
        }
    }
    object_info_url = f"{COMFYUI_BASE}/object_info/CheckpointLoaderSimple"
    canned_reply = httpx.Response(200, json=object_info_payload)
    respx.get(object_info_url).mock(return_value=canned_reply)

    models = await list_available_models()

    for checkpoint in ("flux1-schnell.safetensors", "sdxl.safetensors"):
        assert checkpoint in models
@respx.mock
@pytest.mark.asyncio
async def test_list_available_models_comfyui_offline():
    """A connection failure yields a single 'not reachable' entry."""
    object_info_url = f"{COMFYUI_BASE}/object_info/CheckpointLoaderSimple"
    connect_failure = httpx.ConnectError("connection refused")
    respx.get(object_info_url).mock(side_effect=connect_failure)

    models = await list_available_models()

    assert len(models) == 1
    only_entry = models[0]
    assert "not reachable" in only_entry.lower()
# ---------------------------------------------------------------------------
# get_generation_status
# ---------------------------------------------------------------------------
@respx.mock
@pytest.mark.asyncio
async def test_get_generation_status_pending(queue_with_pending):
    """With the pending-queue fixture served, the status is 'pending'."""
    prompt_id = "test-uuid-1234"
    queue_reply = httpx.Response(200, json=queue_with_pending)
    respx.get(f"{COMFYUI_BASE}/api/queue").mock(return_value=queue_reply)

    status = await get_generation_status(prompt_id)

    assert status["status"] == "pending"
    assert status["prompt_id"] == prompt_id
@respx.mock
@pytest.mark.asyncio
async def test_get_generation_status_running(queue_with_running):
    """With the running-queue fixture served, the status is 'running'."""
    queue_reply = httpx.Response(200, json=queue_with_running)
    respx.get(f"{COMFYUI_BASE}/api/queue").mock(return_value=queue_reply)

    status = await get_generation_status("test-uuid-1234")

    assert status["status"] == "running"
@respx.mock
@pytest.mark.asyncio
async def test_get_generation_status_complete(queue_empty, mock_history_response):
    """Empty queue plus a history entry for the prompt gives 'completed'."""
    prompt_id = "test-uuid-1234"
    queue_reply = httpx.Response(200, json=queue_empty)
    history_reply = httpx.Response(200, json=mock_history_response)
    respx.get(f"{COMFYUI_BASE}/api/queue").mock(return_value=queue_reply)
    respx.get(f"{COMFYUI_BASE}/api/history/{prompt_id}").mock(
        return_value=history_reply
    )

    status = await get_generation_status(prompt_id)

    assert status["status"] == "completed"
# ---------------------------------------------------------------------------
# get_output_directory
# ---------------------------------------------------------------------------
def test_get_output_directory_default(monkeypatch):
    """With IMAGE_OUTPUT_DIR unset, ~/Pictures/mcp-generated comes back expanded."""
    default_dir = "~/Pictures/mcp-generated"
    monkeypatch.delenv("IMAGE_OUTPUT_DIR", raising=False)
    monkeypatch.setattr(server, "IMAGE_OUTPUT_DIR", default_dir)

    resolved = get_output_directory()

    expected = Path(default_dir).expanduser().resolve()
    assert resolved == str(expected)
    # The home-directory shortcut must not survive expansion.
    assert "~" not in resolved
def test_get_output_directory_custom(monkeypatch, tmp_path):
    """A custom IMAGE_OUTPUT_DIR is returned as its resolved absolute path."""
    custom_dir = tmp_path / "custom-output"
    custom = str(custom_dir)
    # Patch both the environment and the module-level constant, as before.
    monkeypatch.setenv("IMAGE_OUTPUT_DIR", custom)
    monkeypatch.setattr(server, "IMAGE_OUTPUT_DIR", custom)

    resolved = get_output_directory()

    expected = Path(custom).expanduser().resolve()
    assert resolved == str(expected)
# ---------------------------------------------------------------------------
# generate_image
# ---------------------------------------------------------------------------
@respx.mock
@pytest.mark.asyncio
async def test_generate_image_success(
    tmp_path, sample_image_bytes, mock_history_response, queue_empty, monkeypatch
):
    """Happy path: submit, poll, read history, download; then check the outputs."""
    prompt_id = "test-uuid-1234"
    output_dir = str(tmp_path)
    monkeypatch.setattr(server, "IMAGE_OUTPUT_DIR", output_dir)

    # One canned reply per endpoint, registered in lifecycle order:
    # submit prompt -> poll queue (already empty) -> history -> image bytes.
    submit_reply = httpx.Response(200, json={"prompt_id": prompt_id})
    queue_reply = httpx.Response(200, json=queue_empty)
    history_reply = httpx.Response(200, json=mock_history_response)
    view_reply = httpx.Response(200, content=sample_image_bytes)
    routes = (
        (respx.post, "/api/prompt", submit_reply),
        (respx.get, "/api/queue", queue_reply),
        (respx.get, f"/api/history/{prompt_id}", history_reply),
        (respx.get, "/api/view", view_reply),
    )
    for register, path, reply in routes:
        register(f"{COMFYUI_BASE}{path}").mock(return_value=reply)

    contents = await generate_image(prompt="a red cat", output_dir=output_dir)

    # Expected shape: a text item followed by an image item.
    assert len(contents) == 2
    text_part, image_part = contents[0], contents[1]

    # The text item reports where the file was written.
    assert "Generated:" in text_part.text
    assert output_dir in text_part.text

    # The image item is a base64 payload that decodes to PNG data.
    assert image_part.type == "image"
    assert image_part.mimeType == "image/png"
    png_signature = b"\x89PNG\r\n\x1a\n"
    raw_image = base64.b64decode(image_part.data)
    assert raw_image[:8] == png_signature

    # Exactly one PNG ended up in the output directory.
    written_pngs = list(tmp_path.glob("*.png"))
    assert len(written_pngs) == 1
@respx.mock
@pytest.mark.asyncio
async def test_generate_image_comfyui_unavailable():
    """A connection error on submit yields one 'not reachable' text item."""
    connect_failure = httpx.ConnectError("connection refused")
    respx.post(f"{COMFYUI_BASE}/api/prompt").mock(side_effect=connect_failure)

    contents = await generate_image(prompt="a cat")

    assert len(contents) == 1
    message = contents[0].text
    assert "not reachable" in message.lower()
@respx.mock
@pytest.mark.asyncio
async def test_generate_image_timeout(monkeypatch, queue_with_pending):
    """A job that stays in the queue past the timeout yields a timeout message."""
    prompt_id = "test-uuid-1234"
    # Zero seconds of budget, so polling gives up straight away.
    monkeypatch.setattr(server, "COMFYUI_TIMEOUT", 0)

    submit_reply = httpx.Response(200, json={"prompt_id": prompt_id})
    # The queue endpoint always serves the pending fixture.
    still_queued = httpx.Response(200, json=queue_with_pending)
    respx.post(f"{COMFYUI_BASE}/api/prompt").mock(return_value=submit_reply)
    respx.get(f"{COMFYUI_BASE}/api/queue").mock(return_value=still_queued)

    contents = await generate_image(prompt="slow image")

    assert len(contents) == 1
    message = contents[0].text
    assert "timed out" in message.lower()
    assert prompt_id in message