spark实现gbdt和lr

10,403次阅读
2 条评论

共计 2378 个字符,预计需要花费 6 分钟才能阅读完成。

spark实现gbdt和lr

Spark 对 Python 开放的接口实在有限，只有 Scala 是“亲生”的。查了一下 Scala 的包和函数，发现提供得非常全。博主从零开始撸 Scala 代码，边写边查，给出以下示例代码供大家参考。示例数据是周志华《机器学习》中的西瓜数据集 3.0：代码把第 1–6 列（从 0 开始计数）当作特征、第 9 列当作标签，这些列需要预先编码为整数，否则 `toInt` 解析会报错。

import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, FeatureType, Strategy}
import org.apache.spark.mllib.tree.model.Node
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession

/**
 * GBDT + LR example on the Spark MLlib RDD API.
 *
 * A gradient-boosted tree model is trained first. Every example is then re-encoded as the
 * concatenation of per-tree one-hot vectors marking which leaf it lands in, and a logistic
 * regression is trained and evaluated on those leaf features.
 */
object gbdt_lr {

  /**
   * Collects the ids of all leaves under `node`, left subtree before right subtree.
   *
   * The index of an id in the returned array is the one-hot slot used for that leaf by
   * [[encodeLeaves]].
   *
   * @param node root of the (sub)tree to traverse
   * @return leaf node ids in left-to-right order
   * @throws IllegalStateException if a non-leaf node is missing a child
   */
  def getLeafNodes(node: Node): Array[Int] =
    if (node.isLeaf) Array(node.id)
    else (node.leftNode, node.rightNode) match {
      case (Some(left), Some(right)) => getLeafNodes(left) ++ getLeafNodes(right)
      case _ => throw new IllegalStateException(s"Internal node ${node.id} is missing a child")
    }

  /**
   * Routes `features` down the tree rooted at `node` and returns the id of the leaf reached.
   *
   * A continuous split sends the example left when the feature value is `<= threshold`;
   * a categorical split sends it left when the value is one of the split's `categories`.
   * Otherwise the example goes right.
   *
   * @param node     root of the (sub)tree to evaluate
   * @param features dense feature vector of one example
   * @return id of the leaf node that `features` lands in
   */
  @scala.annotation.tailrec
  def predictModify(node: Node, features: DenseVector): Int =
    if (node.isLeaf) node.id
    else {
      val split = node.split.get // a non-leaf node is expected to carry its split
      val goLeft =
        if (split.featureType == FeatureType.Continuous) features(split.feature) <= split.threshold
        else split.categories.contains(features(split.feature))
      predictModify(if (goLeft) node.leftNode.get else node.rightNode.get, features)
    }

  /**
   * Builds the GBDT leaf feature vector for one example.
   *
   * For each tree there is a one-hot block whose width is that tree's leaf count and whose hot
   * slot is the leaf `features` lands in. The blocks are concatenated in tree order.
   *
   * @param topNodes   root node of each tree
   * @param treeLeaves leaf ids of each tree as returned by [[getLeafNodes]], same order as `topNodes`
   * @param features   dense feature vector of one example
   * @return concatenated one-hot leaf indicators
   */
  def encodeLeaves(topNodes: Array[Node], treeLeaves: Array[Array[Int]], features: DenseVector): Array[Double] = {
    require(topNodes.length == treeLeaves.length, "one leaf list is required per tree")
    val blocks = topNodes.indices.map { i =>
      val leaves = treeLeaves(i)
      val leafId = predictModify(topNodes(i), features)
      val slot = leaves.indexOf(leafId)
      require(slot >= 0, s"Leaf $leafId of tree $i is not in its leaf list")
      val oneHot = new Array[Double](leaves.length)
      oneHot(slot) = 1.0
      oneHot
    }
    Array.concat(blocks: _*)
  }

