|
@@ -1,76 +0,0 @@
|
|
|
-package org.pytorch.serve.grpc.inference;
|
|
|
-
|
|
|
-import com.alibaba.fastjson.JSON;
|
|
|
-import com.google.protobuf.ByteString;
|
|
|
-import com.google.protobuf.Empty;
|
|
|
-import com.google.protobuf.InvalidProtocolBufferException;
|
|
|
-import io.grpc.ManagedChannel;
|
|
|
-import io.grpc.ManagedChannelBuilder;
|
|
|
-import org.apache.commons.lang3.CharSet;
|
|
|
-import org.apache.commons.lang3.CharSetUtils;
|
|
|
-
|
|
|
-import java.io.UnsupportedEncodingException;
|
|
|
-import java.nio.charset.Charset;
|
|
|
-import java.nio.charset.StandardCharsets;
|
|
|
-import java.util.HashMap;
|
|
|
-import java.util.Iterator;
|
|
|
-
|
|
|
-public class Test {
|
|
|
-
|
|
|
- public static void main(String[] args) throws InvalidProtocolBufferException, UnsupportedEncodingException {
|
|
|
- ManagedChannel channel = ManagedChannelBuilder.forAddress("10.0.0.24", 17070)
|
|
|
- .usePlaintext()
|
|
|
- .build();
|
|
|
- InferenceAPIsServiceGrpc.InferenceAPIsServiceBlockingStub stub = InferenceAPIsServiceGrpc.newBlockingStub(channel);
|
|
|
-
|
|
|
-// TorchServeHealthResponse response = inferenceAPIsServiceBlockingStub.ping(Empty.getDefaultInstance());
|
|
|
-// Iterator<PredictionResponse> response = inferenceAPIsServiceBlockingStub.streamPredictions(PredictionsRequest.getDefaultInstance());
|
|
|
- String dataJson = "{\"bot_id\":\"721\",\"exp_id\":\"721\",\"session_id\":\"C20231025153549AC11FC2905668524\",\"prompt\":\"你是一个客服\",\"history_dia\":[\"您好!|||\",\"喂您好,我是您的客服顾问,我姓刘,工号是96466,那今天给您来电话,主要就是想跟您说一下,您之前领取了一份免费的保障,这份保障是完全免费的,我这边也是带您去看一下您的保障内容和理赔服务都有哪些。嗯就耽误您几分钟时间给您做个讲解好吧?\",\"不接电话。\",\"喂您好,我是您的客服顾问,我姓刘,工号是96466,那今天给您来电话,主要就是想跟您说一下,您之前领取了一份免费的保障,这份保障是完全免费的,我这边也是带您去看一下您的保障内容和理赔服务都有哪些。嗯就耽误您几分钟时间给您做个讲解好吧?\",\"好,不需要。\",\"啊是这样,我这边是水滴保的,就是您之前在水滴领取过一个免费的赠险吗?我这边就是带您去我们的微信公众号上查看一下这个保单,然后告诉您日后怎么去理赔。\",\"不好,就把电话给打了。我喝完。\",\"啊,您现在啊是这样的,咱们的保障呢已经生效了,要么是咱免费领的,要么是咱付过钱了的。如果出险,要花很多的钱的话,咱是可以理赔报销的。现在咱们花几分钟时间带您了解一下具体的内容和操作流程,您看好吧。\",\"嗯。查询。\",\"哦,那您现在手机是在通话页面还是打开微信了?\",\"你你说你说我现在听听。\"],\"generate_args\":{\"max_new_tokens\":1024,\"max_length\":4096,\"num_beams\":1,\"do_sample\":true,\"top_p\":0.7,\"temperature\":0.95},\"extra\":{\"TP_in\":100,\"andan_in_name\":110}}";
|
|
|
-// dataJson = "{\"bot_id\":\"b00001\",\"exp_id\":\"721\",\"norm\":\"进水总氮\",\"session_id\":\"C20231025153549AC11FC2905668524\",\"generate_args\":{\"max_new_tokens\":1024,\"max_length\":4096,\"num_beams\":1,\"do_sample\":true,\"top_p\":0.7,\"temperature\":0.95},\"extra\":{}}";
|
|
|
- System.out.println(dataJson);
|
|
|
- PredictionsRequest request = PredictionsRequest.newBuilder()
|
|
|
- .setModelName("slibra_bot")
|
|
|
- .putInput("method", ByteString.copyFrom("infer_stream", "utf-8"))//推理
|
|
|
-// .putInput("method", ByteString.copyFrom("decision_stream", "utf-8"))//决策
|
|
|
- .putInput("data", ByteString.copyFrom(dataJson, "utf-8"))
|
|
|
- .buildPartial();
|
|
|
- Iterator<PredictionResponse> response = stub.streamPredictions(request);
|
|
|
- while (response.hasNext()){
|
|
|
- String responseStr = response.next().getPrediction().toStringUtf8();
|
|
|
-// System.out.println(unicodeToChinese(responseStr.substring(16, responseStr.length()-1)) + "---" + responseStr);
|
|
|
- System.out.println(responseStr);
|
|
|
- }
|
|
|
- channel.shutdown();
|
|
|
- }
|
|
|
-
|
|
|
-
|
|
|
- /*public static void main(String[] args) {
|
|
|
- String encoded = "\\345\\216\\214\\346\\260\\247\\346\\261\\240\\347\\241\\235\\351\\205\\270\\347\\233\\220\\346\\260\\256\\346\\230\\257\\345\\220\\246\\345\\244\\247\\344\\272";
|
|
|
- String decoded = decodeOctalToUtf8(encoded);
|
|
|
- System.out.println(decoded);
|
|
|
- }*/
|
|
|
-
|
|
|
-
|
|
|
- public static String decodeOctalToUtf8(String encoded) {
|
|
|
- // 移除反斜杠
|
|
|
- String octalSequence = encoded.replaceAll("\\\\", "");
|
|
|
-
|
|
|
- // 检查长度是否是3的倍数
|
|
|
- if (octalSequence.length() % 3 != 0) {
|
|
|
- throw new IllegalArgumentException("Encoded string length is not a multiple of 3");
|
|
|
- }
|
|
|
-
|
|
|
- byte[] bytes = new byte[octalSequence.length() / 3];
|
|
|
- for (int i = 0, j = 0; i < octalSequence.length(); i += 3, j++) {
|
|
|
- // 提取每三个字符的八进制数
|
|
|
- String octal = octalSequence.substring(i, i + 3);
|
|
|
- // 将八进制数转换为字节
|
|
|
- bytes[j] = (byte) Integer.parseInt(octal, 8);
|
|
|
- }
|
|
|
- // 将字节序列转换为UTF-8编码的字符串
|
|
|
- return new String(bytes, StandardCharsets.UTF_8);
|
|
|
- }
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-}
|