import copy
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from threading import Lock
from typing import Any

from app.core.config import settings
from app.core.context import UserContext

logger = logging.getLogger(__name__)


class LearningAssistantSessionStore:
    """Short-lived session store for the AI learning assistant.

    Session state is kept in Redis when ``settings.runtime_memory_backend`` is
    ``"redis"`` and the server answers a ping; otherwise (tests, or a permitted
    fallback) it is kept in an in-process dict guarded by a lock. Both backends
    expire sessions after ``settings.learning_assistant_session_ttl_seconds``.
    """

    key_prefix = "learning_assistant:session:"

    def __init__(self) -> None:
        self._lock = Lock()
        # session_id -> {"expires_at": epoch seconds, "state": session dict}
        self._memory_store: dict[str, dict[str, Any]] = {}
        self._redis_client = self._create_redis_client()

    def create(self, ctx: UserContext, title: str | None = None) -> dict[str, Any]:
        """Create and persist a new assistant session for the current user.

        Args:
            ctx: Caller context; ``user_id``, ``institution_id`` and the
                optional ``institution_name`` profile field are copied into
                the session.
            title: Display title; a falsy value falls back to the default.

        Returns:
            The newly created session state.
        """
        now = self._now()
        session_id = f"las_{uuid.uuid4().hex}"
        state: dict[str, Any] = {
            "assistant_session_id": session_id,
            "user_id": ctx.user_id,
            "institution_id": ctx.institution_id,
            "institution_name": self._profile_value(ctx, "institution_name"),
            "title": title or "AI 学习助手",
            "status": "active",
            "messages": [],
            "created_at": now,
            "updated_at": now,
            "expires_in_seconds": settings.learning_assistant_session_ttl_seconds,
        }
        self._save(state)
        return state

    def get(self, assistant_session_id: str, user_id: str) -> dict[str, Any] | None:
        """Return the session if it exists, has not expired, and belongs to ``user_id``.

        Returns ``None`` otherwise, so callers cannot distinguish "missing"
        from "owned by someone else".
        """
        state = self._load(assistant_session_id)
        if not state or state.get("user_id") != user_id:
            return None
        return state

    def get_messages(self, assistant_session_id: str, user_id: str, limit: int | None = None) -> list[dict[str, Any]]:
        """Return the most recent messages of a session, for prompt assembly.

        Args:
            assistant_session_id: Session to read.
            user_id: Owner check; a foreign or missing session yields ``[]``.
            limit: Maximum number of messages to return. Defaults to
                ``settings.learning_assistant_history_limit``. A non-positive
                value returns an empty list.

        NOTE(review): ``append_message`` retains up to ``history_limit * 2``
        messages while this default returns only ``history_limit`` of them;
        confirm whether ``limit`` is meant to count messages or rounds.
        """
        state = self.get(assistant_session_id, user_id)
        if not state:
            return []
        if limit is None:
            limit = settings.learning_assistant_history_limit
        if limit <= 0:
            # messages[-0:] would return the entire list, not an empty one.
            return []
        messages = list(state.get("messages") or [])
        return messages[-limit:]

    def append_message(
        self,
        assistant_session_id: str,
        user_id: str,
        role: str,
        content: str,
        metadata: dict | None = None,
    ) -> dict[str, Any] | None:
        """Append one message (user question or AI answer) to a session.

        The stored history is trimmed to the most recent
        ``max(history_limit * 2, 2)`` messages and the session TTL is
        refreshed on save.

        Returns:
            The updated session state, or ``None`` if the session is missing,
            expired, or owned by another user.
        """
        state = self.get(assistant_session_id, user_id)
        if not state:
            return None
        # One timestamp so the message and the session update agree.
        now = self._now()
        messages = list(state.get("messages") or [])
        messages.append(
            {
                "role": role,
                "content": content,
                "metadata": metadata or {},
                "created_at": now,
            }
        )
        # Presumably history_limit counts question/answer rounds (2 messages
        # each); keep at least one round.
        max_messages = max(settings.learning_assistant_history_limit * 2, 2)
        state["messages"] = messages[-max_messages:]
        state["updated_at"] = now
        self._save(state)
        return state

    def _create_redis_client(self) -> Any:
        """Build a Redis client honouring the runtime memory backend settings.

        Returns ``None`` (use the in-memory store) when the backend is not
        Redis, or when Redis is unavailable and fallback is permitted.

        Raises:
            RuntimeError: In production with fallback disabled, if Redis
                cannot be imported or reached.
        """
        if settings.runtime_memory_backend.lower() != "redis":
            return None
        try:
            import redis

            client = redis.Redis.from_url(settings.redis_url, decode_responses=True)
            client.ping()
            return client
        except Exception as exc:  # ImportError, bad URL, connection/auth errors
            if settings.is_production and not settings.runtime_memory_fallback_enabled:
                raise RuntimeError("Redis is required for learning assistant sessions") from exc
            logger.warning(
                "Redis unavailable for learning assistant sessions; falling back to in-process memory: %s",
                exc,
            )
            return None

    def _save(self, state: dict[str, Any]) -> None:
        """Persist ``state`` with the configured TTL.

        Redis uses ``SETEX``; the in-memory fallback stores a deep copy with an
        absolute expiry timestamp and sweeps out expired entries on each save.
        """
        session_id = state["assistant_session_id"]
        ttl = settings.learning_assistant_session_ttl_seconds
        if self._redis_client is not None:
            self._redis_client.setex(self._key(session_id), ttl, json.dumps(state, ensure_ascii=False))
            return
        now = time.time()
        # Deep copy so later caller mutations cannot leak into the store,
        # matching the isolation a Redis JSON round-trip gives.
        snapshot = copy.deepcopy(state)
        with self._lock:
            self._purge_expired_locked(now)
            self._memory_store[session_id] = {"expires_at": now + ttl, "state": snapshot}

    def _load(self, assistant_session_id: str) -> dict[str, Any] | None:
        """Load a session if it still exists and has not expired.

        Returns an independent copy of the state, or ``None`` when the session
        is missing, expired, or (Redis only) its payload cannot be decoded
        into a dict.
        """
        if self._redis_client is not None:
            raw = self._redis_client.get(self._key(assistant_session_id))
            if not raw:
                return None
            try:
                state = json.loads(raw)
            except ValueError:
                logger.warning("Discarding unreadable learning assistant session %s", assistant_session_id)
                return None
            return state if isinstance(state, dict) else None
        with self._lock:
            item = self._memory_store.get(assistant_session_id)
            if not item:
                return None
            if item["expires_at"] < time.time():
                self._memory_store.pop(assistant_session_id, None)
                return None
            return copy.deepcopy(item["state"])

    def _purge_expired_locked(self, now: float) -> None:
        """Remove expired in-memory sessions; the caller must hold ``self._lock``."""
        expired = [sid for sid, item in self._memory_store.items() if item["expires_at"] < now]
        for sid in expired:
            del self._memory_store[sid]

    def _key(self, assistant_session_id: str) -> str:
        """Build the Redis key; the prefix keeps these apart from training memory keys."""
        return f"{self.key_prefix}{assistant_session_id}"

    def _profile_value(self, ctx: UserContext, key: str) -> Any:
        """Read an extension field from ``ctx.profile``, or ``None`` if absent."""
        if not ctx.profile:
            return None
        return ctx.profile.get(key)

    def _now(self) -> str:
        """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ`` for display and logs."""
        return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
# Module-level shared instance; constructing it may contact Redis (ping) at import time.
learning_assistant_session_store: LearningAssistantSessionStore = LearningAssistantSessionStore()