流式输出,秒级,效果非常棒
This commit is contained in:
@@ -12,8 +12,10 @@ import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
@@ -62,6 +64,13 @@ public class AiQaItemController {
|
||||
}
|
||||
}
|
||||
|
||||
@PostMapping(value = "/askAIStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
@Operation(summary = "客户问答(大模型流式)", description = "流式返回大模型回答分片,流完成后自动落库")
|
||||
public Flux<String> askCustomerStream(
|
||||
@Parameter(description = "提问请求", required = true) @RequestBody AiQaCustomerAskRequestDto request) {
|
||||
return aiQaCustomerAskService.askCustomerStream(request);
|
||||
}
|
||||
|
||||
@PostMapping("/add")
|
||||
@Operation(summary = "新增", description = "新增 AI 问答子表记录")
|
||||
public ResponseEntity<Map<String, Object>> add(
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.rj.service;
|
||||
|
||||
import com.rj.dto.AiQaCustomerAskRequestDto;
|
||||
import com.rj.dto.AiQaCustomerAskResponseDto;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* AI 问答子表:客户提问并调用大模型的业务
|
||||
@@ -15,4 +16,9 @@ public interface IAiQaCustomerAskService {
|
||||
* @throws Exception HTTP 或解析失败、模型返回错误等
|
||||
*/
|
||||
AiQaCustomerAskResponseDto askCustomer(AiQaCustomerAskRequestDto request) throws Exception;
|
||||
|
||||
/**
 * Streaming answer: continuously emits model output chunks; after the stream
 * finishes the implementation persists the record to the database.
 *
 * @param request the customer question request (same DTO as {@link #askCustomer})
 * @return a Flux of raw answer text fragments, completed once persistence is done
 */
Flux<String> askCustomerStream(AiQaCustomerAskRequestDto request);
|
||||
}
|
||||
|
||||
@@ -9,6 +9,9 @@ import com.rj.entity.AiQaMain;
|
||||
import com.rj.service.IAiQaCustomerAskService;
|
||||
import com.rj.service.IAiQaItemService;
|
||||
import com.rj.service.IAiQaMainService;
|
||||
import dev.langchain4j.model.chat.response.ChatResponse;
|
||||
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.http.HttpEntity;
|
||||
@@ -18,9 +21,13 @@ import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
/**
|
||||
* 客户 AI 问答:本地 OpenAI 兼容 /v1/chat/completions
|
||||
@@ -82,7 +89,8 @@ public class AiQaCustomerAskServiceImpl implements IAiQaCustomerAskService {
|
||||
: configMaxTokens;
|
||||
|
||||
String userMessage = "问题类型:" + questionType + "\n用户问题:" + questionContent;
|
||||
String answer = callLocalOpenAiChatCompletions(systemContent, userMessage, model, maxTokens);
|
||||
// String answer = callLocalOpenAiChatCompletions(systemContent, userMessage, model, maxTokens);
|
||||
String answer = callLocalOpenAiChatCompletionsByLangChain4j(systemContent, userMessage, model, maxTokens);
|
||||
|
||||
LocalDateTime now = LocalDateTime.now();
|
||||
if (newMain) {
|
||||
@@ -115,6 +123,88 @@ public class AiQaCustomerAskServiceImpl implements IAiQaCustomerAskService {
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
 * Streaming variant of askCustomer: emits model output chunks as they arrive and,
 * once the model signals completion, persists the main record (if new) and the
 * item record in one shot.
 *
 * NOTE(review): Flux.create runs the lambda per subscriber, so each subscription
 * triggers a fresh model call AND a fresh DB write — confirm callers subscribe
 * exactly once (the SSE controller does).
 * NOTE(review): unlike askCustomer, persistence here happens on the model's
 * callback thread; presumably outside any @Transactional boundary — verify.
 */
@Override
public Flux<String> askCustomerStream(AiQaCustomerAskRequestDto request) {
    validate(request);

    // A blank parentId means this question starts a new conversation (main record).
    boolean newMain = request.getParentId() == null || request.getParentId().trim().isEmpty();
    String parentId = newMain ? UUID.randomUUID().toString() : request.getParentId().trim();
    String questionType = request.getQuestionType().trim();
    String questionContent = request.getQuestionContent().trim();

    // Fall back to configured defaults when the request leaves these unset/blank.
    String model = request.getModel() != null && !request.getModel().trim().isEmpty()
            ? request.getModel().trim()
            : configDefaultModel;
    String systemContent = request.getSystemPrompt() != null && !request.getSystemPrompt().trim().isEmpty()
            ? request.getSystemPrompt().trim()
            : DEFAULT_SYSTEM_PROMPT;

    int maxTokens = request.getMaxTokens() != null && request.getMaxTokens() > 0
            ? request.getMaxTokens()
            : configMaxTokens;

    String userMessage = "问题类型:" + questionType + "\n用户问题:" + questionContent;
    String openAiBaseUrl = normalizeOpenAiBaseUrl(chatCompletionsUrl);

    // Per-call model instance; local deployments typically do not validate the apiKey.
    OpenAiStreamingChatModel chatModel = OpenAiStreamingChatModel.builder()
            .baseUrl(openAiBaseUrl)
            .apiKey("not-used")
            .modelName(model)
            .maxTokens(maxTokens)
            .build();

    String prompt = "系统指令:" + systemContent + "\n\n用户输入:" + userMessage;

    return Flux.create(sink -> chatModel.chat(prompt, new StreamingChatResponseHandler() {
        // Accumulates every chunk so the full answer can be persisted at the end.
        final StringBuilder answerBuffer = new StringBuilder();

        @Override
        public void onPartialResponse(String partialResponse) {
            if (partialResponse != null && !partialResponse.isEmpty()) {
                answerBuffer.append(partialResponse);
                sink.next(partialResponse);
            }
        }

        @Override
        public void onCompleteResponse(ChatResponse completeResponse) {
            try {
                // Persist once after the stream ends, mirroring the askCustomer logic.
                LocalDateTime now = LocalDateTime.now();
                if (newMain) {
                    AiQaMain main = new AiQaMain();
                    main.setId(parentId);
                    main.setTitle(truncate(questionContent, MAIN_TITLE_MAX_LEN));
                    main.setQaType(questionType);
                    main.setAskContent(questionContent);
                    main.setSalesName(request.getUserName());
                    main.setSalesPhone(request.getLoginAccount());
                    main.setCreateTime(now);
                    main.setUpdateTime(now);
                    aiQaMainService.save(main);
                }
                AiQaItem item = new AiQaItem();
                item.setId(UUID.randomUUID().toString());
                item.setParentId(parentId);
                item.setQuestSrc(truncate(userMessage, ITEM_QUEST_SRC_MAX_LEN));
                item.setAnswerText(buildAnswerJson(questionType, questionContent, answerBuffer.toString(), model));
                item.setCreateTime(now);
                item.setUpdateTime(now);
                aiQaItemService.save(item);

                // Only complete the stream after the DB write succeeds; a failed
                // save surfaces to the subscriber as an error instead.
                sink.complete();
            } catch (Exception e) {
                sink.error(e);
            }
        }

        @Override
        public void onError(Throwable error) {
            sink.error(error);
        }
    }));
}
|
||||
|
||||
private static String truncate(String s, int maxLen) {
|
||||
if (s == null) {
|
||||
return null;
|
||||
@@ -148,6 +238,7 @@ public class AiQaCustomerAskServiceImpl implements IAiQaCustomerAskService {
|
||||
*/
|
||||
private String callLocalOpenAiChatCompletions(String systemPrompt, String userContent, String model, int maxTokens)
|
||||
throws Exception {
|
||||
// 这里保留原有基于 RestTemplate 的 HTTP 调用实现
|
||||
Map<String, Object> body = new LinkedHashMap<>();
|
||||
body.put("model", model);
|
||||
List<Map<String, String>> messages = new ArrayList<>();
|
||||
@@ -184,4 +275,93 @@ public class AiQaCustomerAskServiceImpl implements IAiQaCustomerAskService {
|
||||
log.info("OpenAI 兼容接口返回:" + content);
|
||||
return content.asText("");
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用 langchain4j 调用本地 OpenAI 兼容 /v1/chat/completions
|
||||
*
|
||||
* 说明:
|
||||
* - 这里直接在方法里构造 OpenAiChatModel(本地部署一般不会真正校验 apiKey,可以随便给一个占位值);
|
||||
* - 如果你已经在配置里注入了 OpenAiChatModel Bean,可以把构造逻辑抽出去或改成注入 Bean 的方式。
|
||||
*/
|
||||
private String callLocalOpenAiChatCompletionsByLangChain4j(String systemPrompt, String userContent, String model, int maxTokens) {
|
||||
log.info("LangChain4j 使用 OpenAI 兼容接口,model=" + model + ", chatCompletionsUrl=" + chatCompletionsUrl);
|
||||
String openAiBaseUrl = normalizeOpenAiBaseUrl(chatCompletionsUrl);
|
||||
OpenAiStreamingChatModel chatModel = OpenAiStreamingChatModel.builder()
|
||||
.baseUrl(openAiBaseUrl) // OpenAI 根路径,如 http://host:port/v1
|
||||
.apiKey("not-used") // 本地一般不会真正校验 key,这里给占位值
|
||||
.modelName(model)
|
||||
.maxTokens(maxTokens)
|
||||
.build();
|
||||
|
||||
String prompt = "系统指令:" + systemPrompt + "\n\n用户输入:" + userContent;
|
||||
StringBuilder answerBuffer = new StringBuilder();
|
||||
CountDownLatch done = new CountDownLatch(1);
|
||||
AtomicReference<Throwable> errorRef = new AtomicReference<>();
|
||||
|
||||
chatModel.chat(prompt, new StreamingChatResponseHandler() {
|
||||
@Override
|
||||
public void onPartialResponse(String partialResponse) {
|
||||
if (partialResponse != null && !partialResponse.isEmpty()) {
|
||||
answerBuffer.append(partialResponse);
|
||||
log.info("LangChain4j 流式分片:{}", partialResponse);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCompleteResponse(ChatResponse completeResponse) {
|
||||
done.countDown();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable error) {
|
||||
errorRef.set(error);
|
||||
done.countDown();
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
boolean finished = done.await(120, TimeUnit.SECONDS);
|
||||
if (!finished) {
|
||||
throw new IllegalStateException("LangChain4j 流式调用超时(120s)");
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new IllegalStateException("LangChain4j 流式调用被中断", e);
|
||||
}
|
||||
|
||||
if (errorRef.get() != null) {
|
||||
throw new IllegalStateException("LangChain4j 流式调用失败: " + errorRef.get().getMessage(), errorRef.get());
|
||||
}
|
||||
|
||||
String answer = answerBuffer.toString();
|
||||
if (answer.isEmpty()) {
|
||||
throw new IllegalStateException("LangChain4j 流式调用返回空响应");
|
||||
}
|
||||
log.info("LangChain4j OpenAI 兼容接口返回:" + answer);
|
||||
return answer;
|
||||
}
|
||||
|
||||
/**
|
||||
* 将配置中的 chat/completions URL 规范为 OpenAI baseUrl。
|
||||
* 例如:
|
||||
* - http://127.0.0.1:8000/v1/chat/completions -> http://127.0.0.1:8000/v1
|
||||
* - http://127.0.0.1:8000/v1 -> http://127.0.0.1:8000/v1
|
||||
*/
|
||||
private String normalizeOpenAiBaseUrl(String rawUrl) {
|
||||
if (rawUrl == null || rawUrl.trim().isEmpty()) {
|
||||
throw new IllegalArgumentException("ai.qa.local.chat-url 不能为空");
|
||||
}
|
||||
String normalized = rawUrl.trim();
|
||||
|
||||
// 去掉末尾斜杠,避免后续处理出现双斜杠
|
||||
while (normalized.endsWith("/")) {
|
||||
normalized = normalized.substring(0, normalized.length() - 1);
|
||||
}
|
||||
|
||||
String suffix = "/chat/completions";
|
||||
if (normalized.endsWith(suffix)) {
|
||||
normalized = normalized.substring(0, normalized.length() - suffix.length());
|
||||
}
|
||||
return normalized;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user