commit b4bb38b7be954920661cfc989fdbba5ba138979f Author: shihan11 Date: Fri May 29 15:58:00 2026 +0800 init medical training project diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..bd65738 --- /dev/null +++ b/.env.example @@ -0,0 +1,25 @@ +# ────────── 数据库 ────────── +DB_NAME=medical_training +DB_USER=root +DB_PASSWORD=your-db-password +DB_HOST=localhost +DB_PORT=3306 + +# ────────── Redis ────────── +REDIS_URL=redis://127.0.0.1:6379/1 + +# ────────── 短信 ────────── +SMS_PROVIDER=mock # mock | aliyun +ALIYUN_SMS_ACCESS_KEY_ID= +ALIYUN_SMS_ACCESS_KEY_SECRET= +ALIYUN_SMS_SIGN_NAME=医疗训练平台 +ALIYUN_SMS_TEMPLATE_REGISTER=SMS_xxx_001 +ALIYUN_SMS_TEMPLATE_LOGIN=SMS_xxx_002 +ALIYUN_SMS_TEMPLATE_RESET=SMS_xxx_003 + +# ────────── DeepSeek ────────── +DEEPSEEK_API_KEY=your-deepseek-api-key +DEEPSEEK_BASE_URL=https://api.deepseek.com +DEEPSEEK_MODEL=deepseek-chat +DEEPSEEK_TIMEOUT_SECONDS=120 +DEEPSEEK_MAX_RETRIES=1 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a5e3a41 --- /dev/null +++ b/.gitignore @@ -0,0 +1,61 @@ +# 环境变量(含敏感信息,绝不入库) +.env +.env.local + +# Python +__pycache__/ +*.py[cod] +*.pyo +*.pyd +.Python +.pytest_cache/ +.coverage +htmlcov/ + +# 虚拟环境 +.venv/ +venv/ +env/ + +# 日志(保留 logs/ 目录结构时可提交 logs/.gitkeep) +logs/* +!logs/.gitkeep + +# 本地缓存 / 临时脚本 +.cache/ + +# Redis 运行时快照 +dump.rdb +*.rdb + +# 本地 Redis 二进制(请自行安装 Redis) +redis-server/ + +# 数据库导出 +*.sql + +# 分发 / 构建 +dist/ +build/ +*.egg-info/ +*.egg + +# 静态文件收集目录 +staticfiles/ +media/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# macOS +.DS_Store + +# Windows +Thumbs.db +Desktop.ini + +# 可选:样例 PDF 若含真实病例且不想入库,取消下行注释 +# *.pdf diff --git a/apps/__init__.py b/apps/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/case/__init__.py b/apps/case/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/case/admin.py b/apps/case/admin.py new file mode 100644 index 0000000..139e373 --- /dev/null +++ b/apps/case/admin.py @@ -0,0 +1,50 @@ +from django.contrib import admin +from .models import ( + CaseBase, TraditionalCase, ScriptCase, + TeachingCase, CaseStage, ScoringRule +) + + +@admin.register(CaseBase) +class CaseBaseAdmin(admin.ModelAdmin): + list_display = [ + 'id', 'title', 'case_type', 'difficulty', + 'department', 'patient_age', 'patient_gender', + 'osce_enabled', 'publish_status', 'status', 'created_at' + ] + list_filter = ['case_type', 'difficulty', 'publish_status', 'status', 'osce_enabled'] + search_fields = ['title', 'chief_complaint', 'tags', 'icd_codes'] + ordering = ['-created_at'] + + +@admin.register(TraditionalCase) +class TraditionalCaseAdmin(admin.ModelAdmin): + list_display = ['id', 'case'] + search_fields = ['case__title'] + + +@admin.register(ScriptCase) +class ScriptCaseAdmin(admin.ModelAdmin): + list_display = ['id', 'case', 'emotional_state', 'cultural_level'] + search_fields = ['case__title'] + + +@admin.register(TeachingCase) +class TeachingCaseAdmin(admin.ModelAdmin): + list_display = ['id', 'case'] + search_fields = ['case__title'] + + +@admin.register(CaseStage) +class CaseStageAdmin(admin.ModelAdmin): + list_display = ['id', 'case', 'stage_name', 'stage_type', 'stage_mode', 'sort_order', 'timeout_seconds'] + list_filter = ['stage_type', 'stage_mode'] + search_fields = ['case__title', 'stage_name'] + ordering = ['case', 'sort_order'] + + +@admin.register(ScoringRule) +class ScoringRuleAdmin(admin.ModelAdmin): + list_display = ['id', 'case', 'dimension', 'competency_dimension', 'score_weight', 'ai_auto_score'] + list_filter = ['ai_auto_score', 'osce_dimension'] + search_fields = ['case__title', 'dimension'] diff --git a/apps/case/apps.py b/apps/case/apps.py new file mode 100644 index 0000000..2fcb621 --- /dev/null +++ b/apps/case/apps.py @@ -0,0 +1,7 @@ +from django.apps import AppConfig + + +class CaseConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'apps.case' + verbose_name = '病例管理' diff --git a/apps/case/migrations/0001_initial.py b/apps/case/migrations/0001_initial.py new file mode 100644 index 0000000..920a3c3 --- /dev/null +++ b/apps/case/migrations/0001_initial.py @@ -0,0 +1,146 @@ +# Generated by Django 6.0.5 on 2026-05-26 07:02 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='CaseBase', + fields=[ + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('title', models.CharField(max_length=255, verbose_name='病例标题')), + ('case_type', models.CharField(choices=[('traditional', '传统病例'), ('script', '剧本病例'), ('teaching', '教学互动病例'), ('osce', 'OSCE')], max_length=30, verbose_name='病例类型')), + ('difficulty', models.CharField(blank=True, max_length=20, verbose_name='难度')), + ('difficulty_score', models.IntegerField(blank=True, null=True, verbose_name='AI难度评分')), + ('chief_complaint', models.TextField(blank=True, verbose_name='主诉')), + ('description', models.TextField(blank=True, verbose_name='病例简介')), + ('patient_age', models.IntegerField(blank=True, null=True, verbose_name='患者年龄')), + ('patient_gender', models.CharField(blank=True, max_length=10, verbose_name='患者性别')), + ('tags', models.CharField(blank=True, max_length=500, verbose_name='标签')), + ('symptom_tags', models.JSONField(blank=True, default=list, verbose_name='症状标签')), + ('disease_tags', models.JSONField(blank=True, default=list, verbose_name='疾病标签')), + ('competency_tags', models.JSONField(blank=True, default=list, verbose_name='能力标签')), + ('guideline_tags', models.JSONField(blank=True, default=list, verbose_name='指南标签')), + ('knowledge_points', models.JSONField(blank=True, default=list, verbose_name='知识点')), + ('icd_codes', models.CharField(blank=True, max_length=500, verbose_name='ICD编码')), + ('estimated_minutes', models.IntegerField(blank=True, null=True, verbose_name='预计训练时长')), + ('osce_enabled', models.BooleanField(default=False, verbose_name='是否OSCE')), + ('rag_enabled', models.BooleanField(default=False, verbose_name='是否启用知识增强')), + ('ai_prompt_template', models.TextField(blank=True, verbose_name='AI角色Prompt')), + ('multimodal_assets', models.JSONField(blank=True, default=dict, verbose_name='图片/影像/附件')), + ('vector_status', models.SmallIntegerField(default=0, verbose_name='是否向量化')), + ('publish_status', models.SmallIntegerField(choices=[(0, '草稿'), (1, '已发布'), (2, '已下架')], default=0, verbose_name='发布状态')), + ('status', models.SmallIntegerField(choices=[(0, '禁用'), (1, '正常')], default=1, verbose_name='状态')), + ], + options={ + 'verbose_name': '病例', + 'verbose_name_plural': '病例', + 'db_table': 'case_base', + }, + ), + migrations.CreateModel( + name='CaseStage', + fields=[ + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('stage_type', models.CharField(blank=True, max_length=50, verbose_name='阶段类型')), + ('stage_name', models.CharField(max_length=100, verbose_name='阶段名称')), + ('stage_mode', models.CharField(choices=[('dialogue', '对话'), ('osce', 'OSCE'), ('choice', '选择')], default='dialogue', max_length=30, verbose_name='阶段模式')), + ('stage_goal', models.TextField(blank=True, verbose_name='阶段目标')), + ('ai_role_prompt', models.TextField(blank=True, verbose_name='AI阶段Prompt')), + ('standard_action', models.TextField(blank=True, verbose_name='标准动作')), + ('expected_questions', models.TextField(blank=True, verbose_name='期望问题')), + ('scoring_points', models.TextField(blank=True, verbose_name='评分点')), + ('timeout_seconds', models.IntegerField(blank=True, null=True, verbose_name='超时时间')), + ('unlock_condition', models.CharField(blank=True, max_length=255, verbose_name='解锁条件')), + ('sort_order', models.IntegerField(default=0, verbose_name='排序')), + ], + options={ + 'verbose_name': '病例阶段', + 'verbose_name_plural': '病例阶段', + 'db_table': 'case_stage', + 'ordering': ['sort_order', 'id'], + }, + ), + migrations.CreateModel( + name='ScoringRule', + fields=[ + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('dimension', models.CharField(max_length=50, verbose_name='评分维度')), + ('competency_dimension', models.CharField(blank=True, max_length=50, verbose_name='能力维度')), + ('score_weight', models.DecimalField(decimal_places=2, default=1.0, max_digits=5, verbose_name='权重')), + ('ai_auto_score', models.BooleanField(default=False, verbose_name='AI自动评分')), + ('osce_dimension', models.BooleanField(default=False, verbose_name='是否OSCE')), + ('scoring_standard', models.TextField(blank=True, verbose_name='评分标准')), + ('rubric_json', models.JSONField(blank=True, default=dict, verbose_name='评分Rubric')), + ], + options={ + 'verbose_name': '评分规则', + 'verbose_name_plural': '评分规则', + 'db_table': 'scoring_rule', + }, + ), + migrations.CreateModel( + name='ScriptCase', + fields=[ + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('scenario_setting', models.TextField(blank=True, verbose_name='场景设定')), + ('emotional_state', models.CharField(blank=True, max_length=50, verbose_name='情绪状态')), + ('cultural_level', models.CharField(blank=True, max_length=50, verbose_name='文化水平')), + ('branch_logic', models.TextField(blank=True, verbose_name='分支逻辑')), + ('hidden_clues', models.TextField(blank=True, verbose_name='隐藏线索')), + ], + options={ + 'verbose_name': '剧本病例', + 'verbose_name_plural': '剧本病例', + 'db_table': 'script_case', + }, + ), + migrations.CreateModel( + name='TeachingCase', + fields=[ + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('teaching_goal', models.TextField(blank=True, verbose_name='教学目标')), + ('discussion_questions', models.TextField(blank=True, verbose_name='讨论问题')), + ('teacher_guide', models.TextField(blank=True, verbose_name='教师指南')), + ('scoring_focus', models.TextField(blank=True, verbose_name='评分重点')), + ], + options={ + 'verbose_name': '教学互动病例', + 'verbose_name_plural': '教学互动病例', + 'db_table': 'teaching_case', + }, + ), + migrations.CreateModel( + name='TraditionalCase', + fields=[ + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('standard_diagnosis', models.TextField(blank=True, verbose_name='标准诊断')), + ('standard_treatment', models.TextField(blank=True, verbose_name='标准治疗')), + ('guideline_reference', models.TextField(blank=True, verbose_name='指南参考')), + ], + options={ + 'verbose_name': '传统病例', + 'verbose_name_plural': '传统病例', + 'db_table': 'traditional_case', + }, + ), + ] diff --git a/apps/case/migrations/0002_initial.py b/apps/case/migrations/0002_initial.py new file mode 100644 index 0000000..0126a2d --- /dev/null +++ b/apps/case/migrations/0002_initial.py @@ -0,0 +1,54 @@ +# Generated by Django 6.0.5 on 2026-05-26 07:02 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('case', '0001_initial'), + ('user', '0001_initial'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AddField( + model_name='casebase', + name='created_by', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL, verbose_name='创建人'), + ), + migrations.AddField( + model_name='casebase', + name='department', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='user.department', verbose_name='所属科室'), + ), + migrations.AddField( + model_name='casestage', + name='case', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='stages', to='case.casebase', verbose_name='病例'), + ), + migrations.AddField( + model_name='scoringrule', + name='case', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='scoring_rules', to='case.casebase', verbose_name='病例'), + ), + migrations.AddField( + model_name='scriptcase', + name='case', + field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to='case.casebase', verbose_name='病例'), + ), + migrations.AddField( + model_name='teachingcase', + name='case', + field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to='case.casebase', verbose_name='病例'), + ), + migrations.AddField( + model_name='traditionalcase', + name='case', + field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to='case.casebase', verbose_name='病例'), + ), + ] diff --git a/apps/case/migrations/__init__.py b/apps/case/migrations/__init__.py new file mode 100644 index 0000000..c65238f --- /dev/null +++ b/apps/case/migrations/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/apps/case/models.py b/apps/case/models.py new file mode 100644 index 0000000..6226432 --- /dev/null +++ b/apps/case/models.py @@ -0,0 +1,185 @@ +from django.db import models +from apps.common.models import BaseModel +from apps.user.models import User + + +class CaseBase(BaseModel): + """病例主表""" + CASE_TYPE_CHOICES = [ + ('traditional', '传统病例'), + ('script', '剧本病例'), + ('teaching', '教学互动病例'), + ('osce', 'OSCE'), + ] + STATUS_CHOICES = [ + (0, '禁用'), + (1, '正常'), + ] + PUBLISH_STATUS_CHOICES = [ + (0, '草稿'), + (1, '已发布'), + (2, '已下架'), + ] + + id = models.BigAutoField(primary_key=True) + title = models.CharField('病例标题', max_length=255) + case_type = models.CharField('病例类型', max_length=30, choices=CASE_TYPE_CHOICES) + difficulty = models.CharField('难度', max_length=20, blank=True) + difficulty_score = models.IntegerField('AI难度评分', null=True, blank=True) + department = models.ForeignKey( + 'user.Department', on_delete=models.SET_NULL, + null=True, blank=True, verbose_name='所属科室' + ) + chief_complaint = models.TextField('主诉', blank=True) + description = models.TextField('病例简介', blank=True) + patient_age = models.IntegerField('患者年龄', null=True, blank=True) + patient_gender = models.CharField('患者性别', max_length=10, blank=True) + tags = models.CharField('标签', max_length=500, blank=True) + symptom_tags = models.JSONField('症状标签', default=list, blank=True) + disease_tags = models.JSONField('疾病标签', default=list, blank=True) + competency_tags = models.JSONField('能力标签', default=list, blank=True) + guideline_tags = models.JSONField('指南标签', default=list, blank=True) + knowledge_points = models.JSONField('知识点', default=list, blank=True) + icd_codes = models.CharField('ICD编码', max_length=500, blank=True) + estimated_minutes = models.IntegerField('预计训练时长', null=True, blank=True) + osce_enabled = models.BooleanField('是否OSCE', default=False) + rag_enabled = models.BooleanField('是否启用知识增强', default=False) + ai_prompt_template = models.TextField('AI角色Prompt', blank=True) + multimodal_assets = models.JSONField('图片/影像/附件', default=dict, blank=True) + vector_status = models.SmallIntegerField('是否向量化', default=0) + created_by = models.ForeignKey( + User, on_delete=models.SET_NULL, + null=True, blank=True, verbose_name='创建人' + ) + publish_status = models.SmallIntegerField('发布状态', choices=PUBLISH_STATUS_CHOICES, default=0) + status = models.SmallIntegerField('状态', choices=STATUS_CHOICES, default=1) + + class Meta: + db_table = 'case_base' + verbose_name = '病例' + verbose_name_plural = '病例' + + def __str__(self): + return self.title + + +class TraditionalCase(BaseModel): + """传统病例表""" + id = models.BigAutoField(primary_key=True) + case = models.OneToOneField( + CaseBase, on_delete=models.CASCADE, + verbose_name='病例' + ) + standard_diagnosis = models.TextField('标准诊断', blank=True) + standard_treatment = models.TextField('标准治疗', blank=True) + guideline_reference = models.TextField('指南参考', blank=True) + + class Meta: + db_table = 'traditional_case' + verbose_name = '传统病例' + verbose_name_plural = '传统病例' + + def __str__(self): + return f"传统病例: {self.case.title}" + + +class ScriptCase(BaseModel): + """剧本病例表""" + id = models.BigAutoField(primary_key=True) + case = models.OneToOneField( + CaseBase, on_delete=models.CASCADE, + verbose_name='病例' + ) + scenario_setting = models.TextField('场景设定', blank=True) + emotional_state = models.CharField('情绪状态', max_length=50, blank=True) + cultural_level = models.CharField('文化水平', max_length=50, blank=True) + branch_logic = models.TextField('分支逻辑', blank=True) + hidden_clues = models.TextField('隐藏线索', blank=True) + + class Meta: + db_table = 'script_case' + verbose_name = '剧本病例' + verbose_name_plural = '剧本病例' + + def __str__(self): + return f"剧本病例: {self.case.title}" + + +class TeachingCase(BaseModel): + """教学互动病例表""" + id = models.BigAutoField(primary_key=True) + case = models.OneToOneField( + CaseBase, on_delete=models.CASCADE, + verbose_name='病例' + ) + teaching_goal = models.TextField('教学目标', blank=True) + discussion_questions = models.TextField('讨论问题', blank=True) + teacher_guide = models.TextField('教师指南', blank=True) + scoring_focus = models.TextField('评分重点', blank=True) + + class Meta: + db_table = 'teaching_case' + verbose_name = '教学互动病例' + verbose_name_plural = '教学互动病例' + + def __str__(self): + return f"教学病例: {self.case.title}" + + +class CaseStage(BaseModel): + """病例阶段表""" + STAGE_MODE_CHOICES = [ + ('dialogue', '对话'), + ('osce', 'OSCE'), + ('choice', '选择'), + ] + + id = models.BigAutoField(primary_key=True) + case = models.ForeignKey( + CaseBase, on_delete=models.CASCADE, + related_name='stages', verbose_name='病例' + ) + stage_type = models.CharField('阶段类型', max_length=50, blank=True) + stage_name = models.CharField('阶段名称', max_length=100) + stage_mode = models.CharField('阶段模式', max_length=30, choices=STAGE_MODE_CHOICES, default='dialogue') + stage_goal = models.TextField('阶段目标', blank=True) + ai_role_prompt = models.TextField('AI阶段Prompt', blank=True) + standard_action = models.TextField('标准动作', blank=True) + expected_questions = models.TextField('期望问题', blank=True) + scoring_points = models.TextField('评分点', blank=True) + timeout_seconds = models.IntegerField('超时时间', null=True, blank=True) + unlock_condition = models.CharField('解锁条件', max_length=255, blank=True) + sort_order = models.IntegerField('排序', default=0) + + class Meta: + db_table = 'case_stage' + verbose_name = '病例阶段' + verbose_name_plural = '病例阶段' + ordering = ['sort_order', 'id'] + + def __str__(self): + return f"{self.case.title} - {self.stage_name}" + + +class ScoringRule(BaseModel): + """评分规则表""" + id = models.BigAutoField(primary_key=True) + case = models.ForeignKey( + CaseBase, on_delete=models.CASCADE, + related_name='scoring_rules', verbose_name='病例' + ) + dimension = models.CharField('评分维度', max_length=50) + competency_dimension = models.CharField('能力维度', max_length=50, blank=True) + score_weight = models.DecimalField('权重', max_digits=5, decimal_places=2, default=1.00) + ai_auto_score = models.BooleanField('AI自动评分', default=False) + osce_dimension = models.BooleanField('是否OSCE', default=False) + scoring_standard = models.TextField('评分标准', blank=True) + rubric_json = models.JSONField('评分Rubric', default=dict, blank=True) + + class Meta: + db_table = 'scoring_rule' + verbose_name = '评分规则' + verbose_name_plural = '评分规则' + + def __str__(self): + return f"{self.case.title} - {self.dimension}" diff --git a/apps/case/schemas/case_full.json b/apps/case/schemas/case_full.json new file mode 100644 index 0000000..aff2eed --- /dev/null +++ b/apps/case/schemas/case_full.json @@ -0,0 +1,94 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "CaseFullParseResult", + "description": "C1 parse-pdf 输出 / C3 full-create 入参共用 schema(不含 scoring_rules)", + "type": "object", + "required": ["title", "case_type"], + "properties": { + "title": { + "type": "string", + "minLength": 1, + "maxLength": 255 + }, + "case_type": { + "type": "string", + "enum": ["traditional", "teaching"] + }, + "difficulty": { + "type": "string", + "enum": ["easy", "medium", "hard", ""] + }, + "chief_complaint": { "type": "string" }, + "description": { "type": "string" }, + "patient_age": { + "oneOf": [ + { "type": "integer", "minimum": 0, "maximum": 200 }, + { "type": "null" } + ] + }, + "patient_gender": { + "type": "string", + "enum": ["male", "female", ""] + }, + "tags": { "type": "string" }, + "symptom_tags": { + "type": "array", + "items": { "type": "string" } + }, + "disease_tags": { + "type": "array", + "items": { "type": "string" } + }, + "competency_tags": { + "type": "array", + "items": { "type": "string" } + }, + "guideline_tags": { + "type": "array", + "items": { "type": "string" } + }, + "knowledge_points": { + "type": "array", + "items": { "type": "string" } + }, + "icd_codes": { "type": "string" }, + "estimated_minutes": { + "oneOf": [ + { "type": "integer", "minimum": 1 }, + { "type": "null" } + ] + }, + "osce_enabled": { "type": "boolean" }, + "department_name": { "type": "string" }, + "traditional": { + "type": "object", + "properties": { + "standard_diagnosis": { "type": "string" }, + "standard_treatment": { "type": "string" }, + "guideline_reference": { "type": "string" } + }, + "additionalProperties": false + }, + "teaching": { + "type": "object", + "properties": { + "teaching_goal": { "type": "string" }, + "discussion_questions": { "type": "string" }, + "teacher_guide": { "type": "string" }, + "scoring_focus": { "type": "string" } + }, + "additionalProperties": false + } + }, + "oneOf": [ + { + "properties": { "case_type": { "const": "traditional" } }, + "required": ["traditional"] + }, + { + "properties": { "case_type": { "const": "teaching" } }, + "required": ["teaching"] + } + ], + "additionalProperties": false +} diff --git a/apps/case/schemas/scoring_rules.json b/apps/case/schemas/scoring_rules.json new file mode 100644 index 0000000..b6b85a2 --- /dev/null +++ b/apps/case/schemas/scoring_rules.json @@ -0,0 +1,47 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "ScoringRulesSchema", + "description": "C2 generate-scoring-rules AI 输出校验 schema", + "type": "object", + "required": ["scoring_rules"], + "properties": { + "scoring_rules": { + "type": "array", + "items": { + "type": "object", + "required": ["dimension", "score_weight"], + "properties": { + "dimension": { + "type": "string", + "minLength": 1, + "maxLength": 50 + }, + "competency_dimension": { + "type": "string", + "maxLength": 50 + }, + "score_weight": { + "type": "number", + "exclusiveMinimum": 0, + "maximum": 1 + }, + "ai_auto_score": { + "type": "boolean" + }, + "osce_dimension": { + "type": "boolean" + }, + "scoring_standard": { + "type": "string" + }, + "rubric_json": { + "type": "object", + "additionalProperties": { "type": "string" } + } + }, + "additionalProperties": false + } + } + }, + "additionalProperties": false +} diff --git a/apps/case/serializers.py b/apps/case/serializers.py new file mode 100644 index 0000000..88fad70 --- /dev/null +++ b/apps/case/serializers.py @@ -0,0 +1,84 @@ +from rest_framework import serializers +from .models import ( + CaseBase, TraditionalCase, ScriptCase, + TeachingCase, CaseStage, ScoringRule +) + + +class CaseBaseListSerializer(serializers.ModelSerializer): + department_name = serializers.CharField(source='department.name', read_only=True) + created_by_name = serializers.CharField(source='created_by.real_name', read_only=True) + + class Meta: + model = CaseBase + fields = [ + 'id', 'title', 'case_type', 'difficulty', 'difficulty_score', + 'department', 'department_name', 'chief_complaint', 'patient_age', + 'patient_gender', 'tags', 'estimated_minutes', 'osce_enabled', + 'publish_status', 'status', 'created_by_name', 'created_at', 'updated_at' + ] + + +class CaseBaseDetailSerializer(serializers.ModelSerializer): + department_name = serializers.CharField(source='department.name', read_only=True) + created_by_name = serializers.CharField(source='created_by.real_name', read_only=True) + + class Meta: + model = CaseBase + fields = '__all__' + + +class CaseBaseCreateSerializer(serializers.ModelSerializer): + """病例创建序列化器""" + class Meta: + model = CaseBase + fields = [ + 'title', 'case_type', 'difficulty', 'department', + 'chief_complaint', 'description', 'patient_age', + 'patient_gender', 'tags', 'symptom_tags', 'disease_tags', + 'competency_tags', 'guideline_tags', 'knowledge_points', + 'icd_codes', 'estimated_minutes', 'osce_enabled', + 'rag_enabled', 'ai_prompt_template', 'multimodal_assets' + ] + + def create(self, validated_data): + validated_data['created_by'] = self.context['request'].user + return super().create(validated_data) + + +class TraditionalCaseSerializer(serializers.ModelSerializer): + case_title = serializers.CharField(source='case.title', read_only=True) + + class Meta: + model = TraditionalCase + fields = '__all__' + + +class ScriptCaseSerializer(serializers.ModelSerializer): + case_title = serializers.CharField(source='case.title', read_only=True) + + class Meta: + model = ScriptCase + fields = '__all__' + + +class TeachingCaseSerializer(serializers.ModelSerializer): + case_title = serializers.CharField(source='case.title', read_only=True) + + class Meta: + model = TeachingCase + fields = '__all__' + + +class CaseStageSerializer(serializers.ModelSerializer): + class Meta: + model = CaseStage + fields = '__all__' + + +class ScoringRuleSerializer(serializers.ModelSerializer): + case_title = serializers.CharField(source='case.title', read_only=True) + + class Meta: + model = ScoringRule + fields = '__all__' diff --git a/apps/case/services/__init__.py b/apps/case/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/case/services/case_importer.py b/apps/case/services/case_importer.py new file mode 100644 index 0000000..c817cca --- /dev/null +++ b/apps/case/services/case_importer.py @@ -0,0 +1,95 @@ +import json +import logging +import time +import uuid + +import jsonschema +from django.core.cache import cache +from pathlib import Path + +from config.exceptions import AppError +from . import deepseek_client +from .pdf_reader import extract_text_from_pdfs +from prompts.loader import load_prompt + +logger = logging.getLogger(__name__) +audit = logging.getLogger('audit') + +_SCHEMA_PATH = Path(__file__).resolve().parent.parent / 'schemas' / 'case_full.json' +PARSE_RESULT_TTL = 300 # 5 minutes + + +def parse_pdf(files, case_type: str, user) -> dict: + """C1: PDF 解析 → DeepSeek → 结构化数据(不落库,不含评分规则)。""" + if case_type not in ('traditional', 'teaching'): + raise AppError('CASE_TYPE_NOT_SUPPORTED', f'case_type 不支持: {case_type}', status_code=400) + + t0 = time.time() + + text = extract_text_from_pdfs(files) + + prompt_name = f'case_{case_type}_full' + system_prompt, prompt_version = load_prompt(prompt_name) + + result = deepseek_client.call_deepseek(system_prompt, text) + data = result['data'] + + data.pop('scoring_rules', None) + data.pop('stages', None) + + data['case_type'] = case_type + + _strip_unknown_fields(data) + _validate_schema(data) + + parse_id = uuid.uuid4().hex[:12] + cache.set(f'parse_result:{parse_id}', json.dumps(data, ensure_ascii=False), PARSE_RESULT_TTL) + + source = { + 'files': [f.name for f in files], + 'total_bytes': sum(f.size for f in files), + } + + audit.info( + 'CASE_PARSE user=%s files=%d parse_id=%s tokens=%s prompt_version=%s', + user.id, len(files), parse_id, + result.get('usage', {}), prompt_version, + ) + + return { + 'parse_id': parse_id, + 'case_type': case_type, + 'source': source, + 'ai_usage': result.get('usage', {}), + 'prompt_version': prompt_version, + 'parsing_seconds': round(time.time() - t0, 1), + 'data': data, + } + + +_SCHEMA_ALLOWED_KEYS = { + 'title', 'case_type', 'difficulty', 'chief_complaint', 'description', + 'patient_age', 'patient_gender', 'tags', 'symptom_tags', 'disease_tags', + 'competency_tags', 'guideline_tags', 'knowledge_points', 'icd_codes', + 'estimated_minutes', 'osce_enabled', 'department_name', + 'traditional', 'teaching', +} + + +def _strip_unknown_fields(data): + for key in list(data.keys()): + if key not in _SCHEMA_ALLOWED_KEYS: + data.pop(key) + + +def _validate_schema(data): + schema = json.loads(_SCHEMA_PATH.read_text(encoding='utf-8')) + try: + jsonschema.validate(instance=data, schema=schema) + except jsonschema.ValidationError as e: + logger.error('AI parse output schema violation: %s', e.message) + raise AppError( + 'AI_SCHEMA_VIOLATION', + f'AI 输出字段不合法: {e.message}', + status_code=500, + ) diff --git a/apps/case/services/deepseek_client.py b/apps/case/services/deepseek_client.py new file mode 100644 index 0000000..ca53ce2 --- /dev/null +++ b/apps/case/services/deepseek_client.py @@ -0,0 +1,65 @@ +import json +import logging + +from django.conf import settings +from openai import OpenAI, APITimeoutError, APIConnectionError, APIStatusError + +from config.exceptions import AppError + +logger = logging.getLogger(__name__) + + +def get_client(): + return OpenAI( + api_key=settings.DEEPSEEK_API_KEY, + base_url=settings.DEEPSEEK_BASE_URL, + timeout=settings.DEEPSEEK_TIMEOUT_SECONDS, + ) + + +def call_deepseek(system_prompt: str, user_content: str) -> dict: + """调用 DeepSeek,返回解析后的 JSON dict + usage 信息。 + + 自带 1 次重试:首次失败时将错误信息附给第二次调用。 + """ + client = get_client() + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content}, + ] + + last_error = None + for attempt in range(1 + settings.DEEPSEEK_MAX_RETRIES): + if attempt > 0 and last_error: + messages.append({ + "role": "user", + "content": f"上一次输出不是合法 JSON,错误:{last_error}。请严格输出合法 JSON。", + }) + + try: + response = client.chat.completions.create( + model=settings.DEEPSEEK_MODEL, + messages=messages, + response_format={"type": "json_object"}, + temperature=0.3, + ) + except APITimeoutError: + raise AppError('AI_TIMEOUT', 'DeepSeek 请求超时', status_code=504) + except (APIConnectionError, APIStatusError) as e: + logger.error('DeepSeek API error: %s', e) + raise AppError('AI_PROVIDER_ERROR', f'DeepSeek 服务异常: {e}', status_code=502) + + raw = response.choices[0].message.content + usage = { + 'prompt_tokens': response.usage.prompt_tokens, + 'completion_tokens': response.usage.completion_tokens, + } if response.usage else {} + + try: + parsed = json.loads(raw) + return {'data': parsed, 'usage': usage} + except (json.JSONDecodeError, TypeError) as e: + last_error = str(e) + logger.warning('DeepSeek JSON parse failed (attempt %d): %s', attempt + 1, e) + + raise AppError('AI_BAD_JSON', f'AI 返回非合法 JSON(重试后仍失败): {last_error}', status_code=500) diff --git a/apps/case/services/department_resolver.py b/apps/case/services/department_resolver.py new file mode 100644 index 0000000..032b5d6 --- /dev/null +++ b/apps/case/services/department_resolver.py @@ -0,0 +1,32 @@ +from apps.user.models import Department +from config.exceptions import AppError + + +def resolve_department(department_name: str): + """按名称解析科室,返回 Department 实例。 + + - 空/None → 返回 None(不强制) + - 精确匹配 1 条 → 返回该 Department + - 匹配 0 条 → 400 CASE_DEPARTMENT_NOT_FOUND + - 匹配多条 → 400 CASE_DEPARTMENT_AMBIGUOUS + """ + if not department_name: + return None + + qs = Department.objects.filter(name=department_name) + count = qs.count() + + if count == 0: + raise AppError( + 'CASE_DEPARTMENT_NOT_FOUND', + f'科室 "{department_name}" 不存在', + status_code=400, + ) + if count > 1: + raise AppError( + 'CASE_DEPARTMENT_AMBIGUOUS', + f'科室 "{department_name}" 匹配到多条记录,请精确指定', + details={'matches': list(qs.values_list('name', flat=True))}, + status_code=400, + ) + return qs.first() diff --git a/apps/case/services/pdf_reader.py b/apps/case/services/pdf_reader.py new file mode 100644 index 0000000..18786f9 --- /dev/null +++ b/apps/case/services/pdf_reader.py @@ -0,0 +1,50 @@ +import logging + +import pdfplumber + +from config.exceptions import AppError + +logger = logging.getLogger(__name__) + +MAX_FILES = 5 +MAX_FILE_SIZE = 20 * 1024 * 1024 # 20 MB +MAX_TOTAL_SIZE = 60 * 1024 * 1024 # 60 MB +FILE_BREAK = '\n\n---FILE_BREAK: {name}---\n\n' + + +def extract_text_from_pdfs(files) -> str: + """从 1~5 份 PDF 中提取文本,拼接返回。 + + files: request.FILES.getlist('files') 或类似 UploadedFile 列表。 + """ + if not files: + raise AppError('CASE_PDF_EMPTY', '未上传 PDF 文件') + if len(files) > MAX_FILES: + raise AppError('CASE_TOO_MANY_FILES', f'最多上传 {MAX_FILES} 份 PDF', status_code=400) + + total_size = 0 + for f in files: + if f.size > MAX_FILE_SIZE: + raise AppError('CASE_FILE_TOO_LARGE', f'单份 PDF 不得超过 {MAX_FILE_SIZE // (1024*1024)} MB', status_code=400) + total_size += f.size + if total_size > MAX_TOTAL_SIZE: + raise AppError('CASE_FILE_TOO_LARGE', f'PDF 总大小不得超过 {MAX_TOTAL_SIZE // (1024*1024)} MB', status_code=400) + + parts = [] + for f in files: + text = _extract_single(f) + if not text.strip(): + raise AppError('CASE_PDF_EMPTY', f'PDF "{f.name}" 无法提取文本(可能为扫描版)', status_code=400) + parts.append(FILE_BREAK.format(name=f.name) + text if len(files) > 1 else text) + + return ''.join(parts) + + +def _extract_single(uploaded_file) -> str: + try: + with pdfplumber.open(uploaded_file) as pdf: + pages = [page.extract_text() or '' for page in pdf.pages] + return '\n'.join(pages) + except Exception as e: + logger.error('pdfplumber extract failed for %s: %s', uploaded_file.name, e) + raise AppError('CASE_PDF_EMPTY', f'PDF "{uploaded_file.name}" 解析失败: {e}', status_code=400) diff --git a/apps/case/services/scoring_rule_generator.py b/apps/case/services/scoring_rule_generator.py new file mode 100644 index 0000000..6f93e37 --- /dev/null +++ b/apps/case/services/scoring_rule_generator.py @@ -0,0 +1,79 @@ +import json +import logging + +import jsonschema +from pathlib import Path + +from config.exceptions import AppError +from . import deepseek_client +from prompts.loader import load_prompt + +logger = logging.getLogger(__name__) + +_SCHEMA_PATH = Path(__file__).resolve().parent.parent / 'schemas' / 'scoring_rules.json' + +CONTEXT_FIELDS = [ + 'title', 'case_type', 'chief_complaint', 'description', + 'patient_age', 'patient_gender', 'icd_codes', + 'symptom_tags', 'disease_tags', 'competency_tags', + 'guideline_tags', 'knowledge_points', +] + +TRADITIONAL_FIELDS = ['standard_diagnosis', 'standard_treatment', 'guideline_reference'] +TEACHING_FIELDS = ['teaching_goal', 'discussion_questions', 'teacher_guide', 'scoring_focus'] + + +def generate(case_data: dict) -> dict: + """从病例数据 JSON 生成评分规则列表(不写库)。 + + case_data: 前端传入的病例数据(来自 parse-pdf data 或表单)。 + 返回 {"scoring_rules": [...], "usage": {...}, "prompt_version": "..."} + """ + case_type = case_data.get('case_type', '') + if case_type not in ('traditional', 'teaching'): + raise AppError('CASE_TYPE_NOT_SUPPORTED', f'case_type 不支持: {case_type}', status_code=400) + + sub = case_data.get(case_type) or {} + if not sub: + raise AppError('CASE_SUBTYPE_REQUIRED', f'{case_type} 子表数据缺失,AI 无法生成评分规则', status_code=400) + + system_prompt, prompt_version = load_prompt('case_scoring_rules') + + context = {} + for field in CONTEXT_FIELDS: + val = case_data.get(field) + if val is not None and val != '' and val != []: + context[field] = val + + sub_fields = TRADITIONAL_FIELDS if case_type == 'traditional' else TEACHING_FIELDS + for field in sub_fields: + val = sub.get(field) + if val is not None and val != '': + context[field] = val + + result = deepseek_client.call_deepseek(system_prompt, json.dumps(context, ensure_ascii=False)) + + rules = result['data'].get('scoring_rules', []) + _validate_schema(rules) + + if not rules: + raise AppError('AI_EMPTY_RESULT', 'AI 返回 scoring_rules 为空数组', status_code=500) + + return { + 'scoring_rules': rules, + 'usage': result['usage'], + 'prompt_version': prompt_version, + } + + +def _validate_schema(rules): + schema = json.loads(_SCHEMA_PATH.read_text(encoding='utf-8')) + try: + jsonschema.validate(instance={'scoring_rules': rules}, schema=schema) + except jsonschema.ValidationError as e: + logger.error('AI scoring rules schema violation: %s', e.message) + raise AppError( + 'AI_SCHEMA_VIOLATION', + f'AI 输出字段类型不合法: {e.message}', + status_code=500, + ) diff --git a/apps/case/urls.py b/apps/case/urls.py new file mode 100644 index 0000000..8f5ab62 --- /dev/null +++ b/apps/case/urls.py @@ -0,0 +1,15 @@ +from django.urls import path, include +from rest_framework.routers import DefaultRouter +from . import views + +router = DefaultRouter() +router.register(r'cases', views.CaseBaseViewSet, basename='case') +router.register(r'traditional-cases', views.TraditionalCaseViewSet, basename='traditional-case') +router.register(r'script-cases', views.ScriptCaseViewSet, basename='script-case') +router.register(r'teaching-cases', views.TeachingCaseViewSet, basename='teaching-case') +router.register(r'case-stages', views.CaseStageViewSet, basename='case-stage') +router.register(r'scoring-rules', views.ScoringRuleViewSet, basename='scoring-rule') + +urlpatterns = [ + path('', include(router.urls)), +] diff --git a/apps/case/views.py b/apps/case/views.py new file mode 100644 index 0000000..0a0991b --- /dev/null +++ b/apps/case/views.py @@ -0,0 +1,459 @@ +import logging + +from django.db import transaction +from drf_spectacular.utils import extend_schema, OpenApiResponse, inline_serializer +from rest_framework import viewsets, filters, status, serializers as drf_serializers +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.response import Response +from django_filters.rest_framework import DjangoFilterBackend + +from config.exceptions import AppError +from apps.user.permissions import IsCaseOperationPermitted +from apps.user.throttling import PdfParseUserThrottle, ScoringRuleGenerateUserThrottle +from .models import ( + CaseBase, TraditionalCase, ScriptCase, + TeachingCase, CaseStage, ScoringRule +) +from .serializers import ( + CaseBaseListSerializer, CaseBaseDetailSerializer, + CaseBaseCreateSerializer, TraditionalCaseSerializer, + ScriptCaseSerializer, TeachingCaseSerializer, + CaseStageSerializer, ScoringRuleSerializer +) +from .services import case_importer, scoring_rule_generator +from .services.department_resolver import resolve_department + +audit = logging.getLogger('audit') + +TRADITIONAL_FIELDS = {'standard_diagnosis', 'standard_treatment', 'guideline_reference'} +TEACHING_FIELDS = {'teaching_goal', 'discussion_questions', 'teacher_guide', 'scoring_focus'} +SCORING_RULE_FIELDS = { + 'dimension', 'competency_dimension', 'score_weight', + 'ai_auto_score', 'osce_dimension', 'scoring_standard', 'rubric_json', +} + +CASE_BASE_FIELDS = { + 'title', 'case_type', 'difficulty', 'difficulty_score', + 'chief_complaint', 'description', 'patient_age', 'patient_gender', + 'tags', 'symptom_tags', 'disease_tags', 'competency_tags', + 'guideline_tags', 'knowledge_points', 'icd_codes', + 'estimated_minutes', 'osce_enabled', 'rag_enabled', + 'ai_prompt_template', 'multimodal_assets', +} + + +class CaseBaseViewSet(viewsets.ModelViewSet): + """病例管理""" + queryset = CaseBase.objects.all() + filter_backends = [DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter] + filterset_fields = [ + 'case_type', 'difficulty', 'department', + 'publish_status', 'status', 'osce_enabled' + ] + search_fields = ['title', 'chief_complaint', 'tags', 'icd_codes'] + ordering_fields = ['created_at', 'difficulty_score', 'estimated_minutes'] + + def get_serializer_class(self): + if self.action == 'list': + return CaseBaseListSerializer + elif self.action == 'create': + return CaseBaseCreateSerializer + return CaseBaseDetailSerializer + + # ── C1: parse-pdf ──────────────────────────────────────────────────── + + @extend_schema( + summary='C1: PDF 解析', + description='上传 1~5 份 PDF,调用 DeepSeek 提取结构化病例数据。不落库,不含评分规则。', + request={'multipart/form-data': {'type': 'object', 'properties': { + 'files': {'type': 'array', 'items': {'type': 'string', 'format': 'binary'}}, + 'case_type': {'type': 'string', 'enum': ['traditional', 'teaching']}, + }, 'required': ['files', 'case_type']}}, + responses={200: OpenApiResponse(description='解析结果(含 parse_id、data)')}, + tags=['病例'], + ) + @action( + detail=False, methods=['post'], url_path='parse-pdf', + parser_classes=[MultiPartParser], + permission_classes=[IsCaseOperationPermitted], + throttle_classes=[PdfParseUserThrottle], + ) + def parse_pdf(self, request): + """C1: PDF 解析 → 结构化数据(不落库,不含评分规则)""" + files = request.FILES.getlist('files') + case_type = request.data.get('case_type', '') + result = case_importer.parse_pdf(files, case_type, request.user) + return Response(result) + + # ── C2: generate-scoring-rules ─────────────────────────────────────── + + @extend_schema( + summary='C2: AI 生成评分规则预览', + description='传入病例数据 JSON,调用 DeepSeek 生成评分规则。不落库,返回规则列表供前端审核。', + request=inline_serializer('GenerateScoringRulesRequest', fields={ + 'case_type': drf_serializers.ChoiceField(choices=['traditional', 'teaching'], help_text='病例类型'), + 'title': drf_serializers.CharField(help_text='病例标题'), + 'chief_complaint': drf_serializers.CharField(required=False, help_text='主诉'), + 'traditional': drf_serializers.DictField(required=False, help_text='传统病例子表(case_type=traditional 时传)'), + 'teaching': drf_serializers.DictField(required=False, help_text='教学病例子表(case_type=teaching 时传)'), + }), + responses={200: OpenApiResponse(description='评分规则列表 + AI 用量信息')}, + tags=['病例'], + ) + @action( + detail=False, methods=['post'], url_path='generate-scoring-rules', + permission_classes=[IsCaseOperationPermitted], + throttle_classes=[ScoringRuleGenerateUserThrottle], + ) + def generate_scoring_rules(self, request): + """C2: AI 生成评分规则预览(不落库)""" + result = scoring_rule_generator.generate(request.data) + + audit.info( + 'CASE_SCORING_RULE_PREVIEW user=%s case_type=%s rules=%d prompt_version=%s', + request.user.id, request.data.get('case_type', ''), + len(result['scoring_rules']), result['prompt_version'], + ) + + return Response({ + 'generated': len(result['scoring_rules']), + 'ai_usage': result['usage'], + 'prompt_version': result['prompt_version'], + 'scoring_rules': result['scoring_rules'], + }) + + # ── C3: full-create ────────────────────────────────────────────────── + + @extend_schema( + summary='C3: 创建病例(统一落库入口)', + description='病例主表 + 子表 + scoring_rules(≥1 条)同一事务入库。scoring_rules 必填。', + request=inline_serializer('FullCreateRequest', fields={ + 'title': drf_serializers.CharField(help_text='病例标题'), + 'case_type': drf_serializers.ChoiceField(choices=['traditional', 'teaching']), + 'department_name': drf_serializers.CharField(required=False, help_text='科室名称(后端解析为 department_id)'), + 'traditional': drf_serializers.DictField(required=False), + 'teaching': drf_serializers.DictField(required=False), + 'scoring_rules': drf_serializers.ListField(child=drf_serializers.DictField(), help_text='评分规则(≥1 条,必填)'), + 'parse_id': drf_serializers.CharField(required=False, help_text='来自 parse-pdf 的 parse_id(审计用)'), + 'auto_publish': drf_serializers.BooleanField(required=False, default=False), + }), + responses={201: OpenApiResponse(description='完整病例结构(同 GET full)')}, + tags=['病例'], + ) + @action( + detail=False, methods=['post'], url_path='full-create', + permission_classes=[IsCaseOperationPermitted], + ) + def full_create(self, request): + """C3: 创建病例(主表+子表+评分规则同一事务)""" + data = request.data + case_type = data.get('case_type', '') + + if case_type not in ('traditional', 'teaching'): + raise AppError('CASE_TYPE_NOT_SUPPORTED', f'case_type 不支持: {case_type}', status_code=400) + + if 'stages' in data: + raise AppError('CASE_FIELD_NOT_ALLOWED', '本接口不接收 stages 字段', status_code=400) + + sub_data = data.get(case_type) + other_type = 'teaching' if case_type == 'traditional' else 'traditional' + if not sub_data: + raise AppError('CASE_SUBTYPE_REQUIRED', f'{case_type} 子表数据缺失', status_code=400) + if data.get(other_type): + raise AppError('CASE_SUBTYPE_CONFLICT', '不允许同时传入两种子表数据', status_code=400) + + scoring_rules_data = data.get('scoring_rules', []) + if not scoring_rules_data: + raise AppError('CASE_VALIDATION_ERROR', 'scoring_rules 必填且至少 1 条', status_code=400) + _validate_scoring_rules(scoring_rules_data) + + department = resolve_department(data.get('department_name', '')) + + with transaction.atomic(): + case_kwargs = {k: data[k] for k in CASE_BASE_FIELDS if k in data} + case_kwargs['department'] = department + case_kwargs['created_by'] = request.user + case_kwargs['status'] = 1 + case_kwargs['vector_status'] = 0 + case_kwargs['publish_status'] = 1 if data.get('auto_publish') else 0 + case = CaseBase.objects.create(**case_kwargs) + + sub_model = TraditionalCase if case_type == 'traditional' else TeachingCase + allowed_sub = TRADITIONAL_FIELDS if case_type == 'traditional' else TEACHING_FIELDS + sub_model.objects.create(case=case, **{k: v for k, v in sub_data.items() if k in allowed_sub}) + + rule_objs = [ + ScoringRule(case=case, **{k: v for k, v in rule.items() if k in SCORING_RULE_FIELDS}) + for rule in scoring_rules_data + ] + ScoringRule.objects.bulk_create(rule_objs) + + audit.info( + 'CASE_CREATE case_id=%s from=%s by=%s scoring_rules=%d', + case.id, data.get('parse_id', 'form'), request.user.id, len(rule_objs), + ) + + return Response(_build_full_response(case), status=status.HTTP_201_CREATED) + + # ── C4 / C5: GET + PATCH full ─────────────────────────────────────── + + @extend_schema( + summary='C4: GET 完整查看 / C5: PATCH 局部编辑草稿', + description='GET: 返回主表+子表+scoring_rules。PATCH: 仅草稿可编辑,scoring_rules 传入时整体替换。', + responses={200: OpenApiResponse(description='完整病例结构')}, + tags=['病例'], + ) + @action(detail=True, methods=['get', 'patch'], url_path='full') + def full(self, request, pk=None): + """C4: GET 完整查看 / C5: PATCH 局部编辑草稿""" + if request.method == 'GET': + return self._full_detail(request, pk) + return self._full_update(request, pk) + + def _full_detail(self, request, pk): + case = CaseBase.objects.select_related( + 'department', 'created_by' + ).prefetch_related('scoring_rules').filter(pk=pk).first() + + if not case: + raise AppError('NOT_FOUND', '病例不存在', status_code=404) + + if case.publish_status != 1: + if not request.user.is_authenticated: + raise AppError('AUTH_UNAUTHORIZED', '请先登录', status_code=401) + if not (request.user.id == case.created_by_id + or request.user.role_type in ('super_admin', 'content_admin') + or request.user.is_staff): + raise AppError('CASE_PERMISSION_DENIED', '无权查看该草稿', status_code=403) + + return Response(_build_full_response(case)) + + def _full_update(self, request, pk): + case = CaseBase.objects.select_related('department', 'created_by').filter(pk=pk).first() + if not case: + raise AppError('NOT_FOUND', '病例不存在', status_code=404) + + if not (request.user.id == case.created_by_id + or request.user.role_type in ('super_admin', 'content_admin') + or request.user.is_staff): + raise AppError('CASE_PERMISSION_DENIED', '无权编辑该病例', status_code=403) + + if case.publish_status != 0: + raise AppError('CASE_NOT_EDITABLE', '仅草稿可编辑,请先下架', status_code=400) + + data = request.data + + with transaction.atomic(): + changed = False + for field in CASE_BASE_FIELDS: + if field in data: + setattr(case, field, data[field]) + changed = True + if 'department_name' in data: + case.department = resolve_department(data['department_name']) + changed = True + if changed: + case.save() + + case_type = case.case_type + sub_data = data.get(case_type) + if sub_data: + sub_model = TraditionalCase if case_type == 'traditional' else TeachingCase + allowed_sub = TRADITIONAL_FIELDS if case_type == 'traditional' else TEACHING_FIELDS + sub_obj, _ = sub_model.objects.get_or_create(case=case) + for k, v in sub_data.items(): + if k in allowed_sub: + setattr(sub_obj, k, v) + sub_obj.save() + + scoring_rules_data = data.get('scoring_rules') + if scoring_rules_data is not None: + _validate_scoring_rules(scoring_rules_data) + case.scoring_rules.all().delete() + rule_objs = [ + ScoringRule(case=case, **{k: v for k, v in rule.items() if k in SCORING_RULE_FIELDS}) + for rule in scoring_rules_data + ] + ScoringRule.objects.bulk_create(rule_objs) + + audit.info('CASE_UPDATE case_id=%s by=%s', case.id, request.user.id) + + case.refresh_from_db() + return Response(_build_full_response(case)) + + # ── existing actions ───────────────────────────────────────────────── + + @action(detail=True, methods=['get']) + def stages(self, request, pk=None): + case = self.get_object() + serializer = CaseStageSerializer(case.stages.all(), many=True) + return Response(serializer.data) + + @action(detail=True, methods=['post']) + def add_stage(self, request, pk=None): + case = self.get_object() + serializer = CaseStageSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save(case=case) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + @action(detail=True, methods=['get']) + def scoring_rules_list(self, request, pk=None): + case = self.get_object() + serializer = ScoringRuleSerializer(case.scoring_rules.all(), many=True) + return Response(serializer.data) + + @action(detail=True, methods=['post']) + def add_scoring_rule(self, request, pk=None): + case = self.get_object() + serializer = ScoringRuleSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save(case=case) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + @action(detail=True, methods=['post']) + def publish(self, request, pk=None): + case = self.get_object() + case.publish_status = 1 + case.save() + return Response({'message': '病例已发布'}) + + @action(detail=True, methods=['post']) + def unpublish(self, request, pk=None): + case = self.get_object() + case.publish_status = 2 + case.save() + return Response({'message': '病例已下架'}) + + +class TraditionalCaseViewSet(viewsets.ModelViewSet): + queryset = TraditionalCase.objects.all() + serializer_class = TraditionalCaseSerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = ['case'] + + +class ScriptCaseViewSet(viewsets.ModelViewSet): + queryset = ScriptCase.objects.all() + serializer_class = ScriptCaseSerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = ['case'] + + +class TeachingCaseViewSet(viewsets.ModelViewSet): + queryset = TeachingCase.objects.all() + serializer_class = TeachingCaseSerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = ['case'] + + +class CaseStageViewSet(viewsets.ModelViewSet): + queryset = CaseStage.objects.all() + serializer_class = CaseStageSerializer + filter_backends = [DjangoFilterBackend, filters.OrderingFilter] + filterset_fields = ['case', 'stage_type', 'stage_mode'] + ordering_fields = ['sort_order', 'created_at'] + + +class ScoringRuleViewSet(viewsets.ModelViewSet): + queryset = ScoringRule.objects.all() + serializer_class = ScoringRuleSerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = ['case', 'dimension', 'ai_auto_score', 'osce_dimension'] + + +# ── helpers ────────────────────────────────────────────────────────────── + +def _validate_scoring_rules(rules): + if not isinstance(rules, list): + raise AppError('CASE_VALIDATION_ERROR', 'scoring_rules 必须为数组', status_code=400) + for i, rule in enumerate(rules): + if not isinstance(rule, dict): + raise AppError('CASE_VALIDATION_ERROR', f'scoring_rules[{i}] 必须为对象', status_code=400) + if not rule.get('dimension'): + raise AppError('CASE_VALIDATION_ERROR', f'scoring_rules[{i}].dimension 必填', status_code=400) + weight = rule.get('score_weight') + if weight is not None: + try: + weight = float(weight) + except (TypeError, ValueError): + raise AppError('CASE_VALIDATION_ERROR', f'scoring_rules[{i}].score_weight 须为数字', status_code=400) + if weight <= 0 or weight > 1: + raise AppError('CASE_VALIDATION_ERROR', f'scoring_rules[{i}].score_weight 须在 (0, 1]', status_code=400) + + +def _build_full_response(case): + result = { + 'case': { + 'id': case.id, + 'title': case.title, + 'case_type': case.case_type, + 'difficulty': case.difficulty, + 'difficulty_score': case.difficulty_score, + 'department': case.department_id, + 'department_name': case.department.name if case.department else None, + 'chief_complaint': case.chief_complaint, + 'description': case.description, + 'patient_age': case.patient_age, + 'patient_gender': case.patient_gender, + 'tags': case.tags, + 'symptom_tags': case.symptom_tags, + 'disease_tags': case.disease_tags, + 'competency_tags': case.competency_tags, + 'guideline_tags': case.guideline_tags, + 'knowledge_points': case.knowledge_points, + 'icd_codes': case.icd_codes, + 'estimated_minutes': case.estimated_minutes, + 'osce_enabled': case.osce_enabled, + 'rag_enabled': case.rag_enabled, + 'ai_prompt_template': case.ai_prompt_template, + 'multimodal_assets': case.multimodal_assets, + 'vector_status': case.vector_status, + 'publish_status': case.publish_status, + 'status': case.status, + 'created_by': case.created_by_id, + 'created_by_name': case.created_by.real_name if case.created_by else None, + 'created_at': case.created_at.isoformat() if case.created_at else None, + 'updated_at': case.updated_at.isoformat() if case.updated_at else None, + }, + } + + if case.case_type == 'traditional': + try: + tc = case.traditionalcase + result['traditional'] = { + 'standard_diagnosis': tc.standard_diagnosis, + 'standard_treatment': tc.standard_treatment, + 'guideline_reference': tc.guideline_reference, + } + except TraditionalCase.DoesNotExist: + result['traditional'] = None + elif case.case_type == 'teaching': + try: + tc = case.teachingcase + result['teaching'] = { + 'teaching_goal': tc.teaching_goal, + 'discussion_questions': tc.discussion_questions, + 'teacher_guide': tc.teacher_guide, + 'scoring_focus': tc.scoring_focus, + } + except TeachingCase.DoesNotExist: + result['teaching'] = None + + rules = case.scoring_rules.all().order_by('id') + result['scoring_rules'] = [ + { + 'id': r.id, + 'dimension': r.dimension, + 'competency_dimension': r.competency_dimension, + 'score_weight': float(r.score_weight), + 'ai_auto_score': r.ai_auto_score, + 'osce_dimension': r.osce_dimension, + 'scoring_standard': r.scoring_standard, + 'rubric_json': r.rubric_json, + } + for r in rules + ] + + return result diff --git a/apps/common/__init__.py b/apps/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/common/apps.py b/apps/common/apps.py new file mode 100644 index 0000000..b1138ae --- /dev/null +++ b/apps/common/apps.py @@ -0,0 +1,7 @@ +from django.apps import AppConfig + + +class CommonConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'apps.common' + verbose_name = '公共组件' diff --git a/apps/common/migrations/__init__.py b/apps/common/migrations/__init__.py new file mode 100644 index 0000000..c65238f --- /dev/null +++ b/apps/common/migrations/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/apps/common/models.py b/apps/common/models.py new file mode 100644 index 0000000..73b6fbc --- /dev/null +++ b/apps/common/models.py @@ -0,0 +1,10 @@ +from django.db import models + + +class BaseModel(models.Model): + """基础模型,包含通用字段""" + created_at = models.DateTimeField('创建时间', auto_now_add=True) + updated_at = models.DateTimeField('更新时间', auto_now=True) + + class Meta: + abstract = True diff --git a/apps/training/__init__.py b/apps/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/training/admin.py b/apps/training/admin.py new file mode 100644 index 0000000..777b328 --- /dev/null +++ b/apps/training/admin.py @@ -0,0 +1,21 @@ +from django.contrib import admin +from .models import TrainingRecord, TrainingScoreDetail + + +@admin.register(TrainingRecord) +class TrainingRecordAdmin(admin.ModelAdmin): + list_display = [ + 'id', 'user', 'case', 'training_mode', + 'total_score', 'evaluation_level', 'status', + 'start_time', 'duration_seconds' + ] + list_filter = ['training_mode', 'evaluation_level', 'status'] + search_fields = ['user__real_name', 'case__title', 'feedback'] + ordering = ['-start_time'] + + +@admin.register(TrainingScoreDetail) +class TrainingScoreDetailAdmin(admin.ModelAdmin): + list_display = ['id', 'record', 'dimension', 'score', 'ai_confidence'] + list_filter = ['dimension'] + search_fields = ['record__user__real_name', 'dimension'] diff --git a/apps/training/apps.py b/apps/training/apps.py new file mode 100644 index 0000000..7d86db2 --- /dev/null +++ b/apps/training/apps.py @@ -0,0 +1,7 @@ +from django.apps import AppConfig + + +class TrainingConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'apps.training' + verbose_name = '训练管理' diff --git a/apps/training/migrations/0001_initial.py b/apps/training/migrations/0001_initial.py new file mode 100644 index 0000000..6b0c65f --- /dev/null +++ b/apps/training/migrations/0001_initial.py @@ -0,0 +1,71 @@ +# Generated by Django 6.0.5 on 2026-05-26 07:02 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('case', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='TrainingScoreDetail', + fields=[ + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('dimension', models.CharField(max_length=50, verbose_name='评分维度')), + ('score', models.DecimalField(decimal_places=2, max_digits=5, verbose_name='分数')), + ('deducted_reason', models.TextField(blank=True, verbose_name='扣分原因')), + ('evidence_message_ids', models.JSONField(blank=True, default=list, verbose_name='对应对话证据')), + ('ai_confidence', models.DecimalField(blank=True, decimal_places=2, max_digits=5, null=True, verbose_name='AI评分置信度')), + ('comment', models.TextField(blank=True, verbose_name='评语')), + ], + options={ + 'verbose_name': '评分明细', + 'verbose_name_plural': '评分明细', + 'db_table': 'training_score_detail', + }, + ), + migrations.CreateModel( + name='TrainingRecord', + fields=[ + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('training_mode', models.CharField(choices=[('novice', '新手'), ('practice', '练习'), ('exam', '考试')], max_length=50, verbose_name='训练模式')), + ('case_type', models.CharField(blank=True, max_length=30, verbose_name='病例类型')), + ('start_time', models.DateTimeField(auto_now_add=True, verbose_name='开始时间')), + ('end_time', models.DateTimeField(blank=True, null=True, verbose_name='结束时间')), + ('duration_seconds', models.IntegerField(blank=True, null=True, verbose_name='训练时长')), + ('total_score', models.DecimalField(blank=True, decimal_places=2, max_digits=5, null=True, verbose_name='总分')), + ('ai_score', models.DecimalField(blank=True, decimal_places=2, max_digits=5, null=True, verbose_name='AI评分')), + ('teacher_score', models.DecimalField(blank=True, decimal_places=2, max_digits=5, null=True, verbose_name='教师评分')), + ('evaluation_level', models.CharField(blank=True, choices=[('excellent', '优秀'), ('good', '良好'), ('average', '一般'), ('poor', '较差')], max_length=20, verbose_name='评价等级')), + ('status', models.CharField(choices=[('in_progress', '进行中'), ('completed', '已完成'), ('aborted', '已中断')], default='in_progress', max_length=30, verbose_name='状态')), + ('feedback', models.TextField(blank=True, verbose_name='总评')), + ('thinking_chain', models.TextField(blank=True, verbose_name='临床推理链')), + ('diagnosis_path', models.TextField(blank=True, verbose_name='诊断路径')), + ('wrong_points', models.JSONField(blank=True, default=list, verbose_name='错误知识点')), + ('missed_questions', models.JSONField(blank=True, default=list, verbose_name='漏问项')), + ('recommendation_result', models.JSONField(blank=True, default=dict, verbose_name='AI推荐')), + ('ai_feedback_structured', models.JSONField(blank=True, default=dict, verbose_name='AI结构化反馈')), + ('osce_station_score', models.JSONField(blank=True, default=dict, verbose_name='OSCE各站点成绩')), + ('interruption_count', models.IntegerField(default=0, verbose_name='中断次数')), + ('emotion_analysis', models.JSONField(blank=True, default=dict, verbose_name='情绪分析')), + ('prompt_version', models.CharField(blank=True, max_length=50, verbose_name='Prompt版本')), + ('rag_context_version', models.CharField(blank=True, max_length=50, verbose_name='知识上下文版本')), + ('case', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_records', to='case.casebase', verbose_name='病例')), + ], + options={ + 'verbose_name': '训练记录', + 'verbose_name_plural': '训练记录', + 'db_table': 'training_record', + }, + ), + ] diff --git a/apps/training/migrations/0002_initial.py b/apps/training/migrations/0002_initial.py new file mode 100644 index 0000000..08d5aa9 --- /dev/null +++ b/apps/training/migrations/0002_initial.py @@ -0,0 +1,39 @@ +# Generated by Django 6.0.5 on 2026-05-26 07:02 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('case', '0002_initial'), + ('training', '0001_initial'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AddField( + model_name='trainingrecord', + name='teacher', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='supervised_records', to=settings.AUTH_USER_MODEL, verbose_name='带教老师'), + ), + migrations.AddField( + model_name='trainingrecord', + name='user', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_records', to=settings.AUTH_USER_MODEL, verbose_name='用户'), + ), + migrations.AddField( + model_name='trainingscoredetail', + name='record', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='score_details', to='training.trainingrecord', verbose_name='训练记录'), + ), + migrations.AddField( + model_name='trainingscoredetail', + name='rule', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='case.scoringrule', verbose_name='评分规则'), + ), + ] diff --git a/apps/training/migrations/__init__.py b/apps/training/migrations/__init__.py new file mode 100644 index 0000000..c65238f --- /dev/null +++ b/apps/training/migrations/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/apps/training/models.py b/apps/training/models.py new file mode 100644 index 0000000..729d811 --- /dev/null +++ b/apps/training/models.py @@ -0,0 +1,96 @@ +from django.db import models +from apps.common.models import BaseModel +from apps.user.models import User +from apps.case.models import CaseBase + + +class TrainingRecord(BaseModel): + """训练记录表""" + TRAINING_MODE_CHOICES = [ + ('novice', '新手'), + ('practice', '练习'), + ('exam', '考试'), + ] + EVALUATION_LEVEL_CHOICES = [ + ('excellent', '优秀'), + ('good', '良好'), + ('average', '一般'), + ('poor', '较差'), + ] + STATUS_CHOICES = [ + ('in_progress', '进行中'), + ('completed', '已完成'), + ('aborted', '已中断'), + ] + + id = models.BigAutoField(primary_key=True) + user = models.ForeignKey( + User, on_delete=models.CASCADE, + related_name='training_records', verbose_name='用户' + ) + case = models.ForeignKey( + CaseBase, on_delete=models.CASCADE, + related_name='training_records', verbose_name='病例' + ) + training_mode = models.CharField('训练模式', max_length=50, choices=TRAINING_MODE_CHOICES) + case_type = models.CharField('病例类型', max_length=30, blank=True) + teacher = models.ForeignKey( + User, on_delete=models.SET_NULL, + null=True, blank=True, related_name='supervised_records', + verbose_name='带教老师' + ) + start_time = models.DateTimeField('开始时间', auto_now_add=True) + end_time = models.DateTimeField('结束时间', null=True, blank=True) + duration_seconds = models.IntegerField('训练时长', null=True, blank=True) + total_score = models.DecimalField('总分', max_digits=5, decimal_places=2, null=True, blank=True) + ai_score = models.DecimalField('AI评分', max_digits=5, decimal_places=2, null=True, blank=True) + teacher_score = models.DecimalField('教师评分', max_digits=5, decimal_places=2, null=True, blank=True) + evaluation_level = models.CharField('评价等级', max_length=20, choices=EVALUATION_LEVEL_CHOICES, blank=True) + status = models.CharField('状态', max_length=30, choices=STATUS_CHOICES, default='in_progress') + feedback = models.TextField('总评', blank=True) + thinking_chain = models.TextField('临床推理链', blank=True) + diagnosis_path = models.TextField('诊断路径', blank=True) + wrong_points = models.JSONField('错误知识点', default=list, blank=True) + missed_questions = models.JSONField('漏问项', default=list, blank=True) + recommendation_result = models.JSONField('AI推荐', default=dict, blank=True) + ai_feedback_structured = models.JSONField('AI结构化反馈', default=dict, blank=True) + osce_station_score = models.JSONField('OSCE各站点成绩', default=dict, blank=True) + interruption_count = models.IntegerField('中断次数', default=0) + emotion_analysis = models.JSONField('情绪分析', default=dict, blank=True) + prompt_version = models.CharField('Prompt版本', max_length=50, blank=True) + rag_context_version = models.CharField('知识上下文版本', max_length=50, blank=True) + + class Meta: + db_table = 'training_record' + verbose_name = '训练记录' + verbose_name_plural = '训练记录' + + def __str__(self): + return f"{self.user.username} - {self.case.title}" + + +class TrainingScoreDetail(BaseModel): + """评分明细表""" + id = models.BigAutoField(primary_key=True) + record = models.ForeignKey( + TrainingRecord, on_delete=models.CASCADE, + related_name='score_details', verbose_name='训练记录' + ) + rule = models.ForeignKey( + 'case.ScoringRule', on_delete=models.CASCADE, + null=True, blank=True, verbose_name='评分规则' + ) + dimension = models.CharField('评分维度', max_length=50) + score = models.DecimalField('分数', max_digits=5, decimal_places=2) + deducted_reason = models.TextField('扣分原因', blank=True) + evidence_message_ids = models.JSONField('对应对话证据', default=list, blank=True) + ai_confidence = models.DecimalField('AI评分置信度', max_digits=5, decimal_places=2, null=True, blank=True) + comment = models.TextField('评语', blank=True) + + class Meta: + db_table = 'training_score_detail' + verbose_name = '评分明细' + verbose_name_plural = '评分明细' + + def __str__(self): + return f"{self.record} - {self.dimension}: {self.score}" diff --git a/apps/training/serializers.py b/apps/training/serializers.py new file mode 100644 index 0000000..ca77e92 --- /dev/null +++ b/apps/training/serializers.py @@ -0,0 +1,49 @@ +from rest_framework import serializers +from .models import TrainingRecord, TrainingScoreDetail + + +class TrainingRecordListSerializer(serializers.ModelSerializer): + user_name = serializers.CharField(source='user.real_name', read_only=True) + case_title = serializers.CharField(source='case.title', read_only=True) + teacher_name = serializers.CharField(source='teacher.real_name', read_only=True) + + class Meta: + model = TrainingRecord + fields = [ + 'id', 'user', 'user_name', 'case', 'case_title', + 'training_mode', 'case_type', 'teacher', 'teacher_name', + 'start_time', 'end_time', 'duration_seconds', 'total_score', + 'evaluation_level', 'status', 'created_at', 'updated_at' + ] + + +class TrainingRecordDetailSerializer(serializers.ModelSerializer): + user_name = serializers.CharField(source='user.real_name', read_only=True) + case_title = serializers.CharField(source='case.title', read_only=True) + teacher_name = serializers.CharField(source='teacher.real_name', read_only=True) + + class Meta: + model = TrainingRecord + fields = '__all__' + + +class TrainingRecordCreateSerializer(serializers.ModelSerializer): + """训练记录创建序列化器""" + class Meta: + model = TrainingRecord + fields = [ + 'case', 'training_mode', 'case_type', 'teacher' + ] + + def create(self, validated_data): + validated_data['user'] = self.context['request'].user + return super().create(validated_data) + + +class TrainingScoreDetailSerializer(serializers.ModelSerializer): + record_info = serializers.CharField(source='record', read_only=True) + rule_dimension = serializers.CharField(source='rule.dimension', read_only=True) + + class Meta: + model = TrainingScoreDetail + fields = '__all__' diff --git a/apps/training/urls.py b/apps/training/urls.py new file mode 100644 index 0000000..cd5f592 --- /dev/null +++ b/apps/training/urls.py @@ -0,0 +1,11 @@ +from django.urls import path, include +from rest_framework.routers import DefaultRouter +from . import views + +router = DefaultRouter() +router.register(r'training-records', views.TrainingRecordViewSet, basename='training-record') +router.register(r'training-score-details', views.TrainingScoreDetailViewSet, basename='training-score-detail') + +urlpatterns = [ + path('', include(router.urls)), +] diff --git a/apps/training/views.py b/apps/training/views.py new file mode 100644 index 0000000..09337e7 --- /dev/null +++ b/apps/training/views.py @@ -0,0 +1,104 @@ +from rest_framework import viewsets, filters, status +from rest_framework.decorators import action +from rest_framework.response import Response +from django_filters.rest_framework import DjangoFilterBackend +from .models import TrainingRecord, TrainingScoreDetail +from .serializers import ( + TrainingRecordListSerializer, TrainingRecordDetailSerializer, + TrainingRecordCreateSerializer, TrainingScoreDetailSerializer +) + + +class TrainingRecordViewSet(viewsets.ModelViewSet): + """训练记录管理 + + list: 获取训练记录列表(支持过滤、搜索、排序) + create: 开始训练(创建记录) + retrieve: 获取训练详情 + update: 更新训练记录 + destroy: 删除训练记录 + """ + queryset = TrainingRecord.objects.all() + filter_backends = [DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter] + filterset_fields = [ + 'user', 'case', 'training_mode', 'case_type', + 'teacher', 'evaluation_level', 'status' + ] + search_fields = ['feedback'] + ordering_fields = ['start_time', 'end_time', 'total_score', 'created_at'] + + def get_serializer_class(self): + if self.action == 'list': + return TrainingRecordListSerializer + elif self.action == 'create': + return TrainingRecordCreateSerializer + return TrainingRecordDetailSerializer + + def get_queryset(self): + """普通用户只能看到自己的记录,老师可以看到学生的""" + queryset = super().get_queryset() + user = self.request.user + + # 超级管理员可以看所有 + if user.is_superuser: + return queryset + + # 老师可以看到自己学生的记录 + return queryset.filter(user=user) | queryset.filter(teacher=user) + + @action(detail=True, methods=['get']) + def score_details(self, request, pk=None): + """获取训练评分明细""" + record = self.get_object() + details = record.score_details.all() + serializer = TrainingScoreDetailSerializer(details, many=True) + return Response(serializer.data) + + @action(detail=True, methods=['post']) + def complete(self, request, pk=None): + """完成训练""" + record = self.get_object() + record.status = 'completed' + record.end_time = timezone.now() + + # 计算训练时长 + if record.start_time: + duration = (record.end_time - record.start_time).total_seconds() + record.duration_seconds = int(duration) + + record.save() + + # 更新用户统计 + user = record.user + user.total_training_count += 1 + user.total_case_count += 1 + user.save() + + return Response({'message': '训练已完成', 'duration_seconds': record.duration_seconds}) + + @action(detail=True, methods=['post']) + def abort(self, request, pk=None): + """中断训练""" + record = self.get_object() + record.status = 'aborted' + record.end_time = timezone.now() + record.interruption_count += 1 + record.save() + return Response({'message': '训练已中断'}) + + @action(detail=True, methods=['post']) + def add_score(self, request, pk=None): + """添加评分""" + record = self.get_object() + serializer = TrainingScoreDetailSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save(record=record) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + +class TrainingScoreDetailViewSet(viewsets.ModelViewSet): + """评分明细管理""" + queryset = TrainingScoreDetail.objects.all() + serializer_class = TrainingScoreDetailSerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = ['record', 'rule', 'dimension'] diff --git a/apps/user/__init__.py b/apps/user/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/user/admin.py b/apps/user/admin.py new file mode 100644 index 0000000..c799c1c --- /dev/null +++ b/apps/user/admin.py @@ -0,0 +1,65 @@ +from django.contrib import admin +from django.contrib.auth.admin import UserAdmin as BaseUserAdmin +from .models import User, Role, TeacherStudentRelation, Institution, Department + + +@admin.register(User) +class UserAdmin(BaseUserAdmin): + list_display = [ + 'id', 'username', 'real_name', 'phone', 'role_type', + 'institution', 'department', 'current_level', + 'total_training_count', 'status', 'last_login_time', 'created_at' + ] + list_filter = ['role_type', 'status', 'gender', 'created_at'] + search_fields = ['username', 'real_name', 'phone'] + ordering = ['-created_at'] + fieldsets = ( + (None, {'fields': ('username', 'password')}), + ('个人信息', {'fields': ( + 'real_name', 'phone', 'avatar', 'gender', + 'title_name', 'major', 'training_stage' + )}), + ('机构信息', {'fields': ( + 'institution', 'department', 'role_type' + )}), + ('能力数据', {'fields': ( + 'competency_profile', 'weak_dimensions', 'strong_dimensions', + 'learning_target', 'current_level', + 'total_training_count', 'total_case_count' + )}), + ('AI偏好', {'fields': ('ai_preference',)}), + ('状态', {'fields': ('status', 'last_login_time', 'is_staff', 'is_superuser')}), + ) + add_fieldsets = ( + (None, { + 'classes': ('wide',), + 'fields': ('username', 'password1', 'password2', 'real_name', 'role_type'), + }), + ) + + +@admin.register(Role) +class RoleAdmin(admin.ModelAdmin): + list_display = ['id', 'role_code', 'role_name'] + search_fields = ['role_code', 'role_name'] + + +@admin.register(TeacherStudentRelation) +class TeacherStudentRelationAdmin(admin.ModelAdmin): + list_display = ['id', 'teacher', 'student', 'relation_type', 'status', 'start_time'] + list_filter = ['relation_type', 'status'] + search_fields = ['teacher__real_name', 'student__real_name'] + + +@admin.register(Institution) +class InstitutionAdmin(admin.ModelAdmin): + list_display = ['id', 'name', 'type', 'level', 'province', 'city'] + list_filter = ['type', 'province'] + search_fields = ['name'] + + +@admin.register(Department) +class DepartmentAdmin(admin.ModelAdmin): + list_display = ['id', 'name', 'institution', 'category'] + list_filter = ['category'] + search_fields = ['name', 'institution__name'] diff --git a/apps/user/apps.py b/apps/user/apps.py new file mode 100644 index 0000000..93a23dc --- /dev/null +++ b/apps/user/apps.py @@ -0,0 +1,11 @@ +from django.apps import AppConfig + + +class UserConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'apps.user' + verbose_name = '用户管理' + + def ready(self): + # 注册 drf-spectacular 认证扩展(OpenApiAuthenticationExtension 在导入时自动注册) + import apps.user.openapi # noqa: F401 diff --git a/apps/user/audit.py b/apps/user/audit.py new file mode 100644 index 0000000..a28f7ec --- /dev/null +++ b/apps/user/audit.py @@ -0,0 +1,36 @@ +import logging + +audit_logger = logging.getLogger('audit') + + +def log_login_success(user_id, phone, ip=None, ua=None): + audit_logger.info('LOGIN_SUCCESS user_id=%s phone=%s ip=%s ua=%s', user_id, phone, ip, ua) + + +def log_login_fail(phone, ip=None, reason=None): + audit_logger.warning('LOGIN_FAIL phone=%s ip=%s reason=%s', phone, ip, reason) + + +def log_register(user_id, phone): + audit_logger.info('USER_REGISTER user_id=%s phone=%s', user_id, phone) + + +def log_logout(user_id, jti=None): + audit_logger.info('USER_LOGOUT user_id=%s jti=%s', user_id, jti) + + +def log_password_change(user_id): + audit_logger.warning('PASSWORD_CHANGE user_id=%s', user_id) + + +def log_password_reset(user_id): + audit_logger.warning('PASSWORD_RESET user_id=%s', user_id) + + +def log_user_list(viewer_id, filters=None): + audit_logger.info('USER_LIST viewer=%s filters=%s', viewer_id, filters) + + +def log_case_event(event, case_id=None, user_id=None, **extra): + extra_str = ' '.join(f'{k}={v}' for k, v in extra.items()) + audit_logger.info('CASE_%s case_id=%s user_id=%s %s', event.upper(), case_id, user_id, extra_str) diff --git a/apps/user/auth/__init__.py b/apps/user/auth/__init__.py new file mode 100644 index 0000000..f4fa4cf --- /dev/null +++ b/apps/user/auth/__init__.py @@ -0,0 +1,31 @@ +from rest_framework_simplejwt.tokens import RefreshToken + +ALLOWED_ROLE_TYPES = ('student', 'doctor', 'teacher') + + +def get_tokens_for_user(user): + refresh = RefreshToken.for_user(user) + return {'access': str(refresh.access_token), 'refresh': str(refresh)} + + +def build_user_response(user): + return { + 'id': user.id, + 'username': user.username, + 'phone': user.phone, + 'real_name': user.real_name, + 'role_type': user.role_type, + 'institution': user.institution.name if user.institution_id else None, + 'department': user.department.name if user.department_id else None, + } + + +def get_client_ip(request): + xff = request.META.get('HTTP_X_FORWARDED_FOR') + if xff: + return xff.split(',')[0].strip() + return request.META.get('REMOTE_ADDR') + + +def get_user_agent(request): + return request.META.get('HTTP_USER_AGENT', '') diff --git a/apps/user/auth/login.py b/apps/user/auth/login.py new file mode 100644 index 0000000..1696baa --- /dev/null +++ b/apps/user/auth/login.py @@ -0,0 +1,150 @@ +import re + +from django.core.cache import cache +from django.utils import timezone +from rest_framework.decorators import api_view, permission_classes +from rest_framework.permissions import AllowAny +from rest_framework.response import Response +from rest_framework import serializers as drf_serializers +from drf_spectacular.utils import extend_schema, inline_serializer + +from config.exceptions import AppError +from apps.user.models import User +from apps.user.audit import log_login_success, log_login_fail +from apps.user.auth import get_tokens_for_user, build_user_response, get_client_ip, get_user_agent + +LOGIN_FAIL_MAX = 5 +LOGIN_FAIL_LOCK_SECONDS = 15 * 60 # 15 分钟 + +_LOGIN_RESPONSE = inline_serializer('LoginResponse', fields={ + 'message': drf_serializers.CharField(), + 'user': drf_serializers.DictField(help_text='用户基本信息'), + 'tokens': drf_serializers.DictField(help_text='access + refresh'), +}) + + +# ── U3 密码登录 ────────────────────────────────────────────────────────────── + +@extend_schema( + summary='U3 密码登录', + request=inline_serializer('LoginPasswordRequest', fields={ + 'phone': drf_serializers.CharField(help_text='手机号'), + 'password': drf_serializers.CharField(help_text='密码'), + }), + responses={200: _LOGIN_RESPONSE}, + tags=['认证'], +) +@api_view(['POST']) +@permission_classes([AllowAny]) +def login_password(request): + """U3 密码登录""" + data = request.data + phone = str(data.get('phone', '')) + password = str(data.get('password', '')) + + if not phone or not password: + raise AppError('AUTH_BAD_CREDENTIALS', '手机号和密码不能为空') + + ip = get_client_ip(request) + ua = get_user_agent(request) + + # 检查账号锁定 + fail_key = f'login_fail:{phone}' + fail_count = cache.get(fail_key) + if fail_count is not None and int(fail_count) >= LOGIN_FAIL_MAX: + raise AppError('AUTH_ACCOUNT_LOCKED', '登录失败次数过多,请 15 分钟后再试', status_code=423) + + # 查找用户(不区分"未注册"和"密码错",防用户名枚举) + try: + user = User.objects.select_related('institution', 'department').get(phone=phone) + except User.DoesNotExist: + log_login_fail(phone, ip=ip, reason='phone_not_found') + raise AppError('AUTH_BAD_CREDENTIALS', '手机号或密码错误') + + # 账号禁用检查 + if user.status == 0: + log_login_fail(phone, ip=ip, reason='account_disabled') + raise AppError('AUTH_ACCOUNT_DISABLED', '账号已被禁用,请联系管理员', status_code=403) + + # 校验密码 + if not user.check_password(password): + current = cache.get(fail_key) + new_count = (int(current) + 1) if current is not None else 1 + cache.set(fail_key, new_count, timeout=LOGIN_FAIL_LOCK_SECONDS) + log_login_fail(phone, ip=ip, reason='wrong_password') + raise AppError('AUTH_BAD_CREDENTIALS', '手机号或密码错误') + + # 登录成功 + cache.delete(fail_key) + user.last_login_time = timezone.now() + user.save(update_fields=['last_login_time']) + + tokens = get_tokens_for_user(user) + log_login_success(user.id, phone, ip=ip, ua=ua) + + return Response({ + 'message': '登录成功', + 'user': build_user_response(user), + 'tokens': tokens, + }) + + +# ── U4 验证码登录 ──────────────────────────────────────────────────────────── + +@extend_schema( + summary='U4 验证码登录', + request=inline_serializer('LoginCodeRequest', fields={ + 'phone': drf_serializers.CharField(help_text='手机号'), + 'code': drf_serializers.CharField(help_text='6 位短信验证码'), + }), + responses={200: _LOGIN_RESPONSE}, + tags=['认证'], +) +@api_view(['POST']) +@permission_classes([AllowAny]) +def login_code(request): + """U4 验证码登录""" + data = request.data + phone = str(data.get('phone', '')) + code = str(data.get('code', '')) + + if not re.match(r'^1[3-9]\d{9}$', phone): + raise AppError('SMS_INVALID_PHONE', '手机号格式不合法') + + if not code: + raise AppError('AUTH_CODE_INVALID', '请输入验证码') + + # 查找用户 + try: + user = User.objects.select_related('institution', 'department').get(phone=phone) + except User.DoesNotExist: + raise AppError('AUTH_PHONE_NOT_FOUND', '手机号未注册') + + if user.status == 0: + raise AppError('AUTH_ACCOUNT_DISABLED', '账号已被禁用,请联系管理员', status_code=403) + + # 校验验证码 + cache_key = f'sms:login:{phone}' + cached_code = cache.get(cache_key) + if not cached_code: + raise AppError('AUTH_CODE_EXPIRED', '验证码已过期或未发送') + if str(cached_code) != code: + raise AppError('AUTH_CODE_MISMATCH', '验证码错误') + + # 成功:清理验证码 + 清理密码失败计数 + cache.delete(cache_key) + cache.delete(f'login_fail:{phone}') + + user.last_login_time = timezone.now() + user.save(update_fields=['last_login_time']) + + tokens = get_tokens_for_user(user) + ip = get_client_ip(request) + ua = get_user_agent(request) + log_login_success(user.id, phone, ip=ip, ua=ua) + + return Response({ + 'message': '登录成功', + 'user': build_user_response(user), + 'tokens': tokens, + }) diff --git a/apps/user/auth/logout.py b/apps/user/auth/logout.py new file mode 100644 index 0000000..87faf74 --- /dev/null +++ b/apps/user/auth/logout.py @@ -0,0 +1,51 @@ +from rest_framework.decorators import api_view, permission_classes +from rest_framework.permissions import AllowAny +from rest_framework.response import Response +from rest_framework import serializers as drf_serializers +from rest_framework_simplejwt.tokens import RefreshToken +from drf_spectacular.utils import extend_schema, inline_serializer + +from apps.user.utils.jwt_redis import revoke_token +from apps.user.audit import log_logout + + +@extend_schema( + summary='U7 退出登录', + request=inline_serializer('LogoutRequest', fields={ + 'refresh': drf_serializers.CharField(help_text='refresh token'), + }), + responses={200: inline_serializer('LogoutResponse', fields={ + 'message': drf_serializers.CharField(), + })}, + tags=['认证'], +) +@api_view(['POST']) +@permission_classes([AllowAny]) +def logout(request): + """U7 退出登录 — 无论 token 是否合法均返回 200(防探测)""" + refresh_raw = request.data.get('refresh', '') + jti = None + user_id = None + + if refresh_raw: + try: + token = RefreshToken(refresh_raw) + jti = token.payload.get('jti') + exp = token.payload.get('exp') + user_id = token.payload.get('user_id') + if jti and exp: + revoke_token(jti, exp) + except Exception: + pass # 静默处理无效 token + + # 审计:优先取已认证用户,兜底取 token payload + audit_uid = None + if hasattr(request, 'user') and request.user and request.user.is_authenticated: + audit_uid = request.user.id + elif user_id: + audit_uid = user_id + + if audit_uid: + log_logout(audit_uid, jti=jti) + + return Response({'message': '已退出登录'}) diff --git a/apps/user/auth/refresh.py b/apps/user/auth/refresh.py new file mode 100644 index 0000000..fd6eaf9 --- /dev/null +++ b/apps/user/auth/refresh.py @@ -0,0 +1,54 @@ +from rest_framework.permissions import AllowAny +from rest_framework_simplejwt.views import TokenRefreshView +from rest_framework_simplejwt.tokens import RefreshToken +from rest_framework_simplejwt.exceptions import TokenError +from drf_spectacular.utils import extend_schema + +from config.exceptions import AppError +from apps.user.utils.jwt_redis import revoke_token, is_token_revoked, get_user_invalid_before + + +@extend_schema(tags=['认证']) +class CustomTokenRefreshView(TokenRefreshView): + """U8 刷新 Token — 在 simplejwt 旋转前后加入 Redis 黑名单检查 + 旧 token 吊销""" + permission_classes = [AllowAny] + authentication_classes = () + + def post(self, request, *args, **kwargs): + refresh_raw = request.data.get('refresh', '') + if not refresh_raw: + raise AppError('AUTH_TOKEN_INVALID', '请提供 refresh token', status_code=401) + + # ── 解析旧 token(必须在 super().post() 之前,因为 simplejwt 会 mutate) ── + try: + old_token = RefreshToken(refresh_raw) + except TokenError: + raise AppError('AUTH_TOKEN_INVALID', 'refresh token 无效或已过期', status_code=401) + + old_jti = old_token.payload.get('jti') + old_exp = old_token.payload.get('exp') + uid = old_token.payload.get('user_id') + iat = old_token.payload.get('iat') + + # ── Redis 黑名单检查 ── + if old_jti and is_token_revoked(old_jti): + raise AppError('AUTH_TOKEN_INVALID', 'refresh token 已被吊销', status_code=401) + + # ── 用户级失效截止检查 ── + if uid and iat is not None: + invalid_before = get_user_invalid_before(uid) + if invalid_before is not None and iat < invalid_before: + raise AppError('AUTH_TOKEN_INVALID', 'token 已失效,请重新登录', status_code=401) + + # ── 交给 simplejwt 处理旋转 ── + response = super().post(request, *args, **kwargs) + + # ── 吊销旧 refresh token ── + if old_jti and old_exp: + revoke_token(old_jti, old_exp) + + return response + + +# 函数式引用,供 urls.py 保持一致风格 +refresh_token = CustomTokenRefreshView.as_view() diff --git a/apps/user/auth/register.py b/apps/user/auth/register.py new file mode 100644 index 0000000..dabf37f --- /dev/null +++ b/apps/user/auth/register.py @@ -0,0 +1,141 @@ +import re + +from django.core.cache import cache +from django.conf import settings +from django.db import transaction, IntegrityError +from rest_framework import status +from rest_framework.decorators import api_view, permission_classes, throttle_classes +from rest_framework.permissions import AllowAny +from rest_framework.response import Response +from rest_framework import serializers as drf_serializers +from drf_spectacular.utils import extend_schema, inline_serializer + +from config.exceptions import AppError +from apps.user.models import User, Institution, Department +from apps.user.throttling import RegisterIpThrottle +from apps.user.utils.password import validate_password_strength +from apps.user.audit import log_register +from apps.user.auth import get_tokens_for_user, build_user_response, ALLOWED_ROLE_TYPES + + +@extend_schema( + summary='U2 用户注册', + request=inline_serializer('RegisterRequest', fields={ + 'phone': drf_serializers.CharField(help_text='手机号'), + 'code': drf_serializers.CharField(help_text='6 位短信验证码'), + 'password': drf_serializers.CharField(help_text='密码(>=6 位,含字母+数字)'), + 'real_name': drf_serializers.CharField(help_text='真实姓名'), + 'role_type': drf_serializers.ChoiceField( + choices=['student', 'doctor', 'teacher'], + required=False, default='student', help_text='角色类型'), + 'institution_name': drf_serializers.CharField(required=False, help_text='机构名称'), + 'department_name': drf_serializers.CharField(required=False, help_text='科室名称'), + }), + responses={201: inline_serializer('RegisterResponse', fields={ + 'message': drf_serializers.CharField(), + 'user': drf_serializers.DictField(help_text='用户基本信息'), + 'tokens': drf_serializers.DictField(help_text='access + refresh'), + })}, + tags=['认证'], +) +@api_view(['POST']) +@permission_classes([AllowAny]) +@throttle_classes([RegisterIpThrottle]) +def register(request): + """U2 用户注册(手机号 + 验证码 + 密码)""" + data = request.data + + phone = str(data.get('phone', '')) + code = str(data.get('code', '')) + password = str(data.get('password', '')) + real_name = str(data.get('real_name', '')) + role_type = str(data.get('role_type', 'student')) + institution_name = data.get('institution_name') or '' + department_name = data.get('department_name') or '' + + # ── 入参校验 ────────────────────────────────────────────────────────────── + + if not re.match(r'^1[3-9]\d{9}$', phone): + raise AppError('SMS_INVALID_PHONE', '手机号格式不合法') + + if not code or len(code) != 6 or not code.isdigit(): + raise AppError('AUTH_CODE_INVALID', '验证码必须为 6 位数字') + + if not real_name or len(real_name) < 2 or len(real_name) > 20: + raise AppError('USER_INVALID_NAME', '姓名长度应在 2-20 字符之间') + + if role_type not in ALLOWED_ROLE_TYPES: + raise AppError('AUTH_INVALID_ROLE', '角色类型无效,仅允许 student / doctor / teacher') + + # ── 密码强度 ────────────────────────────────────────────────────────────── + + pwd_errors = validate_password_strength(password, phone=phone, real_name=real_name) + if pwd_errors: + raise AppError('AUTH_PASSWORD_WEAK', pwd_errors[0], details=pwd_errors) + + # ── 验证码校验 ──────────────────────────────────────────────────────────── + + cache_key = f'sms:register:{phone}' + cached_code = cache.get(cache_key) + if not cached_code: + raise AppError('AUTH_CODE_EXPIRED', '验证码已过期或未发送') + if str(cached_code) != code: + raise AppError('AUTH_CODE_MISMATCH', '验证码错误') + + # ── 机构 / 科室解析 ────────────────────────────────────────────────────── + + institution = None + if institution_name: + try: + institution = Institution.objects.get(name=institution_name) + except Institution.DoesNotExist: + raise AppError('USER_INSTITUTION_NOT_FOUND', f'机构"{institution_name}"不存在') + except Institution.MultipleObjectsReturned: + raise AppError('USER_INSTITUTION_AMBIGUOUS', f'存在多个同名机构"{institution_name}"') + + department = None + if department_name: + qs = Department.objects.filter(name=department_name) + if institution: + qs = qs.filter(institution=institution) + cnt = qs.count() + if cnt == 0: + raise AppError('USER_DEPARTMENT_NOT_FOUND', f'科室"{department_name}"不存在') + if cnt > 1: + raise AppError('USER_DEPARTMENT_AMBIGUOUS', + f'科室"{department_name}"不唯一,请同时指定 institution_name') + department = qs.first() + + # ── 事务内创建用户 ──────────────────────────────────────────────────────── + + try: + with transaction.atomic(): + if User.objects.filter(phone=phone).exists(): + raise AppError('AUTH_PHONE_REGISTERED', '该手机号已注册') + + user = User.objects.create_user( + username=phone, + password=password, + phone=phone, + real_name=real_name, + role_type=role_type, + institution=institution, + department=department, + status=1, + ) + except AppError: + raise + except IntegrityError: + raise AppError('AUTH_PHONE_REGISTERED', '该手机号已注册') + + # ── 善后 ────────────────────────────────────────────────────────────────── + + cache.delete(cache_key) + tokens = get_tokens_for_user(user) + log_register(user.id, phone) + + return Response({ + 'message': '注册成功', + 'user': build_user_response(user), + 'tokens': tokens, + }, status=status.HTTP_201_CREATED) diff --git a/apps/user/auth/reset_password.py b/apps/user/auth/reset_password.py new file mode 100644 index 0000000..1a61faf --- /dev/null +++ b/apps/user/auth/reset_password.py @@ -0,0 +1,92 @@ +import re + +from django.core.cache import cache +from django.db import transaction +from rest_framework.decorators import api_view, permission_classes, throttle_classes +from rest_framework.permissions import AllowAny +from rest_framework.response import Response +from rest_framework import serializers as drf_serializers +from drf_spectacular.utils import extend_schema, inline_serializer + +from config.exceptions import AppError +from apps.user.models import User +from apps.user.throttling import ResetPhoneThrottle +from apps.user.utils.password import validate_password_strength +from apps.user.utils.jwt_redis import invalidate_user_tokens +from apps.user.audit import log_password_reset + + +@extend_schema( + summary='U5 找回密码', + request=inline_serializer('ResetPasswordRequest', fields={ + 'phone': drf_serializers.CharField(help_text='手机号'), + 'code': drf_serializers.CharField(help_text='6 位短信验证码'), + 'new_password': drf_serializers.CharField(help_text='新密码'), + }), + responses={200: inline_serializer('ResetPasswordResponse', fields={ + 'message': drf_serializers.CharField(), + })}, + tags=['认证'], +) +@api_view(['POST']) +@permission_classes([AllowAny]) +@throttle_classes([ResetPhoneThrottle]) +def reset_password(request): + """U5 找回密码(手机号 + 验证码 + 新密码)""" + data = request.data + phone = str(data.get('phone', '')) + code = str(data.get('code', '')) + new_password = str(data.get('new_password', '')) + + # ── 入参校验 ────────────────────────────────────────────────────────────── + + if not re.match(r'^1[3-9]\d{9}$', phone): + raise AppError('SMS_INVALID_PHONE', '手机号格式不合法') + + if not code or len(code) != 6 or not code.isdigit(): + raise AppError('AUTH_CODE_INVALID', '验证码必须为 6 位数字') + + if not new_password: + raise AppError('AUTH_PASSWORD_WEAK', '请输入新密码') + + # ── 查找用户 ────────────────────────────────────────────────────────────── + + try: + user = User.objects.get(phone=phone) + except User.DoesNotExist: + raise AppError('AUTH_PHONE_NOT_FOUND', '手机号未注册') + + # ── 验证码校验 ──────────────────────────────────────────────────────────── + + cache_key = f'sms:reset:{phone}' + cached_code = cache.get(cache_key) + if not cached_code: + raise AppError('AUTH_CODE_EXPIRED', '验证码已过期或未发送') + if str(cached_code) != code: + raise AppError('AUTH_CODE_MISMATCH', '验证码错误') + + # ── 新密码校验 ──────────────────────────────────────────────────────────── + + # 新密码不得与旧密码相同(独立错误码) + if user.check_password(new_password): + raise AppError('AUTH_PASSWORD_SAME_AS_OLD', '新密码不能与旧密码相同') + + # 密码强度校验 + pwd_errors = validate_password_strength(new_password, phone=user.phone, real_name=user.real_name) + if pwd_errors: + raise AppError('AUTH_PASSWORD_WEAK', pwd_errors[0], details=pwd_errors) + + # ── 事务内:重置密码 + 失效旧 token ────────────────────────────────────── + + with transaction.atomic(): + user.set_password(new_password) + user.save(update_fields=['password']) + invalidate_user_tokens(user.id) + + # ── 善后 ────────────────────────────────────────────────────────────────── + + cache.delete(cache_key) + cache.delete(f'login_fail:{phone}') + log_password_reset(user.id) + + return Response({'message': '密码已重置,请重新登录'}) diff --git a/apps/user/auth/send_code.py b/apps/user/auth/send_code.py new file mode 100644 index 0000000..bf2fc11 --- /dev/null +++ b/apps/user/auth/send_code.py @@ -0,0 +1,61 @@ +import re + +from django.core.cache import cache +from django.conf import settings +from rest_framework.decorators import api_view, permission_classes, throttle_classes +from rest_framework.permissions import AllowAny +from rest_framework.response import Response +from rest_framework import serializers as drf_serializers +from drf_spectacular.utils import extend_schema, inline_serializer + +from config.exceptions import AppError +from apps.user.models import User +from apps.user.throttling import SmsPhoneMinuteThrottle, SmsPhoneDayThrottle, SmsIpThrottle +from apps.user.utils.sms import generate_sms_code, get_sms_service, SmsError + + +@extend_schema( + summary='U1 发送短信验证码', + request=inline_serializer('SendCodeRequest', fields={ + 'phone': drf_serializers.CharField(help_text='手机号'), + 'scene': drf_serializers.ChoiceField(choices=['register', 'login', 'reset'], + help_text='场景:register/login/reset'), + }), + responses={200: inline_serializer('SendCodeResponse', fields={ + 'message': drf_serializers.CharField(), + })}, + tags=['认证'], +) +@api_view(['POST']) +@permission_classes([AllowAny]) +@throttle_classes([SmsPhoneMinuteThrottle, SmsPhoneDayThrottle, SmsIpThrottle]) +def send_code(request): + """U1 发送短信验证码""" + data = request.data + phone = data.get('phone', '') + scene = data.get('scene', '') + + if not re.match(r'^1[3-9]\d{9}$', str(phone)): + raise AppError('SMS_INVALID_PHONE', '手机号格式不合法') + + if scene not in ('register', 'login', 'reset'): + raise AppError('SMS_INVALID_SCENE', 'scene 参数无效,仅允许 register / login / reset') + + user_exists = User.objects.filter(phone=phone).exists() + if scene == 'register' and user_exists: + raise AppError('AUTH_PHONE_REGISTERED', '该手机号已注册') + if scene in ('login', 'reset') and not user_exists: + raise AppError('AUTH_PHONE_NOT_FOUND', '手机号未注册') + + code = generate_sms_code() + cache_key = f'sms:{scene}:{phone}' + cache.set(cache_key, code, timeout=settings.SMS_CODE_EXPIRE) + + try: + get_sms_service().send_code(phone, scene, code) + except SmsError as e: + cache.delete(cache_key) + err_code = str(e) if str(e) in ('SMS_BIZ_ERROR',) else 'SMS_PROVIDER_ERROR' + raise AppError(err_code, '短信发送失败,请稍后重试', status_code=500) + + return Response({'message': '验证码已发送'}) diff --git a/apps/user/authentication.py b/apps/user/authentication.py new file mode 100644 index 0000000..fe710cc --- /dev/null +++ b/apps/user/authentication.py @@ -0,0 +1,43 @@ +import logging + +from django.core.cache import cache +from rest_framework.exceptions import APIException +from rest_framework_simplejwt.authentication import JWTAuthentication +from rest_framework_simplejwt.exceptions import InvalidToken + +logger = logging.getLogger(__name__) + +_REVOKED = {'code': 'AUTH_TOKEN_REVOKED', 'message': 'token已被吊销,请重新登录'} +_INVALIDATED = {'code': 'AUTH_TOKEN_INVALIDATED', 'message': 'token已失效,请重新登录'} +_REDIS_DOWN = {'code': 'SYS_DEPENDENCY_DOWN', 'message': 'Redis服务不可用'} + + +class RedisBlacklistJWTAuthentication(JWTAuthentication): + """在 simplejwt 校验通过后追加 Redis 黑名单检查。""" + + def get_validated_token(self, raw_token): + token = super().get_validated_token(raw_token) + try: + jti = token.payload.get('jti') + uid = token.payload.get('user_id') + iat = token.payload.get('iat') + + if jti and cache.get(f'jwt_blacklist:{jti}'): + raise InvalidToken(_REVOKED) + + if uid and iat is not None: + invalid_before = cache.get(f'jwt_user_invalid_before:{uid}') + if invalid_before is not None and iat < int(invalid_before): + raise InvalidToken(_INVALIDATED) + + except InvalidToken: + raise + except Exception as e: + # fail-closed:Redis 不可用时拒绝请求(安全 > 可用) + logger.error('Redis error during JWT blacklist check: %s', e) + ex = APIException() + ex.status_code = 503 + ex.detail = _REDIS_DOWN + raise ex + + return token diff --git a/apps/user/management/__init__.py b/apps/user/management/__init__.py new file mode 100644 index 0000000..423df72 --- /dev/null +++ b/apps/user/management/__init__.py @@ -0,0 +1 @@ +# init diff --git a/apps/user/management/commands/__init__.py b/apps/user/management/commands/__init__.py new file mode 100644 index 0000000..423df72 --- /dev/null +++ b/apps/user/management/commands/__init__.py @@ -0,0 +1 @@ +# init diff --git a/apps/user/management/commands/init_users.py b/apps/user/management/commands/init_users.py new file mode 100644 index 0000000..4354e5e --- /dev/null +++ b/apps/user/management/commands/init_users.py @@ -0,0 +1,114 @@ +import os +from django.core.management.base import BaseCommand +from django.contrib.auth import get_user_model + +User = get_user_model() + + +class Command(BaseCommand): + help = '初始化系统用户数据' + + def add_arguments(self, parser): + parser.add_argument( + '--reset', + action='store_true', + help='重置所有用户数据', + ) + + def handle(self, *args, **options): + if options['reset']: + self.stdout.write(self.style.WARNING('正在清除所有用户...')) + User.objects.filter(is_superuser=False).delete() + self.stdout.write(self.style.SUCCESS('普通用户已清除')) + + # 创建超级管理员 + self._create_superadmin() + + # 创建测试角色用户 + self._create_test_users() + + self.stdout.write(self.style.SUCCESS('\n[完成] 用户初始化完成')) + + def _create_superadmin(self): + """创建超级管理员""" + username = 'admin' + password = 'admin123' + + user, created = User.objects.get_or_create( + username=username, + defaults={ + 'is_staff': True, + 'is_superuser': True, + 'real_name': '系统管理员', + 'role_type': 'super_admin', + 'status': 1, + } + ) + user.set_password(password) + user.save() + + if created: + self.stdout.write(self.style.SUCCESS(f'[创建] 超级管理员: {username} / {password}')) + else: + self.stdout.write(self.style.WARNING(f'[重置] 管理员密码: {username} / {password}')) + + def _create_test_users(self): + """创建测试用户""" + test_users = [ + { + 'username': 'doctor1', + 'password': 'doctor123', + 'real_name': '张医生', + 'role_type': 'doctor', + 'phone': '13800138001', + 'title_name': '主任医师', + 'major': '心内科', + }, + { + 'username': 'student1', + 'password': 'student123', + 'real_name': '李同学', + 'role_type': 'student', + 'phone': '13800138002', + 'training_stage': '规培', + 'learning_target': '掌握常见病诊断', + }, + { + 'username': 'teacher1', + 'password': 'teacher123', + 'real_name': '王老师', + 'role_type': 'teacher', + 'phone': '13800138003', + 'title_name': '副主任医师', + 'major': '呼吸内科', + }, + { + 'username': 'content_admin', + 'password': 'content123', + 'real_name': '内容管理员', + 'role_type': 'content_admin', + 'phone': '13800138004', + }, + ] + + for user_data in test_users: + password = user_data.pop('password') + user, created = User.objects.get_or_create( + username=user_data['username'], + defaults={**user_data, 'status': 1} + ) + user.set_password(password) + user.save() + + if created: + self.stdout.write( + self.style.SUCCESS( + f'[创建] 用户: {user_data["username"]} / {password} ({user_data["real_name"]})' + ) + ) + else: + self.stdout.write( + self.style.WARNING( + f'[已存在] 用户: {user_data["username"]} ({user_data["real_name"]})' + ) + ) diff --git a/apps/user/migrations/0001_initial.py b/apps/user/migrations/0001_initial.py new file mode 100644 index 0000000..7192a93 --- /dev/null +++ b/apps/user/migrations/0001_initial.py @@ -0,0 +1,128 @@ +# Generated by Django 6.0.5 on 2026-05-26 07:02 + +import django.db.models.deletion +import django.utils.timezone +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('auth', '0012_alter_user_first_name_max_length'), + ] + + operations = [ + migrations.CreateModel( + name='Institution', + fields=[ + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('name', models.CharField(max_length=255, verbose_name='名称')), + ('type', models.CharField(choices=[('hospital', '医院'), ('school', '学校')], max_length=30, verbose_name='类型')), + ('level', models.CharField(blank=True, max_length=30, verbose_name='等级')), + ('province', models.CharField(blank=True, max_length=50, verbose_name='省份')), + ('city', models.CharField(blank=True, max_length=50, verbose_name='城市')), + ], + options={ + 'verbose_name': '机构', + 'verbose_name_plural': '机构', + 'db_table': 'institution', + }, + ), + migrations.CreateModel( + name='Role', + fields=[ + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('role_code', models.CharField(max_length=50, unique=True, verbose_name='角色编码')), + ('role_name', models.CharField(max_length=50, verbose_name='角色名称')), + ], + options={ + 'verbose_name': '角色', + 'verbose_name_plural': '角色', + 'db_table': 'role', + }, + ), + migrations.CreateModel( + name='Department', + fields=[ + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('name', models.CharField(max_length=100, verbose_name='科室名称')), + ('category', models.CharField(blank=True, max_length=50, verbose_name='科室分类')), + ('institution', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='user.institution', verbose_name='所属机构')), + ], + options={ + 'verbose_name': '科室', + 'verbose_name_plural': '科室', + 'db_table': 'department', + }, + ), + migrations.CreateModel( + name='User', + fields=[ + ('last_login', models.DateTimeField(blank=True, null=True, verbose_name='last login')), + ('is_superuser', models.BooleanField(default=False, help_text='Designates that this user has all permissions without explicitly assigning them.', verbose_name='superuser status')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('username', models.CharField(max_length=50, unique=True, verbose_name='用户名')), + ('password', models.CharField(max_length=255, verbose_name='密码')), + ('real_name', models.CharField(blank=True, max_length=50, verbose_name='真实姓名')), + ('phone', models.CharField(blank=True, max_length=20, verbose_name='手机号')), + ('avatar', models.CharField(blank=True, max_length=255, verbose_name='头像')), + ('gender', models.SmallIntegerField(choices=[(0, '未知'), (1, '男'), (2, '女')], default=0, verbose_name='性别')), + ('role_type', models.CharField(blank=True, max_length=30, verbose_name='主角色')), + ('title_name', models.CharField(blank=True, max_length=50, verbose_name='职称')), + ('major', models.CharField(blank=True, max_length=100, verbose_name='专业')), + ('training_stage', models.CharField(blank=True, max_length=50, verbose_name='培训阶段')), + ('learning_target', models.CharField(blank=True, max_length=255, verbose_name='学习目标')), + ('competency_profile', models.JSONField(blank=True, default=dict, verbose_name='能力画像')), + ('weak_dimensions', models.JSONField(blank=True, default=list, verbose_name='薄弱项')), + ('strong_dimensions', models.JSONField(blank=True, default=list, verbose_name='优势项')), + ('ai_preference', models.JSONField(blank=True, default=dict, verbose_name='AI训练偏好')), + ('total_training_count', models.IntegerField(default=0, verbose_name='总训练次数')), + ('total_case_count', models.IntegerField(default=0, verbose_name='完成病例数')), + ('current_level', models.CharField(blank=True, max_length=30, verbose_name='当前能力等级')), + ('status', models.SmallIntegerField(choices=[(0, '禁用'), (1, '正常')], default=1, verbose_name='状态')), + ('last_login_time', models.DateTimeField(blank=True, null=True, verbose_name='最后登录')), + ('is_staff', models.BooleanField(default=False, verbose_name='staff status')), + ('is_active', models.BooleanField(default=True, verbose_name='active')), + ('date_joined', models.DateTimeField(default=django.utils.timezone.now, verbose_name='date joined')), + ('groups', models.ManyToManyField(blank=True, help_text='The groups this user belongs to. A user will get all permissions granted to each of their groups.', related_name='user_set', related_query_name='user', to='auth.group', verbose_name='groups')), + ('user_permissions', models.ManyToManyField(blank=True, help_text='Specific permissions for this user.', related_name='user_set', related_query_name='user', to='auth.permission', verbose_name='user permissions')), + ('department', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='user.department', verbose_name='所属科室')), + ('institution', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='user.institution', verbose_name='所属机构')), + ], + options={ + 'verbose_name': '用户', + 'verbose_name_plural': '用户', + 'db_table': 'user', + }, + ), + migrations.CreateModel( + name='TeacherStudentRelation', + fields=[ + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('relation_type', models.CharField(blank=True, max_length=30, verbose_name='关系类型')), + ('start_time', models.DateTimeField(blank=True, null=True, verbose_name='开始时间')), + ('end_time', models.DateTimeField(blank=True, null=True, verbose_name='结束时间')), + ('status', models.SmallIntegerField(choices=[(0, '已结束'), (1, '进行中')], default=1, verbose_name='状态')), + ('student', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='student_relations', to=settings.AUTH_USER_MODEL, verbose_name='学员')), + ('teacher', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='teacher_relations', to=settings.AUTH_USER_MODEL, verbose_name='带教老师')), + ], + options={ + 'verbose_name': '师生关系', + 'verbose_name_plural': '师生关系', + 'db_table': 'teacher_student_relation', + }, + ), + ] diff --git a/apps/user/migrations/0002_user_phone_unique.py b/apps/user/migrations/0002_user_phone_unique.py new file mode 100644 index 0000000..42ee7f2 --- /dev/null +++ b/apps/user/migrations/0002_user_phone_unique.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2.14 on 2026-05-28 01:35 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('user', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='user', + name='phone', + field=models.CharField(blank=True, max_length=20, unique=True, verbose_name='手机号'), + ), + ] diff --git a/apps/user/migrations/__init__.py b/apps/user/migrations/__init__.py new file mode 100644 index 0000000..c65238f --- /dev/null +++ b/apps/user/migrations/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/apps/user/models.py b/apps/user/models.py new file mode 100644 index 0000000..26caf88 --- /dev/null +++ b/apps/user/models.py @@ -0,0 +1,170 @@ +from django.contrib.auth.models import AbstractBaseUser, PermissionsMixin +from django.db import models +from django.utils import timezone +from django.contrib.auth.base_user import BaseUserManager + +from apps.common.models import BaseModel + + +class UserManager(BaseUserManager): + def create_user(self, username, password=None, **extra_fields): + if not username: + raise ValueError('用户名不能为空') + user = self.model(username=username, **extra_fields) + user.set_password(password) + user.save(using=self._db) + return user + + def create_superuser(self, username, password=None, **extra_fields): + extra_fields.setdefault('is_staff', True) + extra_fields.setdefault('is_superuser', True) + extra_fields.setdefault('status', 1) + return self.create_user(username, password, **extra_fields) + + +class User(AbstractBaseUser, PermissionsMixin, BaseModel): + """用户表""" + GENDER_CHOICES = [ + (0, '未知'), + (1, '男'), + (2, '女'), + ] + STATUS_CHOICES = [ + (0, '禁用'), + (1, '正常'), + ] + + id = models.BigAutoField(primary_key=True) + username = models.CharField('用户名', max_length=50, unique=True) + password = models.CharField('密码', max_length=255) + real_name = models.CharField('真实姓名', max_length=50, blank=True) + phone = models.CharField('手机号', max_length=20, unique=True, blank=True) + avatar = models.CharField('头像', max_length=255, blank=True) + gender = models.SmallIntegerField('性别', choices=GENDER_CHOICES, default=0) + role_type = models.CharField('主角色', max_length=30, blank=True) + institution = models.ForeignKey( + 'user.Institution', on_delete=models.SET_NULL, + null=True, blank=True, verbose_name='所属机构' + ) + department = models.ForeignKey( + 'user.Department', on_delete=models.SET_NULL, + null=True, blank=True, verbose_name='所属科室' + ) + title_name = models.CharField('职称', max_length=50, blank=True) + major = models.CharField('专业', max_length=100, blank=True) + training_stage = models.CharField('培训阶段', max_length=50, blank=True) + learning_target = models.CharField('学习目标', max_length=255, blank=True) + competency_profile = models.JSONField('能力画像', default=dict, blank=True) + weak_dimensions = models.JSONField('薄弱项', default=list, blank=True) + strong_dimensions = models.JSONField('优势项', default=list, blank=True) + ai_preference = models.JSONField('AI训练偏好', default=dict, blank=True) + total_training_count = models.IntegerField('总训练次数', default=0) + total_case_count = models.IntegerField('完成病例数', default=0) + current_level = models.CharField('当前能力等级', max_length=30, blank=True) + status = models.SmallIntegerField('状态', choices=STATUS_CHOICES, default=1) + last_login_time = models.DateTimeField('最后登录', null=True, blank=True) + + # Django required fields + is_staff = models.BooleanField('staff status', default=False) + is_active = models.BooleanField('active', default=True) + date_joined = models.DateTimeField('date joined', default=timezone.now) + + objects = UserManager() + + USERNAME_FIELD = 'username' + REQUIRED_FIELDS = [] + + class Meta: + db_table = 'user' + verbose_name = '用户' + verbose_name_plural = '用户' + + def __str__(self): + return self.username + + +class Role(BaseModel): + """角色表""" + id = models.BigAutoField(primary_key=True) + role_code = models.CharField('角色编码', max_length=50, unique=True) + role_name = models.CharField('角色名称', max_length=50) + + class Meta: + db_table = 'role' + verbose_name = '角色' + verbose_name_plural = '角色' + + def __str__(self): + return self.role_name + + +class TeacherStudentRelation(BaseModel): + """师生关系表""" + STATUS_CHOICES = [ + (0, '已结束'), + (1, '进行中'), + ] + + id = models.BigAutoField(primary_key=True) + teacher = models.ForeignKey( + User, on_delete=models.CASCADE, + related_name='teacher_relations', verbose_name='带教老师' + ) + student = models.ForeignKey( + User, on_delete=models.CASCADE, + related_name='student_relations', verbose_name='学员' + ) + relation_type = models.CharField('关系类型', max_length=30, blank=True) + start_time = models.DateTimeField('开始时间', null=True, blank=True) + end_time = models.DateTimeField('结束时间', null=True, blank=True) + status = models.SmallIntegerField('状态', choices=STATUS_CHOICES, default=1) + + class Meta: + db_table = 'teacher_student_relation' + verbose_name = '师生关系' + verbose_name_plural = '师生关系' + + def __str__(self): + return f"{self.teacher.real_name or self.teacher.username} -> {self.student.real_name or self.student.username}" + + +class Institution(BaseModel): + """医院/学校表""" + TYPE_CHOICES = [ + ('hospital', '医院'), + ('school', '学校'), + ] + + id = models.BigAutoField(primary_key=True) + name = models.CharField('名称', max_length=255) + type = models.CharField('类型', max_length=30, choices=TYPE_CHOICES) + level = models.CharField('等级', max_length=30, blank=True) + province = models.CharField('省份', max_length=50, blank=True) + city = models.CharField('城市', max_length=50, blank=True) + + class Meta: + db_table = 'institution' + verbose_name = '机构' + verbose_name_plural = '机构' + + def __str__(self): + return self.name + + +class Department(BaseModel): + """科室表""" + id = models.BigAutoField(primary_key=True) + institution = models.ForeignKey( + Institution, on_delete=models.CASCADE, + verbose_name='所属机构' + ) + name = models.CharField('科室名称', max_length=100) + category = models.CharField('科室分类', max_length=50, blank=True) + + class Meta: + db_table = 'department' + verbose_name = '科室' + verbose_name_plural = '科室' + + def __str__(self): + return self.name diff --git a/apps/user/openapi.py b/apps/user/openapi.py new file mode 100644 index 0000000..19e9ee2 --- /dev/null +++ b/apps/user/openapi.py @@ -0,0 +1,17 @@ +"""drf-spectacular 扩展:注册自定义认证类的 OpenAPI 映射。""" + +from drf_spectacular.extensions import OpenApiAuthenticationExtension + + +class RedisBlacklistJWTScheme(OpenApiAuthenticationExtension): + """让 drf-spectacular 将 RedisBlacklistJWTAuthentication 识别为 Bearer JWT。""" + + target_class = 'apps.user.authentication.RedisBlacklistJWTAuthentication' + name = 'jwtAuth' + + def get_security_definition(self, auto_schema): + return { + 'type': 'http', + 'scheme': 'bearer', + 'bearerFormat': 'JWT', + } diff --git a/apps/user/permissions.py b/apps/user/permissions.py new file mode 100644 index 0000000..31c4f1f --- /dev/null +++ b/apps/user/permissions.py @@ -0,0 +1,48 @@ +from rest_framework.permissions import BasePermission + +from config.exceptions import AppError +from apps.user.models import TeacherStudentRelation + + +def _is_admin(user): + """管理员判定:super_admin / content_admin / is_staff""" + return user.role_type in ('super_admin', 'content_admin') or user.is_staff + + +class IsUserListPermitted(BasePermission): + """U9 用户列表权限:管理员全员、教师仅看自己学生、其他 403""" + + def has_permission(self, request, view): + user = request.user + if _is_admin(user): + return True + if user.role_type == 'teacher': + return True + raise AppError('USER_NO_LIST_PERMISSION', '您没有查看用户列表的权限', status_code=403) + + +class IsUserDetailPermitted(BasePermission): + """U10 用户详情权限:管理员任意、本人、教师看自己学生""" + + def has_object_permission(self, request, view, obj): + user = request.user + # 管理员:可查看任意用户 + if _is_admin(user): + return True + # 本人:可查看自己 + if user.id == obj.id: + return True + # 教师:可查看自己名下活跃学生 + if user.role_type == 'teacher': + if TeacherStudentRelation.objects.filter( + teacher=user, student=obj, status=1 + ).exists(): + return True + raise AppError('USER_NO_VIEW_PERMISSION', '您没有查看该用户信息的权限', status_code=403) + + +class IsCaseOperationPermitted(BasePermission): + """病例操作权限:所有已登录用户均可操作""" + + def has_permission(self, request, view): + return request.user and request.user.is_authenticated diff --git a/apps/user/serializers.py b/apps/user/serializers.py new file mode 100644 index 0000000..526da35 --- /dev/null +++ b/apps/user/serializers.py @@ -0,0 +1,84 @@ +from rest_framework import serializers +from .models import User, Role, TeacherStudentRelation, Institution, Department + + +class UserSerializer(serializers.ModelSerializer): + institution_name = serializers.CharField(source='institution.name', read_only=True) + department_name = serializers.CharField(source='department.name', read_only=True) + + class Meta: + model = User + fields = [ + 'id', 'username', 'real_name', 'phone', 'avatar', + 'gender', 'role_type', 'institution', 'institution_name', + 'department', 'department_name', 'title_name', 'major', + 'training_stage', 'learning_target', 'competency_profile', + 'weak_dimensions', 'strong_dimensions', 'ai_preference', + 'total_training_count', 'total_case_count', 'current_level', + 'status', 'last_login_time', 'created_at', 'updated_at' + ] + extra_kwargs = { + 'password': {'write_only': True}, + } + + +class UserCreateSerializer(serializers.ModelSerializer): + class Meta: + model = User + fields = [ + 'id', 'username', 'password', 'real_name', 'phone', + 'gender', 'role_type', 'institution', 'department', + 'title_name', 'major', 'training_stage', 'status' + ] + extra_kwargs = { + 'password': {'write_only': True}, + } + + def create(self, validated_data): + user = User.objects.create_user(**validated_data) + return user + + +class UserUpdateSerializer(serializers.ModelSerializer): + class Meta: + model = User + fields = [ + 'real_name', 'phone', 'avatar', 'gender', 'role_type', + 'institution', 'department', 'title_name', 'major', + 'training_stage', 'learning_target', 'status' + ] + + +class UserPasswordSerializer(serializers.Serializer): + """密码修改序列化器""" + old_password = serializers.CharField(required=True, write_only=True) + new_password = serializers.CharField(required=True, write_only=True, min_length=6) + + +class RoleSerializer(serializers.ModelSerializer): + class Meta: + model = Role + fields = '__all__' + + +class TeacherStudentRelationSerializer(serializers.ModelSerializer): + teacher_name = serializers.CharField(source='teacher.real_name', read_only=True) + student_name = serializers.CharField(source='student.real_name', read_only=True) + + class Meta: + model = TeacherStudentRelation + fields = '__all__' + + +class InstitutionSerializer(serializers.ModelSerializer): + class Meta: + model = Institution + fields = '__all__' + + +class DepartmentSerializer(serializers.ModelSerializer): + institution_name = serializers.CharField(source='institution.name', read_only=True) + + class Meta: + model = Department + fields = '__all__' diff --git a/apps/user/throttling.py b/apps/user/throttling.py new file mode 100644 index 0000000..6bd4918 --- /dev/null +++ b/apps/user/throttling.py @@ -0,0 +1,72 @@ +from rest_framework.throttling import SimpleRateThrottle + + +class SmsPhoneMinuteThrottle(SimpleRateThrottle): + """短信验证码:同一手机号 60 秒 1 次""" + scope = 'sms_phone_minute' + + def get_cache_key(self, request, view): + phone = request.data.get('phone') or request.query_params.get('phone') + if not phone: + return None + return self.cache_format % {'scope': self.scope, 'ident': phone} + + +class SmsPhoneDayThrottle(SimpleRateThrottle): + """短信验证码:同一手机号 24 小时 ≤ 10 次""" + scope = 'sms_phone_day' + + def get_cache_key(self, request, view): + phone = request.data.get('phone') or request.query_params.get('phone') + if not phone: + return None + return self.cache_format % {'scope': self.scope, 'ident': phone} + + +class SmsIpThrottle(SimpleRateThrottle): + """短信验证码:同一 IP 1 小时 ≤ 30 次""" + scope = 'sms_ip' + + def get_cache_key(self, request, view): + ident = self.get_ident(request) + return self.cache_format % {'scope': self.scope, 'ident': ident} + + +class RegisterIpThrottle(SimpleRateThrottle): + """注册:同一 IP 1 小时 ≤ 10 次""" + scope = 'register_ip' + + def get_cache_key(self, request, view): + ident = self.get_ident(request) + return self.cache_format % {'scope': self.scope, 'ident': ident} + + +class ResetPhoneThrottle(SimpleRateThrottle): + """找回密码:同手机号 1 小时 ≤ 5 次""" + scope = 'reset_phone' + + def get_cache_key(self, request, view): + phone = request.data.get('phone') or request.query_params.get('phone') + if not phone: + return None + return self.cache_format % {'scope': self.scope, 'ident': phone} + + +class PdfParseUserThrottle(SimpleRateThrottle): + """PDF 解析:同用户 1 小时 ≤ 20 次""" + scope = 'pdf_parse_user' + + def get_cache_key(self, request, view): + if not request.user or not request.user.is_authenticated: + return None + return self.cache_format % {'scope': self.scope, 'ident': request.user.id} + + +class ScoringRuleGenerateUserThrottle(SimpleRateThrottle): + """AI 生成评分规则:同用户 1 小时 ≤ 20 次""" + scope = 'scoring_rule_generate_user' + + def get_cache_key(self, request, view): + if not request.user or not request.user.is_authenticated: + return None + return self.cache_format % {'scope': self.scope, 'ident': request.user.id} diff --git a/apps/user/urls.py b/apps/user/urls.py new file mode 100644 index 0000000..7c5f2dd --- /dev/null +++ b/apps/user/urls.py @@ -0,0 +1,28 @@ +from django.urls import path, include +from rest_framework.routers import DefaultRouter +from . import views +from .auth.send_code import send_code +from .auth.register import register +from .auth.login import login_password, login_code +from .auth.logout import logout +from .auth.refresh import refresh_token +from .auth.reset_password import reset_password + +router = DefaultRouter() +router.register(r'users', views.UserViewSet, basename='user') +router.register(r'roles', views.RoleViewSet, basename='role') +router.register(r'teacher-student-relations', views.TeacherStudentRelationViewSet, basename='teacher-student-relation') +router.register(r'institutions', views.InstitutionViewSet, basename='institution') +router.register(r'departments', views.DepartmentViewSet, basename='department') + +urlpatterns = [ + path('', include(router.urls)), + # 认证相关 + path('auth/send-code/', send_code, name='send-code'), + path('auth/register/', register, name='register'), + path('auth/login/', login_password, name='login-password'), + path('auth/login-code/', login_code, name='login-code'), + path('auth/logout/', logout, name='logout'), + path('auth/refresh/', refresh_token, name='refresh-token'), + path('auth/reset-password/', reset_password, name='reset-password'), +] diff --git a/apps/user/utils/__init__.py b/apps/user/utils/__init__.py new file mode 100644 index 0000000..c18ebbb --- /dev/null +++ b/apps/user/utils/__init__.py @@ -0,0 +1 @@ +# utils diff --git a/apps/user/utils/jwt_redis.py b/apps/user/utils/jwt_redis.py new file mode 100644 index 0000000..1fee4f6 --- /dev/null +++ b/apps/user/utils/jwt_redis.py @@ -0,0 +1,31 @@ +import time + +from django.core.cache import cache + +_BLACKLIST_PREFIX = 'jwt_blacklist:' +_INVALID_BEFORE_PREFIX = 'jwt_user_invalid_before:' +_USER_TOKEN_TTL = 7 * 24 * 3600 # 覆盖 refresh 最长生命周期 + + +def revoke_token(jti: str, exp_ts: float) -> None: + """将单个 token 加入黑名单。U7 退出 / U8 旋转旧 refresh 时调用。""" + ttl = max(int(exp_ts - time.time()), 1) + cache.set(f'{_BLACKLIST_PREFIX}{jti}', '1', timeout=ttl) + + +def invalidate_user_tokens(user_id: int) -> None: + """用户级失效截止:写入当前时间戳+1,此前所有 token 立即失效。U5/U6 改密后调用。 + + +1 确保同一秒内签发的旧 token 被拒绝,而改密后新登录签发的 token(iat >= now+1)被放行。 + """ + cache.set(f'{_INVALID_BEFORE_PREFIX}{user_id}', int(time.time()) + 1, timeout=_USER_TOKEN_TTL) + + +def is_token_revoked(jti: str) -> bool: + return bool(cache.get(f'{_BLACKLIST_PREFIX}{jti}')) + + +def get_user_invalid_before(user_id: int): + """返回用户级失效截止时间戳(unix seconds),不存在则返回 None。""" + val = cache.get(f'{_INVALID_BEFORE_PREFIX}{user_id}') + return int(val) if val is not None else None diff --git a/apps/user/utils/password.py b/apps/user/utils/password.py new file mode 100644 index 0000000..5a0d193 --- /dev/null +++ b/apps/user/utils/password.py @@ -0,0 +1,36 @@ +import re +from typing import Callable, Optional + + +def validate_password_strength( + password: str, + phone: Optional[str] = None, + real_name: Optional[str] = None, + old_password_check: Optional[Callable[[str], bool]] = None, +) -> list: + """ + 校验密码强度,返回错误信息列表。列表为空表示通过。 + + old_password_check: 传入 password 返回 True 表示与旧密码相同。 + """ + errors = [] + + if len(password) < 8 or len(password) > 32: + errors.append('密码长度必须在 8-32 位之间') + + if not re.search(r'[a-zA-Z]', password): + errors.append('密码必须包含字母') + + if not re.search(r'\d', password): + errors.append('密码必须包含数字') + + if phone and password == phone: + errors.append('密码不能与手机号相同') + + if real_name and password == real_name: + errors.append('密码不能与真实姓名相同') + + if old_password_check is not None and old_password_check(password): + errors.append('新密码不能与旧密码相同') + + return errors diff --git a/apps/user/utils/sms.py b/apps/user/utils/sms.py new file mode 100644 index 0000000..ef3eb8b --- /dev/null +++ b/apps/user/utils/sms.py @@ -0,0 +1,33 @@ +import random +import string +import logging +from abc import ABC, abstractmethod + +from django.conf import settings + +logger = logging.getLogger(__name__) + + +def generate_sms_code(length=6) -> str: + return ''.join(random.choices(string.digits, k=length)) + + +# ── 策略接口 ────────────────────────────────────────────────────────────────── + +class SmsService(ABC): + @abstractmethod + def send_code(self, phone: str, scene: str, code: str) -> None: + """发送验证码短信。失败时抛出 SmsError。""" + + +class SmsError(Exception): + pass + + +def get_sms_service() -> SmsService: + provider = getattr(settings, 'SMS_PROVIDER', 'mock') + if provider == 'aliyun': + from apps.user.utils.sms_aliyun import AliyunSmsService + return AliyunSmsService() + from apps.user.utils.sms_mock import MockSmsService + return MockSmsService() diff --git a/apps/user/utils/sms_aliyun.py b/apps/user/utils/sms_aliyun.py new file mode 100644 index 0000000..5a8841b --- /dev/null +++ b/apps/user/utils/sms_aliyun.py @@ -0,0 +1,57 @@ +import logging + +from django.conf import settings + +logger = logging.getLogger(__name__) + +_SCENE_TEMPLATE = { + 'register': 'ALIYUN_SMS_TEMPLATE_REGISTER', + 'login': 'ALIYUN_SMS_TEMPLATE_LOGIN', + 'reset': 'ALIYUN_SMS_TEMPLATE_RESET', +} + + +class AliyunSmsService: + """阿里云短信实现。SDK 按需 import,避免未安装时影响 mock 模式启动。""" + + def send_code(self, phone: str, scene: str, code: str) -> None: + from apps.user.utils.sms import SmsError + + try: + from alibabacloud_dysmsapi20170525.client import Client + from alibabacloud_dysmsapi20170525 import models as sms_models + from alibabacloud_tea_openapi import models as open_api_models + except ImportError as e: + logger.error('Aliyun SMS SDK not installed: %s', e) + raise SmsError('SMS_PROVIDER_ERROR') from e + + config = open_api_models.Config( + access_key_id=settings.ALIYUN_SMS_ACCESS_KEY_ID, + access_key_secret=settings.ALIYUN_SMS_ACCESS_KEY_SECRET, + ) + config.endpoint = 'dysmsapi.aliyuncs.com' + client = Client(config) + + template_attr = _SCENE_TEMPLATE.get(scene, '') + template_code = getattr(settings, template_attr, '') + + req = sms_models.SendSmsRequest( + phone_numbers=phone, + sign_name=settings.ALIYUN_SMS_SIGN_NAME, + template_code=template_code, + template_param=f'{{"code":"{code}"}}', + ) + + try: + resp = client.send_sms(req) + if resp.body.code != 'OK': + logger.error( + 'Aliyun SMS biz error: code=%s msg=%s request_id=%s', + resp.body.code, resp.body.message, resp.body.request_id, + ) + raise SmsError('SMS_BIZ_ERROR') + except SmsError: + raise + except Exception as e: + logger.error('Aliyun SMS provider error: %s', e) + raise SmsError('SMS_PROVIDER_ERROR') from e diff --git a/apps/user/utils/sms_mock.py b/apps/user/utils/sms_mock.py new file mode 100644 index 0000000..dd9afe7 --- /dev/null +++ b/apps/user/utils/sms_mock.py @@ -0,0 +1,11 @@ +import logging + +logger = logging.getLogger(__name__) + + +class MockSmsService: + """开发环境短信实现:打印到控制台 + 写日志。不在响应里回填 code。""" + + def send_code(self, phone: str, scene: str, code: str) -> None: + print(f'[SMS-MOCK] phone={phone} scene={scene} code={code}') + logger.info('[SMS-MOCK] phone=%s scene=%s code=%s', phone, scene, code) diff --git a/apps/user/views.py b/apps/user/views.py new file mode 100644 index 0000000..05268b6 --- /dev/null +++ b/apps/user/views.py @@ -0,0 +1,198 @@ +from rest_framework import viewsets, filters, status +from rest_framework.decorators import action +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from django_filters.rest_framework import DjangoFilterBackend + +from config.exceptions import AppError +from .models import User, Role, TeacherStudentRelation, Institution, Department +from .serializers import ( + UserSerializer, UserCreateSerializer, UserUpdateSerializer, + RoleSerializer, + TeacherStudentRelationSerializer, InstitutionSerializer, DepartmentSerializer +) +from .permissions import IsUserListPermitted, IsUserDetailPermitted +from .utils.password import validate_password_strength +from .utils.jwt_redis import invalidate_user_tokens +from .audit import log_password_change, log_password_reset, log_user_list + + +class UserViewSet(viewsets.ModelViewSet): + """用户管理 + + list: 获取用户列表(支持过滤、搜索、排序)— U9 角色分级权限 + create: 创建用户 + retrieve: 获取用户详情 — U10 对象级权限 + update: 更新用户信息 + destroy: 删除用户 + """ + queryset = User.objects.all() # 保留供 DRF router basename 检测 + filter_backends = [DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter] + filterset_fields = ['role_type', 'status', 'institution', 'department', 'gender'] + search_fields = ['username', 'real_name', 'phone'] + ordering_fields = ['created_at', 'last_login_time', 'total_training_count'] + + # ── 权限分派 ────────────────────────────────────────────────────────────── + + def get_permissions(self): + if self.action == 'list': + return [IsAuthenticated(), IsUserListPermitted()] + elif self.action == 'retrieve': + return [IsAuthenticated(), IsUserDetailPermitted()] + return super().get_permissions() + + # ── 序列化器分派 ────────────────────────────────────────────────────────── + + def get_serializer_class(self): + if self.action == 'create': + return UserCreateSerializer + elif self.action in ['update', 'partial_update']: + return UserUpdateSerializer + return UserSerializer + + # ── 查询集:U9 角色分级 + N+1 优化 ─────────────────────────────────────── + + def get_queryset(self): + qs = User.objects.select_related('institution', 'department') + user = self.request.user + + if not user.is_authenticated: + return qs.none() + + # list 动作按角色限制可见范围 + if self.action == 'list': + if user.role_type in ('super_admin', 'content_admin') or user.is_staff: + return qs # 管理员:全员 + elif user.role_type == 'teacher': + # 教师:仅自己名下活跃学生 + student_ids = TeacherStudentRelation.objects.filter( + teacher=user, status=1 + ).values_list('student_id', flat=True) + return qs.filter(id__in=student_ids, role_type='student') + else: + return qs.none() # 兜底(被 IsUserListPermitted 拦截在前) + + return qs + + # ── U9 list 审计 ───────────────────────────────────────────────────────── + + def list(self, request, *args, **kwargs): + filters_dict = {k: v for k, v in request.query_params.items() + if k in ('role_type', 'status', 'search', 'ordering', 'page')} + log_user_list(request.user.id, filters=filters_dict) + return super().list(request, *args, **kwargs) + + # ── U6 修改密码(已登录) ───────────────────────────────────────────────── + + @action(detail=False, methods=['post'], url_path='change-password') + def change_password(self, request): + """U6 修改当前用户密码""" + old_password = request.data.get('old_password', '') + new_password = request.data.get('new_password', '') + + if not old_password or not new_password: + raise AppError('VALIDATION_ERROR', '请提供旧密码和新密码') + + user = request.user + + # 1. 校验旧密码 + if not user.check_password(old_password): + raise AppError('AUTH_BAD_OLD_PASSWORD', '原密码错误') + + # 2. 新密码不得与旧密码相同 + if user.check_password(new_password): + raise AppError('AUTH_PASSWORD_SAME_AS_OLD', '新密码不能与旧密码相同') + + # 3. 密码强度校验 + pwd_errors = validate_password_strength( + new_password, + phone=user.phone, + real_name=user.real_name, + ) + if pwd_errors: + raise AppError('AUTH_PASSWORD_WEAK', pwd_errors[0], details=pwd_errors) + + # 4. 改密 + 全设备登出 + user.set_password(new_password) + user.save(update_fields=['password']) + invalidate_user_tokens(user.id) + + # 5. 审计 + log_password_change(user.id) + + return Response({'message': '密码修改成功,请重新登录'}) + + # ── 管理员重置用户密码 ──────────────────────────────────────────────────── + + @action(detail=True, methods=['post'], url_path='reset-password') + def reset_password(self, request, pk=None): + """重置用户密码(管理员操作)""" + user = self.get_object() + new_password = request.data.get('password', '') + + if not new_password: + raise AppError('VALIDATION_ERROR', '请提供新密码') + + # 密码强度校验 + pwd_errors = validate_password_strength(new_password, phone=user.phone, real_name=user.real_name) + if pwd_errors: + raise AppError('AUTH_PASSWORD_WEAK', pwd_errors[0], details=pwd_errors) + + user.set_password(new_password) + user.save(update_fields=['password']) + invalidate_user_tokens(user.id) + + log_password_reset(user.id) + + return Response({'message': '密码重置成功'}) + + # ── 当前用户信息 ───────────────────────────────────────────────────────── + + @action(detail=False, methods=['get'], url_path='me') + def me(self, request): + """获取当前登录用户信息""" + serializer = self.get_serializer(request.user) + return Response(serializer.data) + + +class RoleViewSet(viewsets.ModelViewSet): + """角色管理""" + queryset = Role.objects.all() + serializer_class = RoleSerializer + filter_backends = [filters.SearchFilter] + search_fields = ['role_code', 'role_name'] + + +class TeacherStudentRelationViewSet(viewsets.ModelViewSet): + """师生关系管理""" + queryset = TeacherStudentRelation.objects.all() + serializer_class = TeacherStudentRelationSerializer + filter_backends = [DjangoFilterBackend, filters.OrderingFilter] + filterset_fields = ['teacher', 'student', 'relation_type', 'status'] + ordering_fields = ['start_time', 'end_time', 'created_at'] + + +class InstitutionViewSet(viewsets.ModelViewSet): + """机构管理(医院/学校)""" + queryset = Institution.objects.all() + serializer_class = InstitutionSerializer + filter_backends = [DjangoFilterBackend, filters.SearchFilter] + filterset_fields = ['type', 'province', 'city'] + search_fields = ['name'] + + @action(detail=True, methods=['get']) + def departments(self, request, pk=None): + """获取机构下的科室列表""" + institution = self.get_object() + departments = institution.department_set.all() + serializer = DepartmentSerializer(departments, many=True) + return Response(serializer.data) + + +class DepartmentViewSet(viewsets.ModelViewSet): + """科室管理""" + queryset = Department.objects.all() + serializer_class = DepartmentSerializer + filter_backends = [DjangoFilterBackend, filters.SearchFilter] + filterset_fields = ['institution', 'category'] + search_fields = ['name'] diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/config/asgi.py b/config/asgi.py new file mode 100644 index 0000000..ffbb5f5 --- /dev/null +++ b/config/asgi.py @@ -0,0 +1,16 @@ +""" +ASGI config for config project. + +It exposes the ASGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/6.0/howto/deployment/asgi/ +""" + +import os + +from django.core.asgi import get_asgi_application + +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') + +application = get_asgi_application() diff --git a/config/exceptions.py b/config/exceptions.py new file mode 100644 index 0000000..275d5fe --- /dev/null +++ b/config/exceptions.py @@ -0,0 +1,77 @@ +import uuid +import logging + +from rest_framework.views import exception_handler as drf_exception_handler +from rest_framework.response import Response +from rest_framework.exceptions import APIException, ValidationError + +logger = logging.getLogger(__name__) + +_STATUS_TO_CODE = { + 400: 'BAD_REQUEST', + 401: 'AUTH_UNAUTHORIZED', + 403: 'PERMISSION_DENIED', + 404: 'NOT_FOUND', + 405: 'METHOD_NOT_ALLOWED', + 429: 'SYS_RATE_LIMIT', + 500: 'SYS_INTERNAL', + 503: 'SYS_DEPENDENCY_DOWN', +} + + +def custom_exception_handler(exc, context): + trace_id = uuid.uuid4().hex[:12] + response = drf_exception_handler(exc, context) + + if response is None: + logger.exception('Unhandled server error trace_id=%s', trace_id) + return Response( + {'code': 'SYS_INTERNAL', 'message': '服务器内部错误', 'details': None, 'trace_id': trace_id}, + status=500, + ) + + data = response.data + status_code = response.status_code + + # Already structured by AppError or our custom exceptions + if isinstance(data, dict) and 'code' in data and 'message' in data: + response.data = {**data, 'trace_id': trace_id} + return response + + # DRF ValidationError: {'field': ['error msg']} + if isinstance(exc, ValidationError): + response.data = { + 'code': 'VALIDATION_ERROR', + 'message': '请求参数不合法', + 'details': data, + 'trace_id': trace_id, + } + return response + + # Standard DRF exceptions (AuthenticationFailed, PermissionDenied, Throttled, etc.) + message = _extract_message(data) + response.data = { + 'code': _STATUS_TO_CODE.get(status_code, 'SYS_INTERNAL'), + 'message': message, + 'details': None, + 'trace_id': trace_id, + } + return response + + +def _extract_message(data): + if isinstance(data, dict): + return str(data.get('detail', data)) + if isinstance(data, list) and data: + return str(data[0]) + return str(data) + + +class AppError(APIException): + """统一业务异常。视图中 raise AppError('CODE', '消息', details, status_code) 即可。""" + + def __init__(self, code, message, details=None, status_code=400): + self.status_code = status_code + super().__init__(detail=code) + # 绕过 DRF _get_error_details,防止 None 被转成字符串 "None" + self.detail = {'code': code, 'message': message, 'details': details} diff --git a/config/logging_handlers.py b/config/logging_handlers.py new file mode 100644 index 0000000..7d800e6 --- /dev/null +++ b/config/logging_handlers.py @@ -0,0 +1,93 @@ +"""自定义日志 Handler:按日期分文件,兼容 Windows 文件锁。 + +替代 TimedRotatingFileHandler,避免 Windows 上因 rename 操作遇到文件锁 +(PermissionError: [WinError 32])导致日志丢失的问题。 + +文件命名规则:{prefix}-YYYY-MM-DD.log +日期切换时自动打开新文件,无需 rename 旧文件。 +""" + +import logging +import time +from pathlib import Path + + +class DailyFileHandler(logging.Handler): + """按日期自动分文件的日志 Handler。 + + 与 TimedRotatingFileHandler 的关键区别: + - 文件直接以日期命名,日期切换时打开新文件,**不 rename 旧文件** + - 多进程(dev server + 测试 / 管理命令)可同时写入,互不阻塞 + - 自动清理超过 backup_count 天的旧文件 + + dictConfig 用法:: + + 'audit_file': { + 'class': 'config.logging_handlers.DailyFileHandler', + 'dir_path': '/path/to/logs', + 'prefix': 'audit', + 'backup_count': 30, + 'formatter': 'verbose', + } + """ + + def __init__(self, dir_path, prefix='audit', backup_count=30, encoding='utf-8'): + super().__init__() + self.dir_path = Path(dir_path) + self.dir_path.mkdir(parents=True, exist_ok=True) + self.prefix = prefix + self.backup_count = backup_count + self.encoding = encoding + self._current_date = None + self._stream = None + self._open_today() + + def _today(self): + return time.strftime('%Y-%m-%d') + + def _open_today(self): + """打开当天的日志文件(追加模式)。""" + if self._stream: + try: + self._stream.close() + except OSError: + pass + self._current_date = self._today() + filepath = self.dir_path / f'{self.prefix}-{self._current_date}.log' + self._stream = open(filepath, 'a', encoding=self.encoding) + + def emit(self, record): + try: + today = self._today() + if today != self._current_date: + self._open_today() + self._cleanup_old_files() + msg = self.format(record) + self._stream.write(msg + '\n') + self._stream.flush() + except Exception: + self.handleError(record) + + def close(self): + self.acquire() + try: + if self._stream: + try: + self._stream.close() + except OSError: + pass + self._stream = None + finally: + self.release() + super().close() + + def _cleanup_old_files(self): + """删除超过 backup_count 天的旧日志文件。""" + if self.backup_count <= 0: + return + try: + files = sorted(self.dir_path.glob(f'{self.prefix}-*.log')) + for f in files[:-self.backup_count]: + f.unlink(missing_ok=True) + except OSError: + pass diff --git a/config/middleware.py b/config/middleware.py new file mode 100644 index 0000000..71b65fc --- /dev/null +++ b/config/middleware.py @@ -0,0 +1,138 @@ +"""API 请求/响应日志中间件。 + +记录每次 API 调用的完整信息:方法、路径、请求头、查询参数、请求体、 +响应状态码、响应头、响应体。 +日志输出到 `api_access` logger → `logs/api-access-YYYY-MM-DD.log`。 +""" + +import json +import logging +import time + +api_logger = logging.getLogger('api_access') + +# 跳过日志的路径前缀(静态文件、admin 等) +_SKIP_PREFIXES = ('/static/', '/admin/', '/api/schema/', '/api/docs/') + +# 请求体/响应体最大记录长度(字符) +_MAX_BODY_LEN = 2000 + +# 需要记录的请求头(Django 中请求头加 HTTP_ 前缀并大写) +_LOG_REQUEST_HEADERS = ( + 'HTTP_AUTHORIZATION', + 'HTTP_ACCEPT', + 'HTTP_USER_AGENT', + 'HTTP_X_FORWARDED_FOR', + 'HTTP_X_REAL_IP', +) + + +def _truncate(text, max_len=_MAX_BODY_LEN): + if len(text) > max_len: + return text[:max_len] + f'... (truncated, total {len(text)} chars)' + return text + + +def _safe_json_body(body_bytes, content_type=''): + """尝试将请求/响应体解析为 JSON,失败则返回原始文本摘要。""" + if not body_bytes: + return '' + # multipart (文件上传) 不记录原始内容 + if 'multipart' in content_type: + return f'' + try: + text = body_bytes.decode('utf-8') + data = json.loads(text) + return _truncate(json.dumps(data, ensure_ascii=False)) + except (UnicodeDecodeError, json.JSONDecodeError): + return _truncate(body_bytes.decode('utf-8', errors='replace')) + + +class APIAccessLogMiddleware: + """记录 /api/ 请求与响应的中间件。""" + + def __init__(self, get_response): + self.get_response = get_response + + @staticmethod + def _collect_request_headers(request): + """提取关键请求头。""" + headers = {} + ct = request.content_type or request.META.get('CONTENT_TYPE', '') + if ct: + headers['Content-Type'] = ct + for meta_key in _LOG_REQUEST_HEADERS: + val = request.META.get(meta_key) + if val: + name = meta_key.replace('HTTP_', '').replace('_', '-').title() + headers[name] = val + return headers + + @staticmethod + def _collect_response_headers(response): + """提取关键响应头。""" + keys = ('Content-Type', 'Content-Length', 'Allow', + 'X-Request-Id', 'Retry-After', 'WWW-Authenticate') + headers = {} + for k in keys: + v = response.get(k) + if v: + headers[k] = v + return headers + + def __call__(self, request): + path = request.path + + if not path.startswith('/api/') or any(path.startswith(p) for p in _SKIP_PREFIXES): + return self.get_response(request) + + start = time.time() + + req_headers = self._collect_request_headers(request) + + req_content_type = request.content_type or '' + if 'multipart' in req_content_type: + form_fields = dict(request.POST) + file_info = { + name: f'' + for name, f in request.FILES.items() + } + all_fields = {**form_fields, **file_info} + req_body = _truncate(json.dumps(all_fields, ensure_ascii=False)) if all_fields else '' + else: + req_body = _safe_json_body(request.body, req_content_type) + + query = dict(request.GET) if request.GET else '' + + user_id = None + + response = self.get_response(request) + + if hasattr(request, 'user') and request.user and request.user.is_authenticated: + user_id = request.user.id + + duration_ms = int((time.time() - start) * 1000) + + resp_headers = self._collect_response_headers(response) + + resp_content_type = response.get('Content-Type', '') + resp_body = '' + if hasattr(response, 'content'): + resp_body = _safe_json_body(response.content, resp_content_type) + + req_h_str = json.dumps(req_headers, ensure_ascii=False) if req_headers else '' + resp_h_str = json.dumps(resp_headers, ensure_ascii=False) if resp_headers else '' + + api_logger.info( + '%s %s | user=%s | query=%s | status=%s | %dms\n' + ' >>> headers: %s\n' + ' >>> body: %s\n' + ' <<< headers: %s\n' + ' <<< body: %s', + request.method, path, + user_id, query, response.status_code, duration_ms, + req_h_str, req_body, + resp_h_str, resp_body, + ) + + return response diff --git a/config/settings.py b/config/settings.py new file mode 100644 index 0000000..baac3f2 --- /dev/null +++ b/config/settings.py @@ -0,0 +1,252 @@ +import os +from pathlib import Path +from datetime import timedelta +from dotenv import load_dotenv + +BASE_DIR = Path(__file__).resolve().parent.parent +load_dotenv(BASE_DIR / '.env') + +SECRET_KEY = 'django-insecure-!-mtect5n-yyxkp2m=j(8dz_yi$b3w3ddo&w#i(@4kv-spdthy' + +DEBUG = True + +ALLOWED_HOSTS = [] + + +INSTALLED_APPS = [ + 'django.contrib.admin', + 'django.contrib.auth', + 'django.contrib.contenttypes', + 'django.contrib.sessions', + 'django.contrib.messages', + 'django.contrib.staticfiles', + + # Third-party apps + 'rest_framework', + 'rest_framework_simplejwt', + 'django_filters', + 'drf_spectacular', + + # Local apps + 'apps.common', + 'apps.user', + 'apps.case', + 'apps.training', +] + +MIDDLEWARE = [ + 'django.middleware.security.SecurityMiddleware', + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.middleware.common.CommonMiddleware', + 'django.middleware.csrf.CsrfViewMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + 'django.middleware.clickjacking.XFrameOptionsMiddleware', + 'config.middleware.APIAccessLogMiddleware', +] + +ROOT_URLCONF = 'config.urls' + +TEMPLATES = [ + { + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': [], + 'APP_DIRS': True, + 'OPTIONS': { + 'context_processors': [ + 'django.template.context_processors.request', + 'django.contrib.auth.context_processors.auth', + 'django.contrib.messages.context_processors.messages', + ], + }, + }, +] + +WSGI_APPLICATION = 'config.wsgi.application' + + +# Database - MySQL +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.mysql', + 'NAME': os.getenv('DB_NAME', 'medical_training'), + 'USER': os.getenv('DB_USER', 'root'), + 'PASSWORD': os.getenv('DB_PASSWORD', ''), + 'HOST': os.getenv('DB_HOST', 'localhost'), + 'PORT': os.getenv('DB_PORT', '3306'), + 'OPTIONS': { + 'charset': 'utf8mb4', + 'init_command': "SET sql_mode='STRICT_TRANS_TABLES'", + }, + 'TEST': { + 'NAME': 'test_medical_training', + 'CHARSET': 'utf8mb4', + 'COLLATION': 'utf8mb4_unicode_ci', + }, + } +} + + +AUTH_PASSWORD_VALIDATORS = [ + {'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator'}, + {'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator'}, + {'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator'}, + {'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator'}, +] + + +LANGUAGE_CODE = 'zh-hans' +TIME_ZONE = 'Asia/Shanghai' +USE_I18N = True +USE_TZ = True + +STATIC_URL = 'static/' + +DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' + +# ─── REST Framework ─────────────────────────────────────────────────────────── + +REST_FRAMEWORK = { + 'EXCEPTION_HANDLER': 'config.exceptions.custom_exception_handler', + 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', + 'DEFAULT_AUTHENTICATION_CLASSES': [ + 'apps.user.authentication.RedisBlacklistJWTAuthentication', + 'rest_framework.authentication.SessionAuthentication', + ], + 'DEFAULT_PERMISSION_CLASSES': [ + 'rest_framework.permissions.IsAuthenticated', + ], + 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', + 'PAGE_SIZE': 20, + 'DEFAULT_THROTTLE_CLASSES': [], + 'DEFAULT_THROTTLE_RATES': { + 'sms_phone_minute': '1/minute', + 'sms_phone_day': '10/day', + 'sms_ip': '30/hour', + 'register_ip': '10/hour', + 'reset_phone': '5/hour', + 'pdf_parse_user': '20/hour', + 'scoring_rule_generate_user': '20/hour', + }, +} + +# ─── JWT ────────────────────────────────────────────────────────────────────── + +SIMPLE_JWT = { + 'ACCESS_TOKEN_LIFETIME': timedelta(hours=2), + 'REFRESH_TOKEN_LIFETIME': timedelta(days=7), + 'ROTATE_REFRESH_TOKENS': True, + 'BLACKLIST_AFTER_ROTATION': False, + 'ALGORITHM': 'HS256', + 'SIGNING_KEY': SECRET_KEY, + 'AUTH_HEADER_TYPES': ('Bearer',), + 'AUTH_HEADER_NAME': 'HTTP_AUTHORIZATION', +} + +# ─── Redis Cache ────────────────────────────────────────────────────────────── + +CACHES = { + 'default': { + 'BACKEND': 'django_redis.cache.RedisCache', + 'LOCATION': os.getenv('REDIS_URL', 'redis://127.0.0.1:6379/1'), + 'OPTIONS': { + 'CLIENT_CLASS': 'django_redis.client.DefaultClient', + }, + } +} + +# ─── SMS ────────────────────────────────────────────────────────────────────── + +SMS_CODE_EXPIRE = 300 # 验证码有效期(秒) +SMS_CODE_INTERVAL = 60 # 发送间隔(秒) +SMS_PROVIDER = os.getenv('SMS_PROVIDER', 'mock') + +ALIYUN_SMS_ACCESS_KEY_ID = os.getenv('ALIYUN_SMS_ACCESS_KEY_ID', '') +ALIYUN_SMS_ACCESS_KEY_SECRET = os.getenv('ALIYUN_SMS_ACCESS_KEY_SECRET', '') +ALIYUN_SMS_SIGN_NAME = os.getenv('ALIYUN_SMS_SIGN_NAME', '医疗训练平台') +ALIYUN_SMS_TEMPLATE_REGISTER = os.getenv('ALIYUN_SMS_TEMPLATE_REGISTER', '') +ALIYUN_SMS_TEMPLATE_LOGIN = os.getenv('ALIYUN_SMS_TEMPLATE_LOGIN', '') +ALIYUN_SMS_TEMPLATE_RESET = os.getenv('ALIYUN_SMS_TEMPLATE_RESET', '') + +# ─── DeepSeek ───────────────────────────────────────────────────────────────── + +DEEPSEEK_API_KEY = os.getenv('DEEPSEEK_API_KEY', '') +DEEPSEEK_BASE_URL = os.getenv('DEEPSEEK_BASE_URL', 'https://api.deepseek.com') +DEEPSEEK_MODEL = os.getenv('DEEPSEEK_MODEL', 'deepseek-chat') +DEEPSEEK_TIMEOUT_SECONDS = int(os.getenv('DEEPSEEK_TIMEOUT_SECONDS', '120')) +DEEPSEEK_MAX_RETRIES = int(os.getenv('DEEPSEEK_MAX_RETRIES', '1')) + +# ─── Spectacular (Swagger / OpenAPI) ───────────────────────────────────────── + +SPECTACULAR_SETTINGS = { + 'TITLE': 'Medical Training API', + 'DESCRIPTION': '医疗训练系统 API 文档', + 'VERSION': '1.0.0', + 'SERVE_INCLUDE_SCHEMA': False, + 'COMPONENT_SPLIT_REQUEST': True, + # 修复同名枚举冲突(User.STATUS_CHOICES 与 CaseBase.STATUS_CHOICES 值相同,共用一个名称) + 'ENUM_NAME_OVERRIDES': { + 'CaseTypeEnum': 'apps.case.models.CaseBase.CASE_TYPE_CHOICES', + 'CreatableCaseTypeEnum': [('traditional', 'traditional'), ('teaching', 'teaching')], + 'CommonStatusEnum': 'apps.case.models.CaseBase.STATUS_CHOICES', + 'PublishStatusEnum': 'apps.case.models.CaseBase.PUBLISH_STATUS_CHOICES', + 'TrainingStatusEnum': 'apps.training.models.TrainingRecord.STATUS_CHOICES', + 'TeacherStudentStatusEnum': 'apps.user.models.TeacherStudentRelation.STATUS_CHOICES', + }, +} + +# ─── Auth ───────────────────────────────────────────────────────────────────── + +AUTH_USER_MODEL = 'user.User' + +# ─── Logging ────────────────────────────────────────────────────────────────── + +LOGS_DIR = BASE_DIR / 'logs' +LOGS_DIR.mkdir(exist_ok=True) + +LOGGING = { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'verbose': { + 'format': '{asctime} {levelname} [{name}] {message}', + 'style': '{', + }, + }, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'formatter': 'verbose', + }, + 'audit_file': { + 'class': 'config.logging_handlers.DailyFileHandler', + 'dir_path': str(LOGS_DIR), + 'prefix': 'audit', + 'backup_count': 30, + 'formatter': 'verbose', + }, + 'api_access_file': { + 'class': 'config.logging_handlers.DailyFileHandler', + 'dir_path': str(LOGS_DIR), + 'prefix': 'api-access', + 'backup_count': 30, + 'formatter': 'verbose', + }, + }, + 'loggers': { + 'audit': { + 'handlers': ['audit_file', 'console'], + 'level': 'INFO', + 'propagate': False, + }, + 'api_access': { + 'handlers': ['api_access_file'], + 'level': 'INFO', + 'propagate': False, + }, + }, + 'root': { + 'handlers': ['console'], + 'level': 'INFO', + }, +} diff --git a/config/urls.py b/config/urls.py new file mode 100644 index 0000000..42871fe --- /dev/null +++ b/config/urls.py @@ -0,0 +1,38 @@ +""" +URL configuration for config project. + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/6.0/topics/http/urls/ +Examples: +Function views + 1. Add an import: from my_app import views + 2. Add a URL to urlpatterns: path('', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function: from django.urls import include, path + 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) +""" +from django.contrib import admin +from django.urls import path, include +from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView +from drf_spectacular.views import SpectacularAPIView, SpectacularSwaggerView, SpectacularRedocView + +urlpatterns = [ + path('admin/', admin.site.urls), + + # API Routes + path('api/user/', include('apps.user.urls')), + path('api/case/', include('apps.case.urls')), + path('api/training/', include('apps.training.urls')), + + # JWT Token + path('api/token/', TokenObtainPairView.as_view(), name='token_obtain_pair'), + path('api/token/refresh/', TokenRefreshView.as_view(), name='token_refresh'), + + # API Documentation + path('api/schema/', SpectacularAPIView.as_view(), name='schema'), + path('api/docs/swagger/', SpectacularSwaggerView.as_view(url_name='schema'), name='swagger-ui'), + path('api/docs/redoc/', SpectacularRedocView.as_view(url_name='schema'), name='redoc'), +] diff --git a/config/wsgi.py b/config/wsgi.py new file mode 100644 index 0000000..4ced574 --- /dev/null +++ b/config/wsgi.py @@ -0,0 +1,16 @@ +""" +WSGI config for config project. + +It exposes the WSGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/6.0/howto/deployment/wsgi/ +""" + +import os + +from django.core.wsgi import get_wsgi_application + +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') + +application = get_wsgi_application() diff --git a/manage.py b/manage.py new file mode 100644 index 0000000..8e7ac79 --- /dev/null +++ b/manage.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +"""Django's command-line utility for administrative tasks.""" +import os +import sys + + +def main(): + """Run administrative tasks.""" + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') + try: + from django.core.management import execute_from_command_line + except ImportError as exc: + raise ImportError( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) from exc + execute_from_command_line(sys.argv) + + +if __name__ == '__main__': + main() diff --git a/prompts/__init__.py b/prompts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/prompts/case_scoring_rules.md b/prompts/case_scoring_rules.md new file mode 100644 index 0000000..4ad0482 --- /dev/null +++ b/prompts/case_scoring_rules.md @@ -0,0 +1,57 @@ +# 角色 + +你是一位医学教育评估专家。你的任务是根据病例数据生成评分规则(评分维度)。 + +# 输入 + +用户会提供一份病例的结构化数据(JSON),包含病例类型、主诉、诊断、治疗方案等信息。 + +# 输出要求 + +请严格按以下 JSON 结构输出,不要输出任何其他内容: + +```json +{ + "scoring_rules": [ + { + "dimension": "评分维度名称", + "competency_dimension": "对应能力维度", + "score_weight": 0.25, + "ai_auto_score": true, + "osce_dimension": false, + "scoring_standard": "该维度的评分标准描述", + "rubric_json": { + "excellent": "5分:具体标准", + "good": "4分:具体标准", + "average": "3分:具体标准", + "poor": "≤2分:具体标准" + } + } + ] +} +``` + +# 规则 + +1. 生成 3~8 条评分规则,覆盖病例的核心考核维度。 +2. 所有 `score_weight` 之和必须等于 1.0(精确到小数点后两位)。 +3. 每条 `score_weight` 范围为 (0, 1]。 +4. `dimension` 必填且不可为空字符串。 +5. `competency_dimension` 从以下选取:`临床思维`、`问诊技巧`、`体格检查`、`辅助检查判读`、`诊断能力`、`治疗决策`、`医患沟通`、`医学人文`、`团队协作`、`应急处理`。也可根据病例特殊性扩展。 +6. `ai_auto_score`:该维度是否适合 AI 自动评分。对话类、知识问答类设为 true;操作类、情感沟通类设为 false。 +7. `osce_dimension`:仅当病例明确涉及 OSCE 考站时设为 true,否则 false。 + +## 按病例类型侧重 + +### traditional(传统病例) +- 必须包含:问诊全面性、诊断准确性、治疗方案合理性 +- 根据 `standard_diagnosis`、`standard_treatment` 细化评分标准 +- 若有 `guideline_reference`,在 rubric 中引用指南标准 + +### teaching(教学互动病例) +- 必须包含:教学目标达成度、讨论参与度 +- 根据 `teaching_goal`、`discussion_questions` 生成对应维度 +- `scoring_focus` 中提到的方向应作为高权重维度 + +8. `rubric_json` 中每个等级的描述必须具体、可操作,不要使用模糊表述。 +9. 输出必须是合法 JSON,不要包含注释或 markdown 代码块标记。 diff --git a/prompts/case_teaching_full.md b/prompts/case_teaching_full.md new file mode 100644 index 0000000..a82150f --- /dev/null +++ b/prompts/case_teaching_full.md @@ -0,0 +1,55 @@ +# 角色 + +你是一位医学教育内容结构化专家。你的任务是将医学教学病例 PDF 文本解析为结构化 JSON。 + +# 输入 + +用户会提供一份或多份医学教学病例 PDF 的文本内容。 + +# 输出要求 + +请严格按以下 JSON 结构输出,不要输出任何其他内容。所有字段必须存在,若原文未提及则填空字符串或空数组。 + +```json +{ + "title": "病例标题", + "case_type": "teaching", + "difficulty": "easy|medium|hard", + "chief_complaint": "主诉", + "description": "病例简介/摘要", + "patient_age": 45, + "patient_gender": "male|female", + "tags": "逗号分隔标签", + "symptom_tags": ["发热", "咳嗽"], + "disease_tags": ["肺炎"], + "competency_tags": ["临床思维", "医患沟通"], + "guideline_tags": ["社区获得性肺炎诊疗指南"], + "knowledge_points": ["肺炎的鉴别诊断", "教学查房流程"], + "icd_codes": "J18.9", + "estimated_minutes": 45, + "osce_enabled": false, + "department_name": "呼吸内科", + "teaching": { + "teaching_goal": "教学目标", + "discussion_questions": "讨论问题", + "teacher_guide": "教师指南", + "scoring_focus": "评分重点" + } +} +``` + +# 规则 + +1. `case_type` 固定为 `"teaching"`。 +2. `difficulty` 根据病例复杂度判断:简单常见病 → easy,需鉴别诊断 → medium,多系统/罕见病 → hard。 +3. `patient_age` 为整数,无法判断填 `null`。 +4. `patient_gender` 仅 `"male"` 或 `"female"`,无法判断填空字符串。 +5. 标签类字段(symptom_tags、disease_tags 等)至少各提取 1 个,从原文推断。 +6. `department_name` 根据病例内容推断最匹配的科室名称。 +7. `teaching` 子对象中: + - `teaching_goal`:提取或推断本病例的教学目标。 + - `discussion_questions`:提取讨论题目,多个用换行分隔。 + - `teacher_guide`:提取教师引导要点。 + - `scoring_focus`:提取评分重点关注方向。 +8. 不要生成 `scoring_rules`、`stages` 等字段。 +9. 输出必须是合法 JSON,不要包含注释或 markdown 代码块标记。 diff --git a/prompts/case_traditional_full.md b/prompts/case_traditional_full.md new file mode 100644 index 0000000..e54bd08 --- /dev/null +++ b/prompts/case_traditional_full.md @@ -0,0 +1,53 @@ +# 角色 + +你是一位医学教育内容结构化专家。你的任务是将医学病例 PDF 文本解析为结构化 JSON。 + +# 输入 + +用户会提供一份或多份医学病例 PDF 的文本内容。 + +# 输出要求 + +请严格按以下 JSON 结构输出,不要输出任何其他内容。所有字段必须存在,若原文未提及则填空字符串或空数组。 + +```json +{ + "title": "病例标题", + "case_type": "traditional", + "difficulty": "easy|medium|hard", + "chief_complaint": "主诉", + "description": "病例简介/摘要", + "patient_age": 45, + "patient_gender": "male|female", + "tags": "逗号分隔标签", + "symptom_tags": ["发热", "咳嗽"], + "disease_tags": ["肺炎"], + "competency_tags": ["临床思维", "问诊技巧"], + "guideline_tags": ["社区获得性肺炎诊疗指南"], + "knowledge_points": ["肺炎的鉴别诊断", "抗生素选择原则"], + "icd_codes": "J18.9", + "estimated_minutes": 30, + "osce_enabled": false, + "department_name": "呼吸内科", + "traditional": { + "standard_diagnosis": "标准诊断", + "standard_treatment": "标准治疗方案", + "guideline_reference": "参考指南" + } +} +``` + +# 规则 + +1. `case_type` 固定为 `"traditional"`。 +2. `difficulty` 根据病例复杂度判断:简单常见病 → easy,需鉴别诊断 → medium,多系统/罕见病 → hard。 +3. `patient_age` 为整数,无法判断填 `null`。 +4. `patient_gender` 仅 `"male"` 或 `"female"`,无法判断填空字符串。 +5. 标签类字段(symptom_tags、disease_tags 等)至少各提取 1 个,从原文推断。 +6. `department_name` 根据病例内容推断最匹配的科室名称。 +7. `traditional` 子对象中: + - `standard_diagnosis`:从原文提取或推断标准诊断。 + - `standard_treatment`:从原文提取标准治疗方案。 + - `guideline_reference`:引用相关临床指南名称。 +8. 不要生成 `scoring_rules`、`stages` 等字段。 +9. 输出必须是合法 JSON,不要包含注释或 markdown 代码块标记。 diff --git a/prompts/loader.py b/prompts/loader.py new file mode 100644 index 0000000..00e6dfb --- /dev/null +++ b/prompts/loader.py @@ -0,0 +1,17 @@ +import hashlib +from pathlib import Path +from functools import lru_cache + +_PROMPTS_DIR = Path(__file__).resolve().parent + + +@lru_cache(maxsize=16) +def load_prompt(name: str) -> tuple[str, str]: + """加载提示词文件,返回 (content, prompt_version)。 + + prompt_version = 文件内容 MD5 前 8 位,用于审计追溯。 + """ + path = _PROMPTS_DIR / f'{name}.md' + content = path.read_text(encoding='utf-8') + version = hashlib.md5(content.encode('utf-8')).hexdigest()[:8] + return content, version diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..dffbc31 --- /dev/null +++ b/readme.md @@ -0,0 +1,19 @@ +### 数据库迁移 +```bash +python manage.py migrate +``` + +### 创建管理员 +```bash +python manage.py createsuperuser +``` + +### 启动服务 +```bash +python manage.py runserver +``` + +### 访问 API 文档 +- Swagger UI: http://localhost:8000/api/docs/swagger/ +- ReDoc: http://localhost:8000/api/docs/redoc/ +- 管理后台: http://localhost:8000/admin/ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..977c61d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +Django>=5.0,<6.0 +djangorestframework>=3.14.0 +djangorestframework-simplejwt>=5.3.0 +drf-spectacular>=0.27.0 +django-filter>=23.0 +Pillow>=10.0.0 +python-dotenv>=1.0.0 +mysqlclient>=2.2.0 +redis>=5.0.0 +django-redis>=5.4.0 + +# 新增 - 短信 +alibabacloud_dysmsapi20170525>=2.0 +alibabacloud_tea_util +alibabacloud_tea_openapi + +# 新增 - 病例端 +openai>=1.30 +pdfplumber>=0.11 +jsonschema>=4.0 diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..fc635ed --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,252 @@ +"""D8 测试共享工具:URL 常量、用户/科室夹具、SMS 注入、载荷构建。""" + +from django.core.cache import cache +from django.core.files.uploadedfile import SimpleUploadedFile +from django.test import TestCase, TransactionTestCase +from rest_framework.test import APIClient +from rest_framework_simplejwt.tokens import RefreshToken + +from apps.user.models import User, Institution, Department, TeacherStudentRelation + + +class CacheTestCase(TestCase): + """所有测试基类:setUp 自动 clear Redis,隔离 SMS 验证码/限流/JWT 黑名单。""" + + def setUp(self): + super().setUp() + cache.clear() + + +class CacheTransactionTestCase(TransactionTestCase): + """事务测试基类:setUp 自动 clear Redis。""" + + def setUp(self): + super().setUp() + cache.clear() + +# ─── URL 常量 ───────────────────────────────────────────────────────────────── + +# 用户认证 +USER_SEND_CODE_URL = '/api/user/auth/send-code/' +USER_REGISTER_URL = '/api/user/auth/register/' +USER_LOGIN_URL = '/api/user/auth/login/' +USER_LOGIN_CODE_URL = '/api/user/auth/login-code/' +USER_LOGOUT_URL = '/api/user/auth/logout/' +USER_REFRESH_URL = '/api/user/auth/refresh/' +USER_RESET_PWD_URL = '/api/user/auth/reset-password/' +USER_CHANGE_PWD_URL = '/api/user/users/change-password/' +USER_ME_URL = '/api/user/users/me/' +USER_LIST_URL = '/api/user/users/' + +# 病例 +CASE_PARSE_URL = '/api/case/cases/parse-pdf/' +CASE_GENERATE_RULES_URL = '/api/case/cases/generate-scoring-rules/' +CASE_FULL_CREATE_URL = '/api/case/cases/full-create/' + + +def user_detail_url(user_id): + return f'/api/user/users/{user_id}/' + + +def case_full_url(case_id): + return f'/api/case/cases/{case_id}/full/' + + +# ─── SMS 注入 ───────────────────────────────────────────────────────────────── + +def inject_sms_code(phone, scene, code='123456'): + """直接将验证码写入 Redis,绕过 SMS 服务。""" + cache.set(f'sms:{scene}:{phone}', code, timeout=300) + + +# ─── 用户工具 ───────────────────────────────────────────────────────────────── + +def create_test_user(phone='13900000001', password='TestPass1', + real_name='测试用户', role_type='student', status=1): + """创建测试用户(已知密码),返回 User 实例。""" + user = User.objects.create_user( + username=phone, + password=password, + phone=phone, + real_name=real_name, + role_type=role_type, + status=status, + ) + return user + + +def get_tokens(user): + """返回 {'access': '...', 'refresh': '...'} 字符串。""" + refresh = RefreshToken.for_user(user) + return {'access': str(refresh.access_token), 'refresh': str(refresh)} + + +def get_auth_client(user): + """返回已携带 JWT Bearer 的 APIClient。""" + tokens = get_tokens(user) + client = APIClient() + client.credentials(HTTP_AUTHORIZATION=f'Bearer {tokens["access"]}') + return client + + +# ─── 师生关系 ───────────────────────────────────────────────────────────────── + +def create_teacher_student_relation(teacher, student, status=1): + """创建师生关系记录。""" + return TeacherStudentRelation.objects.create( + teacher=teacher, + student=student, + relation_type='指导', + status=status, + ) + + +# ─── 科室工具 ───────────────────────────────────────────────────────────────── + +def ensure_institution(name='测试医院'): + inst, _ = Institution.objects.get_or_create( + name=name, + defaults={'type': 'hospital', 'province': '北京', 'city': '北京'}, + ) + return inst + + +def ensure_department(name='儿科', institution_name='测试医院'): + inst = ensure_institution(institution_name) + dept, _ = Department.objects.get_or_create( + name=name, + defaults={'institution': inst, 'category': '临床'}, + ) + return dept + + +# ─── 病例载荷构建 ───────────────────────────────────────────────────────────── + +def build_traditional_payload(department_name='儿科', scoring_rules_count=2): + """构建合法的传统病例 full-create 载荷。""" + rules = [ + { + 'dimension': f'测试维度{i + 1}', + 'score_weight': round(1.0 / scoring_rules_count, 2), + 'ai_auto_score': True, + 'scoring_standard': f'评分标准{i + 1}', + } + for i in range(scoring_rules_count) + ] + return { + 'title': '测试传统病例-表单录入', + 'case_type': 'traditional', + 'difficulty': 'medium', + 'chief_complaint': '发热 3 天', + 'description': '患儿,男,4 岁,因发热 3 天就诊。', + 'patient_age': 4, + 'patient_gender': 'male', + 'department_name': department_name, + 'estimated_minutes': 30, + 'osce_enabled': False, + 'tags': '儿科,发热', + 'traditional': { + 'standard_diagnosis': '上呼吸道感染', + 'standard_treatment': '对症治疗,退热处理', + 'guideline_reference': '《儿科学》第 9 版', + }, + 'scoring_rules': rules, + } + + +def build_teaching_payload(department_name='儿科', scoring_rules_count=2): + """构建合法的教学病例 full-create 载荷。""" + rules = [ + { + 'dimension': f'教学维度{i + 1}', + 'score_weight': round(1.0 / scoring_rules_count, 2), + 'ai_auto_score': False, + 'scoring_standard': f'教学评分标准{i + 1}', + } + for i in range(scoring_rules_count) + ] + return { + 'title': '测试教学病例-表单录入', + 'case_type': 'teaching', + 'difficulty': 'hard', + 'chief_complaint': '腹痛 2 天', + 'description': '患者,女,28 岁,因腹痛 2 天就诊。', + 'patient_age': 28, + 'patient_gender': 'female', + 'department_name': department_name, + 'estimated_minutes': 45, + 'osce_enabled': False, + 'teaching': { + 'teaching_goal': '掌握急腹症鉴别诊断', + 'discussion_questions': '如何鉴别急性阑尾炎与其他急腹症?', + 'teacher_guide': '引导学生按 SOAP 格式分析', + 'scoring_focus': '鉴别诊断思路', + }, + 'scoring_rules': rules, + } + + +# ─── AI Mock 数据 ───────────────────────────────────────────────────────────── + +MOCK_C1_PARSE_RESULT = { + 'data': { + 'title': 'Mock-儿科发热病例', + 'case_type': 'traditional', + 'difficulty': 'medium', + 'chief_complaint': '发热 3 天', + 'description': '患儿,男,4 岁,因发热 3 天就诊。', + 'patient_age': 4, + 'patient_gender': 'male', + 'department_name': '儿科', + 'estimated_minutes': 30, + 'osce_enabled': False, + 'tags': '儿科,发热', + 'traditional': { + 'standard_diagnosis': 'Mock 上呼吸道感染', + 'standard_treatment': 'Mock 对症治疗', + 'guideline_reference': 'Mock 指南', + }, + }, + 'usage': {'prompt_tokens': 100, 'completion_tokens': 200, 'total_tokens': 300}, +} + +MOCK_C2_SCORING_RULES = { + 'data': { + 'scoring_rules': [ + { + 'dimension': '诊断准确性', + 'competency_dimension': '临床推理', + 'score_weight': 0.4, + 'ai_auto_score': True, + 'osce_dimension': False, + 'scoring_standard': '能准确判断上呼吸道感染', + }, + { + 'dimension': '治疗方案合理性', + 'competency_dimension': '治疗决策', + 'score_weight': 0.3, + 'ai_auto_score': True, + 'osce_dimension': False, + 'scoring_standard': '治疗方案符合指南推荐', + }, + { + 'dimension': '医患沟通', + 'competency_dimension': '沟通技巧', + 'score_weight': 0.3, + 'ai_auto_score': False, + 'osce_dimension': True, + 'scoring_standard': '能向家属解释病情和治疗方案', + }, + ], + }, + 'usage': {'prompt_tokens': 150, 'completion_tokens': 250, 'total_tokens': 400}, +} + + +def make_fake_pdf(): + """创建一个假 PDF 上传文件(仅用于触发 multipart 解析)。""" + return SimpleUploadedFile( + 'test.pdf', + b'%PDF-1.4 fake content for testing', + content_type='application/pdf', + ) diff --git a/test/manage.py b/test/manage.py new file mode 100644 index 0000000..8e7ac79 --- /dev/null +++ b/test/manage.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +"""Django's command-line utility for administrative tasks.""" +import os +import sys + + +def main(): + """Run administrative tasks.""" + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') + try: + from django.core.management import execute_from_command_line + except ImportError as exc: + raise ImportError( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) from exc + execute_from_command_line(sys.argv) + + +if __name__ == '__main__': + main() diff --git a/test/swagger_tryout.py b/test/swagger_tryout.py new file mode 100644 index 0000000..555fc91 --- /dev/null +++ b/test/swagger_tryout.py @@ -0,0 +1,481 @@ +""" +Swagger Try-it-out 等效脚本:逐个调用所有接口,验证可达性和基本功能。 +运行方式:.venv\\Scripts\\python.exe test/swagger_tryout.py +前提:Django dev server 已在 http://127.0.0.1:8000 运行,Redis 已启动。 +日志输出:logs/test-swagger-YYYY-MM-DD.log(含完整请求体和响应体) +""" + +import io +import json +import sys +import time +import subprocess +from datetime import datetime +from pathlib import Path + +import requests + +# 修复 Windows GBK 编码问题 +sys.stdout.reconfigure(encoding='utf-8') +sys.stderr.reconfigure(encoding='utf-8') + +BASE = 'http://127.0.0.1:8000' +PYTHON = r'D:\01Agent\medical_training\.venv\Scripts\python.exe' +CWD = r'D:\01Agent\medical_training' +PASS = 'PASS' +FAIL = 'FAIL' +results = [] + +# ─── 日志文件 ───────────────────────────────────────────────────────────────── + +LOGS_DIR = Path(CWD) / 'logs' +LOGS_DIR.mkdir(exist_ok=True) +LOG_FILE = LOGS_DIR / f'test-swagger-{datetime.now().strftime("%Y-%m-%d")}.log' +_log_fh = open(LOG_FILE, 'a', encoding='utf-8') + + +def _write_log(text): + """同时写入日志文件和控制台。""" + _log_fh.write(text + '\n') + _log_fh.flush() + print(text) + + +def _format_headers(headers): + """将 requests 的 headers 格式化为 JSON 字符串。""" + if not headers: + return '' + return json.dumps(dict(headers), ensure_ascii=False) + + +def log(api_id, method, url, expected, actual, detail='', + req_headers=None, req_body=None, resp_headers=None, resp_body=None): + status = PASS if actual in (expected if isinstance(expected, (list, tuple)) else [expected]) else FAIL + results.append((api_id, method, url, expected, actual, status)) + exp_str = str(expected) + _write_log(f' {status} {api_id:<6} {method:<6} {url:<50} expect={exp_str:<20} got={actual} {detail}') + # 详细请求/响应写入日志文件(不打印到控制台,避免过长) + if req_headers is not None: + _log_fh.write(f' >>> headers: {_format_headers(req_headers)}\n') + if req_body is not None: + body_str = json.dumps(req_body, ensure_ascii=False, indent=2) if isinstance(req_body, dict) else str(req_body) + _log_fh.write(f' >>> body: {body_str}\n') + if resp_headers is not None: + _log_fh.write(f' <<< headers: {_format_headers(resp_headers)}\n') + if resp_body is not None: + body_str = json.dumps(resp_body, ensure_ascii=False, indent=2) if isinstance(resp_body, dict) else str(resp_body) + # 截断过长响应 + if len(body_str) > 2000: + body_str = body_str[:2000] + f'... (truncated, {len(body_str)} chars total)' + _log_fh.write(f' <<< body: {body_str}\n') + _log_fh.flush() + + +def section(title): + line = f'\n{"="*90}\n {title}\n{"="*90}' + _write_log(line) + + +def django_eval(code): + """在独立进程中执行 Django 代码并返回 stdout。""" + preamble = ( + 'import django, os; ' + 'os.environ.setdefault("DJANGO_SETTINGS_MODULE","config.settings"); ' + 'django.setup(); ' + ) + proc = subprocess.run( + [PYTHON, '-c', preamble + code], + capture_output=True, text=True, cwd=CWD, + ) + return proc.stdout.strip() + + +def get_sms_code(phone, scene): + """从 Redis 读取短信验证码。""" + val = django_eval( + f'from django.core.cache import cache; print(cache.get("sms:{scene}:{phone}"))' + ) + return val if val and val != 'None' else None + + +def inject_sms_code(phone, scene, code='123456'): + """手动注入短信验证码到 Redis。""" + django_eval( + f'from django.core.cache import cache; ' + f'cache.set("sms:{scene}:{phone}", "{code}", 300)' + ) + return code + + +# ─── 清理 Redis 残留数据(限流计数等)──────────────────────────────────────── + +print('\n[准备] 清理 Redis 缓存...') +django_eval('from django.core.cache import cache; cache.clear(); print("OK")') + +# 删除上次可能残留的测试用户 +PHONE = '13700000099' +django_eval( + f'from apps.user.models import User; ' + f'User.objects.filter(phone="{PHONE}").delete(); print("cleaned")' +) +print('[准备] 完成\n') + +s = requests.Session() +PASSWORD = 'SwagTest1' + +# ═══════════════════════════════════════════════════════════════════════════════ +section('用户端接口 (U1-U10)') +# ═══════════════════════════════════════════════════════════════════════════════ + +# U1: 发送验证码 (register) +u1_body = {'phone': PHONE, 'scene': 'register'} +r = s.post(f'{BASE}/api/user/auth/send-code/', json=u1_body) +log('U1', 'POST', '/api/user/auth/send-code/', 200, r.status_code, + req_headers=r.request.headers, req_body=u1_body, + resp_headers=dict(r.headers), resp_body=r.json()) + +# U2: 注册 +code = get_sms_code(PHONE, 'register') +if not code: + code = inject_sms_code(PHONE, 'register') +u2_body = {'phone': PHONE, 'code': code, 'password': PASSWORD, 'real_name': 'Swagger测试'} +r = s.post(f'{BASE}/api/user/auth/register/', json=u2_body) +log('U2', 'POST', '/api/user/auth/register/', 201, r.status_code, + req_headers=r.request.headers, req_body=u2_body, + resp_headers=dict(r.headers), resp_body=r.json()) + +# U3: 密码登录 +u3_body = {'phone': PHONE, 'password': PASSWORD} +r = s.post(f'{BASE}/api/user/auth/login/', json=u3_body) +log('U3', 'POST', '/api/user/auth/login/', 200, r.status_code, + req_headers=r.request.headers, req_body=u3_body, + resp_headers=dict(r.headers), resp_body=r.json()) +tokens = r.json().get('tokens', {}) +access = tokens.get('access', '') +refresh = tokens.get('refresh', '') +auth = {'Authorization': f'Bearer {access}'} + +# U4: 验证码登录 +r = s.post(f'{BASE}/api/user/auth/send-code/', json={'phone': PHONE, 'scene': 'login'}) +login_code = get_sms_code(PHONE, 'login') +if not login_code: + login_code = inject_sms_code(PHONE, 'login', '654321') +u4_body = {'phone': PHONE, 'code': login_code} +r = s.post(f'{BASE}/api/user/auth/login-code/', json=u4_body) +log('U4', 'POST', '/api/user/auth/login-code/', 200, r.status_code, + req_headers=r.request.headers, req_body=u4_body, + resp_headers=dict(r.headers), resp_body=r.json()) +if r.status_code == 200: + tokens = r.json().get('tokens', {}) + access = tokens.get('access', access) + refresh = tokens.get('refresh', refresh) + auth = {'Authorization': f'Bearer {access}'} + +# U5: 重置密码 +NEW_PASSWORD = 'SwagNew1' +r = s.post(f'{BASE}/api/user/auth/send-code/', json={'phone': PHONE, 'scene': 'reset'}) +reset_code = get_sms_code(PHONE, 'reset') +if not reset_code: + reset_code = inject_sms_code(PHONE, 'reset', '111111') +u5_body = {'phone': PHONE, 'code': reset_code, 'new_password': NEW_PASSWORD} +r = s.post(f'{BASE}/api/user/auth/reset-password/', json=u5_body) +log('U5', 'POST', '/api/user/auth/reset-password/', 200, r.status_code, + req_headers=r.request.headers, req_body=u5_body, + resp_headers=dict(r.headers), resp_body=r.json()) + +# reset-password 调用 invalidate_user_tokens(time()+1),必须等 1s 再登录 +time.sleep(1.2) + +# 用新密码重新登录 +r = s.post(f'{BASE}/api/user/auth/login/', json={'phone': PHONE, 'password': NEW_PASSWORD}) +tokens = r.json().get('tokens', {}) +access = tokens.get('access', '') +refresh = tokens.get('refresh', '') +auth = {'Authorization': f'Bearer {access}'} + +# U6: 修改密码 +FINAL_PASSWORD = 'SwagFin1' +u6_body = {'old_password': NEW_PASSWORD, 'new_password': FINAL_PASSWORD} +r = s.post(f'{BASE}/api/user/users/change-password/', json=u6_body, headers=auth) +log('U6', 'POST', '/api/user/users/change-password/', 200, r.status_code, + req_headers=r.request.headers, req_body=u6_body, + resp_headers=dict(r.headers), resp_body=r.json()) + +# change-password 同样 invalidate_user_tokens(time()+1) +time.sleep(1.2) + +# 用最终密码重新登录 +r = s.post(f'{BASE}/api/user/auth/login/', json={'phone': PHONE, 'password': FINAL_PASSWORD}) +tokens = r.json().get('tokens', {}) +access = tokens.get('access', '') +refresh = tokens.get('refresh', '') +auth = {'Authorization': f'Bearer {access}'} + +# U8: 刷新 Token +u8_body = {'refresh': refresh} +r = s.post(f'{BASE}/api/user/auth/refresh/', json=u8_body) +log('U8', 'POST', '/api/user/auth/refresh/', 200, r.status_code, + req_headers=r.request.headers, req_body=u8_body, + resp_headers=dict(r.headers), resp_body=r.json()) +if r.status_code == 200: + access = r.json().get('access', access) + auth = {'Authorization': f'Bearer {access}'} + +# /me: GET /me +r = s.get(f'{BASE}/api/user/users/me/', headers=auth) +log('/me', 'GET', '/api/user/users/me/', 200, r.status_code, + f'phone={r.json().get("phone","")}' if r.status_code == 200 else '', + req_headers=r.request.headers, resp_headers=dict(r.headers), resp_body=r.json()) +test_user_id = r.json().get('id') if r.status_code == 200 else None + +# ── U9/U10: 用户列表 + 用户详情 ────────────────────────────────────────────── +# 创建 admin、teacher、student 用户 + 师生关系 +ADMIN_PHONE = '13700000088' +TEACHER_PHONE = '13700000077' +STUDENT_PHONE = '13700000066' +ROLE_PWD = 'RoleTest1' + +django_eval( + f'from apps.user.models import User, TeacherStudentRelation; ' + f'User.objects.filter(phone__in=["{ADMIN_PHONE}","{TEACHER_PHONE}","{STUDENT_PHONE}"]).delete(); ' + f'admin = User.objects.create_user(username="{ADMIN_PHONE}", password="{ROLE_PWD}", ' + f' phone="{ADMIN_PHONE}", real_name="Swagger管理员", role_type="super_admin", status=1); ' + f'teacher = User.objects.create_user(username="{TEACHER_PHONE}", password="{ROLE_PWD}", ' + f' phone="{TEACHER_PHONE}", real_name="Swagger教师", role_type="teacher", status=1); ' + f'student = User.objects.create_user(username="{STUDENT_PHONE}", password="{ROLE_PWD}", ' + f' phone="{STUDENT_PHONE}", real_name="Swagger学生", role_type="student", status=1); ' + f'TeacherStudentRelation.objects.create(teacher=teacher, student=student, ' + f' relation_type="指导", status=1); ' + f'print(f"admin={{admin.id}} teacher={{teacher.id}} student={{student.id}}")' +) + +# 管理员登录 +r = s.post(f'{BASE}/api/user/auth/login/', json={'phone': ADMIN_PHONE, 'password': ROLE_PWD}) +admin_tokens = r.json().get('tokens', {}) +admin_auth = {'Authorization': f'Bearer {admin_tokens.get("access", "")}'} +admin_refresh = admin_tokens.get('refresh', '') + +# U9: 管理员获取用户列表 +r = s.get(f'{BASE}/api/user/users/', headers=admin_auth) +u9_detail = '' +if r.status_code == 200: + u9_data = r.json() + u9_results = u9_data.get('results', u9_data) + u9_detail = f'count={len(u9_results)}' +log('U9', 'GET', '/api/user/users/', 200, r.status_code, u9_detail, + req_headers=r.request.headers, resp_headers=dict(r.headers), resp_body=r.json()) + +# U9-b: 教师获取用户列表(仅名下学生) +r_teacher_login = s.post(f'{BASE}/api/user/auth/login/', + json={'phone': TEACHER_PHONE, 'password': ROLE_PWD}) +teacher_tokens = r_teacher_login.json().get('tokens', {}) +teacher_auth = {'Authorization': f'Bearer {teacher_tokens.get("access", "")}'} + +r = s.get(f'{BASE}/api/user/users/', headers=teacher_auth) +u9b_detail = '' +if r.status_code == 200: + u9b_data = r.json() + u9b_results = u9b_data.get('results', u9b_data) + u9b_detail = f'count={len(u9b_results)}(should=1 student only)' +log('U9-b', 'GET', '/api/user/users/ (teacher)', 200, r.status_code, u9b_detail, + req_headers=r.request.headers, resp_headers=dict(r.headers), resp_body=r.json()) + +# U9-c: 普通用户(当前测试用户)获取列表 → 403 +r = s.get(f'{BASE}/api/user/users/', headers=auth) +log('U9-c', 'GET', '/api/user/users/ (normal)', 403, r.status_code, + f'code={r.json().get("code","")}' if r.status_code == 403 else '', + req_headers=r.request.headers, resp_headers=dict(r.headers), resp_body=r.json()) + +# U10: 管理员查看学生详情 +# 先获取 student id +student_id = django_eval( + f'from apps.user.models import User; ' + f'u = User.objects.get(phone="{STUDENT_PHONE}"); print(u.id)' +) +r = s.get(f'{BASE}/api/user/users/{student_id}/', headers=admin_auth) +log('U10', 'GET', f'/api/user/users/{student_id}/', 200, r.status_code, + f'real_name={r.json().get("real_name","")}' if r.status_code == 200 else '', + resp_body=r.json()) + +# U10-b: 教师查看名下学生详情 +r = s.get(f'{BASE}/api/user/users/{student_id}/', headers=teacher_auth) +log('U10-b', 'GET', f'/api/user/users/{student_id}/ (teacher)', 200, r.status_code, + resp_body=r.json()) + +# U10-c: 普通用户查看自己详情 +if test_user_id: + r = s.get(f'{BASE}/api/user/users/{test_user_id}/', headers=auth) + log('U10-c', 'GET', f'/api/user/users/{test_user_id}/ (self)', 200, r.status_code, + resp_body=r.json()) + +# 清理辅助用户 +django_eval( + f'from apps.user.models import User; ' + f'User.objects.filter(phone__in=["{ADMIN_PHONE}","{TEACHER_PHONE}","{STUDENT_PHONE}"]).delete(); ' + f'print("cleaned")' +) + +# U7: 退出登录 (logout) — 放最后,因为会吊销 refresh +u7_body = {'refresh': refresh} +r = s.post(f'{BASE}/api/user/auth/logout/', json=u7_body, headers=auth) +log('U7', 'POST', '/api/user/auth/logout/', 200, r.status_code, + req_body=u7_body, resp_body=r.json()) + +# ═══════════════════════════════════════════════════════════════════════════════ +section('病例端接口 (C1-C5)') +# ═══════════════════════════════════════════════════════════════════════════════ + +# 重新登录(logout 吊销了上一个 refresh) +time.sleep(1.2) +r = s.post(f'{BASE}/api/user/auth/login/', json={'phone': PHONE, 'password': FINAL_PASSWORD}) +tokens = r.json().get('tokens', {}) +access = tokens.get('access', '') +auth = {'Authorization': f'Bearer {access}'} + +# 确保科室存在 +django_eval( + 'from apps.user.models import Institution, Department; ' + 'inst, _ = Institution.objects.get_or_create(name="测试医院", ' + ' defaults={"type":"hospital","province":"北京","city":"北京"}); ' + 'Department.objects.get_or_create(name="儿科", ' + ' defaults={"institution":inst,"category":"临床"}); ' + 'print("OK")' +) + +# C1: PDF 解析 — 使用项目真实 PDF 文件 +REAL_PDF = r'D:\01Agent\medical_training\儿科 病例样例(SOAP+循证).pdf' +with open(REAL_PDF, 'rb') as f: + r = s.post( + f'{BASE}/api/case/cases/parse-pdf/', + files={'files': ('儿科 病例样例(SOAP+循证).pdf', f, 'application/pdf')}, + data={'case_type': 'traditional'}, + headers=auth, + ) +c1_ok = [200, 500, 429] # 200=AI 解析成功, 500=AI 异常, 429=限流 +detail = '' +c1_resp = None +if r.headers.get('content-type', '').startswith('application/json'): + body = r.json() + c1_resp = body + detail = f'code={body.get("code", "")}' + if r.status_code == 200: + detail = f'parse_id={body.get("parse_id","")}, keys={list(body.get("data",{}).keys())[:5]}' +log('C1', 'POST', '/api/case/cases/parse-pdf/', c1_ok, r.status_code, detail, + req_body={'case_type': 'traditional', 'files': ''}, resp_body=c1_resp) + +# C2: 生成评分规则 — 如果 C1 成功则用其返回的 data,否则用手工载荷 +c1_data = None +if r.status_code == 200: + c1_data = body.get('data', {}) + +c2_payload = c1_data if c1_data else { + 'title': 'Swagger测试病例', + 'case_type': 'traditional', + 'chief_complaint': '发热3天', + 'description': '患儿男4岁发热3天就诊', + 'traditional': { + 'standard_diagnosis': '上呼吸道感染', + 'standard_treatment': '对症治疗', + }, +} +r = s.post(f'{BASE}/api/case/cases/generate-scoring-rules/', json=c2_payload, headers=auth) +c2_ok = [200, 500, 429] +detail = '' +scoring_rules_from_ai = None +c2_resp = None +if r.headers.get('content-type', '').startswith('application/json'): + c2_body = r.json() + c2_resp = c2_body + if r.status_code == 200: + scoring_rules_from_ai = c2_body.get('scoring_rules', []) + detail = f'generated={c2_body.get("generated","")}, rules={len(scoring_rules_from_ai)}' + else: + detail = f'code={c2_body.get("code", "")}' +log('C2', 'POST', '/api/case/cases/generate-scoring-rules/', c2_ok, r.status_code, detail, + req_body=c2_payload, resp_body=c2_resp) + +# C3: full-create — 优先用 C1+C2 AI 结果,否则用手工载荷 +if c1_data and scoring_rules_from_ai: + _write_log(' [INFO] C3 使用 C1 AI 解析 + C2 AI 评分规则组装载荷') + payload = {**c1_data} + payload['scoring_rules'] = scoring_rules_from_ai +else: + _write_log(' [INFO] C3 使用手工表单载荷(C1/C2 未全部成功)') + payload = { + 'title': 'Swagger-Try-It-Out-病例', + 'case_type': 'traditional', + 'difficulty': 'medium', + 'chief_complaint': '发热 3 天', + 'description': '患儿,男,4 岁,因发热 3 天就诊。', + 'patient_age': 4, + 'patient_gender': 'male', + 'department_name': '儿科', + 'estimated_minutes': 30, + 'osce_enabled': False, + 'tags': '儿科,发热', + 'traditional': { + 'standard_diagnosis': '上呼吸道感染', + 'standard_treatment': '对症治疗,退热处理', + 'guideline_reference': '《儿科学》第 9 版', + }, + 'scoring_rules': [ + { + 'dimension': '诊断准确性', + 'score_weight': 0.6, + 'ai_auto_score': True, + 'scoring_standard': '准确判断上呼吸道感染', + }, + { + 'dimension': '治疗方案', + 'score_weight': 0.4, + 'ai_auto_score': True, + 'scoring_standard': '治疗方案合理', + }, + ], + } +r = s.post(f'{BASE}/api/case/cases/full-create/', json=payload, headers=auth) +c3_resp = r.json() if r.headers.get('content-type', '').startswith('application/json') else None +log('C3', 'POST', '/api/case/cases/full-create/', 201, r.status_code, + req_body=payload, resp_body=c3_resp) +case_id = None +if r.status_code == 201: + case_id = r.json()['case']['id'] + +# C4: GET full +if case_id: + r = s.get(f'{BASE}/api/case/cases/{case_id}/full/', headers=auth) + log('C4', 'GET', f'/api/case/cases/{case_id}/full/', 200, r.status_code, + resp_body=r.json() if r.headers.get('content-type', '').startswith('application/json') else None) +else: + _write_log(' SKIP C4 — C3 未返回 case_id') + +# C5: PATCH full +if case_id: + c5_body = {'title': 'Swagger-更新标题'} + r = s.patch(f'{BASE}/api/case/cases/{case_id}/full/', json=c5_body, headers=auth) + log('C5', 'PATCH', f'/api/case/cases/{case_id}/full/', 200, r.status_code, + req_body=c5_body, + resp_body=r.json() if r.headers.get('content-type', '').startswith('application/json') else None) +else: + _write_log(' SKIP C5 — C3 未返回 case_id') + +# ═══════════════════════════════════════════════════════════════════════════════ +section('汇总') +# ═══════════════════════════════════════════════════════════════════════════════ +total = len(results) +passed = sum(1 for r in results if r[5] == PASS) +failed = sum(1 for r in results if r[5] == FAIL) +_write_log(f'\n 总计: {total} 个接口 | 通过: {passed} | 失败: {failed}') +if failed: + _write_log('\n 失败接口:') + for r in results: + if r[5] == FAIL: + _write_log(f' {r[0]} {r[1]} {r[2]} -- expect={r[3]}, got={r[4]}') + _write_log(f'\n 日志文件: {LOG_FILE}') + _log_fh.close() + sys.exit(1) +else: + _write_log('\n ALL PASSED - 全部接口 Swagger Try-it-out 验证通过!') + _write_log(f'\n 日志文件: {LOG_FILE}') + _log_fh.close() + sys.exit(0) diff --git a/test/test_case_happy.py b/test/test_case_happy.py new file mode 100644 index 0000000..de73536 --- /dev/null +++ b/test/test_case_happy.py @@ -0,0 +1,164 @@ +"""病例域 2 条 happy-path 流程测试。""" + +from unittest.mock import patch, MagicMock + +from django.core.cache import cache + +from apps.case.models import CaseBase, TraditionalCase, ScoringRule +from apps.user.throttling import PdfParseUserThrottle, ScoringRuleGenerateUserThrottle +from .conftest import ( + CacheTestCase, + CASE_PARSE_URL, CASE_GENERATE_RULES_URL, CASE_FULL_CREATE_URL, + case_full_url, create_test_user, get_auth_client, ensure_department, + build_traditional_payload, make_fake_pdf, + MOCK_C1_PARSE_RESULT, MOCK_C2_SCORING_RULES, +) + + +class CaseFormHappyPathTest(CacheTestCase): + """HP-5: 表单录入 → full-create → GET → PATCH → GET 验证""" + + def setUp(self): + super().setUp() + self.user = create_test_user(phone='13900100001', password='CaseTest1') + self.client = get_auth_client(self.user) + ensure_department('儿科') + + def test_flow_form_create_read_update(self): + """HP-5: C3 full-create → C4 GET full → C5 PATCH → C4 GET verify""" + # C3: full-create(2 条评分规则) + payload = build_traditional_payload(department_name='儿科', scoring_rules_count=2) + resp = self.client.post(CASE_FULL_CREATE_URL, payload, format='json') + self.assertEqual(resp.status_code, 201, resp.content) + + created = resp.json() + case_id = created['case']['id'] + self.assertEqual(created['case']['case_type'], 'traditional') + self.assertEqual(created['case']['publish_status'], 0) # 草稿 + self.assertIn('traditional', created) + self.assertIsNotNone(created['traditional']) + self.assertEqual(len(created['scoring_rules']), 2) + + # C4: GET full + resp = self.client.get(case_full_url(case_id)) + self.assertEqual(resp.status_code, 200, resp.content) + full = resp.json() + self.assertEqual(full['case']['title'], payload['title']) + self.assertEqual(len(full['scoring_rules']), 2) + + # C5: PATCH(改标题 + 改子表 + 替换为 1 条评分规则) + patch_data = { + 'title': '更新后的标题', + 'traditional': { + 'standard_diagnosis': '更新后的诊断', + }, + 'scoring_rules': [ + { + 'dimension': '更新后的维度', + 'score_weight': 1.0, + 'ai_auto_score': False, + 'scoring_standard': '更新后的标准', + }, + ], + } + resp = self.client.patch(case_full_url(case_id), patch_data, format='json') + self.assertEqual(resp.status_code, 200, resp.content) + + # C4: GET 验证更新 + resp = self.client.get(case_full_url(case_id)) + self.assertEqual(resp.status_code, 200, resp.content) + full = resp.json() + self.assertEqual(full['case']['title'], '更新后的标题') + self.assertEqual(full['traditional']['standard_diagnosis'], '更新后的诊断') + self.assertEqual(len(full['scoring_rules']), 1) + self.assertEqual(full['scoring_rules'][0]['dimension'], '更新后的维度') + + # 验证 DB + case = CaseBase.objects.get(id=case_id) + self.assertEqual(case.title, '更新后的标题') + self.assertEqual(ScoringRule.objects.filter(case_id=case_id).count(), 1) + tc = TraditionalCase.objects.get(case_id=case_id) + self.assertEqual(tc.standard_diagnosis, '更新后的诊断') + + +class CasePdfMockHappyPathTest(CacheTestCase): + """HP-6: PDF 解析(mock AI)→ 生成评分规则 → full-create → GET → PATCH""" + + def setUp(self): + super().setUp() + self.user = create_test_user(phone='13900100002', password='CaseTest2') + self.client = get_auth_client(self.user) + ensure_department('儿科') + + @patch('apps.case.services.case_importer.extract_text_from_pdfs', + return_value='患儿,男,4岁,因发热3天就诊。体温38.5°C...') + def test_flow_pdf_mock_full_pipeline(self, mock_pdf): + """HP-6: C1 parse-pdf → C2 generate-scoring-rules → C3 full-create → C4 GET → C5 PATCH""" + + # mock call_deepseek: 第 1 次返回 C1 解析结果,第 2 次返回 C2 评分规则 + call_count = {'n': 0} + def mock_deepseek(system_prompt, user_content): + call_count['n'] += 1 + if call_count['n'] == 1: + return MOCK_C1_PARSE_RESULT + return MOCK_C2_SCORING_RULES + + with ( + patch('apps.case.services.deepseek_client.call_deepseek', side_effect=mock_deepseek), + patch.object(PdfParseUserThrottle, 'allow_request', return_value=True), + patch.object(ScoringRuleGenerateUserThrottle, 'allow_request', return_value=True), + ): + # C1: parse-pdf + fake_pdf = make_fake_pdf() + resp = self.client.post( + CASE_PARSE_URL, + {'files': fake_pdf, 'case_type': 'traditional'}, + format='multipart', + ) + self.assertEqual(resp.status_code, 200, resp.content) + parse_result = resp.json() + self.assertIn('parse_id', parse_result) + self.assertEqual(parse_result['case_type'], 'traditional') + data = parse_result['data'] + self.assertEqual(data['case_type'], 'traditional') + self.assertIn('traditional', data) + self.assertNotIn('scoring_rules', data) + + # C2: generate-scoring-rules + resp = self.client.post( + CASE_GENERATE_RULES_URL, + data, + format='json', + ) + self.assertEqual(resp.status_code, 200, resp.content) + gen_result = resp.json() + self.assertGreaterEqual(gen_result['generated'], 1) + scoring_rules = gen_result['scoring_rules'] + + # C3: full-create(组装 C1 data + C2 scoring_rules) + create_payload = {**data} + create_payload['scoring_rules'] = scoring_rules + create_payload['parse_id'] = parse_result['parse_id'] + resp = self.client.post(CASE_FULL_CREATE_URL, create_payload, format='json') + self.assertEqual(resp.status_code, 201, resp.content) + created = resp.json() + case_id = created['case']['id'] + self.assertEqual(len(created['scoring_rules']), len(scoring_rules)) + + # C4: GET full + resp = self.client.get(case_full_url(case_id)) + self.assertEqual(resp.status_code, 200, resp.content) + full = resp.json() + self.assertEqual(full['case']['id'], case_id) + self.assertEqual(full['case']['case_type'], 'traditional') + + # C5: PATCH + resp = self.client.patch(case_full_url(case_id), { + 'title': 'AI-更新标题', + }, format='json') + self.assertEqual(resp.status_code, 200, resp.content) + + # 验证 PATCH 生效 + resp = self.client.get(case_full_url(case_id)) + self.assertEqual(resp.status_code, 200, resp.content) + self.assertEqual(resp.json()['case']['title'], 'AI-更新标题') diff --git a/test/test_case_negative.py b/test/test_case_negative.py new file mode 100644 index 0000000..9495223 --- /dev/null +++ b/test/test_case_negative.py @@ -0,0 +1,224 @@ +"""病例域负向测试:字段校验、越权、限流、事务回滚、AI Schema 违规。""" + +from unittest.mock import patch + +from django.core.cache import cache +from django.db import IntegrityError +from rest_framework.test import APIClient + +from apps.case.models import CaseBase, ScoringRule +from apps.user.throttling import PdfParseUserThrottle +from config.exceptions import AppError +from .conftest import ( + CacheTestCase, CacheTransactionTestCase, + CASE_PARSE_URL, CASE_FULL_CREATE_URL, + case_full_url, create_test_user, get_auth_client, ensure_department, + build_traditional_payload, make_fake_pdf, +) + + +class CaseFieldValidationTest(CacheTestCase): + """病例字段校验负向测试。""" + + def setUp(self): + super().setUp() + self.user = create_test_user(phone='13800002001', password='CaseNeg1') + self.client = get_auth_client(self.user) + ensure_department('儿科') + + def test_invalid_case_type_400(self): + """N10: case_type='invalid' → 400 CASE_TYPE_NOT_SUPPORTED""" + payload = build_traditional_payload() + payload['case_type'] = 'invalid' + resp = self.client.post(CASE_FULL_CREATE_URL, payload, format='json') + self.assertEqual(resp.status_code, 400, resp.content) + self.assertEqual(resp.json()['code'], 'CASE_TYPE_NOT_SUPPORTED') + + def test_empty_scoring_rules_400(self): + """N11: scoring_rules=[] → 400 CASE_VALIDATION_ERROR""" + payload = build_traditional_payload() + payload['scoring_rules'] = [] + resp = self.client.post(CASE_FULL_CREATE_URL, payload, format='json') + self.assertEqual(resp.status_code, 400, resp.content) + self.assertEqual(resp.json()['code'], 'CASE_VALIDATION_ERROR') + + def test_subtable_conflict_400(self): + """N12: 同时传 traditional + teaching → 400 CASE_SUBTYPE_CONFLICT""" + payload = build_traditional_payload() + payload['teaching'] = { + 'teaching_goal': '不应该出现', + 'discussion_questions': '冲突', + } + resp = self.client.post(CASE_FULL_CREATE_URL, payload, format='json') + self.assertEqual(resp.status_code, 400, resp.content) + self.assertEqual(resp.json()['code'], 'CASE_SUBTYPE_CONFLICT') + + def test_missing_subtable_400(self): + """N13: case_type=traditional 但无 traditional 子表 → 400 CASE_SUBTYPE_REQUIRED""" + payload = build_traditional_payload() + del payload['traditional'] + resp = self.client.post(CASE_FULL_CREATE_URL, payload, format='json') + self.assertEqual(resp.status_code, 400, resp.content) + self.assertEqual(resp.json()['code'], 'CASE_SUBTYPE_REQUIRED') + + def test_patch_published_case_400(self): + """N16: PATCH 已发布病例 → 400 CASE_NOT_EDITABLE""" + # 先创建病例 + payload = build_traditional_payload() + resp = self.client.post(CASE_FULL_CREATE_URL, payload, format='json') + self.assertEqual(resp.status_code, 201, resp.content) + case_id = resp.json()['case']['id'] + + # 发布 + case = CaseBase.objects.get(id=case_id) + case.publish_status = 1 + case.save(update_fields=['publish_status']) + + # PATCH → 应被拒绝 + resp = self.client.patch(case_full_url(case_id), { + 'title': '不应该成功', + }, format='json') + self.assertEqual(resp.status_code, 400, resp.content) + self.assertEqual(resp.json()['code'], 'CASE_NOT_EDITABLE') + + +class CaseAuthorizationTest(CacheTestCase): + """病例越权测试。""" + + def setUp(self): + super().setUp() + ensure_department('儿科') + + def test_unauth_full_create_401(self): + """N14: 未登录 POST full-create → 401""" + client = APIClient() + payload = build_traditional_payload() + resp = client.post(CASE_FULL_CREATE_URL, payload, format='json') + self.assertEqual(resp.status_code, 401, resp.content) + + def test_view_other_draft_403(self): + """N15: 用户 B 看用户 A 的草稿 → 403""" + # 用户 A 创建草稿 + user_a = create_test_user(phone='13800002010', password='UserAPass1') + client_a = get_auth_client(user_a) + payload = build_traditional_payload() + resp = client_a.post(CASE_FULL_CREATE_URL, payload, format='json') + self.assertEqual(resp.status_code, 201, resp.content) + case_id = resp.json()['case']['id'] + + # 验证是草稿 (publish_status=0) + self.assertEqual(resp.json()['case']['publish_status'], 0) + + # 用户 B 尝试访问 + user_b = create_test_user(phone='13800002011', password='UserBPass1') + client_b = get_auth_client(user_b) + resp = client_b.get(case_full_url(case_id)) + self.assertEqual(resp.status_code, 403, resp.content) + self.assertEqual(resp.json()['code'], 'CASE_PERMISSION_DENIED') + + +class CaseRateLimitTest(CacheTestCase): + """病例限流测试。""" + + def setUp(self): + super().setUp() + self.user = create_test_user(phone='13800002020', password='RateTest1') + self.client = get_auth_client(self.user) + + def test_rate_limit_pdf_parse_429(self): + """N17: PDF 解析限流 → 429""" + with ( + patch.object(PdfParseUserThrottle, 'allow_request', return_value=False), + patch.object(PdfParseUserThrottle, 'wait', return_value=60), + ): + fake_pdf = make_fake_pdf() + resp = self.client.post( + CASE_PARSE_URL, + {'files': fake_pdf, 'case_type': 'traditional'}, + format='multipart', + ) + self.assertEqual(resp.status_code, 429, resp.content) + self.assertEqual(resp.json()['code'], 'SYS_RATE_LIMIT') + + +class CaseTransactionRollbackTest(CacheTransactionTestCase): + """N18: 事务回滚测试 — 必须用 TransactionTestCase。""" + + def setUp(self): + super().setUp() + self.user = create_test_user(phone='13800002030', password='TxTest1') + self.client = get_auth_client(self.user) + ensure_department('儿科') + + def test_transaction_rollback(self): + """N18: bulk_create 失败 → CaseBase 应回滚""" + initial_count = CaseBase.objects.count() + + with patch.object( + ScoringRule.objects, 'bulk_create', + side_effect=IntegrityError('mocked DB error'), + ): + payload = build_traditional_payload() + resp = self.client.post(CASE_FULL_CREATE_URL, payload, format='json') + + # 请求应失败(500 或被异常处理捕获) + self.assertGreaterEqual(resp.status_code, 400) + + # CaseBase 应回滚到初始状态 + self.assertEqual( + CaseBase.objects.count(), initial_count, + 'CaseBase 应因事务回滚而未增加', + ) + + +class CaseAISchemaTest(CacheTestCase): + """AI Schema 违规测试。""" + + def setUp(self): + super().setUp() + self.user = create_test_user(phone='13800002040', password='AITest1') + self.client = get_auth_client(self.user) + + @patch('apps.case.services.case_importer.extract_text_from_pdfs', + return_value='虚拟文本内容') + def test_ai_bad_json_500(self, mock_pdf): + """N19: DeepSeek 返回非法 JSON → 500 AI_BAD_JSON""" + with ( + patch( + 'apps.case.services.deepseek_client.call_deepseek', + side_effect=AppError('AI_BAD_JSON', 'AI 返回非合法 JSON', status_code=500), + ), + patch.object(PdfParseUserThrottle, 'allow_request', return_value=True), + ): + fake_pdf = make_fake_pdf() + resp = self.client.post( + CASE_PARSE_URL, + {'files': fake_pdf, 'case_type': 'traditional'}, + format='multipart', + ) + self.assertEqual(resp.status_code, 500, resp.content) + self.assertEqual(resp.json()['code'], 'AI_BAD_JSON') + + @patch('apps.case.services.case_importer.extract_text_from_pdfs', + return_value='虚拟文本内容') + def test_ai_schema_violation_500(self, mock_pdf): + """N20: DeepSeek 输出不符合 JSON Schema → 500 AI_SCHEMA_VIOLATION""" + # 返回缺少 title 和 case_type 的数据 → jsonschema 校验失败 + bad_result = { + 'data': { + 'wrong_field': 'no title here', + }, + 'usage': {'prompt_tokens': 10, 'completion_tokens': 20}, + } + with ( + patch('apps.case.services.deepseek_client.call_deepseek', return_value=bad_result), + patch.object(PdfParseUserThrottle, 'allow_request', return_value=True), + ): + fake_pdf = make_fake_pdf() + resp = self.client.post( + CASE_PARSE_URL, + {'files': fake_pdf, 'case_type': 'traditional'}, + format='multipart', + ) + self.assertEqual(resp.status_code, 500, resp.content) + self.assertEqual(resp.json()['code'], 'AI_SCHEMA_VIOLATION') diff --git a/test/test_user_happy.py b/test/test_user_happy.py new file mode 100644 index 0000000..db31646 --- /dev/null +++ b/test/test_user_happy.py @@ -0,0 +1,383 @@ +"""用户域 4 条 happy-path 流程测试。""" + +import time +from contextlib import ExitStack +from unittest.mock import patch + +from django.core.cache import cache +from rest_framework.test import APIClient + +from apps.user.throttling import ( + SmsPhoneMinuteThrottle, SmsPhoneDayThrottle, SmsIpThrottle, + RegisterIpThrottle, ResetPhoneThrottle, +) +from .conftest import ( + CacheTestCase, + USER_SEND_CODE_URL, USER_REGISTER_URL, USER_LOGIN_URL, + USER_LOGIN_CODE_URL, USER_CHANGE_PWD_URL, USER_RESET_PWD_URL, USER_ME_URL, + USER_LIST_URL, user_detail_url, + inject_sms_code, create_test_user, get_auth_client, get_tokens, + create_teacher_student_relation, +) + + +def _bypass_all_auth_throttles(stack): + """在 ExitStack 中注册所有认证相关限流 bypass。""" + for cls in (SmsPhoneMinuteThrottle, SmsPhoneDayThrottle, SmsIpThrottle, + RegisterIpThrottle, ResetPhoneThrottle): + stack.enter_context(patch.object(cls, 'allow_request', return_value=True)) + + +class UserAuthHappyPathTest(CacheTestCase): + """用户认证 happy-path 测试。""" + + def setUp(self): + super().setUp() + self.client = APIClient() + + # ── HP-1: 注册 → 密码登录 → /me ────────────────────────────────────── + + def test_flow_register_login_me(self): + """HP-1: U1 send-code(register) → U2 register → U3 login → GET /me""" + phone = '13900000001' + password = 'Abc12345' + real_name = '张三' + + with ExitStack() as stack: + _bypass_all_auth_throttles(stack) + + # U1: send-code (register) + resp = self.client.post(USER_SEND_CODE_URL, { + 'phone': phone, 'scene': 'register', + }) + self.assertEqual(resp.status_code, 200, resp.content) + + # 从 cache 读验证码 + code = cache.get(f'sms:register:{phone}') + self.assertIsNotNone(code, '验证码未写入缓存') + + # U2: register + resp = self.client.post(USER_REGISTER_URL, { + 'phone': phone, + 'code': str(code), + 'password': password, + 'real_name': real_name, + }) + self.assertEqual(resp.status_code, 201, resp.content) + data = resp.json() + self.assertIn('tokens', data) + self.assertEqual(data['user']['phone'], phone) + self.assertEqual(data['user']['real_name'], real_name) + + # U3: login (password) — 限流 bypass 已退出,login 无限流 + resp = self.client.post(USER_LOGIN_URL, { + 'phone': phone, 'password': password, + }) + self.assertEqual(resp.status_code, 200, resp.content) + tokens = resp.json()['tokens'] + + # GET /me + self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {tokens["access"]}') + resp = self.client.get(USER_ME_URL) + self.assertEqual(resp.status_code, 200, resp.content) + self.assertEqual(resp.json()['phone'], phone) + self.assertEqual(resp.json()['real_name'], real_name) + + # ── HP-2: 验证码登录 ───────────────────────────────────────────────── + + def test_flow_code_login(self): + """HP-2: 预创建用户 → U1 send-code(login) → U4 login-code → /me""" + phone = '13900000002' + user = create_test_user(phone=phone, password='TestPass1') + + with ExitStack() as stack: + _bypass_all_auth_throttles(stack) + + # U1: send-code (login) + resp = self.client.post(USER_SEND_CODE_URL, { + 'phone': phone, 'scene': 'login', + }) + self.assertEqual(resp.status_code, 200, resp.content) + + code = cache.get(f'sms:login:{phone}') + self.assertIsNotNone(code) + + # U4: login-code + resp = self.client.post(USER_LOGIN_CODE_URL, { + 'phone': phone, 'code': str(code), + }) + self.assertEqual(resp.status_code, 200, resp.content) + tokens = resp.json()['tokens'] + self.assertIn('access', tokens) + self.assertIn('refresh', tokens) + + # GET /me + self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {tokens["access"]}') + resp = self.client.get(USER_ME_URL) + self.assertEqual(resp.status_code, 200, resp.content) + self.assertEqual(resp.json()['phone'], phone) + + # ── HP-3: 重置密码 ────────────────────────────────────────────────── + + def test_flow_reset_password(self): + """HP-3: U1 send-code(reset) → U5 reset-password → U3 login(新密码)""" + phone = '13900000003' + old_pwd = 'OldPass1' + new_pwd = 'NewPass1' + create_test_user(phone=phone, password=old_pwd) + + with ExitStack() as stack: + _bypass_all_auth_throttles(stack) + + # U1: send-code (reset) + resp = self.client.post(USER_SEND_CODE_URL, { + 'phone': phone, 'scene': 'reset', + }) + self.assertEqual(resp.status_code, 200, resp.content) + + code = cache.get(f'sms:reset:{phone}') + self.assertIsNotNone(code) + + # U5: reset-password + resp = self.client.post(USER_RESET_PWD_URL, { + 'phone': phone, + 'code': str(code), + 'new_password': new_pwd, + }) + self.assertEqual(resp.status_code, 200, resp.content) + + # 新密码登录成功 + resp = self.client.post(USER_LOGIN_URL, { + 'phone': phone, 'password': new_pwd, + }) + self.assertEqual(resp.status_code, 200, resp.content) + + # 旧密码登录失败 + resp = self.client.post(USER_LOGIN_URL, { + 'phone': phone, 'password': old_pwd, + }) + self.assertIn(resp.status_code, (400, 401)) + + # ── HP-4: 修改密码 + 旧 token 失效 ────────────────────────────────── + + def test_flow_change_password(self): + """HP-4: login → U6 change-password → 旧 token 失效 → 新密码 login""" + phone = '13900000004' + old_pwd = 'OldPass1' + new_pwd = 'NewPass1' + user = create_test_user(phone=phone, password=old_pwd) + + # U3: login + resp = self.client.post(USER_LOGIN_URL, { + 'phone': phone, 'password': old_pwd, + }) + self.assertEqual(resp.status_code, 200, resp.content) + old_access = resp.json()['tokens']['access'] + + # U6: change-password + self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {old_access}') + resp = self.client.post(USER_CHANGE_PWD_URL, { + 'old_password': old_pwd, + 'new_password': new_pwd, + }) + self.assertEqual(resp.status_code, 200, resp.content) + + # 等待 1 秒:invalidate_user_tokens 写入 time()+1 + time.sleep(1) + + # 旧 token 应被拒绝 + self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {old_access}') + resp = self.client.get(USER_ME_URL) + self.assertEqual(resp.status_code, 401, f'旧 token 应失效: {resp.content}') + + # 新密码登录 + self.client.credentials() # 清除旧 auth + resp = self.client.post(USER_LOGIN_URL, { + 'phone': phone, 'password': new_pwd, + }) + self.assertEqual(resp.status_code, 200, resp.content) + new_access = resp.json()['tokens']['access'] + + # 新 token 正常 + self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {new_access}') + resp = self.client.get(USER_ME_URL) + self.assertEqual(resp.status_code, 200, resp.content) + + +class UserListDetailHappyPathTest(CacheTestCase): + """U9 用户列表 + U10 用户详情 happy-path 测试。""" + + # ── HP-5: 管理员获取全部用户列表 ───────────────────────────────────── + + def test_admin_list_all_users(self): + """HP-5: admin GET /users/ → 200,可见全部用户""" + admin = create_test_user( + phone='13900100001', password='Admin123', + real_name='管理员', role_type='super_admin', + ) + stu1 = create_test_user( + phone='13900100002', password='Stu12345', + real_name='学生A', role_type='student', + ) + stu2 = create_test_user( + phone='13900100003', password='Stu12345', + real_name='学生B', role_type='student', + ) + + client = get_auth_client(admin) + resp = client.get(USER_LIST_URL) + self.assertEqual(resp.status_code, 200, resp.content) + + data = resp.json() + # DRF 分页:results 列表 + results = data.get('results', data) + result_ids = [u['id'] for u in results] + self.assertIn(stu1.id, result_ids) + self.assertIn(stu2.id, result_ids) + self.assertIn(admin.id, result_ids) + + # ── HP-6: 教师仅看到自己名下学生 ───────────────────────────────────── + + def test_teacher_list_own_students_only(self): + """HP-6: teacher GET /users/ → 200,仅包含名下活跃学生""" + teacher = create_test_user( + phone='13900100010', password='Teacher1', + real_name='王老师', role_type='teacher', + ) + stu_own = create_test_user( + phone='13900100011', password='Stu12345', + real_name='我的学生', role_type='student', + ) + stu_other = create_test_user( + phone='13900100012', password='Stu12345', + real_name='其他学生', role_type='student', + ) + + create_teacher_student_relation(teacher, stu_own, status=1) + # stu_other 无关系 + + client = get_auth_client(teacher) + resp = client.get(USER_LIST_URL) + self.assertEqual(resp.status_code, 200, resp.content) + + data = resp.json() + results = data.get('results', data) + result_ids = [u['id'] for u in results] + self.assertIn(stu_own.id, result_ids) + self.assertNotIn(stu_other.id, result_ids) + # 教师自己也不应出现在列表(queryset 过滤 role_type='student') + self.assertNotIn(teacher.id, result_ids) + + # ── HP-7: 教师不可见已结束关系的学生 ───────────────────────────────── + + def test_teacher_list_excludes_ended_relation(self): + """HP-7: 已结束(status=0)的师生关系学生不出现在列表""" + teacher = create_test_user( + phone='13900100020', password='Teacher1', + real_name='李老师', role_type='teacher', + ) + stu_active = create_test_user( + phone='13900100021', password='Stu12345', + real_name='活跃学生', role_type='student', + ) + stu_ended = create_test_user( + phone='13900100022', password='Stu12345', + real_name='已毕业学生', role_type='student', + ) + + create_teacher_student_relation(teacher, stu_active, status=1) + create_teacher_student_relation(teacher, stu_ended, status=0) + + client = get_auth_client(teacher) + resp = client.get(USER_LIST_URL) + self.assertEqual(resp.status_code, 200, resp.content) + + results = resp.json().get('results', resp.json()) + result_ids = [u['id'] for u in results] + self.assertIn(stu_active.id, result_ids) + self.assertNotIn(stu_ended.id, result_ids) + + # ── HP-8: 管理员查看任意用户详情 ───────────────────────────────────── + + def test_admin_retrieve_any_user(self): + """HP-8: admin GET /users/{id}/ → 200,可查看任意用户""" + admin = create_test_user( + phone='13900100030', password='Admin123', + real_name='管理员', role_type='super_admin', + ) + student = create_test_user( + phone='13900100031', password='Stu12345', + real_name='某学生', role_type='student', + ) + + client = get_auth_client(admin) + resp = client.get(user_detail_url(student.id)) + self.assertEqual(resp.status_code, 200, resp.content) + self.assertEqual(resp.json()['id'], student.id) + self.assertEqual(resp.json()['real_name'], '某学生') + + # ── HP-9: 用户查看自己的详情 ───────────────────────────────────────── + + def test_self_retrieve(self): + """HP-9: student GET /users/{self.id}/ → 200,可查看自己""" + student = create_test_user( + phone='13900100040', password='Stu12345', + real_name='自查学生', role_type='student', + ) + + client = get_auth_client(student) + resp = client.get(user_detail_url(student.id)) + self.assertEqual(resp.status_code, 200, resp.content) + self.assertEqual(resp.json()['id'], student.id) + + # ── HP-10: 教师查看名下学生详情 ────────────────────────────────────── + + def test_teacher_retrieve_own_student(self): + """HP-10: teacher GET /users/{student.id}/ → 200,可查看名下学生""" + teacher = create_test_user( + phone='13900100050', password='Teacher1', + real_name='赵老师', role_type='teacher', + ) + student = create_test_user( + phone='13900100051', password='Stu12345', + real_name='赵的学生', role_type='student', + ) + create_teacher_student_relation(teacher, student, status=1) + + client = get_auth_client(teacher) + resp = client.get(user_detail_url(student.id)) + self.assertEqual(resp.status_code, 200, resp.content) + self.assertEqual(resp.json()['id'], student.id) + + # ── HP-11: 管理员列表支持过滤和搜索 ────────────────────────────────── + + def test_admin_list_filter_and_search(self): + """HP-11: admin GET /users/?role_type=student&search=张 → 过滤生效""" + admin = create_test_user( + phone='13900100060', password='Admin123', + real_name='管理员', role_type='super_admin', + ) + stu_zhang = create_test_user( + phone='13900100061', password='Stu12345', + real_name='张同学', role_type='student', + ) + stu_li = create_test_user( + phone='13900100062', password='Stu12345', + real_name='李同学', role_type='student', + ) + teacher = create_test_user( + phone='13900100063', password='Teacher1', + real_name='张老师', role_type='teacher', + ) + + client = get_auth_client(admin) + # 按 role_type 过滤 + search + resp = client.get(USER_LIST_URL, {'role_type': 'student', 'search': '张'}) + self.assertEqual(resp.status_code, 200, resp.content) + + results = resp.json().get('results', resp.json()) + result_ids = [u['id'] for u in results] + self.assertIn(stu_zhang.id, result_ids) + self.assertNotIn(stu_li.id, result_ids) + # 张老师 role_type=teacher,被 role_type=student 过滤掉 + self.assertNotIn(teacher.id, result_ids) diff --git a/test/test_user_negative.py b/test/test_user_negative.py new file mode 100644 index 0000000..5dcbfdb --- /dev/null +++ b/test/test_user_negative.py @@ -0,0 +1,248 @@ +"""用户域负向测试:限流、越权、字段校验。""" + +from unittest.mock import patch + +from django.core.cache import cache +from rest_framework.test import APIClient + +from apps.user.throttling import SmsPhoneMinuteThrottle, RegisterIpThrottle +from .conftest import ( + CacheTestCase, + USER_SEND_CODE_URL, USER_REGISTER_URL, USER_LOGIN_URL, + USER_RESET_PWD_URL, USER_CHANGE_PWD_URL, USER_ME_URL, + USER_LOGOUT_URL, USER_REFRESH_URL, + USER_LIST_URL, user_detail_url, + inject_sms_code, create_test_user, get_auth_client, get_tokens, + create_teacher_student_relation, +) + + +class UserNegativeTest(CacheTestCase): + """用户域负向路径测试。""" + + def setUp(self): + super().setUp() + self.client = APIClient() + + # ── 限流 ───────────────────────────────────────────────────────────── + + def test_rate_limit_sms_429(self): + """N1: SMS 限流 → 429""" + with ( + patch.object(SmsPhoneMinuteThrottle, 'allow_request', return_value=False), + patch.object(SmsPhoneMinuteThrottle, 'wait', return_value=60), + ): + resp = self.client.post(USER_SEND_CODE_URL, { + 'phone': '13800001001', 'scene': 'register', + }) + self.assertEqual(resp.status_code, 429, resp.content) + self.assertEqual(resp.json()['code'], 'SYS_RATE_LIMIT') + + # ── 越权 ───────────────────────────────────────────────────────────── + + def test_unauth_change_password_401(self): + """N2: 未登录 POST change-password → 401""" + resp = self.client.post(USER_CHANGE_PWD_URL, { + 'old_password': 'x', 'new_password': 'y', + }) + self.assertEqual(resp.status_code, 401, resp.content) + + def test_unauth_me_401(self): + """N3: 未登录 GET /me → 401""" + resp = self.client.get(USER_ME_URL) + self.assertEqual(resp.status_code, 401, resp.content) + + # ── 字段校验 ───────────────────────────────────────────────────────── + + def test_register_invalid_phone_400(self): + """N4: 手机号格式不合法 → 400 SMS_INVALID_PHONE""" + with patch.object(RegisterIpThrottle, 'allow_request', return_value=True): + resp = self.client.post(USER_REGISTER_URL, { + 'phone': '123', + 'code': '123456', + 'password': 'Abc12345', + 'real_name': '测试', + }) + self.assertEqual(resp.status_code, 400, resp.content) + self.assertEqual(resp.json()['code'], 'SMS_INVALID_PHONE') + + def test_register_weak_password_400(self): + """N5: 弱密码 → 400 AUTH_PASSWORD_WEAK""" + phone = '13800001002' + inject_sms_code(phone, 'register') + with patch.object(RegisterIpThrottle, 'allow_request', return_value=True): + resp = self.client.post(USER_REGISTER_URL, { + 'phone': phone, + 'code': '123456', + 'password': '123', + 'real_name': '测试弱密码', + }) + self.assertEqual(resp.status_code, 400, resp.content) + self.assertEqual(resp.json()['code'], 'AUTH_PASSWORD_WEAK') + + def test_register_duplicate_phone_400(self): + """N6: 已注册手机号再注册 → 400 AUTH_PHONE_REGISTERED""" + phone = '13800001003' + create_test_user(phone=phone) + inject_sms_code(phone, 'register') + with patch.object(RegisterIpThrottle, 'allow_request', return_value=True): + resp = self.client.post(USER_REGISTER_URL, { + 'phone': phone, + 'code': '123456', + 'password': 'Abc12345', + 'real_name': '重复注册', + }) + self.assertEqual(resp.status_code, 400, resp.content) + self.assertEqual(resp.json()['code'], 'AUTH_PHONE_REGISTERED') + + def test_login_wrong_password(self): + """N7: 错误密码 → 400 AUTH_BAD_CREDENTIALS""" + phone = '13800001004' + create_test_user(phone=phone, password='RealPass1') + resp = self.client.post(USER_LOGIN_URL, { + 'phone': phone, 'password': 'WrongPass1', + }) + self.assertIn(resp.status_code, (400, 401)) + self.assertEqual(resp.json()['code'], 'AUTH_BAD_CREDENTIALS') + + def test_login_account_lock_423(self): + """N8: 连续 5 次错误后第 6 次 → 423 AUTH_ACCOUNT_LOCKED""" + phone = '13800001005' + create_test_user(phone=phone, password='RealPass1') + + # 连续 5 次错误密码 + for _ in range(5): + self.client.post(USER_LOGIN_URL, { + 'phone': phone, 'password': 'Wrong!!!!', + }) + + # 第 6 次 + resp = self.client.post(USER_LOGIN_URL, { + 'phone': phone, 'password': 'Wrong!!!!', + }) + self.assertEqual(resp.status_code, 423, resp.content) + self.assertEqual(resp.json()['code'], 'AUTH_ACCOUNT_LOCKED') + + def test_reset_wrong_code(self): + """N9: 重置密码验证码不匹配 → AUTH_CODE_MISMATCH""" + phone = '13800001006' + create_test_user(phone=phone, password='OldPass1') + inject_sms_code(phone, 'reset', code='123456') + + from apps.user.throttling import ResetPhoneThrottle + with patch.object(ResetPhoneThrottle, 'allow_request', return_value=True): + resp = self.client.post(USER_RESET_PWD_URL, { + 'phone': phone, + 'code': '999999', # 错误验证码 + 'new_password': 'NewPass1', + }) + # AUTH_CODE_MISMATCH raises with default status_code or explicit status + self.assertIn(resp.status_code, (400, 401)) + self.assertEqual(resp.json()['code'], 'AUTH_CODE_MISMATCH') + + def test_refresh_revoked_token_401(self): + """N10: logout 后用旧 refresh → 401""" + phone = '13800001007' + user = create_test_user(phone=phone) + tokens = get_tokens(user) + + # logout(吊销 refresh) + self.client.post(USER_LOGOUT_URL, {'refresh': tokens['refresh']}) + + # 尝试用已吊销的 refresh 刷新 + resp = self.client.post(USER_REFRESH_URL, {'refresh': tokens['refresh']}) + self.assertEqual(resp.status_code, 401, resp.content) + self.assertEqual(resp.json()['code'], 'AUTH_TOKEN_INVALID') + + +class UserListDetailNegativeTest(CacheTestCase): + """U9 用户列表 + U10 用户详情 负向测试。""" + + def setUp(self): + super().setUp() + self.client = APIClient() + + # ── U9 列表权限 ────────────────────────────────────────────────────── + + def test_student_list_403(self): + """N11: student GET /users/ → 403 USER_NO_LIST_PERMISSION""" + student = create_test_user( + phone='13800002001', password='Stu12345', + real_name='学生', role_type='student', + ) + client = get_auth_client(student) + resp = client.get(USER_LIST_URL) + self.assertEqual(resp.status_code, 403, resp.content) + self.assertEqual(resp.json()['code'], 'USER_NO_LIST_PERMISSION') + + def test_doctor_list_403(self): + """N12: doctor GET /users/ → 403 USER_NO_LIST_PERMISSION""" + doctor = create_test_user( + phone='13800002002', password='Doc12345', + real_name='医生', role_type='doctor', + ) + client = get_auth_client(doctor) + resp = client.get(USER_LIST_URL) + self.assertEqual(resp.status_code, 403, resp.content) + self.assertEqual(resp.json()['code'], 'USER_NO_LIST_PERMISSION') + + def test_unauth_list_401(self): + """N13: 未登录 GET /users/ → 401""" + resp = self.client.get(USER_LIST_URL) + self.assertEqual(resp.status_code, 401, resp.content) + + # ── U10 详情权限 ───────────────────────────────────────────────────── + + def test_unauth_detail_401(self): + """N14: 未登录 GET /users/{id}/ → 401""" + user = create_test_user(phone='13800002010', password='Pass1234') + resp = self.client.get(user_detail_url(user.id)) + self.assertEqual(resp.status_code, 401, resp.content) + + def test_student_view_other_student_403(self): + """N15: student A 查看 student B 详情 → 403 USER_NO_VIEW_PERMISSION""" + stu_a = create_test_user( + phone='13800002020', password='Stu12345', + real_name='学生A', role_type='student', + ) + stu_b = create_test_user( + phone='13800002021', password='Stu12345', + real_name='学生B', role_type='student', + ) + client = get_auth_client(stu_a) + resp = client.get(user_detail_url(stu_b.id)) + self.assertEqual(resp.status_code, 403, resp.content) + self.assertEqual(resp.json()['code'], 'USER_NO_VIEW_PERMISSION') + + def test_teacher_view_unrelated_student_403(self): + """N16: teacher 查看非名下学生详情 → 403 USER_NO_VIEW_PERMISSION""" + teacher = create_test_user( + phone='13800002030', password='Teacher1', + real_name='刘老师', role_type='teacher', + ) + unrelated = create_test_user( + phone='13800002031', password='Stu12345', + real_name='非名下学生', role_type='student', + ) + # 无师生关系 + client = get_auth_client(teacher) + resp = client.get(user_detail_url(unrelated.id)) + self.assertEqual(resp.status_code, 403, resp.content) + self.assertEqual(resp.json()['code'], 'USER_NO_VIEW_PERMISSION') + + def test_teacher_view_ended_relation_student_403(self): + """N17: teacher 查看已结束关系学生详情 → 403 USER_NO_VIEW_PERMISSION""" + teacher = create_test_user( + phone='13800002040', password='Teacher1', + real_name='陈老师', role_type='teacher', + ) + student = create_test_user( + phone='13800002041', password='Stu12345', + real_name='已毕业学生', role_type='student', + ) + create_teacher_student_relation(teacher, student, status=0) # 已结束 + + client = get_auth_client(teacher) + resp = client.get(user_detail_url(student.id)) + self.assertEqual(resp.status_code, 403, resp.content) + self.assertEqual(resp.json()['code'], 'USER_NO_VIEW_PERMISSION') diff --git a/test/测试文档-D8.md b/test/测试文档-D8.md new file mode 100644 index 0000000..fcc3922 --- /dev/null +++ b/test/测试文档-D8.md @@ -0,0 +1,364 @@ +# D8 测试文档 + +> 测试日期:2026-05-29(U9/U10 补充) +> 测试人员:Claude AI + 人工审核 +> 测试环境:Windows / Python 3.14 / Django 5.0 / MySQL 8 / Redis + +--- + +## 1. 测试环境 + +| 项目 | 值 | +|---|---| +| Python | 3.14 | +| Django | 5.0+ | +| DRF | 3.14+ | +| 数据库 | MySQL 8 (test_medical_training) | +| 缓存 | Redis(与生产环境一致,`django_redis`) | +| 运行命令 | `.venv\Scripts\python.exe manage.py test test -v2 --keepdb` | + +--- + +## 2. 测试总览 + +| 类别 | 测试文件 | 用例数 | 通过 | 失败 | +|---|---|---|---|---| +| 用户域 happy-path | `test_user_happy.py` | 11 | 11 | 0 | +| 病例域 happy-path | `test_case_happy.py` | 2 | 2 | 0 | +| 用户域 negative | `test_user_negative.py` | 17 | 17 | 0 | +| 病例域 negative | `test_case_negative.py` | 11 | 11 | 0 | +| **合计** | | **41** | **41** | **0** | + +--- + +## 3. Happy-Path 测试结果 + +### 3.1 用户域(11 条流程) + +| ID | 测试方法 | 测试什么 | 结果 | +|---|---|---|---| +| HP-1 | `test_flow_register_login_me` | **新用户注册全流程**:发送短信验证码 → 用验证码+密码注册账号 → 用密码登录 → 查看个人信息(确认手机号和姓名正确) | PASS | +| HP-2 | `test_flow_code_login` | **验证码登录**:已有账号的用户,发送登录验证码 → 用手机号+验证码登录(不需要密码)→ 查看个人信息确认身份正确 | PASS | +| HP-3 | `test_flow_reset_password` | **忘记密码重置**:发送重置验证码 → 用验证码设置新密码 → 用新密码能登录成功 → 用旧密码登录失败(旧密码已失效) | PASS | +| HP-4 | `test_flow_change_password` | **登录后修改密码**:先用旧密码登录拿到 token → 调用修改密码接口 → 等 1 秒后旧 token 被系统自动作废(返回 401)→ 用新密码重新登录,新 token 正常可用 | PASS | +| HP-5 | `test_admin_list_all_users` | **管理员看用户列表**:创建管理员+2 个学生,管理员调用用户列表接口,确认能看到系统里所有用户(包括自己和两个学生) | PASS | +| HP-6 | `test_teacher_list_own_students_only` | **教师只看到自己的学生**:创建教师+自己的学生+别人的学生,教师调用用户列表,确认只能看到与自己有师生关系的学生,看不到别人的学生,也看不到自己 | PASS | +| HP-7 | `test_teacher_list_excludes_ended_relation` | **已结束关系的学生不可见**:教师名下有一个活跃学生和一个已毕业学生(关系状态=已结束),教师调用列表,确认只能看到活跃学生,已毕业的不显示 | PASS | +| HP-8 | `test_admin_retrieve_any_user` | **管理员查看任意用户详情**:管理员可以通过 /users/{id}/ 查看系统中任何用户的详细信息,确认返回的姓名正确 | PASS | +| HP-9 | `test_self_retrieve` | **用户查看自己的详情**:学生通过 /users/{自己的id}/ 查看自己的信息,确认能正常返回 | PASS | +| HP-10 | `test_teacher_retrieve_own_student` | **教师查看名下学生详情**:教师和学生建立师生关系后,教师可以查看该学生的详细信息 | PASS | +| HP-11 | `test_admin_list_filter_and_search` | **列表筛选和搜索**:创建管理员+张同学(学生)+李同学(学生)+张老师(教师),管理员用 `role_type=student&search=张` 筛选,确认只返回张同学(李同学不姓张被排除,张老师不是学生被排除) | PASS | + +### 3.2 病例域(2 条流程) + +| ID | 测试方法 | 测试什么 | 结果 | +|---|---|---|---| +| HP-5 | `test_flow_form_create_read_update` | **手工录入病例全流程**:用表单数据创建一个传统病例(含 2 条评分规则)→ 查看完整病例确认数据正确 → 修改标题+诊断+减少为 1 条评分规则 → 再次查看确认修改生效 → 检查数据库记录是否一致 | PASS | +| HP-6 | `test_flow_pdf_mock_full_pipeline` | **PDF 上传到创建病例的完整流水线**(AI 部分用 mock 模拟):上传 PDF 文件 → AI 解析出病例结构化数据 → 用解析结果生成评分规则 → 组装数据创建病例 → 查看完整病例 → 修改标题 → 确认修改生效 | PASS | + +--- + +## 4. Negative 测试结果 + +### 4.1 用户域(17 条) + +| ID | 测试方法 | 测试什么 | 期望 | 结果 | +|---|---|---|---|---| +| N1 | `test_rate_limit_sms_429` | **短信发送频率超限**:模拟 1 分钟内已发过验证码,再次请求发送时系统拒绝,返回"请求太频繁" | 429 | PASS | +| N2 | `test_unauth_change_password_401` | **没登录就想改密码**:不带任何 token 直接调用修改密码接口,系统拒绝并要求先登录 | 401 | PASS | +| N3 | `test_unauth_me_401` | **没登录就想看个人信息**:不带 token 调用 /me 接口,系统拒绝 | 401 | PASS | +| N4 | `test_register_invalid_phone_400` | **手机号格式错误**:用 "123"(不是 11 位手机号)去注册,系统拒绝并提示手机号不合法 | 400 | PASS | +| N5 | `test_register_weak_password_400` | **密码太简单**:用 "123" 作为密码注册,系统拒绝并提示密码强度不够(要求大小写字母+数字,至少 8 位) | 400 | PASS | +| N6 | `test_register_duplicate_phone_400` | **手机号已被注册**:先创建一个用户,再用同一个手机号注册第二次,系统拒绝并提示该手机号已注册 | 400 | PASS | +| N7 | `test_login_wrong_password` | **密码错误**:用正确的手机号但错误的密码登录,系统拒绝并提示账号或密码错误 | 400 | PASS | +| N8 | `test_login_account_lock_423` | **连续输错密码被锁定**:连续 5 次输入错误密码,第 6 次登录时系统锁定账号,返回"账号已锁定"(防暴力破解) | 423 | PASS | +| N9 | `test_reset_wrong_code` | **重置密码时验证码错误**:真实验证码是 123456,但提交 999999,系统拒绝并提示验证码不匹配 | 400/401 | PASS | +| N10 | `test_refresh_revoked_token_401` | **退出登录后 token 失效**:先退出登录(logout 会吊销 refresh token),再用那个已吊销的 refresh token 去刷新,系统拒绝 | 401 | PASS | +| N11 | `test_student_list_403` | **学生不能看用户列表**:学生角色调用用户列表接口,系统拒绝(只有管理员和教师才能看) | 403 | PASS | +| N12 | `test_doctor_list_403` | **医生不能看用户列表**:医生角色调用用户列表接口,系统同样拒绝(医生也没有列表权限) | 403 | PASS | +| N13 | `test_unauth_list_401` | **没登录不能看用户列表**:不带 token 调用用户列表接口,系统要求先登录 | 401 | PASS | +| N14 | `test_unauth_detail_401` | **没登录不能看用户详情**:不带 token 调用用户详情接口,系统要求先登录 | 401 | PASS | +| N15 | `test_student_view_other_student_403` | **学生不能看别人的详情**:学生 A 试图查看学生 B 的个人信息,系统拒绝(只能看自己的) | 403 | PASS | +| N16 | `test_teacher_view_unrelated_student_403` | **教师不能看非名下学生**:教师试图查看一个和自己没有师生关系的学生的信息,系统拒绝 | 403 | PASS | +| N17 | `test_teacher_view_ended_relation_student_403` | **教师不能看已毕业学生**:教师和学生的师生关系已结束(status=0,如学生已毕业),教师再查看该学生详情,系统拒绝 | 403 | PASS | + +### 4.2 病例域(11 条) + +| ID | 测试方法 | 测试什么 | 期望 | 结果 | +|---|---|---|---|---| +| N10 | `test_invalid_case_type_400` | **病例类型不合法**:创建病例时 case_type 传 "invalid"(只支持 traditional 和 teaching),系统拒绝 | 400 | PASS | +| N11 | `test_empty_scoring_rules_400` | **评分规则为空**:创建病例时 scoring_rules 传空数组 `[]`,系统要求至少有 1 条评分规则 | 400 | PASS | +| N12 | `test_subtable_conflict_400` | **子表类型冲突**:创建传统病例时同时传了 traditional 和 teaching 两个子表数据,系统拒绝(一个病例只能有一种类型的子表) | 400 | PASS | +| N13 | `test_missing_subtable_400` | **缺少必要子表**:声明 case_type=traditional 但没有传 traditional 子表数据,系统拒绝(类型和子表必须对应) | 400 | PASS | +| N15 | `test_patch_published_case_400` | **已发布的病例不能编辑**:先创建病例并将其发布(publish_status=1),再尝试修改标题,系统拒绝(发布后不允许编辑) | 400 | PASS | +| N14 | `test_unauth_full_create_401` | **没登录不能创建病例**:不带 token 直接调用创建病例接口,系统要求先登录 | 401 | PASS | +| N17 | `test_view_other_draft_403` | **不能看别人的草稿**:用户 A 创建了一个草稿病例,用户 B 试图查看,系统拒绝(草稿只有创建者自己能看) | 403 | PASS | +| N18 | `test_rate_limit_pdf_parse_429` | **PDF 解析频率超限**:模拟用户短时间内已多次调用 PDF 解析,再次调用时系统拒绝并提示"请求太频繁" | 429 | PASS | +| N19 | `test_transaction_rollback` | **数据库事务回滚**:创建病例时模拟评分规则写入数据库失败(IntegrityError),验证病例主表也一起回滚(不会出现"有病例但没有评分规则"的残留数据) | 回滚成功 | PASS | +| N20 | `test_ai_bad_json_500` | **AI 返回乱码**:模拟 DeepSeek AI 返回的不是合法 JSON 格式,系统返回 500 并明确告知错误原因是 AI 输出异常 | 500 | PASS | +| N21 | `test_ai_schema_violation_500` | **AI 返回数据缺字段**:模拟 DeepSeek 返回了合法 JSON 但缺少必填字段(如 title),系统校验后返回 500 并提示 AI 输出不符合预期格式 | 500 | PASS | + +--- + +## 5. Bug 修复记录 + +### Bug-1: 限流 mock 导致 500 而非 429 + +- **发现阶段**:N1 `test_rate_limit_sms_429` +- **现象**:mock `SmsPhoneMinuteThrottle.allow_request` 返回 `False` 后,DRF 调用 `throttle.wait()` 获取重试时间,但因 `allow_request` 被完整 mock,`self.history` 未初始化,导致 `AttributeError` → 500 +- **根因**:DRF 的 `check_throttles` 在 `allow_request=False` 后会调用 `wait()` 方法读取 `self.history`,mock 只替换了 `allow_request` 未处理 `wait` +- **修复**:同时 mock `wait` 方法返回固定值 `60` +- **影响文件**:`test/test_user_negative.py`、`test/test_case_negative.py` +- **严重度**:低(仅影响测试代码,非业务代码 Bug) + +### Bug-2: PDF mock 路径不正确 + +- **发现阶段**:HP-6 `test_flow_pdf_mock_full_pipeline` +- **现象**:mock `apps.case.services.pdf_reader.extract_text_from_pdfs` 无效,真实 PDF 解析仍被执行,伪造 PDF 内容触发 `CASE_PDF_EMPTY` 错误 +- **根因**:`case_importer.py` 使用 `from .pdf_reader import extract_text_from_pdfs` 直接导入函数,mock 必须 patch 导入位置 `apps.case.services.case_importer.extract_text_from_pdfs` 而非定义位置 +- **修复**:更改 patch 路径为 `apps.case.services.case_importer.extract_text_from_pdfs` +- **影响文件**:`test/test_case_happy.py`、`test/test_case_negative.py` +- **严重度**:低(仅影响测试代码,非业务代码 Bug) + +### Bug-3: drf-spectacular 无法识别自定义认证类 + +- **发现阶段**:Swagger 文档生成时控制台日志 +- **现象**:所有 ViewSet 和函数视图均产生 `Warning: could not resolve authenticator RedisBlacklistJWTAuthentication`,Swagger UI 上缺少认证标识和 Authorize 按钮 +- **根因**:`RedisBlacklistJWTAuthentication` 继承自 `JWTAuthentication`,drf-spectacular 没有注册对应的 `OpenApiAuthenticationExtension`,不知道如何将其映射为 OpenAPI `securitySchemes` +- **修复**:创建 `apps/user/openapi.py`,定义 `RedisBlacklistJWTScheme` 扩展类,将该认证类映射为 `type: http, scheme: bearer, bearerFormat: JWT`;在 `apps/user/apps.py` 的 `ready()` 中 import 触发自动注册 +- **影响文件**:`apps/user/openapi.py`(新建)、`apps/user/apps.py` +- **严重度**:低(不影响接口功能,仅影响 Swagger 文档展示) + +### Bug-4: 函数视图缺少 Swagger 请求/响应 Schema + +- **发现阶段**:Swagger 文档生成时控制台日志 +- **现象**:`send_code`、`register`、`login_password`、`login_code`、`logout`、`reset_password` 6 个 `@api_view` 函数视图产生 `Error: unable to guess serializer`,Swagger UI 上这些接口没有请求体/响应体描述 +- **根因**:函数视图直接读 `request.data`,未声明 `serializer_class`,drf-spectacular 无法自动推断 schema +- **修复**:为 6 个函数视图添加 `@extend_schema` 装饰器,通过 `inline_serializer` 声明请求和响应字段;`refresh.py` 的类视图也加了 `@extend_schema(tags=['认证'])` +- **影响文件**:`apps/user/auth/send_code.py`、`register.py`、`login.py`、`logout.py`、`reset_password.py`、`refresh.py` +- **严重度**:低(不影响接口功能,仅影响 Swagger 文档展示) + +### Bug-5: 同名枚举冲突导致 Swagger 命名混乱 + +- **发现阶段**:Swagger 文档生成时控制台日志 +- **现象**:`case_type` 和 `status` 字段在不同 model/serializer 中有不同 choices 值集,drf-spectacular 自动生成带哈希后缀的名称(如 `CaseType629Enum`、`StatusDb0Enum`) +- **根因**:`CaseBase.CASE_TYPE_CHOICES`(4 值)与 C2/C3 内联序列化器 `ChoiceField(choices=['traditional','teaching'])`(2 值)同名不同值;`CaseBase.STATUS_CHOICES` / `User.STATUS_CHOICES`(相同值)与 `TrainingRecord.STATUS_CHOICES` / `TeacherStudentRelation.STATUS_CHOICES`(不同值)同名不同值 +- **修复**:在 `SPECTACULAR_SETTINGS['ENUM_NAME_OVERRIDES']` 中显式命名:`CaseTypeEnum`(4 值)、`CreatableCaseTypeEnum`(2 值)、`CommonStatusEnum`、`TrainingStatusEnum`、`TeacherStudentStatusEnum` 等 +- **影响文件**:`config/settings.py` +- **严重度**:极低(不影响接口功能,仅影响 Swagger 文档中枚举名称的可读性) + +### Bug-6: 审计日志文件在 Windows 上写入失败 + +- **发现阶段**:手动检查 `logs/audit.log` 发现文件为空(0 字节) +- **现象**:`TimedRotatingFileHandler` 在首次写入时尝试按日期轮转(rename `audit.log` → `audit.log.2026-05-27`),但 dev server 进程正占着文件,Windows 文件锁导致 `PermissionError: [WinError 32]`,轮转失败,日志丢失 +- **根因**:`TimedRotatingFileHandler` 的轮转机制依赖 `os.rename()`,Windows 上文件被其他进程打开时不允许 rename(Linux 无此问题) +- **修复**:自定义 `DailyFileHandler`(`config/logging_handlers.py`),直接按日期命名文件(`audit-YYYY-MM-DD.log`),日期切换时打开新文件,**不 rename 旧文件**,彻底避免文件锁问题;保留 `backup_count=30` 自动清理超过 30 天的旧日志 +- **影响文件**:`config/logging_handlers.py`(新建)、`config/settings.py`(LOGGING handler 配置) +- **严重度**:中(审计日志完全丢失,影响安全审计能力) + +### 修复验证 + +**Bug-3/4/5 Swagger 修复验证:** + +```bash +# 修复前 +Schema generation summary: +Warnings: 104 (20 unique) +Errors: 1 (1 unique) + +# 修复后 +Schema generation summary: +Warnings: 0 +Errors: 0 +``` + +- `python manage.py spectacular --validate --fail-on-warn` 退出码 0 + +**Bug-6 日志修复验证:** + +```bash +# 修复前:audit.log 0 字节,PermissionError: [WinError 32] +# 修复后:audit-2026-05-29.log 正常写入 29 条审计记录 +``` + +- 全部 41 条单元测试通过,业务逻辑零影响 + +--- + +## 6. .env.example 审计 + +对比 `config/settings.py` 中所有 `os.getenv()` 调用,确认 `.env.example` 覆盖完整: + +| 环境变量 | 是否覆盖 | 备注 | +|---|---|---| +| DB_NAME / DB_USER / DB_PASSWORD / DB_HOST / DB_PORT | ✅ | | +| REDIS_URL | ✅ | | +| SMS_PROVIDER | ✅ | mock / aliyun | +| ALIYUN_SMS_ACCESS_KEY_ID / SECRET | ✅ | | +| ALIYUN_SMS_SIGN_NAME / TEMPLATE_* | ✅ | | +| DEEPSEEK_API_KEY | ✅ | | +| DEEPSEEK_BASE_URL / MODEL / TIMEOUT / MAX_RETRIES | ✅ | | + +**安全修复**:已将 `.env.example` 中的真实密码和 API Key 替换为占位符: +- `DB_PASSWORD=your-db-password` +- `DEEPSEEK_API_KEY=your-deepseek-api-key` + +--- + +## 7. 测试覆盖的接口清单 + +### 用户端 + +| 接口 | URL | happy-path | negative | +|---|---|---|---| +| U1 发送验证码 | POST /api/user/auth/send-code/ | HP-1,2,3 | N1(限流) | +| U2 注册 | POST /api/user/auth/register/ | HP-1 | N4,N5,N6 | +| U3 密码登录 | POST /api/user/auth/login/ | HP-1,3,4 | N7,N8 | +| U4 验证码登录 | POST /api/user/auth/login-code/ | HP-2 | — | +| U5 重置密码 | POST /api/user/auth/reset-password/ | HP-3 | N9 | +| U6 修改密码 | POST /api/user/users/change-password/ | HP-4 | N2 | +| U7 退出登录 | POST /api/user/auth/logout/ | — | N10(辅助) | +| U8 刷新 Token | POST /api/user/auth/refresh/ | — | N10 | +| /me | GET /api/user/users/me/ | HP-1,2,4 | N3 | +| U9 用户列表 | GET /api/user/users/ | HP-5,6,7,11 | N11,N12,N13 | +| U10 用户详情 | GET /api/user/users/{id}/ | HP-8,9,10 | N14,N15,N16,N17 | + +### 病例端 + +| 接口 | URL | happy-path | negative | +|---|---|---|---| +| C1 PDF 解析 | POST /api/case/cases/parse-pdf/ | HP-6 | N18,N20,N21 | +| C2 生成评分规则 | POST /api/case/cases/generate-scoring-rules/ | HP-6 | — | +| C3 创建病例 | POST /api/case/cases/full-create/ | HP-5,6 | N11-N14,N16,N19 | +| C4 完整查看 | GET /api/case/cases/{id}/full/ | HP-5,6 | N17 | +| C5 编辑草稿 | PATCH /api/case/cases/{id}/full/ | HP-5,6 | N15 | + +--- + +## 8. Swagger Try-it-out 接口验证 + +> 脚本:`test/swagger_tryout.py` +> 运行方式:启动 `python manage.py runserver 8000` 后执行 `.venv\Scripts\python.exe test/swagger_tryout.py` +> PDF 文件:项目根目录 `儿科 病例样例(SOAP+循证).pdf`(真实临床 PDF) + +### 8.1 用户端(15 个接口/场景) + +| 接口 | Method | URL | 测试什么 | 期望 | 实际 | 结果 | +|---|---|---|---|---|---|---| +| U1 发送验证码 | POST | /api/user/auth/send-code/ | 向手机号发送注册验证码 | 200 | 200 | PASS | +| U2 注册 | POST | /api/user/auth/register/ | 用验证码+密码+姓名注册新账号 | 201 | 201 | PASS | +| U3 密码登录 | POST | /api/user/auth/login/ | 用手机号+密码登录,拿到 JWT token | 200 | 200 | PASS | +| U4 验证码登录 | POST | /api/user/auth/login-code/ | 用手机号+验证码免密登录 | 200 | 200 | PASS | +| U5 重置密码 | POST | /api/user/auth/reset-password/ | 忘记密码后用验证码设置新密码 | 200 | 200 | PASS | +| U6 修改密码 | POST | /api/user/users/change-password/ | 已登录用户修改密码(需旧密码验证) | 200 | 200 | PASS | +| U8 刷新 Token | POST | /api/user/auth/refresh/ | 用 refresh token 换取新的 access token | 200 | 200 | PASS | +| /me 个人信息 | GET | /api/user/users/me/ | 查看当前登录用户的完整个人信息 | 200 | 200 | PASS | +| U9 管理员列表 | GET | /api/user/users/ | 管理员获取用户列表,确认能看到全部用户 | 200 | 200 | PASS | +| U9-b 教师列表 | GET | /api/user/users/ | 教师获取用户列表,确认只能看到自己名下的 1 个学生 | 200 | 200 | PASS | +| U9-c 普通用户列表 | GET | /api/user/users/ | 普通用户(学生)获取列表被拒绝,没有权限 | 403 | 403 | PASS | +| U10 管理员查看详情 | GET | /api/user/users/{id}/ | 管理员查看任意用户的详细信息 | 200 | 200 | PASS | +| U10-b 教师查看学生 | GET | /api/user/users/{id}/ | 教师查看自己名下学生的详细信息 | 200 | 200 | PASS | +| U10-c 用户查看自己 | GET | /api/user/users/{id}/ | 普通用户查看自己的详细信息 | 200 | 200 | PASS | +| U7 退出登录 | POST | /api/user/auth/logout/ | 退出登录,吊销 refresh token 使其失效 | 200 | 200 | PASS | + +### 8.2 病例端(5 个接口) + +| 接口 | Method | URL | 测试什么 | 期望 | 实际 | 结果 | +|---|---|---|---|---|---|---| +| C1 PDF 解析 | POST | /api/case/cases/parse-pdf/ | 上传真实 PDF 文件,DeepSeek AI 解析出病例结构化数据(病名、症状、诊断等) | 200 | 200 | PASS | +| C2 生成评分规则 | POST | /api/case/cases/generate-scoring-rules/ | 用 C1 的解析结果让 AI 自动生成评分规则(如"诊断准确性""治疗方案"等维度) | 200 | 200 | PASS | +| C3 创建病例 | POST | /api/case/cases/full-create/ | 把 C1 的病例数据 + C2 的评分规则组装起来,创建完整病例 | 201 | 201 | PASS | +| C4 完整查看 | GET | /api/case/cases/{id}/full/ | 查看刚创建的病例完整信息(主表+子表+评分规则) | 200 | 200 | PASS | +| C5 编辑草稿 | PATCH | /api/case/cases/{id}/full/ | 修改草稿病例的标题,确认修改成功 | 200 | 200 | PASS | + +### 8.3 汇总 + +- **总计 20 个接口/场景,全部 PASS** +- C1→C2→C3 走完了真实 PDF 上传 → DeepSeek AI 解析 → AI 生成评分规则 → 创建病例的完整流水线 +- U9/U10 验证了管理员、教师、普通用户三种角色的列表和详情权限控制 +- 脚本自动清理 Redis 缓存、注入验证码、处理 token 失效时序(`time.sleep(1.2)`) + +--- + +## 9. 运行方式 + +```bash +# 全量单元测试(41 条) +.venv\Scripts\python.exe manage.py test test -v2 --keepdb + +# 分模块运行 +.venv\Scripts\python.exe manage.py test test.test_user_happy -v2 --keepdb +.venv\Scripts\python.exe manage.py test test.test_case_happy -v2 --keepdb +.venv\Scripts\python.exe manage.py test test.test_user_negative -v2 --keepdb +.venv\Scripts\python.exe manage.py test test.test_case_negative -v2 --keepdb + +# 单个测试 +.venv\Scripts\python.exe manage.py test test.test_user_happy.UserAuthHappyPathTest.test_flow_register_login_me -v2 --keepdb + +# Swagger Try-it-out(需先启动 dev server) +.venv\Scripts\python.exe manage.py runserver 8000 +.venv\Scripts\python.exe test/swagger_tryout.py +``` + +**前提条件**: +1. MySQL 运行,`test_medical_training` 数据库已创建(首次运行去掉 `--keepdb` 自动创建) +2. 虚拟环境已激活 +3. Redis **需要**运行(测试直接使用 Redis 缓存) +4. Swagger Try-it-out 脚本额外需要 Django dev server 运行在 8000 端口 + +--- + +## 10. 测试日志记录 + +### 10.1 单元测试 — API 访问日志 + +通过 `APIAccessLogMiddleware`(`config/middleware.py`)自动记录所有 `/api/` 请求和响应到日志文件。 + +| 项目 | 说明 | +|---|---| +| 日志文件 | `logs/api-access-YYYY-MM-DD.log` | +| 记录内容 | 请求方法、路径、请求头、Content-Type、用户 ID、查询参数、状态码、耗时、请求体、响应体(完整原文,含 token、密码等) | +| 截断阈值 | 请求体/响应体超过 2000 字符自动截断 | +| Multipart 处理 | 提取表单字段+ 文件名和大小,不记录原始二进制内容 | + +**注意**:单元测试使用 Django TestCase,每个测试结束后事务自动回滚,测试数据不会持久化到数据库。但中间件在视图执行过程中已将日志写入文件,因此日志是完整的。 + +### 10.2 Swagger 脚本 — 独立日志 + +`test/swagger_tryout.py` 将每个接口调用的完整请求体和响应体记录到独立日志文件。 + +| 项目 | 说明 | +|---|---| +| 日志文件 | `logs/test-swagger-YYYY-MM-DD.log` | +| 记录内容 | 接口 ID、方法、URL、期望状态码、实际状态码、请求头、请求体 JSON、响应体 JSON(完整原文) | +| 控制台输出 | 仅显示摘要行(PASS/FAIL + 关键信息),详细请求/响应体仅写入日志文件 | + +### 10.3 日志示例 + +``` +# api-access 日志(单元测试) +2026-05-29 11:40:52,353 INFO [api_access] POST /api/user/auth/register/ | user=None | status=201 | 399ms + >>> headers: {"Content-Type": "application/json"} + >>> body: {"phone": "13900000001", "code": "308868", "password": "TestPass1", "real_name": "张三"} + <<< body: {"message": "注册成功", "user": {...}, "tokens": {"access": "eyJhbGci...", "refresh": "eyJhbGci..."}} + +# test-swagger 日志(Swagger 脚本) + PASS U2 POST /api/user/auth/register/ expect=201 got=201 + >>> body: {"phone": "13700000099", "code": "877405", "password": "TestPass1", "real_name": "Swagger测试"} + <<< body: {"message": "注册成功", "user": {...}, "tokens": {"access": "eyJhbGci...", "refresh": "eyJhbGci..."}} +``` + +--- + +## 11. 测试结论 + +- ✅ 全部 **41 条** 单元测试通过(13 happy-path + 28 negative) +- ✅ **20 个** Swagger Try-it-out 接口验证全部通过(含真实 PDF + DeepSeek AI 完整流水线) +- ✅ 用户端 11 个接口功能正常(含 U9 用户列表、U10 用户详情的角色分级权限) +- ✅ 病例端 5 个接口功能正常 +- ✅ 限流、越权、字段校验、事务回滚、AI Schema 校验 均有覆盖 +- ✅ U9/U10 权限矩阵验证:管理员全员可见、教师仅名下活跃学生、学生/医生 403、已结束关系 403 +- ✅ `.env.example` 与代码完全一致,敏感信息已替换为占位符 +- ✅ 测试过程中发现 6 个问题,均已修复(见第 5 节) +- ✅ 完整的测试日志记录:单元测试 → API 访问日志,Swagger 脚本 → 独立日志文件 +- ✅ 未发现业务代码 Bug