首页 技术 正文
技术 2022年11月13日
0 收藏 468 点赞 2,473 浏览 5126 个字

 

一、概念

决策树及其集合是分类和回归的机器学习任务的流行方法。 决策树被广泛使用,因为它们易于解释,处理分类特征,扩展到多类分类设置,不需要特征缩放,并且能够捕获非线性和特征交互。 诸如随机森林和增强的树集合算法是分类和回归任务的最佳表现者。

决策树(decision tree)是一种基本的分类与回归方法,这里主要介绍用于分类的决策树。决策树模式呈树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。学习时利用训练数据,根据损失函数最小化的原则建立决策树模型;预测时,对新的数据,利用决策树模型进行分类。

二、基本原理

决策树学习通常包含三个方面:特征选择、决策树生成和决策树剪枝。决策树学习思想主要来源于:Quinlan在1986年提出的ID算法、在1993年提出的C4.5算法和Breiman等人在1984年提出的CART算法。

2.1、特征选择

特征选择在于选取对训练数据具有分类能力的特征,这样可以提高决策树学习的效率。通常特征选择的准则是信息增益(或信息增益比、基尼指数等),每次计算每个特征的信息增益,并比较它们的大小,选择信息增益最大(信息增益比最大、基尼指数最小)的特征。

那么问题来了:怎么找到这样的最优划分特征呢?如何来衡量最优?

什么是最优特征,通俗的理解是对训练数据具有很强的分类能力的特征,比如要看相亲的男女是否合适,他们的年龄差这个特征就远比他们的出生地重要,因为年龄差能更好得对相亲是否成功这个分类问题具有更强的分类能力。

但是计算机并不知道哪些特征是最优的,因此,就要找一个衡量特征是不是最优的指标,使得决策树在每一个分支上的数据尽可能属于同一类别的数据,即样本纯度最高。

我们用熵来衡量样本集合的纯度。

这是概率统计与信息论中的一个概念,定义为:

 

其中p(x)=pi表示随机变量X发生概率。

我们可以从两个角度理解这个概念。

第一就是不确定度的一个度量,我们的目标是为了找到一颗树,使得每个分枝上都代表一个分类,也就是说我们希望这个分枝上的不确定性最小,即确定性最大,也就是这些数据都是同一个类别的。熵越小,代表这些数据是同一类别的越多。

第二个角度就是从纯度理解。因为熵是不确定度的度量,如果他们不确定度越小,意味着这个群体的差异很小,也就是它的纯度很高。比如,在明大的某富翁聚会上,来的人大多是某总,普通工薪白领就会很少,如果新来了一个刘总,他是富翁的确定性就很大,不确定性就很小,同时这个群体的纯度很大。总结来说就是熵越小,纯度越大,而我们希望的就是纯度越大越好。

信息增益

我们用信息熵来衡量一个分支的纯度,以及哪个特征是最优的特征
在决策树学习中应用信息增益准则来选择最优特征。信息增益定义如下:

 信息增益

特征A对训练数据集D的信息增益g(D,A) 等于D的不确定度H(D) 减去给定条件A下D的不确定度H(D|A),可以理解为由于特征A使得对数据集D的分类的不确定性减少的程度,信息增益大的特征具有更强的分类能力。

信息增益率

信息增益选择特征倾向于选择取值较多的特征,假设某个属性存在大量的不同值,决策树在选择属性时,将偏向于选择该属性,但这肯定是不正确(导致过拟合)的。因此有必要使用一种更好的方法,那就是信息增益率(Info Gain Ratio)来矫正这一问题。
其公式为:

 信息增益率

其中

 训练数据集D关于特征A的值的熵

,n为特征A取值的个数

基尼指数

概率分布的基尼指数定义为

 基尼指数

其中K表示分类问题中类别的个数

2.2、决策树的生成

从根结点开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子结点,再对子结点递归地调用以上方法,构建决策树;直到所有特征的信息增均很小或没有特征可以选择为止,最后得到一个决策树。

决策树需要有停止条件来终止其生长的过程。一般来说最低的条件是:当该节点下面的所有记录都属于同一类,或者当所有的记录属性都具有相同的值时。这两种条件是停止决策树的必要条件,也是最低的条件。在实际运用中一般希望决策树提前停止生长,限定叶节点包含的最低数据量,以防止由于过度生长造成的过拟合问题。

2.3、决策树的剪枝

决策树生成只考虑了通过信息增益或信息增益比来对训练数据更好的拟合,但没有考虑到如果模型过于复杂,会导致过拟合的产生。而剪枝就是缓解过拟合的一种手段,单纯的决策树生成学习局部的模型,而剪枝后的决策树会生成学习整体的模型,因为剪枝的过程中,通过最小化损失函数,可以平衡决策树的对训练数据的拟合程度和整个模型的复杂度。

决策树的损失函数定义如下:

 损失函数

其中,

 图1

三、代码实现

我们以iris数据集(https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data)为例进行分析。iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。

3.1、读取数据

