import copy
import json
import logging
from datetime import datetime, timedelta, timezone
from threading import Lock

from app.core.config import settings

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value (and the same ``isoformat()`` text) as the
    deprecated ``datetime.utcnow()`` without its Python 3.12 deprecation warning.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _build_message(role: str, content: str, structured: dict | None) -> dict:
    """Build the message record shared by every backend."""
    return {
        "role": role,
        "content": content,
        "structured": structured,
        "created_at": _utcnow().isoformat(),
    }


class BaseRuntimeMemoryService:
    """Short-term memory base class: uniform interface for per-session message caches."""

    def create(self, memory_key: str, patient_opening: str | None = None) -> None:
        """Initialize (or reset) the message list for ``memory_key``.

        If ``patient_opening`` is truthy it is stored as the first "patient" message.
        """
        raise NotImplementedError

    def add_message(self, memory_key: str, role: str, content: str, structured: dict | None = None) -> None:
        """Append one message with the given ``role`` to the session."""
        raise NotImplementedError

    def get_messages(self, memory_key: str | None) -> list[dict]:
        """Return the session's messages in insertion order (``[]`` for a falsy key)."""
        raise NotImplementedError

    def has_doctor_message(self, memory_key: str | None) -> bool:
        """Consultation check: return True if the session holds any "doctor" message."""
        return any(item.get("role") == "doctor" for item in self.get_messages(memory_key))

    def release(self, memory_key: str | None) -> None:
        """Drop all messages stored for ``memory_key``."""
        raise NotImplementedError


class InMemoryRuntimeMemoryService(BaseRuntimeMemoryService):
    """In-process short-term memory: demo fallback for environments without Redis.

    An entry expires ``settings.runtime_memory_ttl_seconds`` after its last write,
    mirroring the Redis backend, which refreshes the key's TTL on every write.
    Data lives only in this process.
    """

    def __init__(self) -> None:
        # memory_key -> {"expires_at": naive UTC datetime, "messages": list[dict]}
        self._store: dict[str, dict] = {}
        self._lock = Lock()

    def create(self, memory_key: str, patient_opening: str | None = None) -> None:
        """Create an empty container for a new session, replacing any previous one."""
        with self._lock:
            # Sessions that are never released would otherwise stay in the dict forever.
            self._purge_expired()
            self._store[memory_key] = self._new_entry()
            if patient_opening:
                self._append(memory_key, _build_message("patient", patient_opening, None))

    def add_message(self, memory_key: str, role: str, content: str, structured: dict | None = None) -> None:
        """Append a doctor, patient or tool message, (re)creating the session if needed."""
        with self._lock:
            self._ensure(memory_key)
            # Deep-copy so later mutation of the caller's dict cannot rewrite history
            # (the Redis backend gets the same isolation from JSON serialization).
            self._append(memory_key, _build_message(role, content, copy.deepcopy(structured)))

    def get_messages(self, memory_key: str | None) -> list[dict]:
        """Return a copy of the session's messages; ``[]`` if missing or expired.

        Reading never creates an entry, so lookups of unknown keys do not grow the store.
        """
        if not memory_key:
            return []
        with self._lock:
            entry = self._store.get(memory_key)
            if entry is None or self._is_expired(entry, _utcnow()):
                self._store.pop(memory_key, None)
                return []
            # Return independent copies so callers cannot mutate stored messages.
            return copy.deepcopy(entry["messages"])

    def release(self, memory_key: str | None) -> None:
        """Delete the session's short-term chat history (e.g. once evaluation is done)."""
        if not memory_key:
            return
        with self._lock:
            self._store.pop(memory_key, None)

    # --- helpers below must be called with self._lock held ---

    @staticmethod
    def _new_entry() -> dict:
        """Return a fresh, empty session container with a full TTL."""
        return {
            "expires_at": _utcnow() + timedelta(seconds=settings.runtime_memory_ttl_seconds),
            "messages": [],
        }

    @staticmethod
    def _is_expired(entry: dict, now: datetime) -> bool:
        """Return True if ``entry`` expired before ``now``."""
        return entry["expires_at"] < now

    def _ensure(self, memory_key: str) -> None:
        """Recreate an empty container when the entry is missing or expired."""
        current = self._store.get(memory_key)
        if current is None or self._is_expired(current, _utcnow()):
            self._store[memory_key] = self._new_entry()

    def _append(self, memory_key: str, message: dict) -> None:
        """Append ``message`` and slide the entry's expiry forward (like Redis EXPIRE)."""
        entry = self._store[memory_key]
        entry["messages"].append(message)
        entry["expires_at"] = _utcnow() + timedelta(seconds=settings.runtime_memory_ttl_seconds)

    def _purge_expired(self) -> None:
        """Remove every expired entry from the store."""
        now = _utcnow()
        expired_keys = [key for key, entry in self._store.items() if self._is_expired(entry, now)]
        for key in expired_keys:
            del self._store[key]


