134 lines
5.5 KiB
Python
134 lines
5.5 KiB
Python
import json
|
|
from datetime import datetime, timedelta
|
|
from threading import Lock
|
|
|
|
from app.core.config import settings
|
|
|
|
|
|
class BaseRuntimeMemoryService:
|
|
"""短期 memory 基类:定义会话消息缓存的统一接口。"""
|
|
|
|
def create(self, memory_key: str, patient_opening: str | None = None) -> None:
|
|
raise NotImplementedError
|
|
|
|
def add_message(self, memory_key: str, role: str, content: str, structured: dict | None = None) -> None:
|
|
raise NotImplementedError
|
|
|
|
def get_messages(self, memory_key: str | None) -> list[dict]:
|
|
raise NotImplementedError
|
|
|
|
def has_doctor_message(self, memory_key: str | None) -> bool:
|
|
"""问诊校验:判断当前会话是否存在医生问诊消息。"""
|
|
return any(item["role"] == "doctor" for item in self.get_messages(memory_key))
|
|
|
|
def release(self, memory_key: str | None) -> None:
|
|
raise NotImplementedError
|
|
|
|
|
|
class InMemoryRuntimeMemoryService(BaseRuntimeMemoryService):
|
|
"""进程内短期 memory:用于无 Redis 环境下的 Demo 兜底。"""
|
|
|
|
def __init__(self) -> None:
|
|
self._store: dict[str, dict] = {}
|
|
self._lock = Lock()
|
|
|
|
def create(self, memory_key: str, patient_opening: str | None = None) -> None:
|
|
"""memory 创建:为新会话初始化短期消息容器。"""
|
|
with self._lock:
|
|
self._store[memory_key] = {
|
|
"expires_at": datetime.utcnow() + timedelta(seconds=settings.runtime_memory_ttl_seconds),
|
|
"messages": [],
|
|
}
|
|
if patient_opening:
|
|
self._store[memory_key]["messages"].append(
|
|
{"role": "patient", "content": patient_opening, "structured": None, "created_at": datetime.utcnow().isoformat()}
|
|
)
|
|
|
|
def add_message(self, memory_key: str, role: str, content: str, structured: dict | None = None) -> None:
|
|
"""memory 写入:追加医生、病人或工具消息。"""
|
|
with self._lock:
|
|
self._ensure(memory_key)
|
|
self._store[memory_key]["messages"].append(
|
|
{"role": role, "content": content, "structured": structured, "created_at": datetime.utcnow().isoformat()}
|
|
)
|
|
|
|
def get_messages(self, memory_key: str | None) -> list[dict]:
|
|
"""memory 读取:返回当前会话的短期消息列表。"""
|
|
if not memory_key:
|
|
return []
|
|
with self._lock:
|
|
self._ensure(memory_key)
|
|
return list(self._store[memory_key]["messages"])
|
|
|
|
def release(self, memory_key: str | None) -> None:
|
|
"""memory 释放:评价完成后删除短期聊天记录。"""
|
|
if not memory_key:
|
|
return
|
|
with self._lock:
|
|
self._store.pop(memory_key, None)
|
|
|
|
def _ensure(self, memory_key: str) -> None:
|
|
"""memory 兜底:内存丢失或过期时重新创建空容器。"""
|
|
current = self._store.get(memory_key)
|
|
if not current or current["expires_at"] < datetime.utcnow():
|
|
self._store[memory_key] = {
|
|
"expires_at": datetime.utcnow() + timedelta(seconds=settings.runtime_memory_ttl_seconds),
|
|
"messages": [],
|
|
}
|
|
|
|
|
|
class RedisRuntimeMemoryService(BaseRuntimeMemoryService):
    """Redis short-term memory: keeps one training session's consultation messages.

    Each session is a Redis list at ``memory_key`` holding JSON-encoded messages.
    Every write re-issues ``EXPIRE``, so the key is dropped after
    ``settings.runtime_memory_ttl_seconds`` without writes. Writes go through a
    transactional pipeline (MULTI/EXEC): a list can never be left without a TTL,
    and ``create`` resets and seeds the session atomically.
    """

    def __init__(self) -> None:
        """Connect to ``settings.redis_url``.

        Raises:
            RuntimeError: if the ``redis`` package is not installed.
        """
        try:
            import redis
        except ImportError as exc:
            raise RuntimeError("redis package is required for RedisRuntimeMemoryService") from exc
        # decode_responses=True makes LRANGE return str, which json.loads accepts directly.
        self.client = redis.Redis.from_url(settings.redis_url, decode_responses=True)

    def create(self, memory_key: str, patient_opening: str | None = None) -> None:
        """Reset the session list, optionally seeding it with the patient's opening line."""
        with self.client.pipeline() as pipe:
            pipe.delete(memory_key)
            if patient_opening:
                pipe.rpush(memory_key, self._encode("patient", patient_opening, None))
                pipe.expire(memory_key, settings.runtime_memory_ttl_seconds)
            pipe.execute()

    def add_message(self, memory_key: str, role: str, content: str, structured: dict | None = None) -> None:
        """Append one consultation, patient or tool message and refresh the key's TTL."""
        with self.client.pipeline() as pipe:
            # RPUSH and EXPIRE in one transaction: a crash between them cannot leave an immortal key.
            pipe.rpush(memory_key, self._encode(role, content, structured))
            pipe.expire(memory_key, settings.runtime_memory_ttl_seconds)
            pipe.execute()

    def get_messages(self, memory_key: str | None) -> list[dict]:
        """Return the session's messages in insertion order; ``[]`` for a falsy or missing key."""
        if not memory_key:
            return []
        return [json.loads(item) for item in self.client.lrange(memory_key, 0, -1)]

    def release(self, memory_key: str | None) -> None:
        """Delete the session's short-term messages, e.g. once evaluation is done."""
        if memory_key:
            self.client.delete(memory_key)

    @staticmethod
    def _encode(role: str, content: str, structured: dict | None) -> str:
        """Serialise one message record to JSON, keeping non-ASCII text unescaped."""
        payload = {
            "role": role,
            "content": content,
            "structured": structured,
            "created_at": datetime.utcnow().isoformat(),
        }
        return json.dumps(payload, ensure_ascii=False)