首先,读取文本文件;然后,通过map将每行的数据用“,”隔开,在我们的数据集中,每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类。把这里我们用LabeledPoint来存储标签列和特征列。LabeledPoint在监督学习中常用来存储标签和特征,其中要求标签的类型是double,特征的类型是Vector。所以,我们把莺尾花的分类进行了一下改变,”Iris-setosa”对应分类0,”Iris-versicolor”对应分类1,其余对应分类2;然后获取莺尾花的4个特征,存储在Vector中。

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import scala.Tuple2;
SparkConf conf = new SparkConf().setAppName("decisionTree").setMaster("local");
JavaSparkContext sc = new JavaSparkContext(conf);/**
* 读取数据
* 转化成 LabeledPoint类型
*/
JavaRDD<String> source = sc.textFile("data/mllib/iris.data");
JavaRDD<LabeledPoint> data = source.map(line->{
String[] parts = line.split(",");
double label = 0.0;
if(parts[4].equals("Iris-setosa")) {
label = 0.0;
}else if(parts[4].equals("Iris-versicolor")) {
label = 1.0;
}else {
label = 2.0;
}
return new LabeledPoint(label,Vectors.dense(Double.parseDouble(parts[0]),
Double.parseDouble(parts[1]),
Double.parseDouble(parts[2]),
Double.parseDouble(parts[3])));
});

3.2、划分数据集

接下来,首先进行数据集的划分,这里划分70%的训练集和30%的测试集:

JavaRDD<LabeledPoint>[] splits =  data.randomSplit(new double[] {0.7,0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

3.3、构建模型

调用决策树的trainClassifier方法构建决策树模型,设置参数,比如分类数、信息增益的选择、树的最大深度等:

int numClasses = 3;//分类数
int maxDepth = 5; //树的最大深度
int maxBins = 30;//离散连续特征时使用的bin数。增加maxBins允许算法考虑更多的分割候选者并进行细粒度的分割决策。
String impurity = "gini";
Map<Integer,Integer> categoricalFeaturesInfo = new HashMap<Integer,Integer>();//空的categoricalFeaturesInfo表示所有功能都是连续的。
DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);

3.4、模型预测

接下来我们调用决策树模型的predict方法对测试数据集进行预测,并把模型结构打印出来:

JavaPairRDD<Double, Double> predictionAndLabel =  testData.mapToPair(point->{
return new Tuple2<>(model.predict(point.features()),point.label());
});
//打印预测和实际结果
predictionAndLabel.foreach(x->{
System.out.println("predictionAndLabel:"+x);
});
System.out.println("Learned classification tree model:"+model.toDebugString());
/**
*控制台输出结果:
-----------------------
Learned classification tree model:DecisionTreeModel classifier of depth 5 with 15 nodes
If (feature 2 <= 2.45)
Predict: 0.0
Else (feature 2 > 2.45)
If (feature 2 <= 4.75)
Predict: 1.0
Else (feature 2 > 4.75)
If (feature 2 <= 4.95)
If (feature 0 <= 6.25)
If (feature 1 <= 3.05)
Predict: 2.0
Else (feature 1 > 3.05)
Predict: 1.0
Else (feature 0 > 6.25)
Predict: 1.0
Else (feature 2 > 4.95)
If (feature 3 <= 1.7000000000000002)
If (feature 0 <= 6.05)
Predict: 1.0
Else (feature 0 > 6.05)
Predict: 2.0
Else (feature 3 > 1.7000000000000002)
Predict: 2.0
------------------------
**/

3.5、准确性评估

最后,我们把模型预测的准确性打印出来:

double testErr = predictionAndLabel.filter(pl ->  !pl._1().equals(pl._2())).count() / (double)  testData.count();
System.out.println("Test Error:"+testErr);
/**
*控制台输出结果:
------------------------------
Test Error:0.06976744186046512
------------------------------
**/
上一篇: lvm 日常操作。
下一篇: css 箭头三角形
相关推荐
python开发_常用的python模块及安装方法
adodb:我们领导推荐的数据库连接组件bsddb3:BerkeleyDB的连接组件Cheetah-1.0:我比较喜欢这个版本的cheeta…
日期:2022-11-24 点赞:878 阅读:9,088
Educational Codeforces Round 11 C. Hard Process 二分
C. Hard Process题目连接:http://www.codeforces.com/contest/660/problem/CDes…
日期:2022-11-24 点赞:807 阅读:5,564
下载Ubuntn 17.04 内核源代码
zengkefu@server1:/usr/src$ uname -aLinux server1 4.10.0-19-generic #21…
日期:2022-11-24 点赞:569 阅读:6,412
可用Active Desktop Calendar V7.86 注册码序列号
可用Active Desktop Calendar V7.86 注册码序列号Name: www.greendown.cn Code: &nb…
日期:2022-11-24 点赞:733 阅读:6,185
Android调用系统相机、自定义相机、处理大图片
Android调用系统相机和自定义相机实例本博文主要是介绍了android上使用相机进行拍照并显示的两种方式,并且由于涉及到要把拍到的照片显…
日期:2022-11-24 点赞:512 阅读:7,822
Struts的使用
一、Struts2的获取  Struts的官方网站为:http://struts.apache.org/  下载完Struts2的jar包,…
日期:2022-11-24 点赞:671 阅读:4,905