Java和Python使用Grpc访问Tensorflow的Serving代码

我发现网上大量的示例代码用的都是mnist,而我并不是做图像处理的,所以完全不想用这个例子。

wide&deep这种包含各种特征的模型才是我需要的;iris同样是从文本数据训练的模型,所以用它做例子非常简单。

本文给出Python和Java通过gRPC访问TensorFlow Serving的代码。

Java版本使用Grpc访问Tensorflow的Serving代码

package io.github.qf6101.tensorflowserving;

import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import org.tensorflow.example.*;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Command-line gRPC client that sends one iris sample to a TensorFlow Serving
 * endpoint through the Predict API and prints the response.
 *
 * <p>References:
 * <ul>
 *   <li>https://www.jianshu.com/p/d82107165119</li>
 *   <li>https://github.com/grpc/grpc-java</li>
 * </ul>
 */
public class PssIrisGrpcClient {

    /** Upper bound, in seconds, on how long the blocking predict call may take. */
    private static final long PREDICT_TIMEOUT_SECONDS = 10L;

    /**
     * Builds an {@link Example} holding one iris sample.
     *
     * @return an example with the float features {@code SepalLength}, {@code SepalWidth},
     *         {@code PetalLength} and {@code PetalWidth}
     */
    public static Example createExample() {
        Map<String, Float> dataMap = new HashMap<String, Float>();
        dataMap.put("SepalLength", 5.1f);
        dataMap.put("SepalWidth", 3.3f);
        dataMap.put("PetalLength", 1.7f);
        dataMap.put("PetalWidth", 0.5f);

        Features features = Features.newBuilder()
                .putAllFeature(mapToFeatureMap(dataMap))
                .build();
        return Example.newBuilder().setFeatures(features).build();
    }

    /**
     * Converts a name-to-value map into a name-to-{@link Feature} map, wrapping every value
     * in a single-element float list.
     *
     * @param dataMap feature values keyed by feature name; values must be non-null
     * @return a new map with the same keys, each mapped to a float-list feature
     */
    private static Map<String, Feature> mapToFeatureMap(Map<String, Float> dataMap) {
        Map<String, Feature> resultMap = new HashMap<String, Feature>();
        // Iterate over entries so each value is read without a second map lookup.
        for (Map.Entry<String, Float> entry : dataMap.entrySet()) {
            FloatList floatList = FloatList.newBuilder().addValue(entry.getValue()).build();
            Feature feature = Feature.newBuilder().setFloatList(floatList).build();
            resultMap.put(entry.getKey(), feature);
        }
        return resultMap;
    }

