# Files: fastapi/app/services/learning_assistant_session_store.py
# T · 2026-06-11 16:19:50 +08:00
# 142 lines · 5.5 KiB · Python

import copy
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from threading import Lock
from typing import Any

from app.core.config import settings
from app.core.context import UserContext
class LearningAssistantSessionStore:
    """Short-lived session storage for the AI learning assistant.

    Session state is kept in Redis when ``settings.runtime_memory_backend`` is
    ``"redis"`` and the server answers a ping. Otherwise (tests, or degraded
    mode when fallback is allowed) it lives in a lock-guarded, process-local
    dict. In both backends a session expires
    ``settings.learning_assistant_session_ttl_seconds`` after its most recent
    write.
    """

    key_prefix = "learning_assistant:session:"

    def __init__(self) -> None:
        # Guards _memory_store; only used by the in-memory fallback.
        self._lock = Lock()
        # session id -> {"expires_at": time.monotonic() deadline, "state": private session copy}
        self._memory_store: dict[str, dict[str, Any]] = {}
        # None means "use the in-memory fallback".
        self._redis_client = self._create_redis_client()

    def create(self, ctx: UserContext, title: str | None = None) -> dict[str, Any]:
        """Create and persist a new short-lived Q&A session for the current user.

        Args:
            ctx: Caller context. ``user_id`` and ``institution_id`` are copied
                into the session, plus ``institution_name`` from the profile
                when one is present.
            title: Optional display title; defaults to ``"AI 学习助手"``.

        Returns:
            The persisted session state. The store keeps its own copy, so
            mutating the returned dict does not alter the stored session.
        """
        now = self._now()
        session_id = f"las_{uuid.uuid4().hex}"
        state: dict[str, Any] = {
            "assistant_session_id": session_id,
            "user_id": ctx.user_id,
            "institution_id": ctx.institution_id,
            "institution_name": self._profile_value(ctx, "institution_name"),
            "title": title or "AI 学习助手",
            "status": "active",
            "messages": [],
            "created_at": now,
            "updated_at": now,
            "expires_in_seconds": settings.learning_assistant_session_ttl_seconds,
        }
        self._save(state)
        return state

    def get(self, assistant_session_id: str, user_id: str) -> dict[str, Any] | None:
        """Return the session if it exists, has not expired and belongs to ``user_id``.

        Returns ``None`` otherwise, so callers cannot tell "missing" apart from
        "owned by someone else".
        """
        state = self._load(assistant_session_id)
        if not state or state.get("user_id") != user_id:
            return None
        return state

    def get_messages(self, assistant_session_id: str, user_id: str, limit: int | None = None) -> list[dict[str, Any]]:
        """Return the most recent ``limit`` messages of a session, for prompt building.

        Args:
            assistant_session_id: Session to read.
            user_id: Owner check; another user's session yields ``[]``.
            limit: Maximum number of trailing messages to return. ``None``
                means ``settings.learning_assistant_history_limit``. A value
                ``<= 0`` returns an empty list.

        Returns:
            A new list of message dicts, oldest first; ``[]`` when the session
            is missing, expired or not owned by ``user_id``.
        """
        state = self.get(assistant_session_id, user_id)
        if not state:
            return []
        if limit is None:
            limit = settings.learning_assistant_history_limit
        # messages[-0:] would return the WHOLE list, so non-positive limits
        # must be handled explicitly.
        if limit <= 0:
            return []
        messages = list(state.get("messages") or [])
        return messages[-limit:]

    def append_message(
        self,
        assistant_session_id: str,
        user_id: str,
        role: str,
        content: str,
        metadata: dict | None = None,
    ) -> dict[str, Any] | None:
        """Append one message (user question or AI answer) and refresh the session TTL.

        Only the last ``max(2 * history_limit, 2)`` messages are retained.

        Returns:
            The updated session state, or ``None`` when the session is missing,
            expired or not owned by ``user_id``.

        NOTE: this is a non-atomic read-modify-write (load, then save). Two
        concurrent appends to the same session can lose one of the messages.
        """
        state = self.get(assistant_session_id, user_id)
        if not state:
            return None
        messages = list(state.get("messages") or [])
        messages.append(
            {
                "role": role,
                "content": content,
                "metadata": metadata or {},
                "created_at": self._now(),
            }
        )
        max_messages = max(settings.learning_assistant_history_limit * 2, 2)
        state["messages"] = messages[-max_messages:]
        state["updated_at"] = self._now()
        self._save(state)
        return state

    def _create_redis_client(self) -> Any:
        """Create the Redis client according to the runtime-memory settings.

        Returns:
            A connected ``redis.Redis`` client, or ``None`` when the configured
            backend is not Redis or Redis is unavailable and fallback is allowed.

        Raises:
            RuntimeError: In production, when Redis cannot be reached (or the
                ``redis`` package is missing) and fallback is disabled.
        """
        if settings.runtime_memory_backend.lower() != "redis":
            return None
        try:
            import redis

            client = redis.Redis.from_url(settings.redis_url, decode_responses=True)
            client.ping()
        except Exception as exc:
            if settings.is_production and not settings.runtime_memory_fallback_enabled:
                raise RuntimeError("Redis is required for learning assistant sessions") from exc
            # Degrade deliberately, but leave a trace so the fallback is visible.
            logging.getLogger(__name__).warning(
                "Redis unavailable for learning assistant sessions, "
                "falling back to in-process memory: %s",
                exc,
            )
            return None
        return client

    def _save(self, state: dict[str, Any]) -> None:
        """Persist ``state`` and (re)start its TTL.

        Redis uses ``SETEX``. The memory fallback stores a deep copy together
        with a monotonic expiry deadline.
        """
        session_id = state["assistant_session_id"]
        ttl = settings.learning_assistant_session_ttl_seconds
        if self._redis_client is not None:
            self._redis_client.setex(self._key(session_id), ttl, json.dumps(state, ensure_ascii=False))
            return
        # Keep a private copy so later mutations of the caller's dict (or its
        # nested "messages" list) cannot leak into the store. This matches the
        # copy semantics of the Redis serialize/deserialize round-trip.
        snapshot = copy.deepcopy(state)
        with self._lock:
            self._memory_store[session_id] = {
                # Monotonic clock: immune to wall-clock adjustments.
                "expires_at": time.monotonic() + ttl,
                "state": snapshot,
            }

    def _load(self, assistant_session_id: str) -> dict[str, Any] | None:
        """Load a session, returning ``None`` when it is missing or expired.

        The returned dict is always an independent copy of the stored state.
        """
        if self._redis_client is not None:
            raw = self._redis_client.get(self._key(assistant_session_id))
            if not raw:
                return None
            return json.loads(raw)
        with self._lock:
            item = self._memory_store.get(assistant_session_id)
            if item is None:
                return None
            if item["expires_at"] <= time.monotonic():
                # Lazy eviction: expired entries are dropped when accessed.
                del self._memory_store[assistant_session_id]
                return None
            return copy.deepcopy(item["state"])

    def _key(self, assistant_session_id: str) -> str:
        """Build the Redis key; the dedicated prefix keeps these keys separate from training short-term memory."""
        return f"{self.key_prefix}{assistant_session_id}"

    def _profile_value(self, ctx: UserContext, key: str) -> Any:
        """Read an extension field from ``ctx.profile``, or ``None`` when there is no profile.

        Per the original author's note, the profile is the normalized payload of
        Django's ``/me`` endpoint.
        """
        if not ctx.profile:
            return None
        return ctx.profile.get(key)

    def _now(self) -> str:
        """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``.

        This format is readable in the frontend and in logs. It uses an aware
        UTC ``datetime`` because ``datetime.utcnow()`` is deprecated, and drops
        the tzinfo before formatting so the suffix is ``Z`` rather than
        ``+00:00``.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat(timespec="seconds") + "Z"
# Module-level shared instance; importers use this object rather than creating
# their own. It is built at import time, so __init__ runs here: it may try to
# reach Redis, and it raises in production when Redis is unreachable and
# fallback is disabled.
learning_assistant_session_store: LearningAssistantSessionStore = LearningAssistantSessionStore()