make case catalog read-only

This commit is contained in:
刘金宝
2026-06-04 17:50:22 +08:00
parent b46e43aadc
commit 7f1803f9fa
15 changed files with 35 additions and 1268 deletions
+8 -119
View File
@@ -1,15 +1,16 @@
import os
import sys
from decimal import Decimal
import tempfile
from pathlib import Path
os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_api_contract.db")
TEST_DB_PATH = Path(tempfile.gettempdir()) / "medical_agent_test_api_contract.db"
TEST_DB_PATH.unlink(missing_ok=True)
os.environ["DATABASE_URL"] = f"sqlite:///{TEST_DB_PATH.as_posix()}"
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")
os.environ.setdefault("AUTH_USER_ME_URL", "http://django-user-center.test/api/user/users/me/")
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
Path("storage").mkdir(exist_ok=True)
try:
from fastapi.testclient import TestClient
@@ -25,9 +26,6 @@ def run_api_contract_tests() -> None:
from app.main import app
from app.services.external_auth_service import AuthenticatedUser, ExternalAuthService
from app.db.session import SessionLocal
from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TraditionalCase
from app.repositories.case_repository import CaseRepository
from scripts.init_demo_db import init_database
async def fake_authenticate(self, request): # noqa: ARG001
@@ -87,6 +85,10 @@ def run_api_contract_tests() -> None:
auth_me_operation = openapi_payload["paths"]["/api/v1/auth/me"]["get"]
assert any("HTTPBearer" in item for item in auth_me_operation.get("security", []))
assert "HTTPBearer" in openapi_payload["components"]["securitySchemes"]
assert "/api/v1/imports/case-sql/preview" not in openapi_payload["paths"]
assert "/api/v1/imports/case-sql/apply" not in openapi_payload["paths"]
assert "/api/v1/cases/{case_id}/delete-preview" not in openapi_payload["paths"]
assert "delete" not in openapi_payload["paths"]["/api/v1/cases/{case_id}"]
cases = client.get("/api/v1/cases", headers=headers)
assert cases.status_code == 200
@@ -158,119 +160,6 @@ def run_api_contract_tests() -> None:
assert llm_reason.json()["code"] == "OK"
assert "total_latency_ms" in llm_reason.json()["data"]
import_preview = client.post(
"/api/v1/imports/case-sql/preview",
headers=headers,
files={"file": ("bad.sql", b"not sql", "application/sql")},
)
assert import_preview.status_code == 200
assert import_preview.json()["code"] == "OK"
assert import_preview.json()["data"]["can_import"] is False
assert import_preview.json()["data"]["errors"]
import_apply = client.post(
"/api/v1/imports/case-sql/apply",
headers=headers,
files={"file": ("bad.sql", b"not sql", "application/sql")},
)
assert import_apply.status_code == 400
assert import_apply.json()["code"] == "CASE_SQL_IMPORT_INVALID"
temp_case_id = 880001
with SessionLocal() as db:
CaseRepository(db).delete_case_cascade(temp_case_id)
db.add(
CaseBase(
id=temp_case_id,
title="删除测试病例",
case_type="diagnosis_treatment",
difficulty="medium",
chief_complaint="发热、咳嗽",
description="用于删除接口测试的临时病例",
patient_age=6,
patient_gender="male",
tags="test",
symptom_tags=[],
disease_tags=[],
competency_tags=[],
guideline_tags=[],
knowledge_points=[],
icd_codes="",
osce_enabled=False,
rag_enabled=False,
ai_prompt_template="",
multimodal_assets=[],
vector_status=0,
publish_status=1,
status=1,
department_id=1,
)
)
db.add(
TraditionalCase(
id=temp_case_id,
case_id=temp_case_id,
standard_diagnosis="测试诊断",
standard_treatment="测试治疗",
guideline_reference="测试指南",
)
)
db.add(
ScoringRule(
id=temp_case_id,
case_id=temp_case_id,
dimension="信息采集",
competency_dimension="问诊完整性",
score_weight=Decimal("10.00"),
ai_auto_score=True,
osce_dimension=False,
scoring_standard="测试评分标准",
rubric_json={},
)
)
db.add(
CaseExamItem(
id=temp_case_id,
case_id=temp_case_id,
item_code="temp_exam",
item_name="测试检查",
item_type="lab",
result_text="测试结果",
result_structured={},
is_key=True,
is_abnormal=False,
score_weight=Decimal("1.00"),
display_order=1,
)
)
db.commit()
delete_preview = client.get(f"/api/v1/cases/{temp_case_id}/delete-preview", headers=headers)
assert delete_preview.status_code == 200
assert delete_preview.json()["data"]["affected"]["case_base"] == 1
assert delete_preview.json()["data"]["affected"]["case_exam_item"] == 1
delete_without_confirm = client.request(
"DELETE",
f"/api/v1/cases/{temp_case_id}",
headers=headers,
json={"confirm": False, "delete_training_data": True},
)
assert delete_without_confirm.status_code == 400
assert delete_without_confirm.json()["code"] == "CASE_DELETE_CONFIRM_REQUIRED"
delete_case = client.request(
"DELETE",
f"/api/v1/cases/{temp_case_id}",
headers=headers,
json={"confirm": True, "delete_training_data": True},
)
assert delete_case.status_code == 200
assert delete_case.json()["data"]["deleted_counts"]["case_base"] == 1
deleted_detail = client.get(f"/api/v1/cases/{temp_case_id}", headers=headers)
assert deleted_detail.status_code == 404
assert deleted_detail.json()["code"] == "CASE_NOT_FOUND"
if __name__ == "__main__":
+5 -3
View File
@@ -1,15 +1,17 @@
import asyncio
import os
import sys
import tempfile
from pathlib import Path
os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_demo_flow.db")
TEST_DB_PATH = Path(tempfile.gettempdir()) / "medical_agent_test_demo_flow.db"
TEST_DB_PATH.unlink(missing_ok=True)
os.environ["DATABASE_URL"] = f"sqlite:///{TEST_DB_PATH.as_posix()}"
os.environ["REPORT_STORAGE_DIR"] = str(Path(tempfile.gettempdir()) / "medical_agent_test_reports")
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
if os.getenv("DATABASE_URL") == "sqlite:///./storage/test_demo_flow.db":
Path("storage/test_demo_flow.db").unlink(missing_ok=True)
from sqlalchemy import select
-75
View File
@@ -1,75 +0,0 @@
import tempfile
from pathlib import Path
import os
import sys
os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_import.db")
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from scripts.import_source_case_sql import ImportValidationError, extract_create_columns, extract_insert_rows, parse_values_clause
def test_extract_insert_rows_maps_by_create_columns() -> None:
"""导入解析:验证 INSERT VALUES 按 CREATE TABLE 字段顺序映射。"""
sql = """
CREATE TABLE `traditional_case` (
`created_at` datetime(6) NOT NULL,
`updated_at` datetime(6) NOT NULL,
`id` bigint NOT NULL AUTO_INCREMENT,
`standard_diagnosis` longtext NOT NULL,
`standard_treatment` longtext NOT NULL,
`guideline_reference` longtext NOT NULL,
`case_id` bigint NOT NULL
) ENGINE=InnoDB;
INSERT INTO `traditional_case` VALUES ('2026-05-28 09:34:08','2026-05-28 09:34:09',1,'支气管肺炎','抗感染','指南',1);
"""
columns = extract_create_columns(sql, "traditional_case")
rows = extract_insert_rows(sql, "traditional_case", columns)
assert rows[0]["standard_diagnosis"] == "支气管肺炎"
assert rows[0]["case_id"] == 1
def test_extract_insert_rows_rejects_broken_sql() -> None:
"""导入解析:验证损坏 SQL 被拒绝,避免半导入。"""
sql = """
CREATE TABLE `traditional_case` (
`created_at` datetime(6) NOT NULL,
`updated_at` datetime(6) NOT NULL,
`id` bigint NOT NULL AUTO_INCREMENT,
`standard_diagnosis` longtext NOT NULL,
`standard_treatment` longtext NOT NULL,
`guideline_reference` longtext NOT NULL,
`case_id` bigint NOT NULL
) ENGINE=InnoDB;
INSERT INTO `traditional_case` VALUES ('2026-05-28','2026-05-28',1,'支气管肺炎'bad,'抗感染','指南',1);
"""
columns = extract_create_columns(sql, "traditional_case")
try:
extract_insert_rows(sql, "traditional_case", columns)
except ImportValidationError:
return
raise AssertionError("broken SQL should be rejected")
def test_parse_values_clause_recovers_misplaced_quote_separator() -> None:
"""导入解析:兼容接口 SQL 中文本字段写成 `文本,'next'` 的引号/逗号错位。"""
rows = parse_values_clause("(1,'标题?,'traditional','medium',NULL,'主诉?,'描述?,6,'male')")
assert rows[0] == [1, "标题?", "traditional", "medium", None, "主诉?", "描述?", 6, "male"]
def test_parse_values_clause_recovers_unclosed_string_before_tuple_end() -> None:
"""导入解析:兼容接口 SQL 最后文本字段缺少结束引号且后接外键的情况。"""
rows = parse_values_clause("(1,'诊断?,'治疗?,'指南?,1)")
assert rows[0] == [1, "诊断?", "治疗?", "指南?", 1]
if __name__ == "__main__":
test_extract_insert_rows_maps_by_create_columns()
test_extract_insert_rows_rejects_broken_sql()
test_parse_values_clause_recovers_misplaced_quote_separator()
test_parse_values_clause_recovers_unclosed_string_before_tuple_end()
print("import source case sql tests passed")