|
@@ -26,6 +26,9 @@ import org.springframework.web.bind.annotation.GetMapping;
|
|
|
import org.springframework.web.bind.annotation.RequestMapping;
|
|
|
import org.springframework.web.bind.annotation.RestController;
|
|
|
|
|
|
+import javax.servlet.http.HttpServletResponse;
|
|
|
+import java.io.IOException;
|
|
|
+import java.io.OutputStream;
|
|
|
import java.io.UnsupportedEncodingException;
|
|
|
import java.math.BigDecimal;
|
|
|
import java.math.RoundingMode;
|
|
@@ -180,34 +183,58 @@ public class GRPCController extends BaseController {
|
|
|
* @return
|
|
|
*/
|
|
|
@GetMapping(value = "/inferStreamRag")
|
|
|
- public AjaxResult inferStreamRag()
|
|
|
- {
|
|
|
+// public AjaxResult inferStreamRag(HttpServletResponse response) {
|
|
|
+ public void inferStreamRag(HttpServletResponse response) {
|
|
|
log.info("进入了调⽤RAG+⼤模型的调⽤参数");
|
|
|
- ManagedChannel channel = ManagedChannelBuilder.forAddress("10.0.0.24", 17070)
|
|
|
- .usePlaintext()
|
|
|
- .build();
|
|
|
- InferenceAPIsServiceGrpc.InferenceAPIsServiceBlockingStub stub = InferenceAPIsServiceGrpc.newBlockingStub(channel);
|
|
|
- String dataJson = "{\"bot_id\":\"721\",\"exp_id\":\"721\",\"session_id\":\"C20231025153549AC11FC2905668524\",\"use_rag\":\"true\",\"prompt\":\"你是⼀个资深⽔务领域专家,能回答各种⽔务相关问题\",\"history_dia\":[\"什么是BOD\"],\"generate_args\":{\"max_new_tokens\":2048,\"max_length\":4096,\"num_beams\":1,\"do_sample\":true,\"top_p\":0.7,\"temperature\":0.95},\"extra\":{}}";
|
|
|
- System.out.println(dataJson);
|
|
|
- PredictionsRequest request = null;
|
|
|
+ // 获取输出流
|
|
|
+ OutputStream outputStream = null;
|
|
|
+ ManagedChannel channel = null;
|
|
|
try {
|
|
|
- request = PredictionsRequest.newBuilder()
|
|
|
- .setModelName("slibra_bot")
|
|
|
- .putInput("method", ByteString.copyFrom("infer_stream", "utf-8"))//推理
|
|
|
- .putInput("data", ByteString.copyFrom(dataJson, "utf-8"))
|
|
|
- .buildPartial();
|
|
|
- } catch (UnsupportedEncodingException e) {
|
|
|
- log.error("转换数据的时候,出现异常,异常为:", e);
|
|
|
- throw new RuntimeException(e.getMessage());
|
|
|
- }
|
|
|
- Iterator<PredictionResponse> response = stub.streamPredictions(request);
|
|
|
- while (response.hasNext()){
|
|
|
- String responseStr = response.next().getPrediction().toStringUtf8();
|
|
|
-// System.out.println(unicodeToChinese(responseStr.substring(16, responseStr.length()-1)) + "---" + responseStr);
|
|
|
- System.out.println(responseStr);
|
|
|
+ channel = ManagedChannelBuilder.forAddress("10.0.0.24", 17070)
|
|
|
+ .usePlaintext()
|
|
|
+ .build();
|
|
|
+ InferenceAPIsServiceGrpc.InferenceAPIsServiceBlockingStub stub = InferenceAPIsServiceGrpc.newBlockingStub(channel);
|
|
|
+ String dataJson = "{\"bot_id\":\"721\",\"exp_id\":\"721\",\"session_id\":\"C20231025153549AC11FC2905668524\",\"use_rag\":\"true\",\"prompt\":\"你是⼀个资深⽔务领域专家,能回答各种⽔务相关问题\",\"history_dia\":[\"什么是BOD\"],\"generate_args\":{\"max_new_tokens\":2048,\"max_length\":4096,\"num_beams\":1,\"do_sample\":true,\"top_p\":0.7,\"temperature\":0.95},\"extra\":{}}";
|
|
|
+ System.out.println(dataJson);
|
|
|
+ PredictionsRequest request = null;
|
|
|
+ try {
|
|
|
+ request = PredictionsRequest.newBuilder()
|
|
|
+ .setModelName("slibra_bot")
|
|
|
+ .putInput("method", ByteString.copyFrom("infer_stream", "utf-8"))//推理
|
|
|
+ .putInput("data", ByteString.copyFrom(dataJson, "utf-8"))
|
|
|
+ .buildPartial();
|
|
|
+ } catch (UnsupportedEncodingException e) {
|
|
|
+ log.error("转换数据的时候,出现异常,异常为:", e);
|
|
|
+ throw new RuntimeException(e.getMessage());
|
|
|
+ }
|
|
|
+ response.setContentType("text/plain");
|
|
|
+ response.setCharacterEncoding("utf-8");
|
|
|
+ outputStream = response.getOutputStream();
|
|
|
+ Iterator<PredictionResponse> predictions = stub.streamPredictions(request);
|
|
|
+ while (predictions.hasNext()) {
|
|
|
+ String responseStr = predictions.next().getPrediction().toStringUtf8();
|
|
|
+ System.out.println(responseStr);
|
|
|
+ if("complete".equals(responseStr)){
|
|
|
+ System.out.println("结尾语句,无需售出~~~");
|
|
|
+ }else{
|
|
|
+ outputStream.write(responseStr.getBytes());
|
|
|
+ outputStream.flush();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (IOException e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ } finally {
|
|
|
+ // 关闭输出流
|
|
|
+ try {
|
|
|
+ outputStream.close();
|
|
|
+ } catch (IOException e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }finally {
|
|
|
+ channel.shutdown();
|
|
|
+ }
|
|
|
}
|
|
|
- channel.shutdown();
|
|
|
- return AjaxResult.success("ok");
|
|
|
+
|
|
|
+// return AjaxResult.success("ok");
|
|
|
}
|
|
|
|
|
|
|