机器学习实战(五)03-Spark-SVM

Ada ·
更新时间:2024-11-14
· 618 次阅读

官方文档示例

package com.netcloud.bigdata.mllib.com.svm.example

import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Example taken from the official Spark MLlib documentation: trains a linear
 * SVM with stochastic gradient descent on a LIBSVM-format data set and prints
 * the area under the ROC curve measured on a held-out split.
 *
 * Author: yangshaojun
 * Date: 2020/02/17 20:57
 *
 * @version 1.0
 */
object SVMWithSGDExample {

  /**
   * Entry point. Runs the whole train/evaluate pipeline on a local Spark master.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("SVMWithSGDExample").setMaster("local")
    val sc = new SparkContext(conf)

    // Always release the SparkContext, even if loading, training or evaluation throws.
    try {
      // Load training data in LIBSVM format.
      val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

      // Split data into training (60%) and test (40%); the fixed seed keeps the split reproducible.
      val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
      // The training set is cached because the iterative optimizer reads it repeatedly.
      val training = splits(0).cache()
      val test = splits(1)

      // Run the training algorithm to build the model.
      val numIterations = 200
      val model = SVMWithSGD.train(training, numIterations)

      // Clear the default threshold so that predict returns raw scores
      // instead of 0/1 labels; the ROC metric below is computed from scores.
      model.clearThreshold()

      // Compute (raw score, true label) pairs on the test set.
      val scoreAndLabels = test.map { point =>
        val score = model.predict(point.features)
        (score, point.label)
      }

      // Evaluate the model.
      val metrics = new BinaryClassificationMetrics(scoreAndLabels)
      val auROC = metrics.areaUnderROC()

      println(s"Area under ROC = $auROC")
    } finally {
      sc.stop()
    }
  }
}

自定义 SVMWithSGD 参数

package com.netcloud.bigdata.mllib.com.svm.action

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.optimization.{HingeGradient, SquaredL2Updater}
import org.apache.spark.mllib.util.MLUtils

/**
 * Trains a linear SVM with an explicitly configured SGD optimizer (rather than
 * the `SVMWithSGD.train` convenience method) and prints the area under the ROC
 * curve measured on a held-out split.
 */
object SVMAction {

  /**
   * Entry point. Runs the whole train/evaluate pipeline on a local Spark master.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("SVMWithSGDExample").setMaster("local")
    val sc = new SparkContext(conf)

    // Always release the SparkContext, even if loading, training or evaluation throws.
    try {
      // Load training data in LIBSVM format.
      val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

      // Split data into training (60%) and test (40%); the fixed seed keeps the split reproducible.
      val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
      // The training set is cached because the iterative optimizer reads it repeatedly.
      val training = splits(0).cache()
      val test = splits(1)

      // Build the algorithm by hand so every optimizer knob can be set explicitly.
      val svm = new SVMWithSGD
      svm.setIntercept(false)
      svm.optimizer
        .setNumIterations(1000)
        .setStepSize(1.0)
        .setRegParam(0.01)
        .setMiniBatchFraction(1.0)
        .setConvergenceTol(0.001)
        // Alternatives: new LeastSquaresGradient or new LogisticGradient.
        .setGradient(new HingeGradient())
        // Alternative: new L1Updater().
        .setUpdater(new SquaredL2Updater())

      val model = svm.run(training)

      // Clear the default threshold so that predict returns raw scores
      // instead of 0/1 labels; the ROC metric below is computed from scores.
      model.clearThreshold()

      // Compute (raw score, true label) pairs on the test set.
      val scoreAndLabels = test.map { point =>
        val score = model.predict(point.features)
        (score, point.label)
      }

      // Evaluate the model.
      val metrics = new BinaryClassificationMetrics(scoreAndLabels)
      val auROC = metrics.areaUnderROC()

      println(s"Area under ROC = $auROC")
    } finally {
      sc.stop()
    }
  }
}
作者:小飞猪666



机器学习实战 实战 学习 svm 机器学习 spark

需要 登录 后方可回复, 如果你还没有账号请 注册新账号