// 代码如下 (The code is as follows):
import org.apache.spark.mllib.classification.LogisticRegressionModel
// Load a previously saved logistic regression model and inspect its learned parameters.
// Scala has no `#` line comments; `//` must be used, otherwise these lines do not compile.
val modelSavePath = "/user/tech/model"
val model = LogisticRegressionModel.load(sc, modelSavePath)
// Model coefficient vector w.
println(s"weights (w) = ${model.weights}")
// Model bias term b.
println(s"intercept (b) = ${model.intercept}")
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
// Load training data in LIBSVM format.
// Alternative feature set (presumably with word-embedding features — confirm):
// val data = MLUtils.loadLibSVMFile(sc, "/path/to/lr/training_data_with_wordemb")
val data = MLUtils.loadLibSVMFile(sc, "/path/to/model/data/train_data_onehot_crossfeature")
// Split data into training (70%) and test (30%).
// The fixed seed makes the split repeatable for the same input data.
val splits = data.randomSplit(Array(0.7, 0.3), seed = 11L)
// Cache the training split: the iterative optimizer scans it repeatedly.
val training = splits(0).cache()
val test = splits(1)
// Fit a two-class logistic regression model on the cached training split,
// using the L-BFGS optimizer.
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(2)
  .run(training)
// BinaryClassificationMetrics sweeps a decision threshold over the scores, so it needs
// continuous scores, not hard 0/1 labels. While the model's threshold is set, predict()
// returns class labels only, which collapses the ROC/PR curves to a single point.
// Clearing the threshold makes predict() return the predicted probability instead.
// NOTE: the threshold stays cleared on `model` afterwards; call model.setThreshold(...)
// before using it for class-label predictions or saving it.
model.clearThreshold()
// Compute raw scores on the test set as (score, label) pairs.
val scoreAndLabels = test.map { case LabeledPoint(label, features) =>
  (model.predict(features), label)
}
// Get evaluation metrics.
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val auROC = metrics.areaUnderROC()
val auPR = metrics.areaUnderPR()
println(s"Area under ROC = $auROC")
println(s"Area under PR = $auPR")
// Save and load model
// model.save(sc, "/path/to/model/LogisticRegressionWithLBFGS")
// val sameModel = LogisticRegressionModel.load(sc, "/path/to/model/LogisticRegressionWithLBFGS")