Files
fastapi/backend/app/services/runtime_memory.py
T

134 lines
5.5 KiB
Python
Raw Normal View History

import json
from datetime import datetime, timedelta
from threading import Lock
from app.core.config import settings
class BaseRuntimeMemoryService:
"""短期 memory 基类:定义会话消息缓存的统一接口。"""
def create(self, memory_key: str, patient_opening: str | None = None) -> None:
raise NotImplementedError
def add_message(self, memory_key: str, role: str, content: str, structured: dict | None = None) -> None:
raise NotImplementedError
def get_messages(self, memory_key: str | None) -> list[dict]:
raise NotImplementedError
def has_doctor_message(self, memory_key: str | None) -> bool:
"""问诊校验:判断当前会话是否存在医生问诊消息。"""
return any(item["role"] == "doctor" for item in self.get_messages(memory_key))
def release(self, memory_key: str | None) -> None:
raise NotImplementedError
class InMemoryRuntimeMemoryService(BaseRuntimeMemoryService):
    """In-process short-term memory: a demo fallback for environments without Redis.

    Each conversation is stored in ``self._store`` under its memory key as a
    dict with an ``expires_at`` deadline and an ordered ``messages`` list.
    Every access to the store is made while holding ``self._lock``.

    The time-to-live is sliding: it is re-armed on every write, mirroring the
    Redis backend, which calls ``expire`` after each ``rpush``.
    """

    def __init__(self) -> None:
        self._store: dict[str, dict] = {}
        self._lock = Lock()

    def create(self, memory_key: str, patient_opening: str | None = None) -> None:
        """Create a fresh message container for a new session.

        Any container already stored under ``memory_key`` is replaced.

        Args:
            memory_key: Identifier of the conversation.
            patient_opening: Optional opening line; when truthy it is stored as
                the first message with role ``"patient"``.
        """
        with self._lock:
            container = self._new_container()
            if patient_opening:
                container["messages"].append(self._build_message("patient", patient_opening, None))
            self._store[memory_key] = container

    def add_message(self, memory_key: str, role: str, content: str, structured: dict | None = None) -> None:
        """Append a doctor, patient or tool message to the session.

        A missing or expired container is recreated empty before the append.

        Args:
            memory_key: Identifier of the conversation.
            role: Who produced the message.
            content: Message text.
            structured: Optional structured payload stored alongside the text.
        """
        with self._lock:
            container = self._ensure(memory_key)
            container["messages"].append(self._build_message(role, content, structured))
            # Re-arm the TTL on every write so an active session is not wiped
            # mid-conversation; this matches the Redis backend's behaviour.
            container["expires_at"] = self._new_expiry()

    def get_messages(self, memory_key: str | None) -> list[dict]:
        """Return the session's short-term message list.

        Reading never allocates a container: unknown keys simply yield an
        empty list, and an expired container is evicted on the spot.

        Args:
            memory_key: Identifier of the conversation; falsy values yield ``[]``.

        Returns:
            A new list holding the stored message dicts (the dicts themselves
            are not copied).
        """
        if not memory_key:
            return []
        with self._lock:
            container = self._store.get(memory_key)
            if not container:
                return []
            if container["expires_at"] < datetime.utcnow():
                # Expired: drop it now instead of keeping an empty placeholder.
                self._store.pop(memory_key, None)
                return []
            return list(container["messages"])

    def release(self, memory_key: str | None) -> None:
        """Delete the session's messages once evaluation is finished.

        Args:
            memory_key: Identifier of the conversation; falsy values and
                unknown keys are ignored.
        """
        if not memory_key:
            return
        with self._lock:
            self._store.pop(memory_key, None)

    def _ensure(self, memory_key: str) -> dict:
        """Return the live container for ``memory_key``, recreating it if lost or expired.

        Callers must already hold ``self._lock``.
        """
        container = self._store.get(memory_key)
        if not container or container["expires_at"] < datetime.utcnow():
            container = self._new_container()
            self._store[memory_key] = container
        return container

    @staticmethod
    def _new_expiry() -> datetime:
        """Compute the deadline one configured TTL from now (naive UTC)."""
        return datetime.utcnow() + timedelta(seconds=settings.runtime_memory_ttl_seconds)

    @classmethod
    def _new_container(cls) -> dict:
        """Build an empty container with a fresh expiry deadline."""
        return {
            "expires_at": cls._new_expiry(),
            "messages": [],
        }

    @staticmethod
    def _build_message(role: str, content: str, structured: dict | None) -> dict:
        """Build one message record stamped with the current naive-UTC time."""
        return {
            "role": role,
            "content": content,
            "structured": structured,
            "created_at": datetime.utcnow().isoformat(),
        }