    /**
     * Sends the sample built by {@link #createExample()} to model {@code iris}, version 1,
     * signature {@code classification}, at 127.0.0.1:8888 and prints the response.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        String host = "127.0.0.1";
        int port = 8888;

        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port)
                // Channels are secure by default (via SSL/TLS). For the example we disable TLS to avoid
                // needing certificates.
                .usePlaintext()
                .build();
        try {
            // A deadline keeps the blocking call from hanging forever when the server is
            // unreachable or stalls.
            PredictionServiceGrpc.PredictionServiceBlockingStub blockingStub =
                    PredictionServiceGrpc.newBlockingStub(channel)
                            .withDeadlineAfter(PREDICT_TIMEOUT_SECONDS, java.util.concurrent.TimeUnit.SECONDS);

            com.google.protobuf.Int64Value version = com.google.protobuf.Int64Value.newBuilder()
                    .setValue(1)
                    .build();

            Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder()
                    .setName("iris")
                    .setVersion(version)
                    .setSignatureName("classification")
                    .build();

            // The request carries serialized Example protos as a 1-D DT_STRING tensor.
            List<ByteString> exampleList = new ArrayList<ByteString>();
            exampleList.add(createExample().toByteString());

            TensorShapeProto.Dim featureDim = TensorShapeProto.Dim.newBuilder()
                    .setSize(exampleList.size())
                    .build();
            TensorShapeProto shapeProto = TensorShapeProto.newBuilder().addDim(featureDim).build();
            TensorProto tensorProto = TensorProto.newBuilder()
                    .addAllStringVal(exampleList)
                    .setDtype(DataType.DT_STRING)
                    .setTensorShape(shapeProto)
                    .build();

            Predict.PredictRequest request = Predict.PredictRequest.newBuilder()
                    .setModelSpec(modelSpec)
                    .putInputs("inputs", tensorProto)
                    .build();
            Predict.PredictResponse response = blockingStub.predict(request);
            System.out.println(response);
        } finally {
            // Shut the channel down even when the call throws, so it is not leaked.
            channel.shutdown();
        }
    }
}

需要增加如下maven依赖:

        <!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow -->
        <!-- NOTE(review): the Java client imports org.tensorflow.framework.*, org.tensorflow.example.*
             and tensorflow.serving.* classes; confirm which artifact (or generated protobuf
             sources) actually provides them, since no serving-specific dependency is listed here. -->
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.12.0</version>
        </dependency>

        <!-- https://mvnrepository.com/artifact/io.grpc/grpc-netty -->
        <!-- The three io.grpc artifacts below are pinned to the same version (1.20.0). -->
        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-netty</artifactId>
            <version>1.20.0</version>
        </dependency>

        <!-- https://mvnrepository.com/artifact/io.grpc/grpc-protobuf -->
        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-protobuf</artifactId>
            <version>1.20.0</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/io.grpc/grpc-stub -->
        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-stub</artifactId>
            <version>1.20.0</version>
        </dependency>

输出结果:

outputs {
  key: "scores"
  value {
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: 1
      }
      dim {
        size: 3
      }
    }
    float_val: 0.9997806
    float_val: 2.1938368E-4
    float_val: 1.382611E-9
  }
}
outputs {
  key: "classes"
  value {
    dtype: DT_STRING
    tensor_shape {
      dim {
        size: 1
      }
      dim {
        size: 3
      }
    }
    string_val: "0"
    string_val: "1"
    string_val: "2"
  }
}

Python版本使用Grpc访问Tensorflow的Serving代码

# 创建 gRPC 连接
import pandas as pd
from grpc.beta import implementations
import tensorflow as tf
from tensorflow_serving.apis import prediction_service_pb2, classification_pb2

# Open an insecure channel to 127.0.0.1:8888 and create the PredictionService stub on it.
# NOTE(review): an earlier commented-out variant of this line used port 8500 -- confirm
# which port the serving process actually listens on.
channel = implementations.insecure_channel('127.0.0.1', 8888)
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

def _create_feature(v):
    """Wrap the scalar ``v`` in a ``tf.train.Feature`` holding a one-element float list."""
    float_list = tf.train.FloatList(value=[v])
    return tf.train.Feature(float_list=float_list)

def _to_example(row):
    """Convert a {feature name: value} dict into a ``tf.train.Example``."""
    feature_map = {}
    for name, value in row.items():
        feature_map[name] = _create_feature(value)
    return tf.train.Example(features=tf.train.Features(feature=feature_map))


# Two iris samples, each keyed by feature name.
data1 = {"SepalLength":5.1,"SepalWidth":3.3,"PetalLength":1.7,"PetalWidth":0.5}
data2 = {"SepalLength":1.1,"SepalWidth":1.3,"PetalLength":1.7,"PetalWidth":0.5}

example1 = _to_example(data1)
example2 = _to_example(data2)

# Test samples, already converted to Example instances.
examples = [example1, example2]

# Prepare the RPC request and set the model name.
request = classification_pb2.ClassificationRequest()
request.model_spec.name = 'iris'
request.input.example_list.examples.extend(examples)

# Send the request (second argument 10.0 is the call timeout) and print the result.
response = stub.Classify(request, 10.0)
print(response)

Python代码看起来简单不少,但是我们的线上服务都是Java,Python不好集成,只能用来做一些离线的批量预测。

输出如下:

result {
  classifications {
    classes {
      label: "0"
      score: 0.9997805953025818
    }
    classes {
      label: "1"
      score: 0.00021938368445262313
    }
    classes {
      label: "2"
      score: 1.382611025668723e-09
    }
  }
  classifications {
    classes {
      label: "0"
      score: 0.0736534595489502
    }
    classes {
      label: "1"
      score: 0.8393719792366028
    }
    classes {
      label: "2"
      score: 0.08697459846735
    }
  }
}
model_spec {
  name: "iris"
  version {
    value: 1
  }
  signature_name: "serving_default"
}

我个人其实更喜欢HTTP+JSON接口,完全不用折腾gRPC这些麻烦的东西;尤其是Java的gRPC,遇到了好多问题,让人崩溃。

不过据说gRPC的性能比HTTP好不少,所以线上还是只能用gRPC。

相关推荐

Leave a Comment