142 lines
5.5 KiB
Python
142 lines
5.5 KiB
Python
import copy
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from threading import Lock
from typing import Any

from app.core.config import settings
from app.core.context import UserContext
|
|
|
|
|
|
class LearningAssistantSessionStore:
    """Short-term session store for the AI learning assistant.

    Session state is kept in Redis when ``settings.runtime_memory_backend`` is
    ``"redis"`` and the server answers a ping; otherwise (tests, or a permitted
    degradation) it lives in a per-process dict guarded by a lock. Every write
    re-arms the expiry to ``settings.learning_assistant_session_ttl_seconds``.

    Both backends hand out independent copies of the stored state, so mutating
    a returned dict never changes what is stored until it is saved again.
    """

    # Redis key namespace; keeps assistant sessions separate from the training
    # short-term memory keys.
    key_prefix = "learning_assistant:session:"

    def __init__(self) -> None:
        # Guards every access to the in-memory fallback store.
        self._lock = Lock()
        # session_id -> {"expires_at": epoch seconds, "state": session dict}
        self._memory_store: dict[str, dict[str, Any]] = {}
        # None means the in-memory fallback is in use.
        self._redis_client = self._create_redis_client()

    def create(self, ctx: UserContext, title: str | None = None) -> dict[str, Any]:
        """Create and persist a new short-term Q&A session for the current user.

        Args:
            ctx: Caller context; ``user_id`` and ``institution_id`` are copied into
                the session, plus ``institution_name`` from ``ctx.profile`` if present.
            title: Display title; a falsy value falls back to the default title.

        Returns:
            The newly stored session state.
        """
        now = self._now()
        session_id = f"las_{uuid.uuid4().hex}"
        state: dict[str, Any] = {
            "assistant_session_id": session_id,
            "user_id": ctx.user_id,
            "institution_id": ctx.institution_id,
            "institution_name": self._profile_value(ctx, "institution_name"),
            "title": title or "AI 学习助手",
            "status": "active",
            "messages": [],
            "created_at": now,
            "updated_at": now,
            "expires_in_seconds": settings.learning_assistant_session_ttl_seconds,
        }
        self._save(state)
        return state

    def get(self, assistant_session_id: str, user_id: str) -> dict[str, Any] | None:
        """Return the session if it exists, has not expired, and belongs to ``user_id``.

        Returns:
            The session state, or None when it is missing, expired, or owned by
            another user.
        """
        state = self._load(assistant_session_id)
        if not state or state.get("user_id") != user_id:
            return None
        return state

    def get_messages(self, assistant_session_id: str, user_id: str, limit: int | None = None) -> list[dict[str, Any]]:
        """Return the most recent messages of a session, for prompt assembly.

        Args:
            assistant_session_id: Session to read.
            user_id: Owner check; a foreign or missing session yields ``[]``.
            limit: Maximum number of messages to return. Defaults to
                ``settings.learning_assistant_history_limit``. A value <= 0
                returns ``[]``.

        Returns:
            Up to ``limit`` messages, oldest first.
        """
        # NOTE(review): the default returns ``history_limit`` *messages*, while
        # append_message retains up to 2 * history_limit (question + answer
        # rounds). Confirm whether the default should cover rounds instead.
        if limit is None:
            limit = settings.learning_assistant_history_limit
        if limit <= 0:
            # messages[-0:] would return the whole history, not an empty slice.
            return []
        state = self.get(assistant_session_id, user_id)
        if not state:
            return []
        messages = list(state.get("messages") or [])
        return messages[-limit:]

    def append_message(
        self,
        assistant_session_id: str,
        user_id: str,
        role: str,
        content: str,
        metadata: dict | None = None,
    ) -> dict[str, Any] | None:
        """Append one message (user question or AI answer) to a session.

        Only the newest ``max(2 * history_limit, 2)`` messages are kept. Saving
        re-arms the session TTL.

        Returns:
            The updated session state, or None when the session is missing,
            expired, or owned by another user.
        """
        # NOTE(review): load-modify-save is not atomic across processes or
        # threads; concurrent appends to one session can drop a message.
        state = self.get(assistant_session_id, user_id)
        if not state:
            return None
        # Use one timestamp so the message and the session update agree.
        now = self._now()
        messages = list(state.get("messages") or [])
        messages.append(
            {
                "role": role,
                "content": content,
                "metadata": metadata or {},
                "created_at": now,
            }
        )
        max_messages = max(settings.learning_assistant_history_limit * 2, 2)
        state["messages"] = messages[-max_messages:]
        state["updated_at"] = now
        self._save(state)
        return state

    def _create_redis_client(self):
        """Create the Redis client according to the runtime memory settings.

        Returns:
            A connected client when the backend is ``"redis"`` and the server
            answers a ping. Returns None when another backend is configured, or
            when Redis is unavailable and degrading to memory is allowed.

        Raises:
            RuntimeError: In production, when Redis is unavailable and
                ``runtime_memory_fallback_enabled`` is false.
        """
        if settings.runtime_memory_backend.lower() != "redis":
            return None
        try:
            import redis  # imported lazily: only needed for the Redis backend

            client = redis.Redis.from_url(settings.redis_url, decode_responses=True)
            client.ping()
        except Exception as exc:
            if settings.is_production and not settings.runtime_memory_fallback_enabled:
                raise RuntimeError("Redis is required for learning assistant sessions") from exc
            # Degrading is allowed, but make it visible: a per-process store is
            # not shared between workers.
            logging.getLogger(__name__).warning(
                "Redis unavailable for learning assistant sessions; "
                "falling back to in-process memory store",
                exc_info=True,
            )
            return None
        return client

    def _save(self, state: dict[str, Any]) -> None:
        """Persist ``state`` and re-arm its TTL.

        Redis uses ``setex`` with the JSON-encoded state. The memory fallback
        stores a deep copy together with an absolute expiry timestamp, and
        purges entries that have already expired.
        """
        session_id = state["assistant_session_id"]
        ttl = settings.learning_assistant_session_ttl_seconds
        if self._redis_client is not None:
            self._redis_client.setex(self._key(session_id), ttl, json.dumps(state, ensure_ascii=False))
            return
        # Snapshot so later mutations of the caller's dict do not leak into the
        # store (the Redis backend gets this for free via serialization).
        snapshot = copy.deepcopy(state)
        now = time.time()
        with self._lock:
            self._purge_expired_locked(now)
            self._memory_store[session_id] = {"expires_at": now + ttl, "state": snapshot}

    def _purge_expired_locked(self, now: float) -> None:
        """Remove expired fallback entries; the caller must hold ``self._lock``."""
        expired = [sid for sid, item in self._memory_store.items() if item["expires_at"] < now]
        for sid in expired:
            del self._memory_store[sid]

    def _load(self, assistant_session_id: str) -> dict[str, Any] | None:
        """Load a session if it is still valid.

        Returns:
            An independent copy of the stored state, or None when the session is
            missing or expired. An expired memory entry is evicted as a side effect.
        """
        if self._redis_client is not None:
            raw = self._redis_client.get(self._key(assistant_session_id))
            if not raw:
                return None
            return json.loads(raw)
        with self._lock:
            item = self._memory_store.get(assistant_session_id)
            if not item:
                return None
            if item["expires_at"] < time.time():
                del self._memory_store[assistant_session_id]
                return None
            # Deep copy: nested lists/dicts must not be shared with the store.
            return copy.deepcopy(item["state"])

    def _key(self, assistant_session_id: str) -> str:
        """Build the Redis key, namespaced apart from the training short-term memory."""
        return f"{self.key_prefix}{assistant_session_id}"

    def _profile_value(self, ctx: UserContext, key: str) -> Any:
        """Read an optional extension field from ``ctx.profile``.

        Per the original docstring, the profile is normalized from the Django
        ``/me`` endpoint.

        Returns:
            The field value, or None when the profile is empty or lacks ``key``.
        """
        if not ctx.profile:
            return None
        return ctx.profile.get(key)

    def _now(self) -> str:
        """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``, for display and logs."""
        # Aware UTC now (datetime.utcnow() is deprecated since 3.12). Drop tzinfo
        # so isoformat omits "+00:00" and the "Z" suffix format stays unchanged.
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat(timespec="seconds") + "Z"
|
|
|
|
|
|
# Module-level singleton store. Instantiating it runs _create_redis_client at
# import time, which pings Redis when the Redis backend is configured.
learning_assistant_session_store = LearningAssistantSessionStore()
|