class RedisRuntimeMemoryService(BaseRuntimeMemoryService):
    """Redis short-term memory: keeps one training run's consultation messages under a TTL.

    Every conversation is a Redis list stored under its memory key; each list
    element is one JSON-encoded message.
    """

    def __init__(self) -> None:
        # Imported lazily so the module stays importable without the redis package.
        try:
            import redis
        except ImportError as exc:
            raise RuntimeError("redis package is required for RedisRuntimeMemoryService") from exc
        self.client = redis.Redis.from_url(settings.redis_url, decode_responses=True)

    def create(self, memory_key: str, patient_opening: str | None = None) -> None:
        """Reset the session's message list and apply the expiry time.

        Args:
            memory_key: Redis key of the conversation; existing data is deleted.
            patient_opening: Optional opening line; when truthy it is stored as
                the first message with role ``"patient"``.
        """
        redis_client = self.client
        redis_client.delete(memory_key)
        if patient_opening:
            self.add_message(memory_key, "patient", patient_opening)
        redis_client.expire(memory_key, settings.runtime_memory_ttl_seconds)

    def add_message(self, memory_key: str, role: str, content: str, structured: dict | None = None) -> None:
        """Append one doctor, patient or tool message and re-apply the TTL.

        Args:
            memory_key: Redis key of the conversation.
            role: Who produced the message.
            content: Message text.
            structured: Optional structured payload; it is JSON-encoded with
                the rest of the record.
        """
        record = {
            "role": role,
            "content": content,
            "structured": structured,
            "created_at": datetime.utcnow().isoformat(),
        }
        # ensure_ascii=False stores non-ASCII text as-is rather than as \u escapes.
        serialized = json.dumps(record, ensure_ascii=False)
        self.client.rpush(memory_key, serialized)
        self.client.expire(memory_key, settings.runtime_memory_ttl_seconds)

    def get_messages(self, memory_key: str | None) -> list[dict]:
        """Read the whole short-term message list of the session.

        Args:
            memory_key: Redis key of the conversation; falsy values yield ``[]``.

        Returns:
            The decoded messages in list order.
        """
        if not memory_key:
            return []
        raw_items = self.client.lrange(memory_key, 0, -1)
        return list(map(json.loads, raw_items))

    def release(self, memory_key: str | None) -> None:
        """Delete the session's messages once evaluation is finished.

        Args:
            memory_key: Redis key of the conversation; falsy values are ignored.
        """
        if not memory_key:
            return
        self.client.delete(memory_key)
def create_runtime_memory_service() -> BaseRuntimeMemoryService:
    """Select the memory backend from configuration, falling back to in-process memory.

    When ``settings.runtime_memory_backend`` equals ``"redis"`` (compared
    case-insensitively), a Redis-backed service is constructed and checked
    with a ``ping``. If construction or the ping fails, the failure is logged
    as a warning and the in-process implementation is returned instead, so
    start-up never fails because of Redis.

    Returns:
        A ``RedisRuntimeMemoryService`` when Redis is configured and reachable,
        otherwise an ``InMemoryRuntimeMemoryService``.
    """
    if settings.runtime_memory_backend.lower() == "redis":
        try:
            service = RedisRuntimeMemoryService()
            service.client.ping()
        except Exception:
            # Deliberate best-effort fallback, but make it visible: in-process
            # memory is not shared between worker processes.
            import logging

            logging.getLogger(__name__).warning(
                "Redis runtime memory is unavailable; falling back to in-process memory",
                exc_info=True,
            )
        else:
            return service
    return InMemoryRuntimeMemoryService()
# Alias that exposes the in-process implementation under a generic name.
# NOTE(review): presumably kept for callers importing `RuntimeMemoryService` directly — confirm.
RuntimeMemoryService = InMemoryRuntimeMemoryService
# Module-level service instance, built once at import time by
# create_runtime_memory_service(), which picks the backend from settings.
runtime_memory = create_runtime_memory_service()