wangmiaomiao 10 månader sedan
förälder
incheckning
2c7f3afe8d

+ 52 - 25
slibra-admin/src/main/java/com/slibra/web/controller/business/GRPCController.java

@@ -26,6 +26,9 @@ import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RestController;
 
+import javax.servlet.http.HttpServletResponse;
+import java.io.IOException;
+import java.io.OutputStream;
 import java.io.UnsupportedEncodingException;
 import java.math.BigDecimal;
 import java.math.RoundingMode;
@@ -180,34 +183,58 @@ public class GRPCController extends BaseController {
      * @return
      */
     @GetMapping(value = "/inferStreamRag")
-    public AjaxResult inferStreamRag()
-    {
+    public void inferStreamRag(HttpServletResponse response) {
+        // Streams the RAG + large-model inference output from the gRPC service to the HTTP
+        // client: each streamed prediction is written to the response body and flushed at once.
+        // The method is void because the body is produced directly on the servlet response
+        // instead of being wrapped in an AjaxResult.
+        // NOTE: the server address and the request payload below are hard-coded.
+        // I/O and encoding failures are rethrown as RuntimeException with the cause attached.
         log.info("进入了调⽤RAG+⼤模型的调⽤参数");
-        ManagedChannel channel = ManagedChannelBuilder.forAddress("10.0.0.24", 17070)
-                .usePlaintext()
-                .build();
-        InferenceAPIsServiceGrpc.InferenceAPIsServiceBlockingStub stub = InferenceAPIsServiceGrpc.newBlockingStub(channel);
-        String dataJson = "{\"bot_id\":\"721\",\"exp_id\":\"721\",\"session_id\":\"C20231025153549AC11FC2905668524\",\"use_rag\":\"true\",\"prompt\":\"你是⼀个资深⽔务领域专家,能回答各种⽔务相关问题\",\"history_dia\":[\"什么是BOD\"],\"generate_args\":{\"max_new_tokens\":2048,\"max_length\":4096,\"num_beams\":1,\"do_sample\":true,\"top_p\":0.7,\"temperature\":0.95},\"extra\":{}}";
-        System.out.println(dataJson);
-        PredictionsRequest request = null;
+        // Built before the try block so the finally clause always has a channel to shut down.
+        ManagedChannel channel = ManagedChannelBuilder.forAddress("10.0.0.24", 17070)
+                .usePlaintext()
+                .build();
         try {
-            request = PredictionsRequest.newBuilder()
-                    .setModelName("slibra_bot")
-                    .putInput("method", ByteString.copyFrom("infer_stream", "utf-8"))//推理
-                    .putInput("data", ByteString.copyFrom(dataJson, "utf-8"))
-                    .buildPartial();
-        } catch (UnsupportedEncodingException e) {
-            log.error("转换数据的时候,出现异常,异常为:", e);
-            throw new RuntimeException(e.getMessage());
-        }
-        Iterator<PredictionResponse> response = stub.streamPredictions(request);
-        while (response.hasNext()){
-            String responseStr = response.next().getPrediction().toStringUtf8();
-//            System.out.println(unicodeToChinese(responseStr.substring(16, responseStr.length()-1)) + "---" + responseStr);
-            System.out.println(responseStr);
+            InferenceAPIsServiceGrpc.InferenceAPIsServiceBlockingStub stub = InferenceAPIsServiceGrpc.newBlockingStub(channel);
+            String dataJson = "{\"bot_id\":\"721\",\"exp_id\":\"721\",\"session_id\":\"C20231025153549AC11FC2905668524\",\"use_rag\":\"true\",\"prompt\":\"你是⼀个资深⽔务领域专家,能回答各种⽔务相关问题\",\"history_dia\":[\"什么是BOD\"],\"generate_args\":{\"max_new_tokens\":2048,\"max_length\":4096,\"num_beams\":1,\"do_sample\":true,\"top_p\":0.7,\"temperature\":0.95},\"extra\":{}}";
+            System.out.println(dataJson);
+            PredictionsRequest request;
+            try {
+                request = PredictionsRequest.newBuilder()
+                        .setModelName("slibra_bot")
+                        .putInput("method", ByteString.copyFrom("infer_stream", "utf-8")) // inference
+                        .putInput("data", ByteString.copyFrom(dataJson, "utf-8"))
+                        .buildPartial();
+            } catch (UnsupportedEncodingException e) {
+                log.error("转换数据的时候,出现异常,异常为:", e);
+                // Keep the original exception as the cause so its stack trace is not lost.
+                throw new RuntimeException(e.getMessage(), e);
+            }
+            response.setContentType("text/plain");
+            response.setCharacterEncoding("utf-8");
+            // try-with-resources closes the stream even when the gRPC call or a write fails.
+            try (OutputStream outputStream = response.getOutputStream()) {
+                Iterator<PredictionResponse> predictions = stub.streamPredictions(request);
+                while (predictions.hasNext()) {
+                    String responseStr = predictions.next().getPrediction().toStringUtf8();
+                    System.out.println(responseStr);
+                    if ("complete".equals(responseStr)) {
+                        // The literal "complete" is a closing message: printed, never forwarded.
+                        System.out.println("结尾语句,无需售出~~~");
+                    } else {
+                        // Encode explicitly as UTF-8 to match the charset declared on the response.
+                        outputStream.write(responseStr.getBytes(java.nio.charset.StandardCharsets.UTF_8));
+                        outputStream.flush();
+                    }
+                }
+            }
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        } finally {
+            // Release the gRPC channel on every path, successful or not.
+            channel.shutdown();
         }
-        channel.shutdown();
-        return AjaxResult.success("ok");
     }