共计 2378 个字符,预计需要花费 6 分钟才能阅读完成。
spark对python开放的接口实在有限,只有scala才是亲生的。查了下scala的包和函数,发现提供得真齐全。博主从零开始撸scala代码,边写边查,给出以下example代码供大家参考
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, FeatureType, Strategy}
import org.apache.spark.mllib.tree.model.Node
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
/**
 * GBDT + LR pipeline (MLlib RDD API):
 *  1. Train a gradient-boosted-trees classifier.
 *  2. One-hot encode each sample by the leaf it falls into in every tree.
 *  3. Train a logistic regression on those leaf-index features.
 */
object gbdt_lr {

  /**
   * Collects the ids of all leaf nodes of a (sub)tree, depth-first, left to right.
   *
   * @param node root of the tree or subtree to scan
   * @return ids of every leaf under `node`
   */
  def getLeafNodes(node: Node): Array[Int] = {
    if (node.isLeaf) {
      Array(node.id)
    } else {
      // Internal nodes in MLlib decision trees always have both children defined,
      // so the .get calls cannot fail here.
      getLeafNodes(node.leftNode.get) ++ getLeafNodes(node.rightNode.get)
    }
  }

  /**
   * Routes a feature vector down the tree and returns the id of the leaf it
   * lands in (rather than the leaf's predicted value) — used for one-hot
   * leaf-index encoding.
   *
   * @param node     current tree node (start with the root)
   * @param features sample feature vector
   * @return id of the reached leaf node
   */
  def predictModify(node: Node, features: DenseVector): Int = {
    if (node.isLeaf) {
      node.id
    } else {
      val split = node.split.get
      // Continuous splits compare against a threshold; categorical splits test
      // membership of the feature value in the split's category set.
      val goLeft =
        if (split.featureType == FeatureType.Continuous)
          features(split.feature) <= split.threshold
        else
          split.categories.contains(features(split.feature))
      predictModify(if (goLeft) node.leftNode.get else node.rightNode.get, features)
    }
  }

  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("GbdtAndLr")
    sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val sampleDir = "/Users/leiyang/IdeaProjects/spark_2.3/src/watermelon3_0_En.csv"
    val sc = new SparkContext(sparkConf)
    val spark = SparkSession.builder.config(sparkConf).getOrCreate()

    // Columns 1..6 are integer-coded features; column 9 is the 0/1 label.
    // NOTE(review): assumes the CSV is fully integer-coded — confirm against the dataset.
    val dataFrame = spark.read.format("CSV").option("header", "true").load(sampleDir)
    val data = dataFrame.rdd.map { x =>
      LabeledPoint(x(9).toString.toInt, new DenseVector(Array(
        x(1).toString.toInt, x(2).toString.toInt, x(3).toString.toInt,
        x(4).toString.toInt, x(5).toString.toInt, x(6).toString.toInt)))
    }
    val splits = data.randomSplit(Array(0.8, 0.2))
    val train = splits(0)
    val test = splits(1)

    // ---- GBDT model ----
    val numTrees = 2
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.setNumIterations(numTrees)
    val treeStrategy = Strategy.defaultStrategy("Classification")
    treeStrategy.setMaxDepth(5)
    treeStrategy.setNumClasses(2)
    // treeStrategy.setCategoricalFeaturesInfo(Map[Int, Int]())
    boostingStrategy.setTreeStrategy(treeStrategy)
    val gbdtModel = GradientBoostedTrees.train(train, boostingStrategy)
    // val gbdtModelDir = args(2)
    // gbdtModel.save(sc, gbdtModelDir)

    val labelAndPreds = test.map { point =>
      (point.label, gbdtModel.predict(point.features))
    }
    val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / test.count()
    println("Test Error = " + testErr)
    // println("Learned classification GBT model:\n" + gbdtModel.toDebugString)

    // Leaf ids per tree; a sample's position in this array gives its one-hot slot.
    val treeLeafArray = new Array[Array[Int]](numTrees)
    for (i <- 0 until numTrees) {
      treeLeafArray(i) = getLeafNodes(gbdtModel.trees(i).topNode)
    }
    for (i <- 0 until numTrees) {
      // BUG FIX: println has no printf-style overload — the original
      // println("...%d...", i) printed a tuple and never substituted %d.
      // Also print the actual leaf ids instead of their positional indices.
      println(s"正在打印第${i}棵树的topnode叶子节点")
      for (leafId <- treeLeafArray(i)) {
        println(leafId)
      }
    }

    // ---- Build new one-hot features from GBDT leaf indices ----
    val newFeatureDataSet = dataFrame.rdd.map { x =>
      (x(9).toString.toInt, new DenseVector(Array(
        x(1).toString.toInt, x(2).toString.toInt, x(3).toString.toInt,
        x(4).toString.toInt, x(5).toString.toInt, x(6).toString.toInt)))
    }.map { case (label, features) =>
      var newFeature = new Array[Double](0)
      for (i <- 0 until numTrees) {
        val leafId = predictModify(gbdtModel.trees(i).topNode, features)
        // A full binary tree with numNodes nodes has (numNodes + 1) / 2 leaves.
        val treeArray = new Array[Double]((gbdtModel.trees(i).numNodes + 1) / 2)
        treeArray(treeLeafArray(i).indexOf(leafId)) = 1
        newFeature = newFeature ++ treeArray
      }
      (label, newFeature)
    }
    val newData = newFeatureDataSet.map(x => LabeledPoint(x._1, new DenseVector(x._2)))
    val splits2 = newData.randomSplit(Array(0.8, 0.2))
    val train2 = splits2(0)
    val test2 = splits2(1)

    // ---- Logistic regression on the leaf-encoded features ----
    // NOTE(review): setThreshold(0.01) makes predict() output 1 for nearly any
    // positive class probability — confirm this aggressive threshold is intended.
    // (Removed the original dead statement `model.weights`, which had no effect.)
    val model = new LogisticRegressionWithLBFGS().setNumClasses(2).run(train2).setThreshold(0.01)
    val predictionAndLabels = test2.map { case LabeledPoint(label, features) =>
      (model.predict(features), label)
    }
    val metrics = new MulticlassMetrics(predictionAndLabels)
    println("Precision = " + metrics.accuracy)
    sc.stop()
  }
}
正文完
请博主喝杯咖啡吧!
博主 您好 我运行您的这个代码时报错,我推测您是对数据做过预处理,您这个代码对应的数据能提供一下吗?谢谢
@侧耳倾听 这个用的就是周志华机器学习那本书中的西瓜数据集