|
|
|
|
|
|
def create_runtime_memory_service() -> BaseRuntimeMemoryService:
    """Select the short-term memory backend from configuration.

    Returns:
        A ``RedisRuntimeMemoryService`` when ``settings.runtime_memory_backend``
        is ``"redis"`` (case-insensitive, surrounding whitespace ignored) and the
        server answers ``PING``. Otherwise returns an
        ``InMemoryRuntimeMemoryService``. A Redis failure never raises: it is
        logged with its traceback and the in-process fallback is returned.
    """
    import logging

    backend = (settings.runtime_memory_backend or "").strip().lower()
    if backend != "redis":
        return InMemoryRuntimeMemoryService()

    try:
        service = RedisRuntimeMemoryService()
        service.client.ping()
    except Exception:
        # Deliberate best-effort fallback, but make the downgrade visible: with the
        # in-process backend, sessions are not shared across workers or restarts.
        logging.getLogger(__name__).warning(
            "Redis runtime memory unavailable; falling back to in-process memory",
            exc_info=True,
        )
        return InMemoryRuntimeMemoryService()
    return service
|
|
|
|
|
|
# ``RuntimeMemoryService`` is bound to the in-process class regardless of the
# configured backend. NOTE(review): presumably a legacy alias; confirm callers
# do not expect it to name the active backend.
RuntimeMemoryService = InMemoryRuntimeMemoryService

# Shared service instance, built once at import time. With the "redis" backend
# this constructs a Redis client and PINGs the server during import.
runtime_memory: BaseRuntimeMemoryService = create_runtime_memory_service()
|