重庆分公司,新征程启航
为企业提供网站建设、域名注册、服务器等服务
这期内容当中小编将会给大家带来有关使用Spark怎么实现一个随机森林,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。
专注于为中小企业提供成都网站制作、成都网站设计服务,电脑端+手机端+微信端的三站合一,更高效的管理,为中小企业天桥免费做网站提供优质的服务。我们立足成都,凝聚了一批互联网行业人才,有力地推动了超过千家企业的稳健成长,帮助中小企业通过网站建设实现规模扩充和转变。
具体内容如下
public class RandomForestClassficationTest extends TestCase implements Serializable { /** * */ private static final long serialVersionUID = 7802523720751354318L; class PredictResult implements Serializable{ /** * */ private static final long serialVersionUID = -168308887976477219L; double label; double prediction; public PredictResult(double label,double prediction){ this.label = label; this.prediction = prediction; } @Override public String toString(){ return this.label + " : " + this.prediction ; } } public void test_randomForest() throws JAXBException{ SparkConf sparkConf = new SparkConf(); sparkConf.setAppName("RandomForest"); sparkConf.setMaster("local"); SparkContext sc = new SparkContext(sparkConf); String dataPath = RandomForestClassficationTest.class.getResource("/").getPath() + "/sample_libsvm_data.txt"; RDD dataSet = MLUtils.loadLibSVMFile(sc, dataPath); RDD[] rddList = dataSet.randomSplit(new double[]{0.7,0.3},1); RDD trainingData = rddList[0]; RDD testData = rddList[1]; ClassTag labelPointClassTag = trainingData.elementClassTag(); JavaRDD trainingJavaData = new JavaRDD(trainingData,labelPointClassTag); int numClasses = 2; Map categoricalFeatureInfos = new HashMap(); int numTrees = 3; String featureSubsetStrategy = "auto"; String impurity = "gini"; int maxDepth = 4; int maxBins = 32; /** * 1 numClasses分类个数为2 * 2 numTrees 表示的是随机森林中树的个数 * 3 featureSubsetStrategy * 4 */ final RandomForestModel model = RandomForest.trainClassifier(trainingJavaData, numClasses, categoricalFeatureInfos, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, 1); JavaRDD testJavaData = new JavaRDD(testData,testData.elementClassTag()); JavaRDD predictRddResult = testJavaData.map(new Function(){ /** * */ private static final long serialVersionUID = 1L; public PredictResult call(LabeledPoint point) throws Exception { // TODO Auto-generated method stub double pointLabel = point.label(); double prediction = model.predict(point.features()); PredictResult result = new PredictResult(pointLabel,prediction); return result; } }); List predictResultList = predictRddResult.collect(); for(PredictResult result:predictResultList){ System.out.println(result.toString()); } System.out.println(model.toDebugString()); } }
得到的随机森林的展示结果如下:
TreeEnsembleModel classifier with 3 trees Tree 0: If (feature 435 <= 0.0) If (feature 516 <= 0.0) Predict: 0.0 Else (feature 516 > 0.0) Predict: 1.0 Else (feature 435 > 0.0) Predict: 1.0 Tree 1: If (feature 512 <= 0.0) Predict: 1.0 Else (feature 512 > 0.0) Predict: 0.0 Tree 2: If (feature 377 <= 1.0) Predict: 0.0 Else (feature 377 > 1.0) If (feature 455 <= 0.0) Predict: 1.0 Else (feature 455 > 0.0) Predict: 0.0
上述就是小编为大家分享的使用Spark怎么实现一个随机森林了,如果刚好有类似的疑惑,不妨参照上述分析进行理解。如果想知道更多相关知识,欢迎关注创新互联行业资讯频道。