Python 实现 KNN 分类算法

Faye ·
更新时间:2024-09-21
· 742 次阅读

文章目录1. KNN1.1 KNN 分类算法步骤1.2 KNN 的优缺点2. python 实现

本文将详细讲述 KNN 算法及其 python 实现

1. KNN

KNN(K-Nearest Neighbour)即 K最近邻,是分类算法中最简单的算法之一。KNN 算法的核心思想是 如果一个样本在特征空间中的 k 个最相邻的样本中的大多数属于某一个类别,则将该样本归为该类别

1.1 KNN 分类算法步骤

有 N 个已知分类结果的样本点,对新纪录 r 使用 KNN 将其分类

1.确定 k 值,确定计算距离的公式,如常用欧氏距离 d(x,y)=∑i=1n(xi−yi)2d(x,y)=\sqrt{\displaystyle \sum^n_{i = 1}{{(x_i-y_i)}^2}}d(x,y)=i=1∑n​(xi​−yi​)2​ 2.计算 r 和其他样本点之间的距离 dird_{ir}dir​,其中 i∈(1,N)i\in(1,N)i∈(1,N) 3.得到与 r 最接近的 k 个样本 4.将 k 个样本中最多归属类别的分类标签赋予新纪录 r,分类结束 1.2 KNN 的优缺点

优点:

原理简单,容易理解,容易实现 重新训练代价较低 时间复杂度、空间复杂度取决于训练集(一般不会太大)

缺点:

KNN 属于 lazy-learning 算法(对于每一个新加入的预测点,都要从头开始计算与每个样本点的距离),得到的结果及时性差 k 值对结果影响较大 不同类记录相差较大时容易误判 样本点较多时,计算量较大 相对于决策树,结果可解释性不强 2. python 实现

已知分类如图所示(由于是随机产生,所以具体的样本点可能不一样)

其中顺时针依次是第1、2、3类,即红色是第 1 类,蓝色是第 2 类, 灰色是第 3 类

# coding=utf-8 """ @author: shenke @project: AITest @file: knn.py @date: 2020/2/26 @description: python 实现 KNN(K-最邻近)分类算法 """ import numpy as np import matplotlib.pyplot as plt from math import sqrt class KNN(): def __init__(self, k): self.k = k def generate_points(self, x_scope, y_scope, size): """ 产生给定范围内的二维坐标点 """ x = np.random.randint(x_scope[0], x_scope[1], size=size) y = np.random.randint(y_scope[0], y_scope[1], size=size) points = np.dstack((x, y))[0] return points def generate_data(self, size): """ 随机产生三个范围内的数据 """ points1 = self.generate_points([0, 8], [12, 20], size) labels1 = [1] * size points2 = self.generate_points([12, 20], [12, 20], size) labels2 = [2] * size points3 = self.generate_points([7, 13], [0, 8], size) labels3 = [3] * size plt.scatter(points1[:size, 0], points1[:size, 1], color='red') plt.scatter(points2[:size, 0], points2[:size, 1], color='blue') plt.scatter(points3[:size, 0], points3[:size, 1], color='gray') data = np.concatenate([points1, points2, points3]) label = np.concatenate([labels1, labels2, labels3]) return data, label def classify(self, target): """ 实现 KNN 分类 """ k = self.k # 设定每个类别中有 10 个样本点 data, label = self.generate_data(10) # 计算欧氏距离 distance = [sqrt(np.sum((target - point) ** 2)) for point in data] # 返回距离最近的 k 个样本的下标 k_index = np.argsort(distance)[:k] # 返回 k 个样本的标签 k_labels = [label[item] for item in k_index] # 返回 k 个样本中最多归属类别的分类标签 res = max(k_labels, key=k_labels.count) print('该目标点为:第 %d 类' % (res)) # 展示结果 # 标出距离最近的 k 个样本点 plt.scatter([data[index][0] for index in k_index], [data[index][1] for index in k_index], color='', marker='o', edgecolors='green', s=200) # 标出目标点 plt.scatter(target[0], target[1], color='green') plt.show()

测试

from algorithm import knn if __name__ == '__main__': # 设定 k 值为 4,预测点坐标为(10,10) knn.KNN(4).classify([10, 10])

预测结果

上图中标出了预测点(绿色)并圈出了与预测点距离最近的四个点,其中属于第 3 类的样本点个数最多,故预测该点属于第 3 类

但是由于 k 值对预测结果影响较大,可能对预测结果产生误判。如以下情况,四个点中属于第 1 类和第 3 类的样本点个数一样多,这时就无法准确判断出该点的类别


作者:一路是夜幕沉沙



分类算法 算法 Python 分类 knn

需要 登录 后方可回复, 如果你还没有账号请 注册新账号