init medical training project
This commit is contained in:
@@ -0,0 +1,95 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import jsonschema
|
||||
from django.core.cache import cache
|
||||
from pathlib import Path
|
||||
|
||||
from config.exceptions import AppError
|
||||
from . import deepseek_client
|
||||
from .pdf_reader import extract_text_from_pdfs
|
||||
from prompts.loader import load_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
audit = logging.getLogger('audit')
|
||||
|
||||
_SCHEMA_PATH = Path(__file__).resolve().parent.parent / 'schemas' / 'case_full.json'
|
||||
PARSE_RESULT_TTL = 300 # 5 minutes
|
||||
|
||||
|
||||
def parse_pdf(files, case_type: str, user) -> dict:
|
||||
"""C1: PDF 解析 → DeepSeek → 结构化数据(不落库,不含评分规则)。"""
|
||||
if case_type not in ('traditional', 'teaching'):
|
||||
raise AppError('CASE_TYPE_NOT_SUPPORTED', f'case_type 不支持: {case_type}', status_code=400)
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
text = extract_text_from_pdfs(files)
|
||||
|
||||
prompt_name = f'case_{case_type}_full'
|
||||
system_prompt, prompt_version = load_prompt(prompt_name)
|
||||
|
||||
result = deepseek_client.call_deepseek(system_prompt, text)
|
||||
data = result['data']
|
||||
|
||||
data.pop('scoring_rules', None)
|
||||
data.pop('stages', None)
|
||||
|
||||
data['case_type'] = case_type
|
||||
|
||||
_strip_unknown_fields(data)
|
||||
_validate_schema(data)
|
||||
|
||||
parse_id = uuid.uuid4().hex[:12]
|
||||
cache.set(f'parse_result:{parse_id}', json.dumps(data, ensure_ascii=False), PARSE_RESULT_TTL)
|
||||
|
||||
source = {
|
||||
'files': [f.name for f in files],
|
||||
'total_bytes': sum(f.size for f in files),
|
||||
}
|
||||
|
||||
audit.info(
|
||||
'CASE_PARSE user=%s files=%d parse_id=%s tokens=%s prompt_version=%s',
|
||||
user.id, len(files), parse_id,
|
||||
result.get('usage', {}), prompt_version,
|
||||
)
|
||||
|
||||
return {
|
||||
'parse_id': parse_id,
|
||||
'case_type': case_type,
|
||||
'source': source,
|
||||
'ai_usage': result.get('usage', {}),
|
||||
'prompt_version': prompt_version,
|
||||
'parsing_seconds': round(time.time() - t0, 1),
|
||||
'data': data,
|
||||
}
|
||||
|
||||
|
||||
_SCHEMA_ALLOWED_KEYS = {
|
||||
'title', 'case_type', 'difficulty', 'chief_complaint', 'description',
|
||||
'patient_age', 'patient_gender', 'tags', 'symptom_tags', 'disease_tags',
|
||||
'competency_tags', 'guideline_tags', 'knowledge_points', 'icd_codes',
|
||||
'estimated_minutes', 'osce_enabled', 'department_name',
|
||||
'traditional', 'teaching',
|
||||
}
|
||||
|
||||
|
||||
def _strip_unknown_fields(data):
|
||||
for key in list(data.keys()):
|
||||
if key not in _SCHEMA_ALLOWED_KEYS:
|
||||
data.pop(key)
|
||||
|
||||
|
||||
def _validate_schema(data):
|
||||
schema = json.loads(_SCHEMA_PATH.read_text(encoding='utf-8'))
|
||||
try:
|
||||
jsonschema.validate(instance=data, schema=schema)
|
||||
except jsonschema.ValidationError as e:
|
||||
logger.error('AI parse output schema violation: %s', e.message)
|
||||
raise AppError(
|
||||
'AI_SCHEMA_VIOLATION',
|
||||
f'AI 输出字段不合法: {e.message}',
|
||||
status_code=500,
|
||||
)
|
||||
@@ -0,0 +1,65 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from django.conf import settings
|
||||
from openai import OpenAI, APITimeoutError, APIConnectionError, APIStatusError
|
||||
|
||||
from config.exceptions import AppError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_client():
|
||||
return OpenAI(
|
||||
api_key=settings.DEEPSEEK_API_KEY,
|
||||
base_url=settings.DEEPSEEK_BASE_URL,
|
||||
timeout=settings.DEEPSEEK_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def call_deepseek(system_prompt: str, user_content: str) -> dict:
|
||||
"""调用 DeepSeek,返回解析后的 JSON dict + usage 信息。
|
||||
|
||||
自带 1 次重试:首次失败时将错误信息附给第二次调用。
|
||||
"""
|
||||
client = get_client()
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
|
||||
last_error = None
|
||||
for attempt in range(1 + settings.DEEPSEEK_MAX_RETRIES):
|
||||
if attempt > 0 and last_error:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": f"上一次输出不是合法 JSON,错误:{last_error}。请严格输出合法 JSON。",
|
||||
})
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=settings.DEEPSEEK_MODEL,
|
||||
messages=messages,
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.3,
|
||||
)
|
||||
except APITimeoutError:
|
||||
raise AppError('AI_TIMEOUT', 'DeepSeek 请求超时', status_code=504)
|
||||
except (APIConnectionError, APIStatusError) as e:
|
||||
logger.error('DeepSeek API error: %s', e)
|
||||
raise AppError('AI_PROVIDER_ERROR', f'DeepSeek 服务异常: {e}', status_code=502)
|
||||
|
||||
raw = response.choices[0].message.content
|
||||
usage = {
|
||||
'prompt_tokens': response.usage.prompt_tokens,
|
||||
'completion_tokens': response.usage.completion_tokens,
|
||||
} if response.usage else {}
|
||||
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
return {'data': parsed, 'usage': usage}
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
last_error = str(e)
|
||||
logger.warning('DeepSeek JSON parse failed (attempt %d): %s', attempt + 1, e)
|
||||
|
||||
raise AppError('AI_BAD_JSON', f'AI 返回非合法 JSON(重试后仍失败): {last_error}', status_code=500)
|
||||
@@ -0,0 +1,32 @@
|
||||
from apps.user.models import Department
|
||||
from config.exceptions import AppError
|
||||
|
||||
|
||||
def resolve_department(department_name: str):
|
||||
"""按名称解析科室,返回 Department 实例。
|
||||
|
||||
- 空/None → 返回 None(不强制)
|
||||
- 精确匹配 1 条 → 返回该 Department
|
||||
- 匹配 0 条 → 400 CASE_DEPARTMENT_NOT_FOUND
|
||||
- 匹配多条 → 400 CASE_DEPARTMENT_AMBIGUOUS
|
||||
"""
|
||||
if not department_name:
|
||||
return None
|
||||
|
||||
qs = Department.objects.filter(name=department_name)
|
||||
count = qs.count()
|
||||
|
||||
if count == 0:
|
||||
raise AppError(
|
||||
'CASE_DEPARTMENT_NOT_FOUND',
|
||||
f'科室 "{department_name}" 不存在',
|
||||
status_code=400,
|
||||
)
|
||||
if count > 1:
|
||||
raise AppError(
|
||||
'CASE_DEPARTMENT_AMBIGUOUS',
|
||||
f'科室 "{department_name}" 匹配到多条记录,请精确指定',
|
||||
details={'matches': list(qs.values_list('name', flat=True))},
|
||||
status_code=400,
|
||||
)
|
||||
return qs.first()
|
||||
@@ -0,0 +1,50 @@
|
||||
import logging
|
||||
|
||||
import pdfplumber
|
||||
|
||||
from config.exceptions import AppError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_FILES = 5
|
||||
MAX_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
MAX_TOTAL_SIZE = 60 * 1024 * 1024 # 60 MB
|
||||
FILE_BREAK = '\n\n---FILE_BREAK: {name}---\n\n'
|
||||
|
||||
|
||||
def extract_text_from_pdfs(files) -> str:
|
||||
"""从 1~5 份 PDF 中提取文本,拼接返回。
|
||||
|
||||
files: request.FILES.getlist('files') 或类似 UploadedFile 列表。
|
||||
"""
|
||||
if not files:
|
||||
raise AppError('CASE_PDF_EMPTY', '未上传 PDF 文件')
|
||||
if len(files) > MAX_FILES:
|
||||
raise AppError('CASE_TOO_MANY_FILES', f'最多上传 {MAX_FILES} 份 PDF', status_code=400)
|
||||
|
||||
total_size = 0
|
||||
for f in files:
|
||||
if f.size > MAX_FILE_SIZE:
|
||||
raise AppError('CASE_FILE_TOO_LARGE', f'单份 PDF 不得超过 {MAX_FILE_SIZE // (1024*1024)} MB', status_code=400)
|
||||
total_size += f.size
|
||||
if total_size > MAX_TOTAL_SIZE:
|
||||
raise AppError('CASE_FILE_TOO_LARGE', f'PDF 总大小不得超过 {MAX_TOTAL_SIZE // (1024*1024)} MB', status_code=400)
|
||||
|
||||
parts = []
|
||||
for f in files:
|
||||
text = _extract_single(f)
|
||||
if not text.strip():
|
||||
raise AppError('CASE_PDF_EMPTY', f'PDF "{f.name}" 无法提取文本(可能为扫描版)', status_code=400)
|
||||
parts.append(FILE_BREAK.format(name=f.name) + text if len(files) > 1 else text)
|
||||
|
||||
return ''.join(parts)
|
||||
|
||||
|
||||
def _extract_single(uploaded_file) -> str:
|
||||
try:
|
||||
with pdfplumber.open(uploaded_file) as pdf:
|
||||
pages = [page.extract_text() or '' for page in pdf.pages]
|
||||
return '\n'.join(pages)
|
||||
except Exception as e:
|
||||
logger.error('pdfplumber extract failed for %s: %s', uploaded_file.name, e)
|
||||
raise AppError('CASE_PDF_EMPTY', f'PDF "{uploaded_file.name}" 解析失败: {e}', status_code=400)
|
||||
@@ -0,0 +1,79 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import jsonschema
|
||||
from pathlib import Path
|
||||
|
||||
from config.exceptions import AppError
|
||||
from . import deepseek_client
|
||||
from prompts.loader import load_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SCHEMA_PATH = Path(__file__).resolve().parent.parent / 'schemas' / 'scoring_rules.json'
|
||||
|
||||
CONTEXT_FIELDS = [
|
||||
'title', 'case_type', 'chief_complaint', 'description',
|
||||
'patient_age', 'patient_gender', 'icd_codes',
|
||||
'symptom_tags', 'disease_tags', 'competency_tags',
|
||||
'guideline_tags', 'knowledge_points',
|
||||
]
|
||||
|
||||
TRADITIONAL_FIELDS = ['standard_diagnosis', 'standard_treatment', 'guideline_reference']
|
||||
TEACHING_FIELDS = ['teaching_goal', 'discussion_questions', 'teacher_guide', 'scoring_focus']
|
||||
|
||||
|
||||
def generate(case_data: dict) -> dict:
|
||||
"""从病例数据 JSON 生成评分规则列表(不写库)。
|
||||
|
||||
case_data: 前端传入的病例数据(来自 parse-pdf data 或表单)。
|
||||
返回 {"scoring_rules": [...], "usage": {...}, "prompt_version": "..."}
|
||||
"""
|
||||
case_type = case_data.get('case_type', '')
|
||||
if case_type not in ('traditional', 'teaching'):
|
||||
raise AppError('CASE_TYPE_NOT_SUPPORTED', f'case_type 不支持: {case_type}', status_code=400)
|
||||
|
||||
sub = case_data.get(case_type) or {}
|
||||
if not sub:
|
||||
raise AppError('CASE_SUBTYPE_REQUIRED', f'{case_type} 子表数据缺失,AI 无法生成评分规则', status_code=400)
|
||||
|
||||
system_prompt, prompt_version = load_prompt('case_scoring_rules')
|
||||
|
||||
context = {}
|
||||
for field in CONTEXT_FIELDS:
|
||||
val = case_data.get(field)
|
||||
if val is not None and val != '' and val != []:
|
||||
context[field] = val
|
||||
|
||||
sub_fields = TRADITIONAL_FIELDS if case_type == 'traditional' else TEACHING_FIELDS
|
||||
for field in sub_fields:
|
||||
val = sub.get(field)
|
||||
if val is not None and val != '':
|
||||
context[field] = val
|
||||
|
||||
result = deepseek_client.call_deepseek(system_prompt, json.dumps(context, ensure_ascii=False))
|
||||
|
||||
rules = result['data'].get('scoring_rules', [])
|
||||
_validate_schema(rules)
|
||||
|
||||
if not rules:
|
||||
raise AppError('AI_EMPTY_RESULT', 'AI 返回 scoring_rules 为空数组', status_code=500)
|
||||
|
||||
return {
|
||||
'scoring_rules': rules,
|
||||
'usage': result['usage'],
|
||||
'prompt_version': prompt_version,
|
||||
}
|
||||
|
||||
|
||||
def _validate_schema(rules):
|
||||
schema = json.loads(_SCHEMA_PATH.read_text(encoding='utf-8'))
|
||||
try:
|
||||
jsonschema.validate(instance={'scoring_rules': rules}, schema=schema)
|
||||
except jsonschema.ValidationError as e:
|
||||
logger.error('AI scoring rules schema violation: %s', e.message)
|
||||
raise AppError(
|
||||
'AI_SCHEMA_VIOLATION',
|
||||
f'AI 输出字段类型不合法: {e.message}',
|
||||
status_code=500,
|
||||
)
|
||||
Reference in New Issue
Block a user