chore: finalize backend feature scope
This commit is contained in:
@@ -0,0 +1,141 @@
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.context import UserContext
|
||||
|
||||
|
||||
class LearningAssistantSessionStore:
|
||||
"""AI 学习助手短期会话存储:使用 Redis 保存会话状态,测试或降级时使用进程内存。"""
|
||||
|
||||
key_prefix = "learning_assistant:session:"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = Lock()
|
||||
self._memory_store: dict[str, dict[str, Any]] = {}
|
||||
self._redis_client = self._create_redis_client()
|
||||
|
||||
def create(self, ctx: UserContext, title: str | None = None) -> dict[str, Any]:
|
||||
"""学习助手会话创建:按当前用户和机构初始化一个短期问答会话。"""
|
||||
now = self._now()
|
||||
session_id = f"las_{uuid.uuid4().hex}"
|
||||
state: dict[str, Any] = {
|
||||
"assistant_session_id": session_id,
|
||||
"user_id": ctx.user_id,
|
||||
"institution_id": ctx.institution_id,
|
||||
"institution_name": self._profile_value(ctx, "institution_name"),
|
||||
"title": title or "AI 学习助手",
|
||||
"status": "active",
|
||||
"messages": [],
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"expires_in_seconds": settings.learning_assistant_session_ttl_seconds,
|
||||
}
|
||||
self._save(state)
|
||||
return state
|
||||
|
||||
def get(self, assistant_session_id: str, user_id: str) -> dict[str, Any] | None:
|
||||
"""学习助手会话读取:只返回属于当前用户且未过期的会话。"""
|
||||
state = self._load(assistant_session_id)
|
||||
if not state or state.get("user_id") != user_id:
|
||||
return None
|
||||
return state
|
||||
|
||||
def get_messages(self, assistant_session_id: str, user_id: str, limit: int | None = None) -> list[dict[str, Any]]:
|
||||
"""学习助手上下文读取:返回当前会话最近若干轮问答,用于提示词拼接。"""
|
||||
state = self.get(assistant_session_id, user_id)
|
||||
if not state:
|
||||
return []
|
||||
messages = list(state.get("messages") or [])
|
||||
if limit is None:
|
||||
limit = settings.learning_assistant_history_limit
|
||||
return messages[-limit:]
|
||||
|
||||
def append_message(
|
||||
self,
|
||||
assistant_session_id: str,
|
||||
user_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: dict | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""学习助手上下文写入:记录用户问题和 AI 回答,按 TTL 自动过期。"""
|
||||
state = self.get(assistant_session_id, user_id)
|
||||
if not state:
|
||||
return None
|
||||
messages = list(state.get("messages") or [])
|
||||
messages.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"created_at": self._now(),
|
||||
}
|
||||
)
|
||||
max_messages = max(settings.learning_assistant_history_limit * 2, 2)
|
||||
state["messages"] = messages[-max_messages:]
|
||||
state["updated_at"] = self._now()
|
||||
self._save(state)
|
||||
return state
|
||||
|
||||
def _create_redis_client(self):
|
||||
"""Redis 客户端创建:遵循 runtime memory 配置,失败时按配置降级。"""
|
||||
if settings.runtime_memory_backend.lower() != "redis":
|
||||
return None
|
||||
try:
|
||||
import redis
|
||||
|
||||
client = redis.Redis.from_url(settings.redis_url, decode_responses=True)
|
||||
client.ping()
|
||||
return client
|
||||
except Exception:
|
||||
if settings.is_production and not settings.runtime_memory_fallback_enabled:
|
||||
raise RuntimeError("Redis is required for learning assistant sessions")
|
||||
return None
|
||||
|
||||
def _save(self, state: dict[str, Any]) -> None:
|
||||
"""会话保存:Redis 使用 setex,内存 fallback 使用过期时间戳。"""
|
||||
session_id = state["assistant_session_id"]
|
||||
ttl = settings.learning_assistant_session_ttl_seconds
|
||||
if self._redis_client is not None:
|
||||
self._redis_client.setex(self._key(session_id), ttl, json.dumps(state, ensure_ascii=False))
|
||||
return
|
||||
with self._lock:
|
||||
self._memory_store[session_id] = {"expires_at": time.time() + ttl, "state": state}
|
||||
|
||||
def _load(self, assistant_session_id: str) -> dict[str, Any] | None:
|
||||
"""会话加载:读取并校验短期会话是否仍然有效。"""
|
||||
if self._redis_client is not None:
|
||||
raw = self._redis_client.get(self._key(assistant_session_id))
|
||||
if not raw:
|
||||
return None
|
||||
return json.loads(raw)
|
||||
with self._lock:
|
||||
item = self._memory_store.get(assistant_session_id)
|
||||
if not item:
|
||||
return None
|
||||
if item["expires_at"] < time.time():
|
||||
self._memory_store.pop(assistant_session_id, None)
|
||||
return None
|
||||
return dict(item["state"])
|
||||
|
||||
def _key(self, assistant_session_id: str) -> str:
|
||||
"""Redis key 生成:与训练短期 memory 隔离。"""
|
||||
return f"{self.key_prefix}{assistant_session_id}"
|
||||
|
||||
def _profile_value(self, ctx: UserContext, key: str) -> Any:
|
||||
"""用户资料读取:从 Django `/me` 标准化 profile 中提取扩展字段。"""
|
||||
if not ctx.profile:
|
||||
return None
|
||||
return ctx.profile.get(key)
|
||||
|
||||
def _now(self) -> str:
|
||||
"""时间格式化:返回 ISO 字符串,便于前端展示和日志排查。"""
|
||||
return datetime.utcnow().isoformat(timespec="seconds") + "Z"
|
||||
|
||||
|
||||
learning_assistant_session_store = LearningAssistantSessionStore()
|
||||
Reference in New Issue
Block a user