prepare fastapi root layout for server deployment
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""核心配置、异常、响应和用户上下文模块。"""
|
||||
@@ -0,0 +1,165 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
def _load_dotenv_file() -> None:
|
||||
"""环境加载:轻量读取项目根目录 `.env`,避免强依赖 python-dotenv。"""
|
||||
env_path = next(
|
||||
(parent / ".env" for parent in Path(__file__).resolve().parents if (parent / ".env").exists()),
|
||||
None,
|
||||
)
|
||||
if env_path is None:
|
||||
return
|
||||
for line in env_path.read_text(encoding="utf-8").splitlines():
|
||||
if not line or line.strip().startswith("#") or "=" not in line:
|
||||
continue
|
||||
key, value = line.split("=", 1)
|
||||
os.environ.setdefault(key.strip(), value.strip())
|
||||
|
||||
|
||||
_load_dotenv_file()
|
||||
|
||||
|
||||
def _env_first(*keys: str, default: str = "") -> str:
|
||||
"""环境读取:按优先级读取多个环境变量。"""
|
||||
for key in keys:
|
||||
value = os.getenv(key)
|
||||
if value:
|
||||
return value
|
||||
return default
|
||||
|
||||
|
||||
def _normalize_sync_database_url(url: str) -> str:
|
||||
"""数据库连接:将异步 MySQL URL 转换为当前同步 ORM 可用的 URL。"""
|
||||
if url.startswith("mysql+aiomysql://"):
|
||||
return url.replace("mysql+aiomysql://", "mysql+pymysql://", 1)
|
||||
if url.startswith("mysql://"):
|
||||
return url.replace("mysql://", "mysql+pymysql://", 1)
|
||||
return url
|
||||
|
||||
|
||||
def _env_bool(key: str, default: bool) -> bool:
|
||||
"""布尔配置:统一解析环境变量中的 true/false 开关。"""
|
||||
return os.getenv(key, str(default)).lower() == "true"
|
||||
|
||||
|
||||
def _env_csv(key: str, default: str = "") -> list[str]:
|
||||
"""列表配置:把逗号分隔的环境变量转换为去空白列表。"""
|
||||
return [item.strip() for item in os.getenv(key, default).split(",") if item.strip()]
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
"""系统配置:集中管理数据库、DeepSeek、报告和短期 memory 配置。"""
|
||||
|
||||
app_name: str = Field(default_factory=lambda: os.getenv("APP_NAME", "Medical Consultation Agent Demo"))
|
||||
app_env: str = Field(default_factory=lambda: os.getenv("APP_ENV", "local"))
|
||||
app_debug: bool = Field(default_factory=lambda: _env_bool("APP_DEBUG", True))
|
||||
app_root_path: str = Field(default_factory=lambda: os.getenv("APP_ROOT_PATH", ""))
|
||||
api_v1_prefix: str = Field(default_factory=lambda: os.getenv("API_V1_PREFIX", "/api/v1"))
|
||||
cors_allow_origins: list[str] = Field(
|
||||
default_factory=lambda: _env_csv(
|
||||
"CORS_ALLOW_ORIGINS",
|
||||
"http://127.0.0.1:5173,http://localhost:5173,http://127.0.0.1:5174,http://localhost:5174",
|
||||
)
|
||||
)
|
||||
cors_allow_origin_regex: str = Field(
|
||||
default_factory=lambda: os.getenv(
|
||||
"CORS_ALLOW_ORIGIN_REGEX",
|
||||
r"^http://(127\.0\.0\.1|localhost|192\.168\.\d+\.\d+):\d+$",
|
||||
)
|
||||
)
|
||||
|
||||
mysql_url: str = Field(default_factory=lambda: os.getenv("MYSQL_URL", ""))
|
||||
database_url: str = Field(
|
||||
default_factory=lambda: _normalize_sync_database_url(
|
||||
_env_first("DATABASE_URL", "MYSQL_URL", default="sqlite:///./storage/demo.db")
|
||||
)
|
||||
)
|
||||
|
||||
llm_api_key: str = Field(default_factory=lambda: _env_first("LLM_API_KEY", "DEEPSEEK_API_KEY", default=""))
|
||||
llm_base_url: str = Field(
|
||||
default_factory=lambda: _env_first("LLM_BASE_URL", "DEEPSEEK_BASE_URL", default="https://api.deepseek.com")
|
||||
)
|
||||
llm_model: str = Field(default_factory=lambda: _env_first("LLM_MODEL", "DEEPSEEK_FAST_MODEL", default="deepseek-chat"))
|
||||
llm_fast_model: str = Field(default_factory=lambda: _env_first("LLM_FAST_MODEL", "LLM_MODEL", "DEEPSEEK_FAST_MODEL", default="deepseek-chat"))
|
||||
llm_reason_model: str = Field(
|
||||
default_factory=lambda: _env_first("LLM_REASON_MODEL", "LLM_MODEL", "DEEPSEEK_REASON_MODEL", default="deepseek-reasoner")
|
||||
)
|
||||
llm_timeout_seconds: int = Field(default_factory=lambda: int(os.getenv("LLM_TIMEOUT_SECONDS", "45")))
|
||||
llm_chat_timeout_seconds: int = Field(default_factory=lambda: int(os.getenv("LLM_CHAT_TIMEOUT_SECONDS", "20")))
|
||||
llm_stream_first_token_timeout_seconds: int = Field(
|
||||
default_factory=lambda: int(os.getenv("LLM_STREAM_FIRST_TOKEN_TIMEOUT_SECONDS", "15"))
|
||||
)
|
||||
llm_stream_total_timeout_seconds: int = Field(default_factory=lambda: int(os.getenv("LLM_STREAM_TOTAL_TIMEOUT_SECONDS", "45")))
|
||||
llm_stream_enabled: bool = Field(default_factory=lambda: _env_bool("LLM_STREAM_ENABLED", True))
|
||||
llm_mock_enabled: bool = Field(default_factory=lambda: _env_bool("LLM_MOCK_ENABLED", True))
|
||||
llm_fallback_to_mock: bool = Field(default_factory=lambda: _env_bool("LLM_FALLBACK_TO_MOCK", True))
|
||||
llm_fast_thinking_enabled: bool = Field(default_factory=lambda: _env_bool("LLM_FAST_THINKING_ENABLED", False))
|
||||
llm_reason_thinking_enabled: bool = Field(default_factory=lambda: _env_bool("LLM_REASON_THINKING_ENABLED", False))
|
||||
llm_reasoning_effort: str = Field(default_factory=lambda: os.getenv("LLM_REASONING_EFFORT", "low"))
|
||||
llm_fast_max_tokens: int = Field(default_factory=lambda: int(os.getenv("LLM_FAST_MAX_TOKENS", "512")))
|
||||
llm_hint_max_tokens: int = Field(default_factory=lambda: int(os.getenv("LLM_HINT_MAX_TOKENS", "1200")))
|
||||
llm_scoring_json_response: bool = Field(default_factory=lambda: _env_bool("LLM_SCORING_JSON_RESPONSE", True))
|
||||
llm_scoring_max_tokens: int = Field(default_factory=lambda: int(os.getenv("LLM_SCORING_MAX_TOKENS", "4096")))
|
||||
|
||||
report_storage_dir: str = Field(default_factory=lambda: os.getenv("REPORT_STORAGE_DIR", "./storage/reports"))
|
||||
runtime_memory_ttl_seconds: int = Field(default_factory=lambda: int(os.getenv("RUNTIME_MEMORY_TTL_SECONDS", "7200")))
|
||||
runtime_memory_backend: str = Field(default_factory=lambda: os.getenv("RUNTIME_MEMORY_BACKEND", "memory"))
|
||||
runtime_memory_fallback_enabled: bool = Field(
|
||||
default_factory=lambda: _env_bool("RUNTIME_MEMORY_FALLBACK_ENABLED", True)
|
||||
)
|
||||
redis_url: str = Field(default_factory=lambda: os.getenv("REDIS_URL", "redis://redis:6379/0"))
|
||||
auth_validate_enabled: bool = Field(default_factory=lambda: _env_bool("AUTH_VALIDATE_ENABLED", True))
|
||||
auth_user_me_url: str = Field(default_factory=lambda: os.getenv("AUTH_USER_ME_URL", ""))
|
||||
auth_timeout_seconds: int = Field(default_factory=lambda: int(os.getenv("AUTH_TIMEOUT_SECONDS", "5")))
|
||||
auth_cache_ttl_seconds: int = Field(default_factory=lambda: int(os.getenv("AUTH_CACHE_TTL_SECONDS", "300")))
|
||||
|
||||
@property
|
||||
def is_production(self) -> bool:
|
||||
"""环境判断:标识当前是否运行在生产环境。"""
|
||||
return self.app_env.lower() == "production"
|
||||
|
||||
def deployment_config_errors(self) -> list[str]:
|
||||
"""部署检查:返回会导致生产核心链路不可用的配置问题。"""
|
||||
errors: list[str] = []
|
||||
if self.database_url.startswith("sqlite"):
|
||||
errors.append("DATABASE_URL must use MySQL")
|
||||
if "CHANGE_ME" in self.database_url:
|
||||
errors.append("DATABASE_URL still contains a placeholder")
|
||||
if self.runtime_memory_backend.lower() != "redis":
|
||||
errors.append("RUNTIME_MEMORY_BACKEND must be redis")
|
||||
if not self.redis_url:
|
||||
errors.append("REDIS_URL is required")
|
||||
if not self.auth_user_me_url:
|
||||
errors.append("AUTH_USER_ME_URL is required")
|
||||
if self.llm_api_key in {"", "CHANGE_ME"} and not self.llm_mock_enabled:
|
||||
errors.append("LLM_API_KEY is required when mock mode is disabled")
|
||||
return errors
|
||||
|
||||
def as_public_dict(self) -> dict[str, Any]:
|
||||
"""配置展示:返回允许暴露给 Demo 前端的功能开关。"""
|
||||
mock_enabled = self.llm_mock_enabled or not self.llm_api_key
|
||||
return {
|
||||
"stream_chat": self.llm_stream_enabled,
|
||||
"score_types": ["percentage", "five_point"],
|
||||
"pdf_export": True,
|
||||
"knowledge_search": True,
|
||||
"llm_mock_enabled": mock_enabled,
|
||||
"llm_mode": "mock" if mock_enabled else "real",
|
||||
"llm_fallback_to_mock": self.llm_fallback_to_mock,
|
||||
"llm_fast_model": self.llm_fast_model,
|
||||
"llm_reason_model": self.llm_reason_model,
|
||||
"llm_fast_thinking_enabled": self.llm_fast_thinking_enabled,
|
||||
"llm_reason_thinking_enabled": self.llm_reason_thinking_enabled,
|
||||
"llm_reasoning_effort": self.llm_reasoning_effort,
|
||||
"llm_fast_max_tokens": self.llm_fast_max_tokens,
|
||||
"runtime_memory_backend": self.runtime_memory_backend,
|
||||
"auth_validate_enabled": True,
|
||||
"auth_source": "django_user_center",
|
||||
}
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,25 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UserContext:
|
||||
"""用户上下文:承载 Django 用户中心认证后的用户 ID 和入口元数据。"""
|
||||
|
||||
user_id: str
|
||||
tenant_id: str | None = None
|
||||
role: str | None = None
|
||||
class_id: str | None = None
|
||||
institution_id: int | None = None
|
||||
department_id: int | None = None
|
||||
entry_scene: str | None = None
|
||||
request_id: str | None = None
|
||||
ip_address: str | None = None
|
||||
user_agent: str | None = None
|
||||
username: str | None = None
|
||||
display_name: str | None = None
|
||||
phone: str | None = None
|
||||
major: str | None = None
|
||||
training_stage: str | None = None
|
||||
learning_target: str | None = None
|
||||
auth_source: str = "django_user_center"
|
||||
profile: dict | None = None
|
||||
@@ -0,0 +1,64 @@
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from app.core.exceptions import AppError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
"""异常注册:把业务异常转换为统一响应格式。"""
|
||||
|
||||
@app.exception_handler(AppError)
|
||||
async def handle_app_error(request: Request, exc: AppError) -> JSONResponse:
|
||||
logger.warning(
|
||||
"business_error code=%s path=%s request_id=%s",
|
||||
exc.code,
|
||||
request.url.path,
|
||||
request.headers.get("X-Request-Id"),
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"code": exc.code, "message": exc.message, "data": None},
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def handle_validation_error(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||
logger.warning(
|
||||
"validation_error path=%s request_id=%s errors=%s",
|
||||
request.url.path,
|
||||
request.headers.get("X-Request-Id"),
|
||||
exc.errors(),
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content={"code": "VALIDATION_ERROR", "message": "request validation failed", "data": {"errors": exc.errors()}},
|
||||
)
|
||||
|
||||
@app.exception_handler(SQLAlchemyError)
|
||||
async def handle_database_error(request: Request, exc: SQLAlchemyError) -> JSONResponse:
|
||||
logger.exception(
|
||||
"database_error path=%s request_id=%s",
|
||||
request.url.path,
|
||||
request.headers.get("X-Request-Id"),
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"code": "DATABASE_ERROR", "message": "database operation failed", "data": None},
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def handle_unexpected_error(request: Request, exc: Exception) -> JSONResponse:
|
||||
logger.exception(
|
||||
"unexpected_error path=%s request_id=%s",
|
||||
request.url.path,
|
||||
request.headers.get("X-Request-Id"),
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"code": "INTERNAL_ERROR", "message": "internal server error", "data": None},
|
||||
)
|
||||
@@ -0,0 +1,8 @@
|
||||
class AppError(Exception):
|
||||
"""业务异常:承载业务错误码、错误信息和 HTTP 状态码。"""
|
||||
|
||||
def __init__(self, code: str, message: str, status_code: int = 400) -> None:
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(message)
|
||||
@@ -0,0 +1,18 @@
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ApiResponse(BaseModel, Generic[T]):
|
||||
"""统一响应:所有业务接口使用相同的 `code/message/data` 结构。"""
|
||||
|
||||
code: str = "OK"
|
||||
message: str = "success"
|
||||
data: T | None = None
|
||||
|
||||
|
||||
def ok(data: T | None = None) -> ApiResponse[T]:
|
||||
"""响应封装:生成成功响应对象。"""
|
||||
return ApiResponse(data=data)
|
||||
@@ -0,0 +1,40 @@
|
||||
from fastapi import Header, Request, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from app.core.context import UserContext
|
||||
from app.core.exceptions import AppError
|
||||
from app.services.external_auth_service import ExternalAuthService
|
||||
|
||||
bearer_scheme = HTTPBearer(auto_error=False, description="Django 用户中心 access token")
|
||||
|
||||
|
||||
async def get_user_context(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials | None = Security(bearer_scheme),
|
||||
x_entry_scene: str | None = Header(default=None, alias="X-Entry-Scene"),
|
||||
x_request_id: str | None = Header(default=None, alias="X-Request-Id"),
|
||||
) -> UserContext:
|
||||
"""用户校验:只接受宿主系统 access token,并转发 Django 用户中心 `/me` 获取真实用户。"""
|
||||
if not credentials or not credentials.credentials.strip():
|
||||
raise AppError("AUTH_CREDENTIAL_REQUIRED", "Authorization header is required", 401)
|
||||
|
||||
user = await ExternalAuthService().authenticate(request)
|
||||
return UserContext(
|
||||
user_id=user.user_id,
|
||||
tenant_id=user.tenant_id,
|
||||
role=user.role,
|
||||
institution_id=user.institution_id,
|
||||
department_id=user.department_id,
|
||||
entry_scene=x_entry_scene,
|
||||
request_id=x_request_id,
|
||||
ip_address=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("User-Agent"),
|
||||
username=user.username,
|
||||
display_name=user.display_name,
|
||||
phone=user.phone,
|
||||
major=user.major,
|
||||
training_stage=user.training_stage,
|
||||
learning_target=user.learning_target,
|
||||
auth_source=user.source,
|
||||
profile=user.profile,
|
||||
)
|
||||
Reference in New Issue
Block a user