package org.apache.spark.mllib.optimization

import org.apache.spark.{Logging, SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot}
import org.apache.spark.mllib.util.MLUtils
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.mllib.linalg._
import breeze.linalg.{Vector => BV, axpy => brzAxpy, norm => brzNorm}
import java.io._

import breeze.numerics.sqrt


/**
 * Logistic-regression scorer whose output probability is calibrated for
 * negative-class downsampling: `rate` is the fraction of negatives kept
 * during training, and the raw sigmoid score is corrected accordingly.
 */
class LRModel(weights: Vector, intercept: Double, rate: Double) extends Serializable {
  // Learned weight vector, kept as a field for callers that inspect the model.
  val weightMatrix = weights

  /**
   * Scores one example: sigmoid of the linear margin, then corrected for
   * downsampling via p / (p + (1 - p) / rate).
   */
  def predictPoint(dataMatrix: Vector) = {
    val rawMargin = dot(weightMatrix, dataMatrix) + intercept
    val p = 1.0 / (1.0 + math.exp(-rawMargin))
    p / (p + (1.0 - p) / rate)
  }
}

/**
 * Gradient and loss of the binary logistic loss for a single example.
 * Mirrors the in-place accumulation style of MLlib's LogisticGradient.
 */
class LRGradient extends Serializable {

  /**
   * Computes the gradient and loss for one example, returning a freshly
   * allocated gradient vector together with the loss.
   */
  def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
    val grad = Vectors.zeros(weights.size)
    val loss = compute(data, label, weights, grad)
    (grad, loss)
  }

  /**
   * Adds this example's gradient into `cumGradient` in place and returns the
   * logistic loss for the example.
   */
  def compute(data: Vector, label: Double, weights: Vector, cumGradient: Vector): Double = {
    val negMargin = -1.0 * dot(data, weights)
    // d(loss)/d(margin) = sigmoid(margin) - label; scale the feature vector by it.
    val multiplier = 1.0 / (1.0 + math.exp(negMargin)) - label
    axpy(multiplier, data, cumGradient)
    // log(1 + e^x) via MLUtils.log1pExp for numerical stability.
    if (label > 0) MLUtils.log1pExp(negMargin)
    else MLUtils.log1pExp(negMargin) - negMargin
  }
}



/**
 * FTRL-Proximal optimizer for logistic regression (McMahan et al., 2013,
 * "Ad Click Prediction: a View from the Trenches", Algorithm 1).
 *
 * @param numFeatures dimensionality of the weight / state arrays
 * @param iter        number of FTRL rounds to run
 * @param a           alpha, the per-coordinate learning-rate scale
 */
class FTRLLR(numFeatures: Int, iter: Int, a: Double) extends Serializable {
  private val gradient = new LRGradient()

  /**
   * Runs `iter` FTRL-Proximal rounds over `data`.
   *
   * @param data labeled examples as (label, features)
   * @param ww   current weights; mutated in place
   * @param n    per-coordinate sum of squared gradients; mutated in place
   * @param z    per-coordinate FTRL dual state; mutated in place
   * @return (final weights, n, z)
   */
  def compute(data: RDD[(Double, Vector)], ww: Array[Double], n: Array[Double],
      z: Array[Double]): (Vector, Array[Double], Array[Double]) = {
    val numWeight = numFeatures
    val b = 1.0         // beta in the FTRL learning-rate schedule
    val L1 = 1.0        // lambda_1 (L1 regularization strength)
    val L2 = 1.0        // lambda_2 (L2 regularization strength)
    val minibatch = 1.0 // sampling fraction per round

    // Each round follows the steps of the FTRL-Proximal paper directly.
    for (it <- 0 until iter) {
      val bcWeights = data.context.broadcast(ww)
      // FIX: seed varies per round. The original passed the constant 42, so
      // every iteration re-sampled the identical subset of `data`.
      val tmp = data.sample(false, minibatch, 42 + it)
        .treeAggregate((BDV.zeros[Double](numWeight), 0.0, 0L, Vectors.zeros(numWeight)))(
          seqOp = (c, v) => {
            // LRGradient.compute axpy-accumulates into the breeze buffer that
            // fromBreeze wraps, so c._1 collects the gradient sum in place.
            val l = gradient.compute(v._2, v._1, Vectors.dense(bcWeights.value),
              Vectors.fromBreeze(c._1))
            // NOTE(review): c._4 is overwritten, not accumulated — only the
            // last example's sparsity pattern per partition survives. It is
            // used below purely as an index mask for which coordinates to
            // update; confirm this is the intended behavior.
            (c._1, c._2 + l, c._3 + 1, v._2)
          },
          combOp = (c1, c2) => {
            (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3,
              Vectors.fromBreeze(c1._4.toBreeze + c2._4.toBreeze))
          })
      // FIX: release the per-round broadcast; the original leaked one
      // broadcast variable on every iteration.
      bcWeights.unpersist()
      val sampleCount = tmp._3
      if (sampleCount > 0) { // guard: an empty sample would divide by zero
        // FIX: average over the number of sampled examples. The original
        // divided by the sampling *fraction* `minibatch` (1.0), so `g` was
        // the gradient sum rather than the minibatch average.
        val g = Vectors.fromBreeze(tmp._1 / sampleCount.toDouble)
        // Scaling does not change the nonzero pattern; `feature` only selects
        // which coordinates are touched below.
        val feature = Vectors.fromBreeze(tmp._4.toBreeze / minibatch)
        feature.foreachActive {
          case (i, _) =>
            val sign = if (z(i) > 0) 1.0 else -1.0
            // Closed-form proximal step: w_i = 0 when |z_i| < lambda_1,
            // otherwise the shrunken, rate-scaled value from the paper.
            if (sign * z(i) < L1) ww(i) = 0
            else ww(i) = (sign * L1 - z(i)) / ((b + sqrt(n(i))) / a + L2)
            val sigma = (sqrt(n(i) + g(i) * g(i)) - sqrt(n(i))) / a
            z(i) += g(i) - sigma * ww(i)
            n(i) += g(i) * g(i)
        }
      }
    }
    (Vectors.dense(ww), n, z)
  }
}