# fastapi/tests/test_import_source_case_sql.py

# Tests for the SQL parsing helpers in scripts.import_source_case_sql.
import tempfile  # NOTE(review): not referenced by any visible test; possibly a leftover.
from pathlib import Path
import os
import sys
# Test-only environment defaults. setdefault() leaves any value the caller has
# already exported untouched. They are set before the project import below,
# presumably because settings are read at import time — TODO confirm.
os.environ.setdefault("DATABASE_URL", "sqlite:///./storage/test_import.db")
os.environ.setdefault("RUNTIME_MEMORY_BACKEND", "memory")
os.environ.setdefault("LLM_MOCK_ENABLED", "true")
# Put the parent of this file's tests/ directory (parents[1]) at the front of
# sys.path as an absolute path, so `scripts` resolves regardless of the CWD.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from scripts.import_source_case_sql import ImportValidationError, extract_create_columns, extract_insert_rows, parse_values_clause
def test_extract_insert_rows_maps_by_create_columns() -> None:
    """Import parsing: INSERT VALUES are mapped onto the CREATE TABLE column order.

    The INSERT statement has no explicit column list, so each value must be
    assigned positionally using the columns extracted from the DDL.
    """
    sql = """
CREATE TABLE `traditional_case` (
`created_at` datetime(6) NOT NULL,
`updated_at` datetime(6) NOT NULL,
`id` bigint NOT NULL AUTO_INCREMENT,
`standard_diagnosis` longtext NOT NULL,
`standard_treatment` longtext NOT NULL,
`guideline_reference` longtext NOT NULL,
`case_id` bigint NOT NULL
) ENGINE=InnoDB;
INSERT INTO `traditional_case` VALUES ('2026-05-28 09:34:08','2026-05-28 09:34:09',1,'支气管肺炎','抗感染','指南',1);
"""
    columns = extract_create_columns(sql, "traditional_case")
    rows = extract_insert_rows(sql, "traditional_case", columns)
    # The INSERT holds exactly one tuple, so exactly one row must come back.
    assert len(rows) == 1
    row = rows[0]
    # Check all three adjacent text columns: a one-position shift in the
    # mapping would leave a single-column check passing.
    assert row["standard_diagnosis"] == "支气管肺炎"
    assert row["standard_treatment"] == "抗感染"
    assert row["guideline_reference"] == "指南"
    assert row["case_id"] == 1
def test_extract_insert_rows_rejects_broken_sql() -> None:
    """Import parsing: a malformed INSERT is rejected instead of being half-imported."""
    table = "traditional_case"
    broken_sql = (
        "\n"
        "CREATE TABLE `traditional_case` (\n"
        "`created_at` datetime(6) NOT NULL,\n"
        "`updated_at` datetime(6) NOT NULL,\n"
        "`id` bigint NOT NULL AUTO_INCREMENT,\n"
        "`standard_diagnosis` longtext NOT NULL,\n"
        "`standard_treatment` longtext NOT NULL,\n"
        "`guideline_reference` longtext NOT NULL,\n"
        "`case_id` bigint NOT NULL\n"
        ") ENGINE=InnoDB;\n"
        # The stray `bad` token right after a closing quote corrupts the tuple.
        "INSERT INTO `traditional_case` VALUES ('2026-05-28','2026-05-28',1,'支气管肺炎'bad,'抗感染','指南',1);\n"
    )
    column_names = extract_create_columns(broken_sql, table)
    try:
        extract_insert_rows(broken_sql, table, column_names)
    except ImportValidationError:
        pass
    else:
        raise AssertionError("broken SQL should be rejected")
def test_parse_values_clause_recovers_misplaced_quote_separator() -> None:
    """Import parsing: tolerate text fields written as `text,'next'` in the source SQL.

    Some text values in the source SQL end in `?` with no closing quote before the
    comma (e.g. `'标题?,'traditional'`). The parser must still split the tuple
    into the expected fields.
    """
    rows = parse_values_clause("(1,'标题?,'traditional','medium',NULL,'主诉?,'描述?,6,'male')")
    # One tuple in, one row out: the recovery must not split the tuple apart.
    assert len(rows) == 1
    assert rows[0] == [1, "标题?", "traditional", "medium", None, "主诉?", "描述?", 6, "male"]
def test_parse_values_clause_recovers_unclosed_string_before_tuple_end() -> None:
    """Import parsing: the last text field lacks its closing quote and is followed by a foreign key.

    Here `'指南?` runs straight into `,1)`. The trailing integer must still be
    parsed as its own field.
    """
    rows = parse_values_clause("(1,'诊断?,'治疗?,'指南?,1)")
    # One tuple in, one row out: the recovery must not split the tuple apart.
    assert len(rows) == 1
    assert rows[0] == [1, "诊断?", "治疗?", "指南?", 1]
if __name__ == "__main__":
    # Allow running this file directly without pytest. Collect every module-level
    # test_* function in definition order (dict insertion order of globals()), so a
    # newly added test cannot be silently skipped by a hand-maintained call list.
    # list() snapshots globals() because the loop variables themselves are globals.
    for _name, _test in list(globals().items()):
        if _name.startswith("test_") and callable(_test):
            _test()
    print("import source case sql tests passed")