Files
fastapi/backend/app/agents/llm_adapter.py
T
2026-06-01 09:25:26 +08:00

281 lines
12 KiB
Python

import asyncio
import json
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass
import httpx
from app.core.config import settings
from app.core.exceptions import AppError
@dataclass
class LLMResponse:
"""LLM 响应:封装非流式模型输出和耗时指标。"""
content: str
model: str
latency_ms: int
token_usage: dict | None = None
@dataclass
class LLMStreamChunk:
"""LLM 流式片段:封装 SSE 增量内容和完成状态。"""
delta: str
done: bool = False
first_token_ms: int | None = None
total_latency_ms: int | None = None
model: str | None = None
fallback_used: bool = False
class OpenAICompatibleLLMClient:
"""LLM Adapter:统一封装 OpenAI-compatible 模型的可替换调用。"""
def __init__(self) -> None:
self.base_url = settings.llm_base_url.rstrip("/")
self.api_key = settings.llm_api_key
self.timeout = settings.llm_timeout_seconds
self.chat_completions_url = self._build_chat_completions_url()
@property
def is_mock_mode(self) -> bool:
"""模型模式:没有 API Key 或开启 mock 时使用本地模拟响应。"""
return settings.llm_mock_enabled or not self.api_key
async def chat(
self,
messages: list[dict],
model: str,
*,
thinking_enabled: bool | None = None,
reasoning_effort: str | None = None,
response_format: dict | None = None,
max_tokens: int | None = None,
) -> LLMResponse:
"""非流式调用:向 OpenAI-compatible 接口发送 messages 并返回完整文本。"""
start = time.perf_counter()
if self.is_mock_mode:
content = self._mock_response(messages)
return LLMResponse(content=content, model=f"mock-{model}", latency_ms=int((time.perf_counter() - start) * 1000))
try:
async with httpx.AsyncClient(timeout=self._http_timeout()) as client:
resp = await client.post(
self.chat_completions_url,
headers={"Authorization": f"Bearer {self.api_key}"},
json=self._build_payload(
model=model,
messages=messages,
stream=False,
thinking_enabled=thinking_enabled,
reasoning_effort=reasoning_effort,
response_format=response_format,
max_tokens=max_tokens,
),
)
resp.raise_for_status()
data = resp.json()
content = (data["choices"][0]["message"].get("content") or "").strip()
if not content:
raise KeyError("empty llm content")
return LLMResponse(
content=content,
model=model,
latency_ms=int((time.perf_counter() - start) * 1000),
token_usage=data.get("usage"),
)
except (httpx.TimeoutException, httpx.HTTPError, KeyError, IndexError, json.JSONDecodeError) as exc:
if settings.llm_fallback_to_mock:
content = self._mock_response(messages)
return LLMResponse(
content=content,
model=f"mock-fallback-{model}",
latency_ms=int((time.perf_counter() - start) * 1000),
token_usage={"fallback_reason": exc.__class__.__name__},
)
raise AppError("LLM_CALL_FAILED", "llm service call failed", 502) from exc
async def stream_chat(
self,
messages: list[dict],
model: str,
*,
thinking_enabled: bool | None = None,
reasoning_effort: str | None = None,
max_tokens: int | None = None,
) -> AsyncIterator[LLMStreamChunk]:
"""流式调用:以统一 chunk 结构输出 OpenAI-compatible SSE 增量。"""
start = time.perf_counter()
first_token_ms: int | None = None
if self.is_mock_mode:
async for chunk in self._mock_stream(messages, model, start, model_label=f"mock-{model}"):
yield chunk
return
try:
async with httpx.AsyncClient(timeout=self._http_timeout()) as client:
async with client.stream(
"POST",
self.chat_completions_url,
headers={"Authorization": f"Bearer {self.api_key}"},
json=self._build_payload(
model=model,
messages=messages,
stream=True,
thinking_enabled=thinking_enabled,
reasoning_effort=reasoning_effort,
max_tokens=max_tokens,
),
) as resp:
resp.raise_for_status()
async for line in resp.aiter_lines():
if not line.startswith("data:"):
continue
payload = line.removeprefix("data:").strip()
if payload == "[DONE]":
break
data = json.loads(payload)
delta_obj = data["choices"][0].get("delta", {})
content_delta = delta_obj.get("content") or ""
reasoning_delta = delta_obj.get("reasoning_content") or ""
if (content_delta or reasoning_delta) and first_token_ms is None:
first_token_ms = int((time.perf_counter() - start) * 1000)
if content_delta:
yield LLMStreamChunk(delta=content_delta, first_token_ms=first_token_ms)
except (httpx.TimeoutException, httpx.HTTPError, KeyError, IndexError, json.JSONDecodeError) as exc:
if settings.llm_fallback_to_mock:
async for chunk in self._mock_stream(
messages,
model,
start,
model_label=f"mock-fallback-{model}",
fallback_used=True,
):
yield chunk
return
raise AppError("LLM_STREAM_FAILED", "llm stream call failed", 502) from exc
yield LLMStreamChunk(
delta="",
done=True,
first_token_ms=first_token_ms,
total_latency_ms=int((time.perf_counter() - start) * 1000),
model=model,
)
async def _mock_stream(
self,
messages: list[dict],
model: str,
start: float,
model_label: str,
fallback_used: bool = False,
) -> AsyncIterator[LLMStreamChunk]:
"""Mock 流式输出:在模型不可用时保持 Demo 流程可验证。"""
first_token_ms: int | None = None
content = self._mock_response(messages)
for piece in self._split_mock_content(content):
await asyncio.sleep(0.02)
if first_token_ms is None:
first_token_ms = int((time.perf_counter() - start) * 1000)
yield LLMStreamChunk(delta=piece, first_token_ms=first_token_ms)
yield LLMStreamChunk(
delta="",
done=True,
first_token_ms=first_token_ms,
total_latency_ms=int((time.perf_counter() - start) * 1000),
model=model_label,
fallback_used=fallback_used,
)
def _mock_response(self, messages: list[dict]) -> str:
"""Mock 输出:在没有 DeepSeek Key 时保证 Demo 闭环可运行。"""
latest = next((m.get("content", "") for m in reversed(messages) if m.get("role") == "user"), "")
prompt_head = " ".join(m.get("content", "").lower() for m in messages[:2])
if "score_type" in prompt_head and "dimension_scores" in prompt_head:
return json.dumps(
{
"score_type": "percentage",
"total_score": 82,
"dimension_scores": [
{"dimension": "信息获取", "score": 20, "max_score": 25, "comment": "覆盖了发热、咳嗽和喘息,儿科特异性病史仍需加强。"},
{"dimension": "分析推理", "score": 21, "max_score": 25, "comment": "能够识别肺炎方向,鉴别诊断完整性中等。"},
{"dimension": "处置决策", "score": 17, "max_score": 20, "comment": "治疗原则基本合理,风险预案需要更具体。"},
{"dimension": "沟通人文", "score": 12, "max_score": 15, "comment": "有告知意识,家属安抚和健康教育可更系统。"},
{"dimension": "临床整合", "score": 12, "max_score": 15, "comment": "诊疗流程完整,时间分配和整体组织较清晰。"},
],
"errors": [{"title": "儿科特异性病史不足", "description": "疫苗接种、过敏史、既往喘息史追问不足。"}],
"improvement_plan": ["补充儿科问诊框架:出生史、接种史、过敏史、既往喘息史。"],
"evidence_summary": ["用户完成了核心症状追问、检查申请、诊断和治疗提交。"],
"guideline_refs": [],
"overall_comment": "本次训练完成主要诊疗流程,诊断方向正确,治疗方案具备基本可执行性。",
},
ensure_ascii=False,
)
if "体温" in latest or "发热" in latest:
return "最高烧到39度多,已经反复四天了,退烧后会好一点,但很快又起来。"
if "" in latest or "呼吸" in latest:
return "昨天开始喘得明显,活动后更明显,晚上咳嗽也更重。"
if "精神" in latest or "" in latest:
return "精神比平时差一些,吃饭少了,但还能喝水,小便比平时略少。"
if "既往" in latest or "过敏" in latest:
return "以前没有明确哮喘诊断,也没有药物过敏史,小时候感冒时偶尔会咳得久。"
return "家长:孩子主要是发热、咳嗽,昨天开始喘,您可以继续问我具体情况。"
def _split_mock_content(self, content: str) -> list[str]:
"""Mock 分片:把本地模拟文本拆成流式输出片段。"""
return [content[i : i + 8] for i in range(0, len(content), 8)]
def _build_chat_completions_url(self) -> str:
"""接口地址:兼容 base_url 和完整 chat/completions URL 两种写法。"""
if self.base_url.endswith("/chat/completions"):
return self.base_url
return f"{self.base_url}/chat/completions"
def _http_timeout(self) -> httpx.Timeout:
"""超时策略:限制连接、写入和读取等待,避免前端长时间卡在生成中。"""
return httpx.Timeout(
timeout=self.timeout,
connect=min(8, self.timeout),
read=self.timeout,
write=min(15, self.timeout),
pool=min(8, self.timeout),
)
def _build_payload(
self,
*,
model: str,
messages: list[dict],
stream: bool,
thinking_enabled: bool | None = None,
reasoning_effort: str | None = None,
response_format: dict | None = None,
max_tokens: int | None = None,
) -> dict:
"""请求构造:兼容 DeepSeek V4 thinking、reasoning_effort 和 JSON 输出。"""
payload: dict = {"model": model, "messages": messages, "stream": stream}
supports_reasoning_options = self._supports_reasoning_options(model)
if thinking_enabled is not None and supports_reasoning_options:
payload["thinking"] = {"type": "enabled" if thinking_enabled else "disabled"}
if reasoning_effort and supports_reasoning_options and thinking_enabled is not False:
payload["reasoning_effort"] = reasoning_effort
if response_format:
payload["response_format"] = response_format
if max_tokens:
payload["max_tokens"] = max_tokens
return payload
def _supports_reasoning_options(self, model: str) -> bool:
"""厂商兼容:只向 DeepSeek 发送 thinking/reasoning_effort 等专有参数。"""
base = self.base_url.lower()
model_name = model.lower()
return "deepseek" in base or model_name.startswith("deepseek")
# Alias exposing OpenAICompatibleLLMClient under a DeepSeek-specific name.
# Presumably kept for callers that still import DeepSeekClient — confirm before removing.
DeepSeekClient = OpenAICompatibleLLMClient