Introduction
TensorFlow Serving makes it easy to deploy machine learning models trained offline to production, exposing them to external callers over gRPC. It also supports hot model updates and automatic model version management, so algorithm engineers can focus on improving the offline model instead of worrying about the online service.
This article shows how to build a linear-regression prediction service with TensorFlow Serving. For a task this simple you could of course hard-code the trained parameters w and b directly; it is used here only as a minimal introductory example.
Environment Setup
The environment used in this experiment:
- bazel 0.11.0: builds the model server used for deployment
- java 1.8: runs the online client
- python 2.7: trains the model
- tensorflow 1.5.0: trains the model
- gRPC: the service call interface
- gcc 4.8.5: compiles the code
Model Training and Saving
With the environment ready, we write a Python script, train.py, that trains and saves the model.
Model code:
#!/usr/bin/env python
import numpy as np
import tensorflow as tf
import tensorflow.contrib.session_bundle.exporter as exporter

# Generate input data
n_samples = 1000
x_data = np.arange(100, step=.1)
y_data = x_data + 20 * np.sin(x_data / 10)
x_data = np.reshape(x_data, (n_samples, 1))
y_data = np.reshape(y_data, (n_samples, 1))

# Hyperparameters
learning_rate = 0.01
batch_size = 100
n_steps = 500

# Placeholders for batched input
x = tf.placeholder(tf.float32, shape=(batch_size, 1))
y = tf.placeholder(tf.float32, shape=(batch_size, 1))

with tf.variable_scope('test'):
    w = tf.get_variable('weights', (1, 1), initializer=tf.random_normal_initializer())
    b = tf.get_variable('bias', (1,), initializer=tf.constant_initializer(0))
    y_pred = tf.matmul(x, w) + b
    loss = tf.reduce_sum((y - y_pred) ** 2 / n_samples)

opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for _ in range(n_steps):
        # Sample a random mini-batch
        indices = np.random.choice(n_samples, batch_size)
        x_batch = x_data[indices]
        y_batch = y_data[indices]
        _, loss_val = sess.run([opt, loss], feed_dict={x: x_batch, y: y_batch})
    print(w.eval())
    print(b.eval())
    print(loss_val)

    # Export the trained model for TensorFlow Serving
    saver = tf.train.Saver()
    model_exporter = exporter.Exporter(saver)
    model_exporter.init(
        sess.graph.as_graph_def(),
        named_graph_signatures={
            'inputs': exporter.generic_signature({'x': x}),
            'outputs': exporter.generic_signature({'y': y_pred})})
    model_exporter.export("/tmp/linear-regression/",
                          tf.constant("1"),  # model version
                          sess)
Run the code:
python train.py
Model results:
[[1.0108936]]
[1.9290849]
19.266457
The output above means w = 1.0108936 and b = 1.9290849, i.e. y = 1.0108936 * x + 1.9290849.
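As a quick sanity check of that formula (a throwaway sketch using the printed values, not part of train.py), we can apply the learned parameters directly; the input 0.033170342 is reused from the serving results shown later:
# Apply the learned parameters: y = w * x + b
w, b = 1.0108936, 1.9290849

def predict(x):
    return w * x + b

print(predict(0.033170342))  # ~1.9626166, matching the serving response below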
The /tmp/linear-regression/ directory now contains the following files (I trained twice and set the version to 2 for the second run, hence the two subdirectories):
.
├── 00000001
│ ├── checkpoint
│ ├── export.data-00000-of-00001
│ ├── export.index
│ └── export.meta
└── 00000002
├── checkpoint
├── export.data-00000-of-00001
├── export.index
└── export.meta
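To double-check what was exported, one option (a hedged sketch, not part of the original workflow; the tensor names test/weights:0 and test/bias:0 follow from the variable_scope used in train.py) is to reload one version directory with a plain Saver:
import tensorflow as tf

# Reload version 1 of the export and inspect the trained variables.
# 'export' is the checkpoint prefix for the export.index / export.data-* files.
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/linear-regression/00000001/export.meta')
    saver.restore(sess, '/tmp/linear-regression/00000001/export')
    graph = tf.get_default_graph()
    w = graph.get_tensor_by_name('test/weights:0')
    b = graph.get_tensor_by_name('test/bias:0')
    print(sess.run([w, b]))  # should match the values printed by train.py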
Model Deployment
# Clone the tensorflow_serving source code
git clone https://github.com/tensorflow/serving.git
# Build tensorflow_model_server
bazel build //tensorflow_serving/model_servers:tensorflow_model_server
# Start the server
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=test --model_base_path=/tmp/linear-regression
When the command above prints the following line, the deployment has succeeded:
Running ModelServer at 0.0.0.0:9000 ...
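Before writing a full client, you can quickly confirm that port 9000 is reachable; this is a throwaway check (127.0.0.1 is just a placeholder for the serving host):
import socket

# Minimal reachability check for the gRPC port opened by tensorflow_model_server
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(2)
try:
    sock.connect(('127.0.0.1', 9000))  # replace with the server's address
    print('port 9000 is open')
finally:
    sock.close()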
Calling the Service
With the model deployed, we now write the Java client code that calls the service online:
ManagedChannel channel = null;
try {
    channel = ManagedChannelBuilder.forAddress("IP address of the serving host", 9000).usePlaintext(true).build();
    PredictionServiceGrpc.PredictionServiceBlockingStub stub =
            PredictionServiceGrpc.newBlockingStub(channel);
    Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
    Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
    modelSpecBuilder.setName("test");
    predictRequestBuilder.setModelSpec(modelSpecBuilder);
    // Build a 1x1 float tensor holding the input x
    TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
    tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
    TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
    tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
    tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
    tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
    List<Float> floatList = new ArrayList<>();
    Random random = new Random();
    float x = random.nextFloat();
    floatList.add(x);
    tensorProtoBuilder.addAllFloatVal(floatList);
    predictRequestBuilder.putInputs("x", tensorProtoBuilder.build());
    // Send the request and read the output tensor named "y"
    Predict.PredictResponse predictResponse = stub.predict(predictRequestBuilder.build());
    LOG.debug("x={}, result={}", x,
            predictResponse.getOutputsOrThrow("y").getFloatValList().toString());
} catch (StatusRuntimeException e) {
    LOG.error("StatusRuntimeException: ", e);
} finally {
    if (channel != null) {
        channel.shutdown();
    }
}
Maven dependency:
<dependency>
    <groupId>com.yesup.oss</groupId>
    <artifactId>tensorflow-client</artifactId>
    <version>1.4-2</version>
</dependency>
Results after a couple of requests:
x=0.033170342, result=[1.9626166]
x=0.90723795, result=[2.846206]
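The same request can also be sent from Python; the sketch below assumes the tensorflow-serving-api pip package (matching the server version) is installed, and 127.0.0.1 stands in for the serving host:
import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

# Connect to the tensorflow_model_server started above
channel = grpc.insecure_channel('127.0.0.1:9000')  # replace with the serving host
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

# Same request as the Java client: model "test", 1x1 float input tensor "x"
request = predict_pb2.PredictRequest()
request.model_spec.name = 'test'
request.inputs['x'].CopyFrom(
    tf.contrib.util.make_tensor_proto([[0.5]], dtype=tf.float32, shape=[1, 1]))

response = stub.Predict(request, 10.0)  # 10-second timeout
print(response.outputs['y'].float_val)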
Model Update
Change the version number in train.py, for example from 1 to 2, and run it again.
Once it finishes, the server detects the new version and updates the model automatically.
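Concretely, the version is the constant passed to model_exporter.export() in train.py:
# In train.py, bump the export version from "1" to "2"
model_exporter.export("/tmp/linear-regression/",
                      tf.constant("2"),  # new model version
                      sess)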
The parameters from the second training run:
[[1.0106797]]
[2.5606086]
20.278612
After requesting again:
x=0.8308283, result=[3.40031]