|
@@ -30,9 +30,11 @@ import org.pytorch.serve.grpc.inference.PredictionResponse;
|
|
|
import org.pytorch.serve.grpc.inference.PredictionsRequest;
|
|
|
import org.springframework.beans.BeanUtils;
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.beans.factory.annotation.Value;
|
|
|
import org.springframework.util.CollectionUtils;
|
|
|
import org.springframework.web.bind.annotation.*;
|
|
|
|
|
|
+import javax.servlet.http.HttpServletRequest;
|
|
|
import javax.servlet.http.HttpServletResponse;
|
|
|
import java.io.IOException;
|
|
|
import java.io.OutputStream;
|
|
@@ -76,6 +78,10 @@ public class GRPCController extends BaseController {
|
|
|
@Autowired
|
|
|
private BigModelConfig bigModelConfig;
|
|
|
|
|
|
+ // Name of the request header (configured via token.port) used to look up a caller-supplied gRPC port override
|
|
|
+ @Value("${token.port}")
|
|
|
+ private String port;
|
|
|
+
|
|
|
|
|
|
/**
|
|
|
*
|
|
@@ -84,7 +90,7 @@ public class GRPCController extends BaseController {
|
|
|
* @param response
|
|
|
*/
|
|
|
@GetMapping(value = "/test/aaa")
|
|
|
- public void decisionStreamTest(HttpServletResponse response)
|
|
|
+ public void decisionStreamTest(HttpServletRequest httpServletRequest, HttpServletResponse response)
|
|
|
// public void decisionStream(HttpServletResponse response, ChatReq chatReq)
|
|
|
{
|
|
|
log.info("进入了调⽤大模型决策接口");
|
|
@@ -138,8 +144,9 @@ public class GRPCController extends BaseController {
|
|
|
// String rows = JSON.toJSONString(decisionReqs, JSONWriter.Feature.WriteNulls);
|
|
|
boolean needAdd = true;//标识变量是否可以保存
|
|
|
String dataJson = "";
|
|
|
+ String headerPort = httpServletRequest.getHeader(port);
|
|
|
try {
|
|
|
- channel = ManagedChannelBuilder.forAddress(bigModelConfig.getIp(), bigModelConfig.getPort())
|
|
|
+ channel = ManagedChannelBuilder.forAddress(bigModelConfig.getIp(), StringUtils.isBlank(headerPort) ? bigModelConfig.getPort() : Integer.parseInt(headerPort))
|
|
|
.usePlaintext()
|
|
|
.build();
|
|
|
InferenceAPIsServiceGrpc.InferenceAPIsServiceBlockingStub stub = InferenceAPIsServiceGrpc.newBlockingStub(channel);
|
|
@@ -237,7 +244,7 @@ public class GRPCController extends BaseController {
|
|
|
* @return
|
|
|
*/
|
|
|
@PostMapping(value = "/decisionStream")
|
|
|
- public void decisionStream(HttpServletResponse response, @RequestBody ChatReq chatReq)
|
|
|
+ public void decisionStream(HttpServletRequest httpServletRequest, HttpServletResponse response, @RequestBody ChatReq chatReq)
|
|
|
// public void decisionStream(HttpServletResponse response, ChatReq chatReq)
|
|
|
{
|
|
|
log.info("进入了调⽤大模型决策接口");
|
|
@@ -281,8 +288,9 @@ public class GRPCController extends BaseController {
|
|
|
// String rows = JSON.toJSONString(decisionReqs, JSONWriter.Feature.WriteNulls);
|
|
|
boolean needAdd = true;//标识变量是否可以保存
|
|
|
String dataJson = "";
|
|
|
+ String headerPort = httpServletRequest.getHeader(port);
|
|
|
try {
|
|
|
- channel = ManagedChannelBuilder.forAddress(bigModelConfig.getIp(), bigModelConfig.getPort())
|
|
|
+ channel = ManagedChannelBuilder.forAddress(bigModelConfig.getIp(), StringUtils.isBlank(headerPort) ? bigModelConfig.getPort() : Integer.parseInt(headerPort))
|
|
|
.usePlaintext()
|
|
|
.build();
|
|
|
InferenceAPIsServiceGrpc.InferenceAPIsServiceBlockingStub stub = InferenceAPIsServiceGrpc.newBlockingStub(channel);
|
|
@@ -465,7 +473,7 @@ public class GRPCController extends BaseController {
|
|
|
* @return
|
|
|
*/
|
|
|
@PostMapping(value = "/inferStreamRag")
|
|
|
- public void inferStreamRag(HttpServletResponse response, @RequestBody ChatReq chatReq) {
|
|
|
+ public void inferStreamRag(HttpServletRequest httpServletRequest, HttpServletResponse response, @RequestBody ChatReq chatReq) {
|
|
|
// public void inferStreamRag(HttpServletResponse response, ChatReq chatReq) {
|
|
|
log.info("进入了调⽤RAG+⼤模型的调⽤参数");
|
|
|
// 获取输出流
|
|
@@ -579,8 +587,9 @@ public class GRPCController extends BaseController {
|
|
|
}
|
|
|
//将新的问题放入集合中
|
|
|
historyDates.add(chatReq.getQuestion());
|
|
|
+ String headerPort = httpServletRequest.getHeader(port);
|
|
|
try {
|
|
|
- channel = ManagedChannelBuilder.forAddress(bigModelConfig.getIp(), bigModelConfig.getPort())
|
|
|
+ channel = ManagedChannelBuilder.forAddress(bigModelConfig.getIp(), StringUtils.isBlank(headerPort) ? bigModelConfig.getPort() : Integer.parseInt(headerPort))
|
|
|
.usePlaintext()
|
|
|
.build();
|
|
|
InferenceAPIsServiceGrpc.InferenceAPIsServiceBlockingStub stub = InferenceAPIsServiceGrpc.newBlockingStub(channel);
|