Machine Learning Classification of the Iris Dataset with sklearn Logistic Regression


sklearn (scikit-learn) is a widely used third-party Python module for machine learning. It is an open-source library that supports both supervised and unsupervised learning, and it provides tools for model fitting, data preprocessing, model selection and evaluation, along with many other utilities. sklearn wraps the common machine-learning algorithms, including regression, dimensionality reduction, classification and clustering. The functions and methods used in the classification practice below are explained first.

1. Usage of np.c_[] and np.r_[]

>>> import numpy as np
>>> a = np.array([[1,2,3],[4,5,6]])
>>> a
array([[1, 2, 3],
       [4, 5, 6]])
>>> b = np.array([[7,8,9],[10,11,12]])
>>> b
array([[ 7,  8,  9],
       [10, 11, 12]])
>>> c = np.c_[a,b]
>>> c
array([[ 1,  2,  3,  7,  8,  9],
       [ 4,  5,  6, 10, 11, 12]])
>>> d = np.r_[a,b]
>>> d
array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12]])

np.c_[] joins two matrices horizontally (each row of a is extended with the corresponding row of b), while np.r_[] joins them vertically (the rows of b are appended below the rows of a).

2. numpy's flattening function ravel()
The ravel() function converts a multi-dimensional array into a one-dimensional array:

>>> a = np.array([[1,2,3],[4,5,6],[7,8,9]])
>>> a
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])
>>> b = a.ravel()
>>> b
array([1, 2, 3, 4, 5, 6, 7, 8, 9])
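Worth noting (a general numpy fact, not from the original post): ravel() returns a view of the original array whenever possible, whereas flatten() always returns a copy. Continuing the session above:

>>> b[0] = 99   # b is a view, so this writes into a's underlying data
>>> a
array([[99,  2,  3],
       [ 4,  5,  6],
       [ 7,  8,  9]])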

3. sys.stdout.write() compared with print()
print() automatically appends a newline \n to whatever it prints to the console. sys.stdout.write(), by contrast, writes exactly the string it is given to the output stream, with no newline added. If each write starts with a carriage return \r, the cursor jumps back to the beginning of the line and the new output overwrites the old, so in the end only the result of the program's last iteration remains visible.
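A small sketch of this overwrite pattern (the \r and flush() calls are what actually make each write replace the previous one):

import sys
import time

for i in range(5):
    # '\r' moves the cursor back to the start of the line, so this
    # write overwrites the previous output instead of appending to it.
    sys.stdout.write("\rprogress: %d/5" % (i + 1))
    sys.stdout.flush()  # push the text out of the buffer immediately
    time.sleep(0.2)
sys.stdout.write("\n")  # finally move to a fresh line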

4. The train_test_split method in sklearn.model_selection
The train_test_split() function splits the input data and the corresponding class labels into a training set and a test set:

X_train, X_test, y_train, y_test = train_test_split(train_data, train_target, test_size=0.3, random_state=0)

Here X_train, X_test, y_train and y_test are the training inputs, the test inputs, the training labels and the test labels, respectively. If test_size is a number between 0 and 1, it is the fraction of all samples placed in the test set; if it is a positive integer greater than 1, it is the absolute number of test samples.

random_state is the random seed. The seed can be thought of as an index into a fixed table of random sequences: the seed value selects which sequence is used, much like an array index selects an element. If the seed is set, every run of the program produces the same random sequence, because the sequence at that fixed position is drawn each time, so the split is identical across runs. If random_state is not set, each run shuffles the data differently and the split changes from run to run.
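A minimal sketch of such a split on the iris data (with test_size=0.25, 38 of the 150 samples end up in the test set, since sklearn rounds the test share up):

from sklearn import datasets
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
# test_size=0.25 puts 25% of the samples in the test set;
# random_state=1 fixes the shuffle so reruns give the identical split.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.25, random_state=1)
print(X_train.shape, X_test.shape)  # (112, 4) (38, 4)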

5. The metrics.accuracy_score classification accuracy function
The official docstring of the function is shown below:

def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
    """Accuracy classification score.

    In multilabel classification, this function computes subset accuracy:
    the set of labels predicted for a sample must *exactly* match the
    corresponding set of labels in y_true.

    Read more in the :ref:`User Guide `.

    Parameters
    ----------
    y_true : 1d array-like, or label indicator array / sparse matrix
        Ground truth (correct) labels.

    y_pred : 1d array-like, or label indicator array / sparse matrix
        Predicted labels, as returned by a classifier.

    normalize : bool, optional (default=True)
        If ``False``, return the number of correctly classified samples.
        Otherwise, return the fraction of correctly classified samples.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    Returns
    -------
    score : float
        If ``normalize == True``, return the fraction of correctly
        classified samples (float), else returns the number of correctly
        classified samples (int).

        The best performance is 1 with ``normalize == True`` and the
        number of samples with ``normalize == False``.
    """

If the normalize parameter is True, the function returns the fraction of correctly classified samples out of the total number of training or test samples; if False, it returns the number of correctly classified samples.
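A quick illustration of the normalize flag:

>>> from sklearn.metrics import accuracy_score
>>> y_true = [0, 1, 2, 2]
>>> y_pred = [0, 1, 1, 2]
>>> accuracy_score(y_true, y_pred)                    # fraction correct
0.75
>>> accuracy_score(y_true, y_pred, normalize=False)   # count correct
3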

Here I use the iris dataset that ships with sklearn for supervised learning. The dataset is imported with from sklearn import datasets followed by iris = datasets.load_iris(), and print(dir(iris)) lists the dataset's attributes, giving a list of six entries: ['DESCR', 'data', 'feature_names', 'filename', 'target', 'target_names']. The feature description can be inspected with print(iris.DESCR), reproduced after the following snippet.
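As a runnable sketch of those loading steps:

from sklearn import datasets

iris = datasets.load_iris()
print(dir(iris))   # ['DESCR', 'data', 'feature_names', 'filename', 'target', 'target_names']
print(iris.DESCR)  # prints the dataset description reproduced below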

Iris plants dataset
--------------------

**Data Set Characteristics:**

    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica

    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

The famous Iris database, first used by Sir R.A. Fisher. The dataset is
taken from Fisher's paper. Note that it's the same as in R, but not as in
the UCI Machine Learning Repository, which has two wrong data points.

This is perhaps the best known database to be found in the pattern
recognition literature. Fisher's paper is a classic in the field and is
referenced frequently to this day. (See Duda & Hart, for example.) The
data set contains 3 classes of 50 instances each, where each class refers
to a type of iris plant. One class is linearly separable from the other 2;
the latter are NOT linearly separable from each other.

.. topic:: References

   - Fisher, R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene
     Analysis. (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments". IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE
     Transactions on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...

From the feature description we know there are 150 samples in total and three class labels, with 50 samples per class. Apart from the class label, each sample has four numeric attributes: sepal length, sepal width, petal length and petal width, so excluding the header the dataset is a table of 150 rows and 5 columns. Note the minimum and maximum values of the sepal and petal lengths and widths, since they are needed later when drawing the decision boundaries. The three class labels are 0, 1 and 2, standing for the three iris species (setosa, versicolor, virginica), which can be checked with print(iris.target_names).
The full source code is shown below:

from sklearn import datasets
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import numpy as np
import sys
from sklearn import metrics

def show_data_set(x, y, attr, data):
    plt.plot(x[y == 0, 0], x[y == 0, 1], 'rs', label=data.target_names[0])
    plt.plot(x[y == 1, 0], x[y == 1, 1], 'bx', label=data.target_names[1])
    plt.plot(x[y == 2, 0], x[y == 2, 1], 'go', label=data.target_names[2])
    if attr == "sepal":
        plt.xlabel(data.feature_names[0])
        plt.ylabel(data.feature_names[1])
        plt.title("三种鸢尾花的花萼长度和宽度")  # sepal length/width of the three species
    else:
        plt.xlabel(data.feature_names[2])
        plt.ylabel(data.feature_names[3])
        plt.title("三种鸢尾花的花瓣长度和宽度")  # petal length/width of the three species
    plt.legend()
    plt.rcParams["font.sans-serif"] = ["KaiTi"]  # Chinese font so the titles render
    plt.rcParams["axes.unicode_minus"] = False
    plt.show()

def plot_data(x, y, attr):
    plt.plot(x[y == 0, 0], x[y == 0, 1], 'rs', label='setosa')
    plt.plot(x[y == 1, 0], x[y == 1, 1], 'bx', label='versicolor')
    plt.plot(x[y == 2, 0], x[y == 2, 1], 'go', label='virginica')
    if attr == "sepal":
        plt.xlabel("sepal length (cm)")
        plt.ylabel("sepal width (cm)")
        plt.title("三种鸢尾花的花萼长度和宽度分类结果")  # sepal classification result
    else:
        plt.xlabel("petal length (cm)")
        plt.ylabel("petal width (cm)")
        plt.title("三种鸢尾花的花瓣长度和宽度分类结果")  # petal classification result
    plt.legend()
    plt.rcParams["font.sans-serif"] = ["KaiTi"]
    plt.rcParams["axes.unicode_minus"] = False
    plt.show()

def plot_decision_boundary(x_min, x_max, y_min, y_max, pred_func):
    h = 0.01  # grid step size
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # predict the class of every grid point and colour the regions accordingly
    Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)

def test(x_train, x_test, y_train, y_test, multi_class="ovr", solver="newton-cg", attr="sepal"):
    # Instantiate the classifier. solver selects the optimizer; newton-cg
    # (an improved Newton method) is used here. Other common choices are
    # lbfgs (a quasi-Newton variant that avoids forming the Hessian directly),
    # liblinear (similar to a linear-kernel classifier),
    # sag (stochastic average gradient) and saga (its accelerated version).
    log_reg = LogisticRegression(multi_class=multi_class, solver=solver)
    log_reg.fit(x_train, y_train)  # fit the model to the training data
    predict_train = log_reg.predict(x_train)  # predict labels of the training samples
    # classification accuracy: fraction of correctly classified training samples
    sys.stdout.write("LogisticRegression(multi_class=%s,solver=%s) Train Accuracy:%.4f\n" %
                     (multi_class, solver, metrics.accuracy_score(y_train, predict_train)))
    predict_test = log_reg.predict(x_test)  # predict labels of the test samples
    # classification accuracy: fraction of correctly classified test samples
    sys.stdout.write("LogisticRegression(multi_class=%s,solver=%s) Test Accuracy:%.4f\n" %
                     (multi_class, solver, metrics.accuracy_score(y_test, predict_test)))
    if attr == "sepal":
        plot_decision_boundary(4, 8.5, 1.5, 4.8, lambda x: log_reg.predict(x))  # draw the decision boundary
    else:
        plot_decision_boundary(0.5, 7.5, 0, 3, lambda x: log_reg.predict(x))  # draw the decision boundary
    plot_data(x_train, y_train, attr)

if __name__ == "__main__":
    iris = datasets.load_iris()
    print(iris.target_names)
    x_sepal = iris.data[:, :2]   # sepal length and width
    x_petal = iris.data[:, 2:4]  # petal length and width
    y = iris.target
    show_data_set(x_sepal, y, "sepal", iris)
    show_data_set(x_petal, y, "petal", iris)
    x_strain, x_stest, y_strain, y_stest = train_test_split(x_sepal, y, test_size=0.25, random_state=1)
    x_ptrain, x_ptest, y_ptrain, y_ptest = train_test_split(x_petal, y, test_size=0.25, random_state=1)
    test(x_strain, x_stest, y_strain, y_stest, multi_class="ovr", solver="newton-cg", attr="sepal")
    test(x_ptrain, x_ptest, y_ptrain, y_ptest, multi_class="ovr", solver="newton-cg", attr="petal")

Below are the classification results with the one-vs-rest model (multi_class="ovr"). The first pair of lines reports the fraction of correctly classified training and test samples for the sepal features, the second pair for the petal features:
LogisticRegression(multi_class=ovr,solver=newton-cg) Train Accuracy:0.8036
LogisticRegression(multi_class=ovr,solver=newton-cg) Test Accuracy:0.7632
LogisticRegression(multi_class=ovr,solver=newton-cg) Train Accuracy:0.9643
LogisticRegression(multi_class=ovr,solver=newton-cg) Test Accuracy:0.9211

Besides 'ovr', the multi_class parameter of LogisticRegression() accepts one other value, 'multinomial', the multinomial (softmax) regression model. 'ovr' (one-vs-rest, also called one-vs-all) trains n binary classifiers for n classes; a binary classifier here is simply a model that separates one class from all the remaining classes. At prediction time each of the n classifiers gives the probability that the sample belongs to its class, and the class with the highest probability is taken as the final prediction. The 'multinomial' model instead fits a single model over all classes at once, minimizing the cross-entropy (softmax) loss to produce the final prediction.
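The following sketch contrasts the two modes (my addition; it assumes a sklearn version where LogisticRegression still accepts the multi_class parameter, as used throughout this article):

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
for mc in ("ovr", "multinomial"):
    clf = LogisticRegression(multi_class=mc, solver="newton-cg")
    clf.fit(X, y)
    # ovr renormalizes three independent one-vs-rest probabilities,
    # whereas multinomial applies a single softmax over all three classes.
    print(mc, clf.predict_proba(X[:1]).round(3))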
[Figures: scatter plots of the raw sepal and petal data, and the decision boundaries of the ovr model on the sepal and petal features]
Below are the classification results with the multinomial model (multi_class="multinomial"), again reporting the fraction of correctly classified samples, first for the sepal features and then for the petal features:
LogisticRegression(multi_class=multinomial,solver=newton-cg) Train Accuracy:0.8214
LogisticRegression(multi_class=multinomial,solver=newton-cg) Test Accuracy:0.7895
LogisticRegression(multi_class=multinomial,solver=newton-cg) Train Accuracy:0.9732
LogisticRegression(multi_class=multinomial,solver=newton-cg) Test Accuracy:0.9737

[Figures: decision boundaries of the multinomial model on the sepal and petal features]
Comparing the predictions of the two classifiers, the multinomial model reaches higher accuracy than one-vs-rest on this two-feature classification task.


Author: Legolas~


