package org.apache.spark.mllib.optimization

import breeze.linalg.{DenseVector => BDV}
import breeze.numerics.sqrt

import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
// Logistic regression model whose prediction is re-calibrated for negative downsampling:
// `rate` is the rate at which negative examples were kept when the training set was built.
class LRModel(weights: Vector, intercept: Double, rate: Double) extends Serializable {
  val weightMatrix = weights

  def predictPoint(dataMatrix: Vector): Double = {
    val margin = dot(weightMatrix, dataMatrix) + intercept
    val score = 1.0 / (1.0 + math.exp(-margin))
    // Map the raw score back to a calibrated probability: score / (score + (1 - score) / rate).
    score / (score + (1 - score) / rate)
  }
}
// Gradient and loss of the binary logistic loss, accumulated into `cumGradient`
// (same formulation as Spark's LogisticGradient).
class LRGradient extends Serializable {

  def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
    val gradient = Vectors.zeros(weights.size)
    val loss = compute(data, label, weights, gradient)
    (gradient, loss)
  }

  def compute(data: Vector, label: Double, weights: Vector, cumGradient: Vector): Double = {
    val margin = -1.0 * dot(data, weights)
    val multiplier = (1.0 / (1.0 + math.exp(margin))) - label
    // cumGradient += multiplier * data
    axpy(multiplier, data, cumGradient)
    if (label > 0) {
      MLUtils.log1pExp(margin)
    } else {
      MLUtils.log1pExp(margin) - margin
    }
  }
}
// FTRL-Proximal optimizer for logistic regression: `a` is the alpha parameter of the
// per-coordinate learning rate and `iter` the number of passes over the (sampled) data.
class FTRLLR(numFeatures: Int, iter: Int, a: Double) extends Serializable {
  private val gradient = new LRGradient()

  // `ww`, `n` and `z` are the FTRL state (weights, squared-gradient sums, z values);
  // they are updated in place and also returned.
  def compute(data: RDD[(Double, Vector)], ww: Array[Double], n: Array[Double], z: Array[Double])
      : (Vector, Array[Double], Array[Double]) = {
    val numWeight = numFeatures
    val b = 1.0          // beta of the per-coordinate learning rate
    val L1 = 1.0         // L1 regularization strength
    val L2 = 1.0         // L2 regularization strength
    val minibatch = 1.0  // sampling fraction used for each iteration

    /*
     * The loop below directly follows the per-coordinate steps of the FTRL-Proximal paper.
     */
    for (it <- 0 until iter) {
      val bcWeights = data.context.broadcast(ww)
      // One pass over a sample of the data: accumulate the gradient sum, the loss sum,
      // the example count and the sum of the feature vectors with treeAggregate.
      val tmp = data.sample(false, minibatch, 42)
        .treeAggregate((BDV.zeros[Double](numWeight), 0.0, 0L, Vectors.zeros(numWeight)))(
          seqOp = (c, v) => {
            val l = gradient.compute(v._2, v._1, Vectors.dense(bcWeights.value), Vectors.fromBreeze(c._1))
            (c._1, c._2 + l, c._3 + 1, Vectors.fromBreeze(c._4.toBreeze + v._2.toBreeze))
          },
          combOp = (c1, c2) => {
            (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3, Vectors.fromBreeze(c1._4.toBreeze + c2._4.toBreeze))
          })
      // Gradient summed over the sampled examples, rescaled by the sampling fraction.
      val g = Vectors.fromBreeze(tmp._1 / minibatch)
      // Sum of the sampled feature vectors: only coordinates that are active here get updated.
      val feature = Vectors.fromBreeze(tmp._4.toBreeze / minibatch)

      feature.foreachActive {
        case (i, v) =>
          // Closed-form FTRL-Proximal weight:
          //   w_i = 0 if |z_i| <= L1, else (sgn(z_i) * L1 - z_i) / ((b + sqrt(n_i)) / a + L2)
          var sign = -1.0
          if (z(i) > 0) sign = 1.0
          if (sign * z(i) < L1) ww(i) = 0
          else ww(i) = (sign * L1 - z(i)) / ((b + sqrt(n(i))) / a + L2)
          // Per-coordinate learning-rate schedule and state update:
          //   sigma_i = (sqrt(n_i + g_i^2) - sqrt(n_i)) / a,  z_i += g_i - sigma_i * w_i,  n_i += g_i^2
          val sigma = (sqrt(n(i) + g(i) * g(i)) - sqrt(n(i))) / a
          z(i) += g(i) - sigma * ww(i)
          n(i) += g(i) * g(i)
      }
    }
    (Vectors.dense(ww), n, z)
  }
}
The above is a Scala implementation of FTRL (Follow-The-Regularized-Leader) for logistic regression.
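A minimal sketch of how these classes might be wired together, assuming a live SparkContext named sc (for example in spark-shell) and assuming the classes above are on the classpath; the toy data, the feature count and the negRate value below are placeholders, not part of the original post:

// Hypothetical driver code, assuming `sc: SparkContext` is already available.
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.{FTRLLR, LRModel}

val numFeatures = 3
// Labels are 0.0 / 1.0; in a real job the negatives would typically be downsampled.
val train = sc.parallelize(Seq(
  (1.0, Vectors.dense(1.0, 0.0, 2.0)),
  (0.0, Vectors.dense(0.0, 1.0, 0.5))
))

// FTRL state: weights, per-coordinate squared-gradient sums and z values, all zero-initialized.
val w = Array.fill(numFeatures)(0.0)
val n = Array.fill(numFeatures)(0.0)
val z = Array.fill(numFeatures)(0.0)

// 10 iterations with alpha = 0.1 (both values are arbitrary for this sketch).
val ftrl = new FTRLLR(numFeatures, 10, 0.1)
val (weights, nOut, zOut) = ftrl.compute(train, w, n, z)

// Wrap the learned weights in the calibrated model; negRate is the assumed
// negative-downsampling rate (1.0 means no downsampling was applied).
val negRate = 1.0
val model = new LRModel(weights, 0.0, negRate)
println(model.predictPoint(Vectors.dense(1.0, 0.0, 2.0)))

Note that compute both mutates and returns the state arrays, so the same w, n and z can be fed back in to continue training on a later batch of data.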