326 lines
15 KiB
Python
326 lines
15 KiB
Python
import json
|
||
from datetime import datetime
|
||
from decimal import Decimal
|
||
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.agents.orchestrator import MedicalConsultationOrchestrator
|
||
from app.core.context import UserContext
|
||
from app.core.exceptions import AppError
|
||
from app.models.training_record import TrainingRecord, TrainingScoreDetail
|
||
from app.repositories.case_repository import CaseRepository
|
||
from app.repositories.evaluation_repository import EvaluationRepository
|
||
from app.repositories.session_repository import SessionRepository
|
||
from app.repositories.source_case_repository import SourceCaseRepository
|
||
from app.schemas.evaluation import (
|
||
CreateEvaluationRequest,
|
||
DimensionScore,
|
||
EvaluationDetailResponse,
|
||
EvaluationListItem,
|
||
EvaluationListResponse,
|
||
EvaluationResponse,
|
||
ScoreDetailItem,
|
||
)
|
||
from app.services.audit_service import AuditService
|
||
from app.services.knowledge_service import KnowledgeService
|
||
from app.services.runtime_memory import runtime_memory
|
||
|
||
|
||
class EvaluationService:
    """Evaluation service.

    Builds a ``training_record`` plus its ``training_score_detail`` rows from the
    case, the case's scoring rules and the trainee's consultation/treatment work,
    and converts stored records back into API response models.
    """

    # Confidence stored for a score item when the LLM report omits ``ai_confidence``.
    _DEFAULT_AI_CONFIDENCE = 0.85

    def __init__(self, db: Session) -> None:
        """Bind all repositories and collaborating services to one DB session."""
        self.db = db
        self.session_repo = SessionRepository(db)
        self.case_repo = CaseRepository(db)
        self.eval_repo = EvaluationRepository(db)
        self.source_repo = SourceCaseRepository(db)
        self.knowledge = KnowledgeService(db)
        self.audit = AuditService(db)
        self.orchestrator = MedicalConsultationOrchestrator()

    async def create_evaluation(self, ctx: UserContext, session_id: int, payload: CreateEvaluationRequest) -> EvaluationResponse:
        """Generate the evaluation for a session, or return the one already stored.

        Reads the session's short-term runtime memory, the treatment submission,
        the case's scoring rules and matching guidelines, asks the orchestrator for
        a report, persists a ``training_record`` with its score details, marks the
        session ``completed``, releases the runtime memory and writes an audit log.

        Args:
            ctx: Current user context; ``ctx.user_id`` must own the session.
            session_id: ID of the training session to evaluate.
            payload: Request carrying the requested ``score_type``.

        Returns:
            The evaluation response for the new (or existing) record.

        Raises:
            AppError: SESSION_NOT_FOUND (404), SESSION_STATUS_INVALID (400),
                CASE_NOT_FOUND (404) or TREATMENT_REQUIRED (400).
        """
        session = self.session_repo.get_owned_session(session_id, ctx.user_id)
        if not session:
            raise AppError("SESSION_NOT_FOUND", "session not found or not owned by current user", 404)
        if session.status not in {"evaluating", "completed"}:
            raise AppError("SESSION_STATUS_INVALID", "evaluation requires treatment submission", 400)

        # A session is evaluated at most once; repeated calls return the stored record.
        existed = self.eval_repo.get_by_session(session.id, ctx.user_id)
        if existed:
            return self._to_response(existed)

        case = self.case_repo.get_active_case(session.case_id)
        if not case:
            raise AppError("CASE_NOT_FOUND", "case not found or inactive", 404)
        submission = self.session_repo.get_submission(session.id)
        if not submission or not submission.treatment_submitted_at:
            raise AppError("TREATMENT_REQUIRED", "treatment submission is required", 400)

        session.score_type = payload.score_type
        memory_messages = runtime_memory.get_messages(session.memory_key)
        keyword_seed = (case.key_symptoms or []) + (case.key_exams or []) + [case.diagnosis_primary or ""]
        guideline_result = self.knowledge.search_guidelines(case.department_id, session.training_type, keyword_seed)
        # Tolerate a search result without source refs instead of raising KeyError.
        guideline_refs = guideline_result.get("source_refs") or []
        scoring_rules = self.source_repo.get_scoring_rules(case.id)

        report = await self.orchestrator.evaluate(
            session=session,
            case=case,
            memory_messages=memory_messages,
            orders=session.orders,
            submission=submission,
            rubric=None,
            guideline_refs=guideline_refs,
            scoring_rules=scoring_rules,
        )

        record = self._build_training_record(ctx, session, case, submission, report, scoring_rules, guideline_result)
        # NOTE(review): record.id is read right after create_record, which assumes
        # create_record flushes/commits and populates the primary key — confirm.
        self.eval_repo.create_record(record)
        self.eval_repo.replace_score_details(record.id, self._build_score_details(record.id, report, scoring_rules))
        self.session_repo.update_status(session, "completed")
        runtime_memory.release(session.memory_key)
        self.audit.log(ctx, "evaluation.generate", "training_record", str(record.id), session.id)
        return self._to_response(record)

    def _build_training_record(
        self,
        ctx: UserContext,
        session,
        case,
        submission,
        report: dict,
        scoring_rules: list,
        guideline_result: dict,
    ) -> TrainingRecord:
        """Build (but do not persist) the ``training_record`` for a finished session.

        The full report is kept in ``ai_feedback_structured``; selected parts are
        also copied into dedicated columns (feedback, wrong points, JSON-encoded
        thinking chain and diagnosis path, improvement plan).
        """
        from datetime import timezone  # local import: only this method needs it

        # Naive UTC, the same value datetime.utcnow() returns, without its
        # Python 3.12 deprecation warning.
        end_time = datetime.now(timezone.utc).replace(tzinfo=None)
        start_time = session.started_at or session.created_at or end_time
        # Clamp at zero so clock skew between writers never stores a negative duration.
        duration_seconds = max(0, int((end_time - start_time).total_seconds()))
        # Non-numeric or non-finite LLM totals fall back to 0 instead of crashing
        # after the (expensive) evaluation call has already completed.
        total_score = self._score_to_float(report.get("total_score"))
        # Fall back to the session's score type when the key is missing OR null.
        score_type = report.get("score_type") or session.score_type
        structured = {
            "score_type": score_type,
            "total_score": total_score,
            "dimension_scores": report.get("dimension_scores") or [],
            "score_details": report.get("score_details") or [],
            "errors": report.get("errors") or [],
            "improvement_plan": report.get("improvement_plan") or [],
            "evidence_summary": report.get("evidence_summary") or [],
            "guideline_refs": report.get("guideline_refs") or [],
            "overall_comment": report.get("overall_comment") or "",
            "llm_model": report.get("_llm_model"),
            "latency_metrics": report.get("_latency_metrics") or {},
        }
        return TrainingRecord(
            training_mode=session.mode,
            case_type=session.training_type,
            start_time=start_time,
            end_time=end_time,
            duration_seconds=duration_seconds,
            total_score=total_score,
            ai_score=total_score,
            teacher_score=None,
            evaluation_level=self._evaluation_level(total_score, score_type),
            status="completed",
            feedback=structured["overall_comment"],
            thinking_chain=json.dumps(
                {
                    "evidence_summary": structured["evidence_summary"],
                    "guideline_refs": structured["guideline_refs"],
                    "scoring_rule_count": len(scoring_rules),
                },
                ensure_ascii=False,
            ),
            diagnosis_path=json.dumps(
                {
                    "primary_diagnosis": submission.primary_diagnosis,
                    "differential_diagnoses": submission.differential_diagnoses or [],
                    "diagnosis_basis": submission.diagnosis_basis,
                    "standard_diagnosis": case.diagnosis_primary,
                },
                ensure_ascii=False,
            ),
            wrong_points=structured["errors"],
            missed_questions=[],
            recommendation_result={"improvement_plan": structured["improvement_plan"]},
            ai_feedback_structured=structured,
            osce_station_score={},
            interruption_count=0,
            emotion_analysis={},
            prompt_version="v1",
            rag_context_version=self._rag_context_version(guideline_result),
            case_id=case.id,
            teacher_id=None,
            user_id=self._numeric_user_id(ctx.user_id),
            external_user_id=ctx.user_id,
            session_id=session.id,
            evaluation_record_id=None,
            score_type=structured["score_type"],
            pdf_file_path=None,
        )

    def _build_score_details(self, record_id: int, report: dict, scoring_rules: list) -> list[TrainingScoreDetail]:
        """Map the report's structured score items onto ``training_score_detail`` rows.

        Uses ``report["score_details"]``, falling back to ``report["dimension_scores"]``;
        non-dict items are skipped.
        """
        raw_items = report.get("score_details") or report.get("dimension_scores") or []
        rule_map = self._rule_map(scoring_rules)
        details: list[TrainingScoreDetail] = []
        for item in raw_items:
            if not isinstance(item, dict):
                continue
            dimension = str(item.get("dimension") or "综合表现")
            matched_rule = self._match_rule(item, dimension, rule_map)
            deducted_reason = self._join_deductions(item.get("deducted_reason")) or self._join_deductions(
                item.get("deductions")
            )
            evidence = item.get("evidence_message_ids")
            if evidence is None:
                evidence = item.get("evidence") or []
            confidence = item.get("ai_confidence")
            # Only substitute the default when confidence is missing; a genuine 0 is kept.
            if confidence is None or confidence == "":
                confidence = self._DEFAULT_AI_CONFIDENCE
            details.append(
                TrainingScoreDetail(
                    record_id=record_id,
                    # Persist the id of the rule that actually matched; an LLM-supplied
                    # rule_id that matches no scoring rule must never reach the table.
                    rule_id=getattr(matched_rule, "id", None),
                    dimension=dimension[:50],
                    score=self._decimal_or_none(item.get("score")),
                    deducted_reason=deducted_reason,
                    evidence_message_ids=self._as_list(evidence),
                    ai_confidence=self._decimal_or_none(confidence),
                    comment=item.get("comment") or item.get("improvement") or "",
                )
            )
        return details

    def _rule_map(self, scoring_rules: list) -> dict[str, object]:
        """Index scoring rules by their stripped ``dimension`` and ``competency_dimension``.

        Later rules overwrite earlier ones on key collisions.
        """
        result = {}
        for rule in scoring_rules:
            for key in (getattr(rule, "dimension", ""), getattr(rule, "competency_dimension", "")):
                if key:
                    result[str(key).strip()] = rule
        return result

    def _match_rule(self, item: dict, dimension: str, rule_map: dict[str, object]):
        """Find the scoring rule for a score item: first by ``rule_id``, then by dimension text.

        IDs are compared as strings because LLM JSON often returns ``"3"`` for int PK 3.
        Only rules present in ``rule_map`` (i.e. having a dimension key) can match.
        """
        rule_id = item.get("rule_id")
        if rule_id:
            wanted = str(rule_id).strip()
            for rule in rule_map.values():
                candidate = getattr(rule, "id", None)
                if candidate is not None and str(candidate) == wanted:
                    return rule
        # Map keys are stripped, so strip the looked-up text the same way.
        return rule_map.get(dimension.strip()) or rule_map.get(str(item.get("competency_dimension") or "").strip())

    def _decimal_or_none(self, value: object) -> Decimal | None:
        """Convert an LLM-returned value to a finite ``Decimal``; return None if impossible.

        None, unparsable text and NaN/Infinity (which a numeric column cannot store)
        all yield None.
        """
        if value is None:
            return None
        try:
            number = Decimal(str(value))
        except (ArithmeticError, ValueError, TypeError):  # decimal.InvalidOperation is an ArithmeticError
            return None
        return number if number.is_finite() else None

    def _score_to_float(self, value: object) -> float:
        """Convert an LLM-returned score to float, using 0.0 when it is missing or unusable."""
        number = self._decimal_or_none(value)
        return float(number) if number is not None else 0.0

    def _join_deductions(self, deductions: object) -> str:
        """Flatten an LLM deduction field into one ';'-separated string.

        Lists/tuples are joined with ';' (falsy entries skipped); a plain string is
        kept whole rather than split into characters; empty/None gives "".
        """
        if not deductions:
            return ""
        if isinstance(deductions, str):
            return deductions
        if isinstance(deductions, (list, tuple)):
            return ";".join(str(value) for value in deductions if value)
        return str(deductions)

    def _as_list(self, value: object) -> list:
        """Return ``value`` unchanged if it is a list, otherwise wrap it in a one-element list."""
        return value if isinstance(value, list) else [value]

    def _evaluation_level(self, score: float, score_type: str) -> str:
        """Map a total score to a level; five-point scores are scaled by 20 to a percentage first."""
        normalized = score * 20 if score_type == "five_point" else score
        if normalized >= 90:
            return "excellent"
        if normalized >= 80:
            return "good"
        if normalized >= 60:
            return "pass"
        return "needs_improvement"

    def _rag_context_version(self, guideline_result: dict) -> str:
        """Record how many guideline chunks were matched, or "none" when there were none."""
        matched = guideline_result.get("matched_chunks") or []
        return f"knowledge_chunks:{len(matched)}" if matched else "none"

    def _numeric_user_id(self, user_id: str) -> int | None:
        """Return the user-centre ID as int when it is purely numeric, else None.

        The raw ID is always stored in ``external_user_id``; this value feeds the
        source database's integer ``user_id`` column.
        """
        text = str(user_id)
        # isdecimal() matches exactly what int() accepts; isdigit() is also true for
        # characters such as "²" that make int() raise ValueError.
        return int(text) if text.isdecimal() else None

    def list_history(self, user_id: str) -> EvaluationListResponse:
        """List the user's completed training records (by user-centre ID) as history items."""
        records = self.eval_repo.list_by_user(user_id)
        # Memoize titles per call so records sharing a case query it only once.
        titles: dict = {}
        items = []
        for record in records:
            if record.case_id not in titles:
                titles[record.case_id] = self._case_title(record.case_id)
            items.append(
                EvaluationListItem(
                    evaluation_id=record.id,
                    case_title=titles[record.case_id],
                    score_type=record.score_type,
                    total_score=float(record.total_score or 0),
                    created_at=record.created_at,
                    pdf_exported=bool(record.pdf_file_path),
                )
            )
        return EvaluationListResponse(items=items)

    def get_detail(self, evaluation_id: int, user_id: str) -> EvaluationDetailResponse:
        """Return the full report for an evaluation owned by ``user_id``.

        Raises:
            AppError: EVALUATION_NOT_FOUND (404) when the record is missing or not owned.
        """
        record = self.eval_repo.get_owned_record(evaluation_id, user_id)
        if not record:
            raise AppError("EVALUATION_NOT_FOUND", "evaluation not found or not owned by current user", 404)
        base = self._to_response(record)
        return EvaluationDetailResponse(
            **base.model_dump(),
            session_id=record.session_id or 0,
            case_id=record.case_id,
            case_title=self._case_title(record.case_id),
            created_at=record.created_at,
            pdf_file_path=record.pdf_file_path,
        )

    def _to_response(self, record: TrainingRecord) -> EvaluationResponse:
        """Convert a stored ``training_record`` into the API response model.

        Prefers ``ai_feedback_structured`` and falls back to the dedicated columns.
        """
        structured = record.ai_feedback_structured or {}
        dimension_scores = structured.get("dimension_scores") or []
        return EvaluationResponse(
            evaluation_id=record.id,
            score_type=record.score_type,
            total_score=float(record.total_score or structured.get("total_score") or 0),
            # Skip malformed (non-dict) items, as the score-detail fallback does.
            dimension_scores=[DimensionScore(**item) for item in dimension_scores if isinstance(item, dict)],
            score_details=self._score_detail_response(record),
            errors=structured.get("errors") or record.wrong_points or [],
            improvement_plan=structured.get("improvement_plan") or (record.recommendation_result or {}).get("improvement_plan") or [],
            evidence_summary=structured.get("evidence_summary") or [],
            guideline_refs=structured.get("guideline_refs") or [],
            overall_comment=structured.get("overall_comment") or record.feedback or "",
        )

    def _score_detail_response(self, record: TrainingRecord) -> list[ScoreDetailItem]:
        """Return score details from ``training_score_detail``.

        Older records without rows fall back to the structured ``dimension_scores``.
        """
        details = self.eval_repo.list_score_details(record.id)
        if details:
            return [
                ScoreDetailItem(
                    id=item.id,
                    record_id=item.record_id,
                    rule_id=item.rule_id,
                    dimension=item.dimension,
                    score=float(item.score) if item.score is not None else None,
                    deducted_reason=item.deducted_reason,
                    evidence_message_ids=item.evidence_message_ids or [],
                    ai_confidence=float(item.ai_confidence) if item.ai_confidence is not None else None,
                    comment=item.comment,
                )
                for item in details
            ]
        structured = record.ai_feedback_structured or {}
        return [
            ScoreDetailItem(
                record_id=record.id,
                # `or` (not a .get default) so a null dimension/deductions key is handled too.
                dimension=item.get("dimension") or "综合表现",
                score=self._score_to_float(item.get("score")),
                deducted_reason=self._join_deductions(item.get("deductions")),
                evidence_message_ids=self._as_list(item.get("evidence") or []),
                ai_confidence=None,
                comment=item.get("comment") or "",
            )
            for item in structured.get("dimension_scores") or []
            if isinstance(item, dict)
        ]

    def _case_title(self, case_id: int | None) -> str:
        """Look up a case title; records store only ``case_id``.

        Returns "" when the ID is empty or no active case is found.
        """
        if not case_id:
            return ""
        case = self.case_repo.get_active_case(case_id)
        return case.title if case else ""
|