Files
fastapi/app/agents/llm_adapter.py
T

281 lines
12 KiB
Python
Raw Normal View History

import asyncio
import json
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass
import httpx
from app.core.config import settings
from app.core.exceptions import AppError
@dataclass
class LLMResponse:
"""LLM 响应:封装非流式模型输出和耗时指标。"""
content: str
model: str
latency_ms: int
token_usage: dict | None = None
@dataclass
class LLMStreamChunk:
"""LLM 流式片段:封装 SSE 增量内容和完成状态。"""
delta: str
done: bool = False
first_token_ms: int | None = None
total_latency_ms: int | None = None
model: str | None = None
fallback_used: bool = False
class OpenAICompatibleLLMClient:
    """LLM adapter: one replaceable client for OpenAI-compatible chat-completions APIs.

    Endpoint, credentials and timeout come from ``settings`` (``llm_base_url``,
    ``llm_api_key``, ``llm_timeout_seconds``). ``settings.llm_mock_enabled`` or a
    missing API key switches to canned local responses, and
    ``settings.llm_fallback_to_mock`` substitutes those responses when a real
    call fails.
    """

    # Failures treated as "the upstream call failed": transport errors and HTTP
    # error statuses (httpx.TimeoutException is itself an httpx.HTTPError),
    # unexpected response shapes, and undecodable JSON bodies.
    _CALL_ERRORS: tuple[type[Exception], ...] = (
        httpx.TimeoutException,
        httpx.HTTPError,
        KeyError,
        IndexError,
        json.JSONDecodeError,
    )

    def __init__(self) -> None:
        """Snapshot endpoint URL, API key and timeout from ``settings``."""
        self.base_url = settings.llm_base_url.rstrip("/")
        self.api_key = settings.llm_api_key
        self.timeout = settings.llm_timeout_seconds
        self.chat_completions_url = self._build_chat_completions_url()

    @property
    def is_mock_mode(self) -> bool:
        """Truthy when mock mode is enabled in settings or no API key is configured."""
        return settings.llm_mock_enabled or not self.api_key

    async def chat(
        self,
        messages: list[dict],
        model: str,
        *,
        thinking_enabled: bool | None = None,
        reasoning_effort: str | None = None,
        response_format: dict | None = None,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Send ``messages`` in one non-streaming request and return the full text.

        Args:
            messages: Chat messages forwarded as the request's ``messages``.
            model: Model name sent to the endpoint.
            thinking_enabled: Optional thinking switch (sent only where
                ``_supports_reasoning_options`` allows it).
            reasoning_effort: Optional reasoning effort (same restriction).
            response_format: Optional ``response_format`` payload field.
            max_tokens: Optional ``max_tokens`` payload field (omitted when falsy).

        Returns:
            An ``LLMResponse`` with the stripped completion text. In mock mode
            ``model`` is ``"mock-<model>"``. When a real call fails and
            ``settings.llm_fallback_to_mock`` is set, ``model`` is
            ``"mock-fallback-<model>"`` and ``token_usage`` is
            ``{"fallback_reason": <exception class name>}``.

        Raises:
            AppError: ``LLM_CALL_FAILED`` (HTTP 502) when the call fails and
                fallback is disabled. An empty completion counts as a failure.
        """
        start = time.perf_counter()
        if self.is_mock_mode:
            content = self._mock_response(messages)
            return LLMResponse(content=content, model=f"mock-{model}", latency_ms=self._elapsed_ms(start))
        try:
            async with httpx.AsyncClient(timeout=self._http_timeout()) as client:
                resp = await client.post(
                    self.chat_completions_url,
                    headers={"Authorization": f"Bearer {self.api_key}"},
                    json=self._build_payload(
                        model=model,
                        messages=messages,
                        stream=False,
                        thinking_enabled=thinking_enabled,
                        reasoning_effort=reasoning_effort,
                        response_format=response_format,
                        max_tokens=max_tokens,
                    ),
                )
                resp.raise_for_status()
                data = resp.json()
            content = (data["choices"][0]["message"].get("content") or "").strip()
            if not content:
                # Reuse KeyError so an empty completion takes the same
                # fallback / 502 path as a malformed response.
                raise KeyError("empty llm content")
            return LLMResponse(
                content=content,
                model=model,
                latency_ms=self._elapsed_ms(start),
                token_usage=data.get("usage"),
            )
        except self._CALL_ERRORS as exc:
            if settings.llm_fallback_to_mock:
                return LLMResponse(
                    content=self._mock_response(messages),
                    model=f"mock-fallback-{model}",
                    latency_ms=self._elapsed_ms(start),
                    token_usage={"fallback_reason": exc.__class__.__name__},
                )
            raise AppError("LLM_CALL_FAILED", "llm service call failed", 502) from exc

    async def stream_chat(
        self,
        messages: list[dict],
        model: str,
        *,
        thinking_enabled: bool | None = None,
        reasoning_effort: str | None = None,
        max_tokens: int | None = None,
    ) -> AsyncIterator[LLMStreamChunk]:
        """Stream a chat completion as ``LLMStreamChunk`` objects.

        Each non-empty ``content`` delta from the SSE stream is yielded as its own
        chunk, followed by a final ``done=True`` chunk carrying timing and model.
        ``reasoning_content`` deltas are not yielded but do start the
        first-token timer.

        On failure, if ``settings.llm_fallback_to_mock`` is set AND no real
        content has been yielded yet, a mock stream is emitted instead (final
        chunk has ``fallback_used=True``).

        Raises:
            AppError: ``LLM_STREAM_FAILED`` (HTTP 502) when the stream fails and
                fallback is disabled, or when it fails after partial output.
        """
        start = time.perf_counter()
        first_token_ms: int | None = None
        # Tracks whether real content already reached the caller; after that a
        # mock fallback would splice canned text onto the model's partial answer.
        yielded_content = False
        if self.is_mock_mode:
            async for chunk in self._mock_stream(messages, model, start, model_label=f"mock-{model}"):
                yield chunk
            return
        try:
            async with httpx.AsyncClient(timeout=self._http_timeout()) as client:
                async with client.stream(
                    "POST",
                    self.chat_completions_url,
                    headers={"Authorization": f"Bearer {self.api_key}"},
                    json=self._build_payload(
                        model=model,
                        messages=messages,
                        stream=True,
                        thinking_enabled=thinking_enabled,
                        reasoning_effort=reasoning_effort,
                        max_tokens=max_tokens,
                    ),
                ) as resp:
                    resp.raise_for_status()
                    async for line in resp.aiter_lines():
                        if not line.startswith("data:"):
                            continue
                        payload = line.removeprefix("data:").strip()
                        if not payload:
                            continue
                        if payload == "[DONE]":
                            break
                        content_delta, reasoning_delta = self._parse_stream_delta(payload)
                        if (content_delta or reasoning_delta) and first_token_ms is None:
                            first_token_ms = self._elapsed_ms(start)
                        if content_delta:
                            yielded_content = True
                            yield LLMStreamChunk(delta=content_delta, first_token_ms=first_token_ms)
        except self._CALL_ERRORS as exc:
            if settings.llm_fallback_to_mock and not yielded_content:
                async for chunk in self._mock_stream(
                    messages,
                    model,
                    start,
                    model_label=f"mock-fallback-{model}",
                    fallback_used=True,
                ):
                    yield chunk
                return
            raise AppError("LLM_STREAM_FAILED", "llm stream call failed", 502) from exc
        yield LLMStreamChunk(
            delta="",
            done=True,
            first_token_ms=first_token_ms,
            total_latency_ms=self._elapsed_ms(start),
            model=model,
        )

    @staticmethod
    def _parse_stream_delta(payload: str) -> tuple[str, str]:
        """Decode one SSE ``data:`` payload into ``(content_delta, reasoning_delta)``.

        A chunk whose ``choices`` list is empty (e.g. usage-only or provider
        metadata chunks) or whose ``delta`` is null yields ``("", "")`` instead
        of raising. A missing ``choices`` key still raises KeyError and invalid
        JSON still raises ``json.JSONDecodeError``.
        """
        data = json.loads(payload)
        choices = data["choices"]
        if not choices:
            return "", ""
        delta = choices[0].get("delta") or {}
        return delta.get("content") or "", delta.get("reasoning_content") or ""

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since ``start`` (a ``time.perf_counter()`` reading)."""
        return int((time.perf_counter() - start) * 1000)

    async def _mock_stream(
        self,
        messages: list[dict],
        model: str,
        start: float,
        model_label: str,
        fallback_used: bool = False,
    ) -> AsyncIterator[LLMStreamChunk]:
        """Emit the mock reply as a paced stream so the demo flow stays testable.

        Yields 8-character pieces 20 ms apart, then a ``done=True`` chunk
        labelled ``model_label``. ``model`` is accepted for signature symmetry
        but not used.
        """
        first_token_ms: int | None = None
        content = self._mock_response(messages)
        for piece in self._split_mock_content(content):
            await asyncio.sleep(0.02)
            if first_token_ms is None:
                first_token_ms = self._elapsed_ms(start)
            yield LLMStreamChunk(delta=piece, first_token_ms=first_token_ms)
        yield LLMStreamChunk(
            delta="",
            done=True,
            first_token_ms=first_token_ms,
            total_latency_ms=self._elapsed_ms(start),
            model=model_label,
            fallback_used=fallback_used,
        )

    @staticmethod
    def _message_text(message: dict) -> str:
        """Return a message's ``content`` if it is a string, else "" (covers None / parts lists)."""
        content = message.get("content")
        return content if isinstance(content, str) else ""

    def _mock_response(self, messages: list[dict]) -> str:
        """Return a canned reply so the demo loop runs without a real model key.

        If the first two messages mention both ``score_type`` and
        ``dimension_scores`` (case-insensitive), a fixed grading report is
        returned as a JSON string. Otherwise the latest ``user`` message is
        matched against keyword groups to pick a simulated parent's answer,
        falling back to a generic opening line.
        """
        latest = next(
            (self._message_text(m) for m in reversed(messages) if m.get("role") == "user"),
            "",
        )
        prompt_head = " ".join(self._message_text(m).lower() for m in messages[:2])
        if "score_type" in prompt_head and "dimension_scores" in prompt_head:
            return json.dumps(
                {
                    "score_type": "percentage",
                    "total_score": 82,
                    "dimension_scores": [
                        {"dimension": "信息获取", "score": 20, "max_score": 25, "comment": "覆盖了发热、咳嗽和喘息,儿科特异性病史仍需加强。"},
                        {"dimension": "分析推理", "score": 21, "max_score": 25, "comment": "能够识别肺炎方向,鉴别诊断完整性中等。"},
                        {"dimension": "处置决策", "score": 17, "max_score": 20, "comment": "治疗原则基本合理,风险预案需要更具体。"},
                        {"dimension": "沟通人文", "score": 12, "max_score": 15, "comment": "有告知意识,家属安抚和健康教育可更系统。"},
                        {"dimension": "临床整合", "score": 12, "max_score": 15, "comment": "诊疗流程完整,时间分配和整体组织较清晰。"},
                    ],
                    "errors": [{"title": "儿科特异性病史不足", "description": "疫苗接种、过敏史、既往喘息史追问不足。"}],
                    "improvement_plan": ["补充儿科问诊框架:出生史、接种史、过敏史、既往喘息史。"],
                    "evidence_summary": ["用户完成了核心症状追问、检查申请、诊断和治疗提交。"],
                    "guideline_refs": [],
                    "overall_comment": "本次训练完成主要诊疗流程,诊断方向正确,治疗方案具备基本可执行性。",
                },
                ensure_ascii=False,
            )
        # Keywords must be non-empty: `"" in s` is True for every string and
        # would make all later branches unreachable.
        if "体温" in latest or "发热" in latest:
            return "最高烧到39度多,已经反复四天了,退烧后会好一点,但很快又起来。"
        if "喘" in latest or "呼吸" in latest:
            return "昨天开始喘得明显,活动后更明显,晚上咳嗽也更重。"
        if "精神" in latest or "吃" in latest:
            return "精神比平时差一些,吃饭少了,但还能喝水,小便比平时略少。"
        if "既往" in latest or "过敏" in latest:
            return "以前没有明确哮喘诊断,也没有药物过敏史,小时候感冒时偶尔会咳得久。"
        return "家长:孩子主要是发热、咳嗽,昨天开始喘,您可以继续问我具体情况。"

    def _split_mock_content(self, content: str) -> list[str]:
        """Split mock text into consecutive 8-character pieces for simulated streaming."""
        return [content[i : i + 8] for i in range(0, len(content), 8)]

    def _build_chat_completions_url(self) -> str:
        """Accept either a base URL or a full ``.../chat/completions`` URL in settings."""
        if self.base_url.endswith("/chat/completions"):
            return self.base_url
        return f"{self.base_url}/chat/completions"

    def _http_timeout(self) -> httpx.Timeout:
        """Build the request timeout so the frontend is not left waiting indefinitely.

        Read uses the full configured timeout; connect and pool waits are capped
        at 8 and write at 15 (never exceeding the configured timeout).
        """
        return httpx.Timeout(
            timeout=self.timeout,
            connect=min(8, self.timeout),
            read=self.timeout,
            write=min(15, self.timeout),
            pool=min(8, self.timeout),
        )

    def _build_payload(
        self,
        *,
        model: str,
        messages: list[dict],
        stream: bool,
        thinking_enabled: bool | None = None,
        reasoning_effort: str | None = None,
        response_format: dict | None = None,
        max_tokens: int | None = None,
    ) -> dict:
        """Build the chat-completions request body.

        ``thinking`` and ``reasoning_effort`` are vendor-specific and only added
        when ``_supports_reasoning_options`` accepts the endpoint/model.
        ``reasoning_effort`` is also dropped when thinking is explicitly
        disabled. ``response_format`` and ``max_tokens`` are added only when truthy.
        """
        payload: dict = {"model": model, "messages": messages, "stream": stream}
        supports_reasoning_options = self._supports_reasoning_options(model)
        if thinking_enabled is not None and supports_reasoning_options:
            payload["thinking"] = {"type": "enabled" if thinking_enabled else "disabled"}
        if reasoning_effort and supports_reasoning_options and thinking_enabled is not False:
            payload["reasoning_effort"] = reasoning_effort
        if response_format:
            payload["response_format"] = response_format
        if max_tokens:
            payload["max_tokens"] = max_tokens
        return payload

    def _supports_reasoning_options(self, model: str) -> bool:
        """Return True when the base URL contains "deepseek" or the model name starts with it."""
        base = self.base_url.lower()
        model_name = model.lower()
        return "deepseek" in base or model_name.startswith("deepseek")
# Alias: ``DeepSeekClient`` is the same class object as
# ``OpenAICompatibleLLMClient``; presumably kept so imports of the older
# name keep working (verify against callers before removing).
DeepSeekClient = OpenAICompatibleLLMClient