  /**
   * Entry point.
   *
   * Trains a GBDT on the CSV sample, re-encodes the examples as GBDT leaf features, then trains
   * and evaluates a logistic regression on those features.
   *
   * @param args optional; `args(0)` overrides the default CSV path
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("GbdtAndLr")
    sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val sampleDir = args.headOption.getOrElse("/Users/leiyang/IdeaProjects/spark_2.3/src/watermelon3_0_En.csv")
    // A single session; its SparkContext is used instead of constructing one separately.
    val spark = SparkSession.builder.config(sparkConf).getOrCreate()
    val dataFrame = spark.read.format("CSV").option("header", "true").load(sampleDir)

    // Columns 1-6 (0-based) are integer-coded features; column 9 is the integer label.
    val featureColumns = 1 to 6
    val labelColumn = 9
    val data = dataFrame.rdd.map { row =>
      LabeledPoint(row(labelColumn).toString.toInt,
        new DenseVector(featureColumns.map(c => row(c).toString.toInt.toDouble).toArray))
    }
    // Split once. The GBDT and the LR are both fit on `train` and evaluated on `test`, so the
    // LR test set never contains rows the GBDT was trained on.
    val Array(train, test) = data.randomSplit(Array(0.8, 0.2))

    // GBDT model
    val numTrees = 2
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.setNumIterations(numTrees)
    val treeStrategy = Strategy.defaultStrategy("Classification")
    treeStrategy.setMaxDepth(5)
    treeStrategy.setNumClasses(2)
    //    treeStrategy.setCategoricalFeaturesInfo(Map[Int, Int]())
    boostingStrategy.setTreeStrategy(treeStrategy)
    val gbdtModel = GradientBoostedTrees.train(train, boostingStrategy)
    //    val gbdtModelDir = args(2)
    //    gbdtModel.save(spark.sparkContext, gbdtModelDir)
    val labelAndPreds = test.map(point => (point.label, gbdtModel.predict(point.features)))
    val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / test.count()
    println("Test Error = " + testErr)
    //    println("Learned classification GBT model:\n" + gbdtModel.toDebugString)

    // Leaf ids per tree. A leaf's index in its tree's array is its one-hot slot.
    val topNodes = gbdtModel.trees.map(_.topNode)
    val treeLeafArray = topNodes.map(getLeafNodes)
    treeLeafArray.zipWithIndex.foreach { case (leaves, i) =>
      println("正在打印第%d棵树的topnode叶子节点".format(i))
      leaves.foreach(println)
    }

    // Build GBDT leaf features for both halves of the original split.
    val toLeafFeatures = (point: LabeledPoint) =>
      LabeledPoint(point.label, new DenseVector(encodeLeaves(topNodes, treeLeafArray, point.features.toDense)))
    val train2 = train.map(toLeafFeatures).cache() // reused across LBFGS iterations
    val test2 = test.map(toLeafFeatures)

    // NOTE(review): 0.01 is far below the usual 0.5 decision threshold, so most examples are
    // likely predicted positive. Kept from the original; confirm it is intended.
    val model = new LogisticRegressionWithLBFGS().setNumClasses(2).run(train2).setThreshold(0.01)
    val predictionAndLabels = test2.map(point => (model.predict(point.features), point.label))
    val metrics = new MulticlassMetrics(predictionAndLabels)
    // MulticlassMetrics.accuracy is the fraction of correct predictions, not precision.
    println("Accuracy = " + metrics.accuracy)

    spark.stop()
  }
}
正文完
请博主喝杯咖啡吧!
post-qrcode
 1
admin
版权声明:本站原创文章,由 admin 2018-03-19发表,共计2378字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(2 条评论)
验证码
侧耳倾听 评论达人 LV.1
2019-02-02 12:25:46 回复

博主 您好 我运行您的这个代码时报错,我推测您是对数据做过预处理,您这个代码对应的数据能提供一下吗?谢谢

 Windows  Chrome  中国湖北省荆州市电信
    admin 博主
    2019-02-11 19:28:36 回复

    @侧耳倾听 这个用的就是周志华机器学习那本书中的西瓜数据集

     Macintosh  Chrome  中国广东省深圳市电信