"""Tests for QA router: /api/v1/qa/gen-text and /api/v1/qa/gen-image."""

import json
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import AsyncMock
|
||
|
|
|
||
|
|
from app.core.exceptions import LLMCallError, LLMParseError, StorageError
|
||
|
|
|
||
|
|
|
||
|
|
# Canned LLM reply used by the success tests: a JSON array holding a single
# question/answer pair (serialized with default json.dumps settings).
SAMPLE_QA_JSON = json.dumps(
    [{"question": "电缆接头位于哪里?", "answer": "配电箱左侧"}]
)

# Placeholder image content: starts with the JPEG SOI (FF D8) and APP0 (FF E0)
# markers followed by filler text -- presumably so any header sniffing in the
# app treats it as a JPEG; confirm if that matters.
FAKE_IMAGE_BYTES = b"\xff\xd8\xff\xe0" b"fake_jpeg_content"

# Subject/predicate/object triple shared by both request payloads below.
_TRIPLE = {"subject": "电缆接头", "predicate": "位于", "object": "配电箱左侧"}

# Request body for POST /api/v1/qa/gen-text: one triple plus its source snippet.
TEXT_QA_PAYLOAD = {
    "items": [{**_TRIPLE, "source_snippet": "电缆接头位于配电箱左侧"}],
}

# Request body for POST /api/v1/qa/gen-image: one triple plus the path of its
# cropped image.
IMAGE_QA_PAYLOAD = {
    "items": [{**_TRIPLE, "cropped_image_path": "crops/1/0.jpg"}],
}


# ---------------------------------------------------------------------------
# POST /api/v1/qa/gen-text
# ---------------------------------------------------------------------------


def test_gen_text_qa_returns_200(client, mock_llm):
    """A well-formed LLM reply yields exactly one question/answer pair."""
    mock_llm.chat = AsyncMock(return_value=SAMPLE_QA_JSON)

    response = client.post("/api/v1/qa/gen-text", json=TEXT_QA_PAYLOAD)
    assert response.status_code == 200

    body = response.json()
    assert "pairs" in body
    pairs = body["pairs"]
    assert len(pairs) == 1

    first = pairs[0]
    assert first["question"] == "电缆接头位于哪里?"
    assert first["answer"] == "配电箱左侧"


def test_gen_text_qa_llm_parse_error_returns_502(client, mock_llm):
    """A malformed (non-JSON) LLM reply is reported as 502 / LLM_PARSE_ERROR."""
    malformed_reply = "not valid json {{"
    mock_llm.chat = AsyncMock(return_value=malformed_reply)

    response = client.post("/api/v1/qa/gen-text", json=TEXT_QA_PAYLOAD)

    assert response.status_code == 502
    assert response.json()["code"] == "LLM_PARSE_ERROR"


def test_gen_text_qa_llm_call_error_returns_503(client, mock_llm):
    """An LLMCallError raised by the text LLM maps to 503 / LLM_CALL_ERROR."""
    failure = LLMCallError("GLM timeout")
    mock_llm.chat = AsyncMock(side_effect=failure)

    response = client.post("/api/v1/qa/gen-text", json=TEXT_QA_PAYLOAD)

    assert response.status_code == 503
    assert response.json()["code"] == "LLM_CALL_ERROR"


# ---------------------------------------------------------------------------
# POST /api/v1/qa/gen-image
# ---------------------------------------------------------------------------


def test_gen_image_qa_returns_200(client, mock_llm, mock_storage):
    """With storage and the vision LLM stubbed, the endpoint returns one pair
    whose ``image_path`` echoes the request's ``cropped_image_path``."""
    mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)

    response = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
    assert response.status_code == 200

    body = response.json()
    assert "pairs" in body
    assert len(body["pairs"]) == 1

    (only_pair,) = body["pairs"]
    assert only_pair["question"] == "电缆接头位于哪里?"
    assert only_pair["answer"] == "配电箱左侧"
    assert only_pair["image_path"] == "crops/1/0.jpg"


def test_gen_image_qa_llm_parse_error_returns_502(client, mock_llm, mock_storage):
    """A malformed vision-LLM reply is reported as 502 / LLM_PARSE_ERROR."""
    mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
    malformed_reply = "bad json {{"
    mock_llm.chat_vision = AsyncMock(return_value=malformed_reply)

    response = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)

    assert response.status_code == 502
    assert response.json()["code"] == "LLM_PARSE_ERROR"


def test_gen_image_qa_llm_call_error_returns_503(client, mock_llm, mock_storage):
    """An LLMCallError raised by the vision LLM maps to 503 / LLM_CALL_ERROR."""
    mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
    failure = LLMCallError("GLM vision timeout")
    mock_llm.chat_vision = AsyncMock(side_effect=failure)

    response = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)

    assert response.status_code == 503
    assert response.json()["code"] == "LLM_CALL_ERROR"


def test_gen_image_qa_storage_error_returns_502(client, mock_llm, mock_storage):
    """A StorageError while fetching the crop maps to 502 / STORAGE_ERROR.

    The vision LLM is stubbed explicitly so there is an AsyncMock to inspect:
    the test also asserts the model was never awaited, i.e. a failed image
    download is expected to short-circuit before any LLM call is made.
    """
    mock_storage.download_bytes = AsyncMock(side_effect=StorageError("RustFS down"))
    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)

    resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)

    assert resp.status_code == 502
    assert resp.json()["code"] == "STORAGE_ERROR"
    # Without image bytes there is nothing to send to the vision model.
    mock_llm.chat_vision.assert_not_awaited()