目录

简介

本文简单介绍如何搭建基于 Java + LightGBM 的线上实时预测系统。

准备训练数据和测试数据

训练数据格式很简单:每行以制表符(\t)分隔,第一列为待预估的目标值(label),其余各列为特征值。

修改配置文件

train.conf

# task type, supports train and predict
task = train

# boosting type, supports gbdt for now, alias: boosting, boost
boosting_type = gbdt

# application type, supports the following applications
# regression , regression task
# binary , binary classification task
# lambdarank , lambdarank task
# alias: application, app
objective = regression

# eval metrics, supports multiple metrics delimited by ',' ; the following metrics are supported
# l1
# l2 , default metric for regression
# ndcg , default metric for lambdarank
# auc
# binary_logloss , default metric for binary
# binary_error
metric = l2

# frequency of metric output
metric_freq = 1

# true if the metric should also be output for the training data, alias: training_metric, train_metric
is_training_metric = true

# number of bins for feature bucketing, 255 is a recommended setting: it saves memory and still has good accuracy.
max_bin = 255

# training data
# if a weight file exists, it should be named "regression.train.weight"
# NOTE(review): that name comes from the upstream example data file; confirm the expected name for your own data file
# alias: train_data, train
data = TRAIN_FILE_PLACEHOLDER

# validation data, supports multiple validation files, separated by ','
# if a weight file exists, it should be named "regression.test.weight"
# NOTE(review): that name comes from the upstream example data file; confirm the expected name for your own data file
# alias: valid, test, test_data
valid_data = VALID_FILE_PLACEHOLDER

# number of trees(iterations), alias: num_tree, num_iteration, num_iterations, num_round, num_rounds
num_trees = 100

# shrinkage rate , alias: shrinkage_rate
learning_rate = 0.05

# number of leaves for one tree, alias: num_leaf
num_leaves = 31

# type of tree learner, supports the following types:
# serial , single machine version
# feature , use feature parallel to train
# data , use data parallel to train
# voting , use voting based parallel to train
# alias: tree
tree_learner = serial

# number of threads for multi-threading. One thread will use one CPU, default is set to #cpu.
# num_threads = 8

# feature sub-sampling, randomly selects 90% of the features to train on in each iteration
# alias: sub_feature
feature_fraction = 0.9

# bagging (data sub-sampling) is supported; bagging is performed every 5 iterations
bagging_freq = 5

# bagging fraction, randomly selects 80% of the data on each bagging
# alias: sub_row
bagging_fraction = 0.8

# minimal number of data points in one leaf, use this to deal with over-fitting
# alias : min_data_per_leaf, min_data
min_data_in_leaf = 100

# minimal sum of hessians in one leaf, use this to deal with over-fitting
min_sum_hessian_in_leaf = 5.0

# saves memory and is faster for sparse features, alias: is_sparse
is_enable_sparse = true

# set this to true when the data is bigger than the memory size; otherwise false is faster
# alias: two_round_loading, two_round
use_two_round_loading = false

# true to save the data to a binary file; the application will then auto-load the data from that binary file next time
# alias: is_save_binary, save_binary
is_save_binary_file = false

# output model file
output_model = LightGBM_model.txt

# supports continued training from a trained gbdt model
# input_model= trained_model.txt

# output prediction file for predict task
# output_result= prediction.txt

# supports continued training from an initial score file
# input_init_score= init_score.txt


# number of machines in parallel training, alias: num_machine
num_machines = 1

# local listening port in parallel training, alias: local_port
local_listen_port = 12400

# machines list file for parallel training, alias: mlist
machine_list_file = mlist.txt

predict.conf


# task type: run prediction instead of training
task = predict

# data file to predict on; PREDICT_FILE_PLACEHOLDER is meant to be substituted with a real path before use
data = PREDICT_FILE_PLACEHOLDER

# model file to load for prediction
input_model= LightGBM_model.txt

修改

sample_train_file、sample_valid_file为训练数据集和测试数据集的绝对路径地址。

# Substitute the data-file placeholders in both config files.
# '|' is used as the sed delimiter because the replacement values are absolute
# paths containing '/', which would terminate an s/…/…/ expression early.
# The paths themselves must not contain '|' or '&'.
sed -i -e "s|TRAIN_FILE_PLACEHOLDER|${sample_train_file}|g" -e "s|VALID_FILE_PLACEHOLDER|${sample_valid_file}|g" train.conf
sed -i -e "s|PREDICT_FILE_PLACEHOLDER|${sample_valid_file}|g" predict.conf

训练和预测

# Path to the LightGBM command-line binary.
# NOTE(review): assumes dir_name already ends with '/' — confirm.
lightgbm="${dir_name}git/LightGBM/lightgbm"

# Train with train.conf, then predict with predict.conf.
# The variable is quoted so a path containing spaces is not word-split.
"${lightgbm}" config=train.conf
"${lightgbm}" config=predict.conf

其中lightgbm使用的是git上的代码,https://github.com/Microsoft/LightGBM 。

转换成pmml

因为lightgbm生成的模型文件是txt的,需要转换成pmml文件。

