Streaming output, responses within seconds, works really well

2026-03-31 21:06:59 +08:00
parent 279c28b412
commit fe7db4d4cf
3 changed files with 196 additions and 1 deletion


@@ -12,8 +12,10 @@ import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
@@ -62,6 +64,13 @@ public class AiQaItemController {
}
}
@PostMapping(value = "/askAIStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@Operation(summary = "客户问答(大模型流式)", description = "流式返回大模型回答分片,流完成后自动落库")
public Flux<String> askCustomerStream(
@Parameter(description = "提问请求", required = true) @RequestBody AiQaCustomerAskRequestDto request) {
return aiQaCustomerAskService.askCustomerStream(request);
}
@PostMapping("/add")
@Operation(summary = "新增", description = "新增 AI 问答子表记录")
public ResponseEntity<Map<String, Object>> add(


@@ -2,6 +2,7 @@ package com.rj.service;
import com.rj.dto.AiQaCustomerAskRequestDto;
import com.rj.dto.AiQaCustomerAskResponseDto;
import reactor.core.publisher.Flux;
/**
* AI Q&A child table: customers ask questions and the LLM is invoked
@@ -15,4 +16,9 @@ public interface IAiQaCustomerAskService {
* @throws Exception on HTTP or parsing failure, errors returned by the model, etc.
*/
AiQaCustomerAskResponseDto askCustomer(AiQaCustomerAskRequestDto request) throws Exception;
/**
* Streaming answer: emits LLM output chunk by chunk and persists to the database once the stream ends.
*/
Flux<String> askCustomerStream(AiQaCustomerAskRequestDto request);
}
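
At the service level, a caller outside the web layer subscribes to the returned Flux directly; a minimal hedged sketch (the injected service reference and the DTO setters are assumptions inferred from the getters used elsewhere in this commit):

void demoSubscribe(IAiQaCustomerAskService aiQaCustomerAskService) {
    AiQaCustomerAskRequestDto req = new AiQaCustomerAskRequestDto();
    req.setQuestionType("after-sales");       // hypothetical values; setters assumed
    req.setQuestionContent("example question");
    aiQaCustomerAskService.askCustomerStream(req).subscribe(
            System.out::print,                                          // each partial chunk
            err -> System.err.println("stream failed: " + err),         // from the handler's onError
            () -> System.out.println("stream complete, answer saved")); // after persistence
}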


@@ -9,6 +9,9 @@ import com.rj.entity.AiQaMain;
import com.rj.service.IAiQaCustomerAskService;
import com.rj.service.IAiQaItemService;
import com.rj.service.IAiQaMainService;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpEntity;
@@ -18,9 +21,13 @@ import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.client.RestTemplate;
import reactor.core.publisher.Flux;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
/**
* Customer AI Q&A: local OpenAI-compatible /v1/chat/completions
@@ -82,7 +89,8 @@ public class AiQaCustomerAskServiceImpl implements IAiQaCustomerAskService {
: configMaxTokens;
String userMessage = "问题类型:" + questionType + "\n用户问题" + questionContent;
// String answer = callLocalOpenAiChatCompletions(systemContent, userMessage, model, maxTokens);
String answer = callLocalOpenAiChatCompletionsByLangChain4j(systemContent, userMessage, model, maxTokens);
LocalDateTime now = LocalDateTime.now();
if (newMain) {
@@ -115,6 +123,88 @@ public class AiQaCustomerAskServiceImpl implements IAiQaCustomerAskService {
.build();
}
@Override
public Flux<String> askCustomerStream(AiQaCustomerAskRequestDto request) {
validate(request);
boolean newMain = request.getParentId() == null || request.getParentId().trim().isEmpty();
String parentId = newMain ? UUID.randomUUID().toString() : request.getParentId().trim();
String questionType = request.getQuestionType().trim();
String questionContent = request.getQuestionContent().trim();
String model = request.getModel() != null && !request.getModel().trim().isEmpty()
? request.getModel().trim()
: configDefaultModel;
String systemContent = request.getSystemPrompt() != null && !request.getSystemPrompt().trim().isEmpty()
? request.getSystemPrompt().trim()
: DEFAULT_SYSTEM_PROMPT;
int maxTokens = request.getMaxTokens() != null && request.getMaxTokens() > 0
? request.getMaxTokens()
: configMaxTokens;
String userMessage = "问题类型:" + questionType + "\n用户问题" + questionContent;
String openAiBaseUrl = normalizeOpenAiBaseUrl(chatCompletionsUrl);
OpenAiStreamingChatModel chatModel = OpenAiStreamingChatModel.builder()
.baseUrl(openAiBaseUrl)
.apiKey("not-used")
.modelName(model)
.maxTokens(maxTokens)
.build();
String prompt = "系统指令:" + systemContent + "\n\n用户输入" + userMessage;
return Flux.create(sink -> chatModel.chat(prompt, new StreamingChatResponseHandler() {
final StringBuilder answerBuffer = new StringBuilder();
@Override
public void onPartialResponse(String partialResponse) {
if (partialResponse != null && !partialResponse.isEmpty()) {
answerBuffer.append(partialResponse);
sink.next(partialResponse);
}
}
@Override
public void onCompleteResponse(ChatResponse completeResponse) {
try {
// Persist once after the stream finishes, keeping the logic consistent with askCustomer
LocalDateTime now = LocalDateTime.now();
if (newMain) {
AiQaMain main = new AiQaMain();
main.setId(parentId);
main.setTitle(truncate(questionContent, MAIN_TITLE_MAX_LEN));
main.setQaType(questionType);
main.setAskContent(questionContent);
main.setSalesName(request.getUserName());
main.setSalesPhone(request.getLoginAccount());
main.setCreateTime(now);
main.setUpdateTime(now);
aiQaMainService.save(main);
}
AiQaItem item = new AiQaItem();
item.setId(UUID.randomUUID().toString());
item.setParentId(parentId);
item.setQuestSrc(truncate(userMessage, ITEM_QUEST_SRC_MAX_LEN));
item.setAnswerText(buildAnswerJson(questionType, questionContent, answerBuffer.toString(), model));
item.setCreateTime(now);
item.setUpdateTime(now);
aiQaItemService.save(item);
sink.complete();
} catch (Exception e) {
sink.error(e);
}
}
@Override
public void onError(Throwable error) {
sink.error(error);
}
}));
}
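// One caveat with Flux.create as used above: if the SSE client disconnects mid-stream,
// the handler keeps pushing into a cancelled sink. A hedged sketch (hypothetical helper,
// not part of this commit) of guarding against that with FluxSink.onCancel; the flag only
// stops forwarding chunks, LangChain4j's generation itself is not aborted.
private Flux<String> streamWithCancelGuard(OpenAiStreamingChatModel chatModel, String prompt) {
    return Flux.create(sink -> {
        java.util.concurrent.atomic.AtomicBoolean cancelled =
                new java.util.concurrent.atomic.AtomicBoolean(false);
        sink.onCancel(() -> cancelled.set(true)); // client disconnected; stop emitting
        chatModel.chat(prompt, new StreamingChatResponseHandler() {
            @Override
            public void onPartialResponse(String partialResponse) {
                if (!cancelled.get() && partialResponse != null && !partialResponse.isEmpty()) {
                    sink.next(partialResponse);
                }
            }
            @Override
            public void onCompleteResponse(ChatResponse completeResponse) {
                sink.complete();
            }
            @Override
            public void onError(Throwable error) {
                sink.error(error);
            }
        });
    });
}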
private static String truncate(String s, int maxLen) {
if (s == null) {
return null;
@@ -148,6 +238,7 @@ public class AiQaCustomerAskServiceImpl implements IAiQaCustomerAskService {
*/
private String callLocalOpenAiChatCompletions(String systemPrompt, String userContent, String model, int maxTokens)
throws Exception {
// The original RestTemplate-based HTTP implementation is kept here
Map<String, Object> body = new LinkedHashMap<>();
body.put("model", model);
List<Map<String, String>> messages = new ArrayList<>();
@@ -184,4 +275,93 @@ public class AiQaCustomerAskServiceImpl implements IAiQaCustomerAskService {
log.info("OpenAI 兼容接口返回:" + content);
return content.asText("");
}
/**
* Calls the local OpenAI-compatible /v1/chat/completions via LangChain4j.
*
* Notes:
* - The OpenAiStreamingChatModel is built inline in this method; local deployments generally do not validate the apiKey, so any placeholder value works.
* - If an OpenAiStreamingChatModel bean is already registered in configuration, the construction logic can be extracted or replaced with bean injection.
*/
private String callLocalOpenAiChatCompletionsByLangChain4j(String systemPrompt, String userContent, String model, int maxTokens) {
log.info("LangChain4j 使用 OpenAI 兼容接口model=" + model + ", chatCompletionsUrl=" + chatCompletionsUrl);
String openAiBaseUrl = normalizeOpenAiBaseUrl(chatCompletionsUrl);
OpenAiStreamingChatModel chatModel = OpenAiStreamingChatModel.builder()
.baseUrl(openAiBaseUrl) // OpenAI root path, e.g. http://host:port/v1
.apiKey("not-used") // local deployments generally don't validate the key; placeholder value
.modelName(model)
.maxTokens(maxTokens)
.build();
String prompt = "系统指令:" + systemPrompt + "\n\n用户输入" + userContent;
StringBuilder answerBuffer = new StringBuilder();
CountDownLatch done = new CountDownLatch(1);
AtomicReference<Throwable> errorRef = new AtomicReference<>();
chatModel.chat(prompt, new StreamingChatResponseHandler() {
@Override
public void onPartialResponse(String partialResponse) {
if (partialResponse != null && !partialResponse.isEmpty()) {
answerBuffer.append(partialResponse);
log.info("LangChain4j 流式分片:{}", partialResponse);
}
}
@Override
public void onCompleteResponse(ChatResponse completeResponse) {
done.countDown();
}
@Override
public void onError(Throwable error) {
errorRef.set(error);
done.countDown();
}
});
try {
boolean finished = done.await(120, TimeUnit.SECONDS);
if (!finished) {
throw new IllegalStateException("LangChain4j 流式调用超时120s");
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new IllegalStateException("LangChain4j 流式调用被中断", e);
}
if (errorRef.get() != null) {
throw new IllegalStateException("LangChain4j 流式调用失败: " + errorRef.get().getMessage(), errorRef.get());
}
String answer = answerBuffer.toString();
if (answer.isEmpty()) {
throw new IllegalStateException("LangChain4j 流式调用返回空响应");
}
log.info("LangChain4j OpenAI 兼容接口返回:" + answer);
return answer;
}
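// The prompt above folds the system instruction into the user string. LangChain4j can carry
// a real system message instead; a hedged sketch (hypothetical helper, assuming the
// ChatRequest API of langchain4j 1.x; imports: dev.langchain4j.data.message.SystemMessage,
// dev.langchain4j.data.message.UserMessage, dev.langchain4j.model.chat.request.ChatRequest):
private void chatWithSystemMessage(OpenAiStreamingChatModel chatModel, String systemPrompt,
        String userContent, StreamingChatResponseHandler handler) {
    ChatRequest chatRequest = ChatRequest.builder()
            .messages(SystemMessage.from(systemPrompt), UserMessage.from(userContent))
            .build();
    // Same callback contract as the string overload used above.
    chatModel.chat(chatRequest, handler);
}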
/**
* Normalizes the configured chat/completions URL into an OpenAI baseUrl.
* Examples:
* - http://127.0.0.1:8000/v1/chat/completions -> http://127.0.0.1:8000/v1
* - http://127.0.0.1:8000/v1 -> http://127.0.0.1:8000/v1
*/
private String normalizeOpenAiBaseUrl(String rawUrl) {
if (rawUrl == null || rawUrl.trim().isEmpty()) {
throw new IllegalArgumentException("ai.qa.local.chat-url 不能为空");
}
String normalized = rawUrl.trim();
// Strip trailing slashes to avoid double slashes in later processing
while (normalized.endsWith("/")) {
normalized = normalized.substring(0, normalized.length() - 1);
}
String suffix = "/chat/completions";
if (normalized.endsWith(suffix)) {
normalized = normalized.substring(0, normalized.length() - suffix.length());
}
return normalized;
}
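// A hedged sanity check of the normalization rules documented above (hypothetical method,
// not part of this commit); run with -ea, or move the cases into a unit test:
void normalizeOpenAiBaseUrlSanityCheck() {
    assert "http://127.0.0.1:8000/v1".equals(normalizeOpenAiBaseUrl("http://127.0.0.1:8000/v1/chat/completions"));
    assert "http://127.0.0.1:8000/v1".equals(normalizeOpenAiBaseUrl("http://127.0.0.1:8000/v1/"));
    assert "http://127.0.0.1:8000/v1".equals(normalizeOpenAiBaseUrl("http://127.0.0.1:8000/v1"));
}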
}
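
The javadoc on callLocalOpenAiChatCompletionsByLangChain4j mentions extracting model construction into configuration. A hedged sketch of that alternative follows; the ai.qa.local.chat-url property appears in this commit, while the model property name and its default are assumptions, and a singleton bean necessarily pins modelName rather than honoring the per-request model override used in askCustomerStream:

import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class LangChain4jConfig {

    @Bean
    public OpenAiStreamingChatModel streamingChatModel(
            @Value("${ai.qa.local.chat-url}") String chatUrl,
            @Value("${ai.qa.local.default-model:local-model}") String defaultModel) { // hypothetical property
        // Same normalization as normalizeOpenAiBaseUrl: trim trailing slashes,
        // then drop a /chat/completions suffix to get the OpenAI baseUrl.
        String baseUrl = chatUrl.trim();
        while (baseUrl.endsWith("/")) {
            baseUrl = baseUrl.substring(0, baseUrl.length() - 1);
        }
        if (baseUrl.endsWith("/chat/completions")) {
            baseUrl = baseUrl.substring(0, baseUrl.length() - "/chat/completions".length());
        }
        return OpenAiStreamingChatModel.builder()
                .baseUrl(baseUrl)
                .apiKey("not-used") // local deployments generally don't validate the key
                .modelName(defaultModel)
                .build();
    }
}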