chore: initialize medical consultation agent demo
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""数据访问层:封装 ORM 查询和持久化。"""
|
||||
@@ -0,0 +1,16 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.audit import AuditLog
|
||||
|
||||
|
||||
class AuditRepository:
|
||||
"""审计仓储:负责写入关键业务动作的审计日志。"""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def create(self, log: AuditLog) -> AuditLog:
|
||||
"""审计写入:保存一条审计日志并刷新主键。"""
|
||||
self.db.add(log)
|
||||
self.db.flush()
|
||||
return log
|
||||
@@ -0,0 +1,155 @@
|
||||
from sqlalchemy import delete, exists, func, or_, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.models.source_case import CaseBase, CaseExamItem, ScoringRule, TeachingCase, TraditionalCase
|
||||
from app.models.training import SessionOrder, SessionSubmission, TrainingSession
|
||||
from app.models.training_record import TrainingRecord
|
||||
|
||||
|
||||
class CaseRepository:
|
||||
"""病例仓储:基于 case_base 新表体系读取病例、扩展表和检查项目。"""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_active_cases(
|
||||
self,
|
||||
department_id: int | None = None,
|
||||
training_type: str | None = None,
|
||||
mode: str | None = None,
|
||||
) -> list[CaseBase]:
|
||||
"""病例列表:从 case_base 读取已发布病例,并按模式匹配传统/教学扩展表。"""
|
||||
stmt = (
|
||||
select(CaseBase)
|
||||
.options(selectinload(CaseBase.traditional_case), selectinload(CaseBase.teaching_case))
|
||||
.where(CaseBase.status == 1, CaseBase.publish_status == 1)
|
||||
)
|
||||
if department_id:
|
||||
stmt = stmt.where(CaseBase.department_id == department_id)
|
||||
if training_type:
|
||||
stmt = stmt.where(CaseBase.case_type == training_type)
|
||||
if mode == "practice":
|
||||
stmt = stmt.where(exists().where(TraditionalCase.case_id == CaseBase.id))
|
||||
if mode == "teaching":
|
||||
stmt = stmt.where(exists().where(TeachingCase.case_id == CaseBase.id))
|
||||
return list(self.db.scalars(stmt.order_by(CaseBase.id.desc())).all())
|
||||
|
||||
def get_active_case(self, case_id: int) -> CaseBase | None:
|
||||
"""病例详情:读取 case_base,并加载传统病例、教学病例、评分规则和检查项目。"""
|
||||
stmt = (
|
||||
select(CaseBase)
|
||||
.options(
|
||||
selectinload(CaseBase.traditional_case),
|
||||
selectinload(CaseBase.teaching_case),
|
||||
selectinload(CaseBase.scoring_rules),
|
||||
selectinload(CaseBase.exam_items),
|
||||
)
|
||||
.where(CaseBase.id == case_id, CaseBase.status == 1, CaseBase.publish_status == 1)
|
||||
)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def get_case_by_id(self, case_id: int) -> CaseBase | None:
|
||||
"""病例删除:按主键读取病例,不限制发布状态,用于删除前校验。"""
|
||||
return self.db.get(CaseBase, case_id)
|
||||
|
||||
def get_delete_preview_counts(self, case_id: int) -> dict[str, int]:
|
||||
"""病例删除预览:统计删除病例会影响的业务数据数量。"""
|
||||
session_ids = self._session_ids(case_id)
|
||||
return {
|
||||
"case_base": self._count(CaseBase, CaseBase.id == case_id),
|
||||
"traditional_case": self._count(TraditionalCase, TraditionalCase.case_id == case_id),
|
||||
"teaching_case": self._count(TeachingCase, TeachingCase.case_id == case_id),
|
||||
"scoring_rule": self._count(ScoringRule, ScoringRule.case_id == case_id),
|
||||
"case_exam_item": self._count(CaseExamItem, CaseExamItem.case_id == case_id),
|
||||
"training_session": len(session_ids),
|
||||
"training_order": self._count_training_orders(case_id, session_ids),
|
||||
"training_submission": self._count_by_sessions(SessionSubmission, SessionSubmission.session_id, session_ids),
|
||||
"training_record": self._count_training_records(case_id, session_ids),
|
||||
}
|
||||
|
||||
def delete_case_cascade(self, case_id: int) -> dict[str, int]:
|
||||
"""病例删除执行:按外键依赖顺序清理训练数据、检查项、评分规则和病例主表。"""
|
||||
session_ids = self._session_ids(case_id)
|
||||
deleted: dict[str, int] = {}
|
||||
deleted["training_order"] = self._delete_training_orders(case_id, session_ids)
|
||||
deleted["training_submission"] = self._delete_by_sessions(
|
||||
SessionSubmission, SessionSubmission.session_id, session_ids
|
||||
)
|
||||
deleted["training_record"] = self._delete_training_records(case_id, session_ids)
|
||||
deleted["training_session"] = self._delete_where(TrainingSession, TrainingSession.case_id == case_id)
|
||||
deleted["case_exam_item"] = self._delete_where(CaseExamItem, CaseExamItem.case_id == case_id)
|
||||
deleted["scoring_rule"] = self._delete_where(ScoringRule, ScoringRule.case_id == case_id)
|
||||
deleted["traditional_case"] = self._delete_where(TraditionalCase, TraditionalCase.case_id == case_id)
|
||||
deleted["teaching_case"] = self._delete_where(TeachingCase, TeachingCase.case_id == case_id)
|
||||
deleted["case_base"] = self._delete_where(CaseBase, CaseBase.id == case_id)
|
||||
return deleted
|
||||
|
||||
def _session_ids(self, case_id: int) -> list[int]:
|
||||
"""病例删除:读取该病例关联的训练会话 ID 集合。"""
|
||||
stmt = select(TrainingSession.id).where(TrainingSession.case_id == case_id)
|
||||
return [int(item) for item in self.db.scalars(stmt).all()]
|
||||
|
||||
def _count(self, model: type, *criteria) -> int:
|
||||
"""病例删除预览:按条件统计单表记录数。"""
|
||||
stmt = select(func.count()).select_from(model).where(*criteria)
|
||||
return int(self.db.scalar(stmt) or 0)
|
||||
|
||||
def _count_by_sessions(self, model: type, session_column, session_ids: list[int]) -> int:
|
||||
"""病例删除预览:按训练会话集合统计从表记录数。"""
|
||||
if not session_ids:
|
||||
return 0
|
||||
return self._count(model, session_column.in_(session_ids))
|
||||
|
||||
def _count_training_orders(self, case_id: int, session_ids: list[int]) -> int:
|
||||
"""病例删除预览:统计检查申请记录,兼容按病例和按会话两种关联。"""
|
||||
if session_ids:
|
||||
return self._count(SessionOrder, or_(SessionOrder.case_id == case_id, SessionOrder.session_id.in_(session_ids)))
|
||||
return self._count(SessionOrder, SessionOrder.case_id == case_id)
|
||||
|
||||
def _count_training_records(self, case_id: int, session_ids: list[int]) -> int:
|
||||
"""病例删除预览:统计完整训练记录,兼容按病例和按会话两种关联。"""
|
||||
if session_ids:
|
||||
return self._count(
|
||||
TrainingRecord,
|
||||
or_(TrainingRecord.case_id == case_id, TrainingRecord.session_id.in_(session_ids)),
|
||||
)
|
||||
return self._count(TrainingRecord, TrainingRecord.case_id == case_id)
|
||||
|
||||
def _delete_where(self, model: type, *criteria) -> int:
|
||||
"""病例删除执行:按条件删除单表记录并返回影响行数。"""
|
||||
result = self.db.execute(delete(model).where(*criteria))
|
||||
return int(result.rowcount or 0)
|
||||
|
||||
def _delete_by_sessions(self, model: type, session_column, session_ids: list[int]) -> int:
|
||||
"""病例删除执行:按训练会话集合删除从表记录。"""
|
||||
if not session_ids:
|
||||
return 0
|
||||
return self._delete_where(model, session_column.in_(session_ids))
|
||||
|
||||
def _delete_training_orders(self, case_id: int, session_ids: list[int]) -> int:
|
||||
"""病例删除执行:删除该病例下所有检查申请记录,避免阻塞检查项删除。"""
|
||||
if session_ids:
|
||||
return self._delete_where(
|
||||
SessionOrder,
|
||||
or_(SessionOrder.case_id == case_id, SessionOrder.session_id.in_(session_ids)),
|
||||
)
|
||||
return self._delete_where(SessionOrder, SessionOrder.case_id == case_id)
|
||||
|
||||
def _delete_training_records(self, case_id: int, session_ids: list[int]) -> int:
|
||||
"""病例删除执行:删除该病例完整训练后沉淀的评价记录。"""
|
||||
if session_ids:
|
||||
return self._delete_where(
|
||||
TrainingRecord,
|
||||
or_(TrainingRecord.case_id == case_id, TrainingRecord.session_id.in_(session_ids)),
|
||||
)
|
||||
return self._delete_where(TrainingRecord, TrainingRecord.case_id == case_id)
|
||||
|
||||
def get_exam_items(self, case_id: int) -> list[CaseExamItem]:
|
||||
"""检查项目:读取当前病例下全部可申请检查检验项目。"""
|
||||
stmt = select(CaseExamItem).where(CaseExamItem.case_id == case_id).order_by(CaseExamItem.display_order)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
|
||||
def get_exam_item(self, case_id: int, item_code: str) -> CaseExamItem | None:
|
||||
"""检查结果:按病例和项目编码读取固定检查检验结果。"""
|
||||
stmt = select(CaseExamItem).where(CaseExamItem.case_id == case_id, CaseExamItem.item_code == item_code)
|
||||
return self.db.scalar(stmt)
|
||||
@@ -0,0 +1,46 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.training_record import TrainingRecord
|
||||
|
||||
|
||||
class EvaluationRepository:
|
||||
"""评价仓储:负责完整训练结束后的 training_record 读写。"""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def create_record(self, record: TrainingRecord) -> TrainingRecord:
|
||||
"""评价保存:把 AI 评分报告保存为训练记录。"""
|
||||
self.db.add(record)
|
||||
self.db.flush()
|
||||
return record
|
||||
|
||||
def get_by_session(self, session_id: int, user_id: str) -> TrainingRecord | None:
|
||||
"""评价读取:按会话 ID 和外部 user_id 查询训练记录。"""
|
||||
stmt = select(TrainingRecord).where(
|
||||
TrainingRecord.session_id == session_id,
|
||||
TrainingRecord.external_user_id == user_id,
|
||||
)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def get_owned_record(self, evaluation_id: int, user_id: str) -> TrainingRecord | None:
|
||||
"""评价归属校验:按训练记录 ID 和外部 user_id 查询记录。"""
|
||||
stmt = select(TrainingRecord).where(
|
||||
TrainingRecord.id == evaluation_id,
|
||||
TrainingRecord.external_user_id == user_id,
|
||||
)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def list_by_user(self, user_id: str) -> list[TrainingRecord]:
|
||||
"""历史评价:按外部 user_id 查询完整训练后的评价记录。"""
|
||||
stmt = (
|
||||
select(TrainingRecord)
|
||||
.where(TrainingRecord.external_user_id == user_id)
|
||||
.order_by(TrainingRecord.created_at.desc())
|
||||
)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
|
||||
def flush(self) -> None:
|
||||
"""记录更新:刷新 PDF 路径等派生字段。"""
|
||||
self.db.flush()
|
||||
@@ -0,0 +1,34 @@
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.models.knowledge import KnowledgeChunk
|
||||
|
||||
|
||||
class KnowledgeRepository:
|
||||
"""知识库仓储:负责评分参考指南的轻量检索。"""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def search_chunks(
|
||||
self,
|
||||
department_id: int,
|
||||
task_type: str,
|
||||
keywords: list[str],
|
||||
limit: int = 5,
|
||||
) -> list[KnowledgeChunk]:
|
||||
"""知识检索:按科室、任务类型和关键词检索知识片段。"""
|
||||
stmt = (
|
||||
select(KnowledgeChunk)
|
||||
.options(selectinload(KnowledgeChunk.document))
|
||||
.where(KnowledgeChunk.is_active.is_(True))
|
||||
.where(or_(KnowledgeChunk.department_id == department_id, KnowledgeChunk.department_id.is_(None)))
|
||||
.where(or_(KnowledgeChunk.task_type == task_type, KnowledgeChunk.task_type.is_(None)))
|
||||
)
|
||||
|
||||
keyword_clauses = [KnowledgeChunk.chunk_text.contains(keyword) for keyword in keywords if keyword]
|
||||
if keyword_clauses:
|
||||
stmt = stmt.where(or_(*keyword_clauses))
|
||||
|
||||
stmt = stmt.order_by(KnowledgeChunk.weight.desc(), KnowledgeChunk.id.asc()).limit(limit)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
@@ -0,0 +1,25 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.user import UserLearningProfile
|
||||
|
||||
|
||||
class UserLearningProfileRepository:
|
||||
"""学习档案仓储:维护用户训练评价聚合数据。"""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_profile(self, user_id: str, tenant_id: str | None) -> UserLearningProfile | None:
|
||||
"""档案读取:按 user_id 和 tenant_id 获取学习档案。"""
|
||||
stmt = select(UserLearningProfile).where(
|
||||
UserLearningProfile.user_id == user_id,
|
||||
UserLearningProfile.tenant_id == tenant_id,
|
||||
)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def save(self, profile: UserLearningProfile) -> UserLearningProfile:
|
||||
"""档案保存:创建或更新用户学习档案。"""
|
||||
self.db.add(profile)
|
||||
self.db.flush()
|
||||
return profile
|
||||
@@ -0,0 +1,67 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.models.training import SessionOrder, SessionSubmission, TrainingSession
|
||||
|
||||
|
||||
class SessionRepository:
|
||||
"""会话仓储:负责训练会话、检查申请和诊断治疗提交数据。"""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def create_session(self, session: TrainingSession) -> TrainingSession:
|
||||
"""会话创建:保存训练会话主记录。"""
|
||||
self.db.add(session)
|
||||
self.db.flush()
|
||||
return session
|
||||
|
||||
def get_owned_session(self, session_id: int, user_id: str) -> TrainingSession | None:
|
||||
"""会话归属校验:根据 session_id 和 user_id 查询会话。"""
|
||||
stmt = (
|
||||
select(TrainingSession)
|
||||
.options(
|
||||
selectinload(TrainingSession.case),
|
||||
selectinload(TrainingSession.orders),
|
||||
selectinload(TrainingSession.submission),
|
||||
)
|
||||
.where(TrainingSession.id == session_id, TrainingSession.user_id == user_id)
|
||||
)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def update_status(self, session: TrainingSession, status: str) -> TrainingSession:
|
||||
"""状态流转:更新训练会话阶段状态。"""
|
||||
session.status = status
|
||||
if status == "diagnosis":
|
||||
session.inquiry_completed_at = datetime.utcnow()
|
||||
if status == "completed":
|
||||
session.completed_at = datetime.utcnow()
|
||||
self.db.flush()
|
||||
return session
|
||||
|
||||
def create_order(self, order: SessionOrder) -> SessionOrder:
|
||||
"""检查申请保存:保存用户申请过的检查检验结果。"""
|
||||
self.db.add(order)
|
||||
self.db.flush()
|
||||
return order
|
||||
|
||||
def get_order_by_item(self, session_id: int, item_code: str) -> SessionOrder | None:
|
||||
"""检查申请读取:按会话和检查编码获取已申请结果,用于幂等返回。"""
|
||||
stmt = select(SessionOrder).where(
|
||||
SessionOrder.session_id == session_id,
|
||||
SessionOrder.item_code == item_code,
|
||||
)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def get_submission(self, session_id: int) -> SessionSubmission | None:
|
||||
"""提交读取:获取当前会话的诊断治疗提交记录。"""
|
||||
stmt = select(SessionSubmission).where(SessionSubmission.session_id == session_id)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def upsert_submission(self, submission: SessionSubmission) -> SessionSubmission:
|
||||
"""诊断治疗保存:创建或更新当前会话的提交记录。"""
|
||||
self.db.add(submission)
|
||||
self.db.flush()
|
||||
return submission
|
||||
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.models.department import Department
|
||||
from app.models.source_case import CaseBase, ScoringRule, TeachingCase, TraditionalCase
|
||||
|
||||
|
||||
class SourceCaseRepository:
|
||||
"""源库病例仓储:读取 case_base、traditional_case、teaching_case 和 scoring_rule。"""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_active_cases(
|
||||
self,
|
||||
department_id: int | None = None,
|
||||
case_type: str | None = None,
|
||||
mode: str | None = None,
|
||||
) -> list[CaseBase]:
|
||||
"""源库病例列表:按科室、病例分类和训练模式读取已发布病例。"""
|
||||
stmt = (
|
||||
select(CaseBase)
|
||||
.options(selectinload(CaseBase.traditional_case), selectinload(CaseBase.teaching_case))
|
||||
.where(CaseBase.status == 1, CaseBase.publish_status == 1)
|
||||
)
|
||||
if department_id:
|
||||
stmt = stmt.where(CaseBase.department_id == department_id)
|
||||
if case_type:
|
||||
stmt = stmt.where(CaseBase.case_type == case_type)
|
||||
normalized_mode = self.normalize_mode(mode)
|
||||
if normalized_mode == "practice":
|
||||
stmt = stmt.where(exists().where(TraditionalCase.case_id == CaseBase.id))
|
||||
if normalized_mode == "teaching":
|
||||
stmt = stmt.where(exists().where(TeachingCase.case_id == CaseBase.id))
|
||||
return list(self.db.scalars(stmt.order_by(CaseBase.id.desc())).all())
|
||||
|
||||
def get_active_case_base(self, case_id: int) -> CaseBase | None:
|
||||
"""源库病例详情:读取病例主表及传统/教学扩展表。"""
|
||||
stmt = (
|
||||
select(CaseBase)
|
||||
.options(
|
||||
selectinload(CaseBase.traditional_case),
|
||||
selectinload(CaseBase.teaching_case),
|
||||
selectinload(CaseBase.scoring_rules),
|
||||
)
|
||||
.where(CaseBase.id == case_id, CaseBase.status == 1, CaseBase.publish_status == 1)
|
||||
)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def get_department_name(self, department_id: int | None) -> str:
|
||||
"""科室名称:兼容当前 demo 的 departments 表,源库无科室表时返回空字符串。"""
|
||||
if not department_id:
|
||||
return ""
|
||||
department = self.db.scalar(select(Department).where(Department.id == department_id))
|
||||
return department.name if department else ""
|
||||
|
||||
def get_scoring_rules(self, case_id: int) -> list[ScoringRule]:
|
||||
"""评分规则:读取当前病例对应的基础评分细则。"""
|
||||
stmt = select(ScoringRule).where(ScoringRule.case_id == case_id).order_by(ScoringRule.id)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
|
||||
@staticmethod
|
||||
def normalize_mode(mode: str | None) -> str | None:
|
||||
"""模式归一:旧 novice 请求按练习模式处理,第一版只暴露 practice/teaching。"""
|
||||
if mode == "novice":
|
||||
return "practice"
|
||||
return mode
|
||||
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.training_record import TrainingRecord
|
||||
|
||||
|
||||
class TrainingRecordRepository:
|
||||
"""训练记录仓储:完整训练结束后写入 training_record,未完成会话不沉淀长期记录。"""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_by_session(self, session_id: int) -> TrainingRecord | None:
|
||||
"""训练记录读取:按 session_id 保证评价接口重复调用时幂等。"""
|
||||
stmt = select(TrainingRecord).where(TrainingRecord.session_id == session_id)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def create_record(self, record: TrainingRecord) -> TrainingRecord:
|
||||
"""训练记录保存:写入源库兼容训练记录表。"""
|
||||
self.db.add(record)
|
||||
self.db.flush()
|
||||
return record
|
||||
Reference in New Issue
Block a user