76 lines
2.8 KiB
Python
76 lines
2.8 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
from sqlalchemy import text
|
|
|
|
from app.core.config import settings
|
|
from app.db.session import SessionLocal
|
|
|
|
|
|
# Tables wiped by clear_training_runtime_data(), processed in the order listed.
# NOTE(review): the order presumably removes dependent rows before the rows
# they reference -- confirm against the schema before reordering.
# "audit_logs" is part of the set even though it lacks the "training_" prefix.
TRAINING_TABLES = (
    "training_score_detail",
    "training_record",
    "training_submission",
    "training_order",
    "training_session",
    "audit_logs",
)
|
|
|
|
|
|
def clear_training_runtime_data(clear_reports: bool = False) -> dict:
    """Empty the training runtime tables and optionally the local report files.

    Only the tables named in ``TRAINING_TABLES`` are modified. Per the original
    author's note, case, user, scoring-rule and knowledge-base data are not
    deleted. Inside one session every table is first emptied with ``DELETE``,
    then every table has its ``AUTO_INCREMENT`` counter reset to 1, and the
    session is committed once. Row counts are taken before the deletes and
    again after the commit so the caller can verify the result.

    Args:
        clear_reports: When true, ``_clear_reports()`` is called after the
            database work has finished; otherwise no files are touched.

    Returns:
        A dict with the keys ``database`` (the configured database URL with
        everything up to the last ``@`` removed), ``tables_before`` and
        ``tables_after`` (table name -> row count), and
        ``deleted_report_files`` (number of files removed, 0 when
        ``clear_reports`` is false).
    """

    def snapshot(session) -> dict:
        # Row count for every managed table, keyed by table name.
        counts = {}
        for name in TRAINING_TABLES:
            counts[name] = _count(session, name)
        return counts

    with SessionLocal() as session:
        rows_before = snapshot(session)

        # All deletes run first, then all counter resets, in table order.
        # NOTE(review): "ALTER TABLE ... AUTO_INCREMENT" looks MySQL-specific.
        statements = [f"DELETE FROM {name}" for name in TRAINING_TABLES]
        statements += [
            f"ALTER TABLE {name} AUTO_INCREMENT = 1" for name in TRAINING_TABLES
        ]
        for sql in statements:
            session.execute(text(sql))
        session.commit()

        rows_after = snapshot(session)

    removed_files = 0
    if clear_reports:
        removed_files = _clear_reports()

    # Report only the part of the URL after the last "@", presumably so that
    # credentials embedded in the URL are not echoed back.
    url = settings.database_url
    if "@" in url:
        database_label = url.split("@")[-1]
    else:
        database_label = url

    return {
        "database": database_label,
        "tables_before": rows_before,
        "tables_after": rows_after,
        "deleted_report_files": removed_files,
    }
|
|
|
|
|
|
def _count(db, table: str) -> int:
    """Return the current number of rows in ``table``.

    Used to compare table sizes before and after a cleanup run.

    Args:
        db: An object exposing ``execute()``; the caller passes the session
            obtained from ``SessionLocal()``.
        table: Table name, interpolated directly into the SQL text. Callers
            must pass trusted identifiers only.

    Returns:
        The ``COUNT(*)`` result as an int, or 0 when the scalar result is
        missing or falsy.
    """
    statement = text(f"SELECT COUNT(*) FROM {table}")
    row_total = db.execute(statement).scalar()
    if not row_total:
        return int(0)
    return int(row_total)
|
|
|
|
|
|
def _clear_reports() -> int:
    """Delete the files directly inside the report directory, keeping the directory.

    The configured ``settings.report_storage_dir`` is resolved (relative paths
    are anchored at the directory one level above this file's own directory)
    and must be exactly ``<that directory>/storage/reports`` -- described as
    ``backend/storage/reports`` in the original note. Any other location is
    rejected so a misconfigured path can never cause unrelated files to be
    removed. Subdirectories are left alone; only entries for which
    ``is_file()`` is true are unlinked.

    Returns:
        The number of files that were removed.

    Raises:
        RuntimeError: If the resolved report directory is not the expected one.
    """
    configured = Path(settings.report_storage_dir)
    base_dir = Path(__file__).resolve().parents[1]

    target = configured if configured.is_absolute() else base_dir / configured
    target = target.resolve()
    allowed = (base_dir / "storage" / "reports").resolve()

    if target != allowed:
        raise RuntimeError(f"refuse to clear unexpected report directory: {target}")

    # Create the directory if it is missing so a first run simply reports 0.
    target.mkdir(parents=True, exist_ok=True)

    # Collect the full list before deleting anything so the directory is not
    # modified while it is being iterated.
    doomed = []
    for entry in target.iterdir():
        if entry.is_file():
            doomed.append(entry)

    for entry in doomed:
        entry.unlink()

    return len(doomed)
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point; runs the cleanup only after explicit confirmation.

    Options:
        --confirm  Required. Must be exactly ``CLEAR_TRAINING_DATA``.
        --reports  Optional flag; also clears the locally generated reports.

    The summary dict returned by ``clear_training_runtime_data`` is printed.

    Raises:
        SystemExit: If the ``--confirm`` value does not match the expected
            token (argparse also exits by itself on invalid arguments).
    """
    expected_token = "CLEAR_TRAINING_DATA"

    cli = argparse.ArgumentParser(description="Clear training runtime data only.")
    cli.add_argument("--confirm", required=True, help="Must be CLEAR_TRAINING_DATA")
    cli.add_argument(
        "--reports",
        action="store_true",
        help="Also clear local generated PDF reports",
    )
    options = cli.parse_args()

    # Guard clause: nothing destructive happens without the exact token.
    if options.confirm != expected_token:
        raise SystemExit("confirmation mismatch; use --confirm CLEAR_TRAINING_DATA")

    summary = clear_training_runtime_data(clear_reports=options.reports)
    print(summary)
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|