feat(mcp-image-gen): add name and count params to generate_image

- Add name (str) param: filename prefix saved as {name}_{timestamp}_{seed}.png
- Add count (int, 1-10) param: generate N images in one call
- Extract _sanitize_name() helper: strips special chars, collapses underscores, caps at 64 chars
- Extract _build_filename() helper: pure function for testable filename construction
- Extract _generate_single() coroutine: clean loop body for batch generation
- Fixed seed batches increment seed per image (seed+i-1) for deterministic variation
- random seed (-1) batches give independent random seeds per image
- Partial batch failures continue (error TextContent in slot, remaining images proceed)
- Returns flat interleaved [Text1, Image1, Text2, Image2, ...] list
- 34/34 tests passing (was 19, added 15 new tests)
This commit is contained in:
Patrick Plate
2026-04-06 07:45:37 +02:00
parent 79a2e1d10a
commit 79f1e6d65f
3 changed files with 794 additions and 45 deletions
+319 -1
View File
@@ -14,6 +14,8 @@ import respx
import server
from server import (
ComfyUIClient,
_build_filename,
_sanitize_name,
build_flux_workflow,
generate_image,
get_generation_status,
@@ -100,6 +102,74 @@ def test_random_seed_generated():
assert "_meta" in wf2
# ---------------------------------------------------------------------------
# _sanitize_name — pure function
# ---------------------------------------------------------------------------
def test_sanitize_name_basic():
"""Simple alphanumeric name passes through unchanged."""
assert _sanitize_name("lumen_profile") == "lumen_profile"
def test_sanitize_name_spaces_to_underscores():
"""Spaces are converted to underscores."""
assert _sanitize_name("my cool image") == "my_cool_image"
def test_sanitize_name_special_chars_stripped():
"""Special characters (!, @, #, etc.) are stripped."""
result = _sanitize_name("hello! world@2024#")
assert "!" not in result
assert "@" not in result
assert "#" not in result
assert "hello" in result
assert "world" in result
def test_sanitize_name_empty_returns_empty():
"""Empty string or whitespace-only returns empty string."""
assert _sanitize_name("") == ""
assert _sanitize_name(" ") == ""
def test_sanitize_name_collapse_underscores():
"""Multiple consecutive underscores/hyphens are collapsed to one."""
result = _sanitize_name("lumen__profile")
assert "__" not in result
def test_sanitize_name_truncates_at_64():
"""Names longer than 64 chars are truncated."""
long_name = "a" * 100
result = _sanitize_name(long_name)
assert len(result) <= 64
# ---------------------------------------------------------------------------
# _build_filename — pure function
# ---------------------------------------------------------------------------
def test_build_filename_with_name():
"""When name is provided, filename includes it as prefix."""
filename = _build_filename("lumen", "20260406_120000", 12345)
assert filename == "lumen_20260406_120000_12345.png"
def test_build_filename_without_name():
"""When name is empty, filename is timestamp_seed.png."""
filename = _build_filename("", "20260406_120000", 12345)
assert filename == "20260406_120000_12345.png"
assert not filename.startswith("_")
def test_build_filename_sanitizes_name():
"""Name with spaces and special chars is sanitized before use in filename."""
filename = _build_filename("my image!", "20260406_120000", 99)
assert "!" not in filename
assert "my_image" in filename
assert filename.endswith("_20260406_120000_99.png")
# ---------------------------------------------------------------------------
# list_available_models
# ---------------------------------------------------------------------------
@@ -211,7 +281,255 @@ def test_get_output_directory_custom(monkeypatch, tmp_path):
# ---------------------------------------------------------------------------
# generate_image
# generate_image — count/name validation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_generate_image_count_zero_returns_error():
"""count=0 → returns error TextContent without calling ComfyUI."""
result = await generate_image(prompt="a cat", count=0)
assert len(result) == 1
assert "count must be at least 1" in result[0].text
@pytest.mark.asyncio
async def test_generate_image_count_exceeds_max_returns_error():
"""count=11 (> MAX_COUNT=10) → returns error TextContent without calling ComfyUI."""
result = await generate_image(prompt="a cat", count=11)
assert len(result) == 1
assert "at most 10" in result[0].text
# ---------------------------------------------------------------------------
# generate_image — name parameter
# ---------------------------------------------------------------------------
@respx.mock
@pytest.mark.asyncio
async def test_generate_image_with_name(
tmp_path, sample_image_bytes, mock_history_response, queue_empty, monkeypatch
):
"""name param → saved file has name as prefix."""
monkeypatch.setattr(server, "IMAGE_OUTPUT_DIR", str(tmp_path))
respx.post(f"{COMFYUI_BASE}/api/prompt").mock(
return_value=httpx.Response(200, json={"prompt_id": "name-test-uuid"})
)
respx.get(f"{COMFYUI_BASE}/api/queue").mock(
return_value=httpx.Response(200, json=queue_empty)
)
mock_history_named = {
"name-test-uuid": {
"outputs": {
"9": {
"images": [
{
"filename": "mcp-image-gen_00001_.png",
"subfolder": "",
"type": "output",
}
]
}
},
"status": {"completed": True},
}
}
respx.get(f"{COMFYUI_BASE}/api/history/name-test-uuid").mock(
return_value=httpx.Response(200, json=mock_history_named)
)
respx.get(f"{COMFYUI_BASE}/api/view").mock(
return_value=httpx.Response(200, content=sample_image_bytes)
)
result = await generate_image(
prompt="lumen portrait",
name="lumen_profile",
output_dir=str(tmp_path),
)
assert len(result) == 2
saved_files = list(tmp_path.glob("lumen_profile_*.png"))
assert len(saved_files) == 1, f"Expected 1 file with 'lumen_profile_' prefix, got: {list(tmp_path.glob('*.png'))}"
# Path in TextContent also has the name prefix
assert "lumen_profile_" in result[0].text
# ---------------------------------------------------------------------------
# generate_image — count=2 batch
# ---------------------------------------------------------------------------
@respx.mock
@pytest.mark.asyncio
async def test_generate_image_count_2(
tmp_path, sample_image_bytes, queue_empty, monkeypatch
):
"""count=2 → returns 4 content items (Text+Image per image), 2 files saved."""
monkeypatch.setattr(server, "IMAGE_OUTPUT_DIR", str(tmp_path))
# First image
mock_history_1 = {
"uuid-batch-1": {
"outputs": {"9": {"images": [{"filename": "img1.png", "subfolder": "", "type": "output"}]}},
"status": {"completed": True},
}
}
# Second image
mock_history_2 = {
"uuid-batch-2": {
"outputs": {"9": {"images": [{"filename": "img2.png", "subfolder": "", "type": "output"}]}},
"status": {"completed": True},
}
}
respx.post(f"{COMFYUI_BASE}/api/prompt").mock(
side_effect=[
httpx.Response(200, json={"prompt_id": "uuid-batch-1"}),
httpx.Response(200, json={"prompt_id": "uuid-batch-2"}),
]
)
respx.get(f"{COMFYUI_BASE}/api/queue").mock(
return_value=httpx.Response(200, json=queue_empty)
)
respx.get(f"{COMFYUI_BASE}/api/history/uuid-batch-1").mock(
return_value=httpx.Response(200, json=mock_history_1)
)
respx.get(f"{COMFYUI_BASE}/api/history/uuid-batch-2").mock(
return_value=httpx.Response(200, json=mock_history_2)
)
respx.get(f"{COMFYUI_BASE}/api/view").mock(
return_value=httpx.Response(200, content=sample_image_bytes)
)
result = await generate_image(
prompt="a landscape",
count=2,
seed=100,
output_dir=str(tmp_path),
)
# 4 content items: [Text1, Image1, Text2, Image2]
assert len(result) == 4
assert result[0].type == "text"
assert result[1].type == "image"
assert result[2].type == "text"
assert result[3].type == "image"
# 2 files saved
saved_files = list(tmp_path.glob("*.png"))
assert len(saved_files) == 2
# Label contains batch index
assert "1/2" in result[0].text
assert "2/2" in result[2].text
@respx.mock
@pytest.mark.asyncio
async def test_generate_image_count_2_fixed_seed_increments(
tmp_path, sample_image_bytes, queue_empty, monkeypatch
):
"""count=2 with fixed seed → seeds are incremented (seed, seed+1)."""
monkeypatch.setattr(server, "IMAGE_OUTPUT_DIR", str(tmp_path))
submitted_seeds = []
def capture_prompt(request):
body = json.loads(request.content)
seed_val = body["prompt"]["13"]["inputs"]["seed"]
submitted_seeds.append(seed_val)
idx = len(submitted_seeds)
return httpx.Response(200, json={"prompt_id": f"seed-test-{idx}"})
mock_history_1 = {
"seed-test-1": {
"outputs": {"9": {"images": [{"filename": "img1.png", "subfolder": "", "type": "output"}]}},
"status": {"completed": True},
}
}
mock_history_2 = {
"seed-test-2": {
"outputs": {"9": {"images": [{"filename": "img2.png", "subfolder": "", "type": "output"}]}},
"status": {"completed": True},
}
}
respx.post(f"{COMFYUI_BASE}/api/prompt").mock(side_effect=capture_prompt)
respx.get(f"{COMFYUI_BASE}/api/queue").mock(
return_value=httpx.Response(200, json=queue_empty)
)
respx.get(f"{COMFYUI_BASE}/api/history/seed-test-1").mock(
return_value=httpx.Response(200, json=mock_history_1)
)
respx.get(f"{COMFYUI_BASE}/api/history/seed-test-2").mock(
return_value=httpx.Response(200, json=mock_history_2)
)
respx.get(f"{COMFYUI_BASE}/api/view").mock(
return_value=httpx.Response(200, content=sample_image_bytes)
)
await generate_image(
prompt="a test",
count=2,
seed=42,
output_dir=str(tmp_path),
)
assert submitted_seeds == [42, 43], f"Expected [42, 43], got {submitted_seeds}"
@respx.mock
@pytest.mark.asyncio
async def test_generate_image_count_partial_failure_continues(
tmp_path, sample_image_bytes, queue_empty, monkeypatch
):
"""count=2 where first image fails → error in slot 1, second image succeeds in slot 2."""
monkeypatch.setattr(server, "IMAGE_OUTPUT_DIR", str(tmp_path))
mock_history_2 = {
"uuid-ok": {
"outputs": {"9": {"images": [{"filename": "img2.png", "subfolder": "", "type": "output"}]}},
"status": {"completed": True},
}
}
respx.post(f"{COMFYUI_BASE}/api/prompt").mock(
side_effect=[
httpx.Response(500, json={"error": "GPU OOM"}), # first fails
httpx.Response(200, json={"prompt_id": "uuid-ok"}), # second succeeds
]
)
respx.get(f"{COMFYUI_BASE}/api/queue").mock(
return_value=httpx.Response(200, json=queue_empty)
)
respx.get(f"{COMFYUI_BASE}/api/history/uuid-ok").mock(
return_value=httpx.Response(200, json=mock_history_2)
)
respx.get(f"{COMFYUI_BASE}/api/view").mock(
return_value=httpx.Response(200, content=sample_image_bytes)
)
result = await generate_image(
prompt="a test",
count=2,
seed=10,
output_dir=str(tmp_path),
)
# First: error TextContent only (no ImageContent)
# Second: [TextContent, ImageContent]
assert len(result) == 3
assert result[0].type == "text"
assert "500" in result[0].text or "error" in result[0].text.lower()
assert result[1].type == "text"
assert result[2].type == "image"
# Only 1 file saved (the successful one)
saved_files = list(tmp_path.glob("*.png"))
assert len(saved_files) == 1
# ---------------------------------------------------------------------------
# generate_image — existing tests (kept intact)
# ---------------------------------------------------------------------------
@respx.mock