class RedisRuntimeMemoryService(BaseRuntimeMemoryService):
    """Redis short-term memory: stores one training session's messages as a Redis
    list whose TTL is refreshed on every write."""

    def __init__(self) -> None:
        try:
            import redis
        except ImportError as exc:
            raise RuntimeError("redis package is required for RedisRuntimeMemoryService") from exc
        self.client = redis.Redis.from_url(settings.redis_url, decode_responses=True)

    @staticmethod
    def _encode(message: dict) -> str:
        """Serialize a message record for storage in the Redis list."""
        return json.dumps(message, ensure_ascii=False)

    def create(self, memory_key: str, patient_opening: str | None = None) -> None:
        """Reset the session list; if an opening is given, store it and set the TTL.

        The commands run in one MULTI/EXEC pipeline, so the delete and the
        repopulation are applied together.
        """
        with self.client.pipeline() as pipe:
            pipe.delete(memory_key)
            if patient_opening:
                pipe.rpush(memory_key, self._encode(_build_message("patient", patient_opening, None)))
                pipe.expire(memory_key, settings.runtime_memory_ttl_seconds)
            pipe.execute()

    def add_message(self, memory_key: str, role: str, content: str, structured: dict | None = None) -> None:
        """Append one doctor, patient or tool message and refresh the key's TTL."""
        payload = self._encode(_build_message(role, content, structured))
        with self.client.pipeline() as pipe:
            pipe.rpush(memory_key, payload)
            # Same transaction as RPUSH, so the key can never be left without a TTL.
            pipe.expire(memory_key, settings.runtime_memory_ttl_seconds)
            pipe.execute()

    def get_messages(self, memory_key: str | None) -> list[dict]:
        """Read the session's messages in insertion order (``[]`` for a falsy key)."""
        if not memory_key:
            return []
        return [json.loads(item) for item in self.client.lrange(memory_key, 0, -1)]

    def release(self, memory_key: str | None) -> None:
        """Delete the session's short-term messages (e.g. once evaluation is done)."""
        if memory_key:
            self.client.delete(memory_key)


def create_runtime_memory_service() -> BaseRuntimeMemoryService:
    """Select the memory backend from settings.

    Uses Redis when ``settings.runtime_memory_backend`` is "redis" (case-insensitive)
    and Redis is reachable. Otherwise, and only when fallback is permitted,
    returns the in-process backend.

    Raises:
        RuntimeError: Redis was requested but is unavailable, and either
            ``settings.is_production`` is true or fallback is disabled. The
            underlying error is chained as ``__cause__``.
    """
    if settings.runtime_memory_backend.lower() != "redis":
        return InMemoryRuntimeMemoryService()
    try:
        service = RedisRuntimeMemoryService()
        service.client.ping()
    except Exception as exc:
        if settings.is_production or not settings.runtime_memory_fallback_enabled:
            raise RuntimeError("Redis runtime memory is required but unavailable") from exc
        # Falling back silently hides a misconfigured Redis; make the downgrade visible.
        logger.warning("Redis runtime memory unavailable (%s); falling back to in-process memory", exc)
        return InMemoryRuntimeMemoryService()
    return service


# Backward-compatible alias. It always names the in-process class, regardless of
# which backend ``runtime_memory`` below ends up using.
RuntimeMemoryService = InMemoryRuntimeMemoryService
runtime_memory = create_runtime_memory_service()