# Convert the LightGBM text model into a PMML file.
# jar_file is quoted so a path containing spaces is not word-split.
java -jar "${jar_file}" --lgbm-input LightGBM_model.txt --pmml-output LightGBM_model.txt.pmml

# Rewrite the PMML-4_3 marker and the version attribute from 4.3 to 4.2.
# NOTE(review): presumably required because the online evaluator only accepts PMML 4.2 — confirm.
sed -i -e 's/PMML-4_3/PMML-4_2/g' LightGBM_model.txt.pmml
sed -i -e 's/version=\"4.3\"/version=\"4.2\"/g' LightGBM_model.txt.pmml

上面的jar_file指向转换用的jar文件converter-executable-1.2-SNAPSHOT.jar,可以网上下载得到,如果出现问题需要自己修改代码。

线上预测

定义PmmlModel

package com.fashici.model;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.model.ImportFilter;
import org.jpmml.model.JAXBUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.InputSource;

import javax.xml.transform.sax.SAXSource;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Wrapper around a JPMML {@link Evaluator} that loads a PMML document and scores a
 * positional list of feature values against it.
 *
 * <p>Input fields are addressed by the names {@code Column_0}, {@code Column_1}, ... and the
 * prediction is read from the result entry named {@code _target}.
 */
public class PmmlModel {
  private static final Logger LOGGER = LoggerFactory.getLogger(PmmlModel.class);

  /** Prefix of the input field names; the feature index is appended to it. */
  private static final String FEATURE_NAME_PREFIX = "Column_";

  /** Name of the result entry that holds the prediction. */
  private static final FieldName TARGET_FIELD = new FieldName("_target");

  private PMML pmml;
  private Evaluator evaluator;

  /** Not used by this class; kept private so instances are always created from a PMML source. */
  private PmmlModel() {

  }

  /**
   * Loads and verifies a model from a PMML file.
   *
   * @param pmmlPath path of the PMML file
   * @throws Exception if the file cannot be read or parsed, or the model fails verification
   */
  public PmmlModel(String pmmlPath) throws Exception {
    pmml = unmarshal(pmmlPath);
    evaluator = createEvaluator(pmml);
  }

  /**
   * Loads and verifies a model from a stream of PMML content.
   *
   * <p>The stream is not closed here; the caller remains responsible for closing it.
   *
   * @param inputStream stream positioned at the start of the PMML document
   * @throws Exception if the content cannot be parsed, or the model fails verification
   */
  public PmmlModel(InputStream inputStream) throws Exception {
    pmml = unmarshal(inputStream);
    evaluator = createEvaluator(pmml);
  }

  /**
   * Scores one feature vector.
   *
   * <p>Only the first {@code n} entries of {@code features} are used, where {@code n} is the
   * number of active fields of the model; entry {@code i} is bound to field {@code Column_i}.
   *
   * @param features feature values in column order
   * @return the value of the {@code _target} result field, or {@code null} if {@code features}
   *     is {@code null}, has fewer entries than the model has active fields, or evaluation fails
   */
  public Double predict(List<Double> features) {
    if (features == null) {
      LOGGER.error("features is null!");
      return null;
    }
    List<FieldName> activeFields = evaluator.getActiveFields();
    int fieldCount = activeFields.size();
    if (features.size() < fieldCount) {
      LOGGER.error("features size {} is less than activeFields size {}!",
              features.size(), fieldCount);
      return null;
    }
    Map<FieldName, Double> fieldMap = new HashMap<>();
    for (int i = 0; i < fieldCount; ++i) {
      fieldMap.put(new FieldName(FEATURE_NAME_PREFIX + i), features.get(i));
    }
    try {
      Map<FieldName, ?> result = evaluator.evaluate(fieldMap);
      return (Double) result.get(TARGET_FIELD);
    } catch (Exception e) {
      // Pass the exception as the trailing argument without a "{}" placeholder so the
      // logger records it as a throwable (with stack trace) instead of formatting it.
      LOGGER.error("error evaluating PMML model", e);
      return null;
    }
  }

  /** Builds an evaluator for the given PMML document and verifies it. */
  private static Evaluator createEvaluator(PMML pmml) throws Exception {
    ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
    Evaluator evaluator = modelEvaluatorFactory.newModelManager(pmml);
    evaluator.verify();
    return evaluator;
  }

  /** Reads a PMML document from a file, closing the file once parsing has finished. */
  private static PMML unmarshal(String modelPath) throws Exception {
    try (InputStream inputStream = new FileInputStream(new File(modelPath))) {
      return unmarshal(inputStream);
    }
  }

  /** Reads a PMML document from a stream, applying the JPMML import filter first. */
  private static PMML unmarshal(InputStream inputStream) throws Exception {
    InputSource source = new InputSource(inputStream);
    SAXSource transformedSource = ImportFilter.apply(source);
    return JAXBUtil.unmarshalPMML(transformedSource);
  }
}

使用PmmlModel预测

InputStream inputStream = ModelDataLoader.class.getClassLoader().getResourceAsStream(MODEL_FILEPATH);
PmmlModel lgbmModel = new PmmlModel(inputStream);
List<Double> features = ModelFeature.getFeatureList(someVar);
double score = lgbmModel.predict(features);