Ver código fonte

允许用户在header中自定义大模型的port

王苗苗 8 meses atrás
pai
commit
dbaea7ed48

+ 15 - 6
slibra-admin/src/main/java/com/slibra/web/controller/business/GRPCController.java

@@ -30,9 +30,11 @@ import org.pytorch.serve.grpc.inference.PredictionResponse;
 import org.pytorch.serve.grpc.inference.PredictionsRequest;
 import org.springframework.beans.BeanUtils;
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
 import org.springframework.util.CollectionUtils;
 import org.springframework.web.bind.annotation.*;
 
+import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import java.io.IOException;
 import java.io.OutputStream;
@@ -76,6 +78,10 @@ public class GRPCController extends BaseController {
     @Autowired
     private BigModelConfig bigModelConfig;
 
+    // 用户自定义的端口
+    @Value("${token.port}")
+    private String port;
+
 
     /**
      *
@@ -84,7 +90,7 @@ public class GRPCController extends BaseController {
      * @param response
      */
     @GetMapping(value = "/test/aaa")
-    public void decisionStreamTest(HttpServletResponse response)
+    public void decisionStreamTest(HttpServletRequest httpServletRequest, HttpServletResponse response)
 //    public void decisionStream(HttpServletResponse response,  ChatReq chatReq)
     {
         log.info("进入了调⽤大模型决策接口");
@@ -138,8 +144,9 @@ public class GRPCController extends BaseController {
 //        String rows = JSON.toJSONString(decisionReqs, JSONWriter.Feature.WriteNulls);
         boolean needAdd = true;//标识变量是否可以保存
         String dataJson = "";
+        String headerPort = httpServletRequest.getHeader(port);
         try {
-            channel = ManagedChannelBuilder.forAddress(bigModelConfig.getIp(), bigModelConfig.getPort())
+            channel = ManagedChannelBuilder.forAddress(bigModelConfig.getIp(), StringUtils.isBlank(headerPort) ? bigModelConfig.getPort() : Integer.parseInt(headerPort))
                     .usePlaintext()
                     .build();
             InferenceAPIsServiceGrpc.InferenceAPIsServiceBlockingStub stub = InferenceAPIsServiceGrpc.newBlockingStub(channel);
@@ -237,7 +244,7 @@ public class GRPCController extends BaseController {
      * @return
      */
     @PostMapping(value = "/decisionStream")
-    public void decisionStream(HttpServletResponse response, @RequestBody ChatReq chatReq)
+    public void decisionStream(HttpServletRequest httpServletRequest, HttpServletResponse response, @RequestBody ChatReq chatReq)
 //    public void decisionStream(HttpServletResponse response,  ChatReq chatReq)
     {
         log.info("进入了调⽤大模型决策接口");
@@ -281,8 +288,9 @@ public class GRPCController extends BaseController {
 //        String rows = JSON.toJSONString(decisionReqs, JSONWriter.Feature.WriteNulls);
         boolean needAdd = true;//标识变量是否可以保存
         String dataJson = "";
+        String headerPort = httpServletRequest.getHeader(port);
         try {
-            channel = ManagedChannelBuilder.forAddress(bigModelConfig.getIp(), bigModelConfig.getPort())
+            channel = ManagedChannelBuilder.forAddress(bigModelConfig.getIp(), StringUtils.isBlank(headerPort) ? bigModelConfig.getPort() : Integer.parseInt(headerPort))
                     .usePlaintext()
                     .build();
             InferenceAPIsServiceGrpc.InferenceAPIsServiceBlockingStub stub = InferenceAPIsServiceGrpc.newBlockingStub(channel);
@@ -465,7 +473,7 @@ public class GRPCController extends BaseController {
      * @return
      */
     @PostMapping(value = "/inferStreamRag")
-    public void inferStreamRag(HttpServletResponse response, @RequestBody ChatReq chatReq) {
+    public void inferStreamRag(HttpServletRequest httpServletRequest, HttpServletResponse response, @RequestBody ChatReq chatReq) {
 //    public void inferStreamRag(HttpServletResponse response, ChatReq chatReq) {
         log.info("进入了调⽤RAG+⼤模型的调⽤参数");
         // 获取输出流
@@ -579,8 +587,9 @@ public class GRPCController extends BaseController {
         }
         //将新的问题放入集合中
         historyDates.add(chatReq.getQuestion());
+        String headerPort = httpServletRequest.getHeader(port);
         try {
-            channel = ManagedChannelBuilder.forAddress(bigModelConfig.getIp(), bigModelConfig.getPort())
+            channel = ManagedChannelBuilder.forAddress(bigModelConfig.getIp(), StringUtils.isBlank(headerPort) ? bigModelConfig.getPort() : Integer.parseInt(headerPort))
                     .usePlaintext()
                     .build();
             InferenceAPIsServiceGrpc.InferenceAPIsServiceBlockingStub stub = InferenceAPIsServiceGrpc.newBlockingStub(channel);

+ 2 - 0
slibra-admin/src/main/resources/application.yml

@@ -106,6 +106,8 @@ token:
   secret: abcdefghijklmnopqrstuvwxyz
   # 令牌有效期(默认30分钟)
   expireTime: 1440000000
+  # 用户自定义的端口
+  port: port
 
 ## MyBatis配置
 #mybatis: