本文将详细讲述 KNN 算法及其 python 实现
1. KNNKNN(K-Nearest Neighbour)即 K最近邻,是分类算法中最简单的算法之一。KNN 算法的核心思想是 如果一个样本在特征空间中的 k 个最相邻的样本中的大多数属于某一个类别,则将该样本归为该类别
1.1 KNN 分类算法步骤有 N 个已知分类结果的样本点,对新纪录 r
使用 KNN 将其分类
k
值,确定计算距离的公式,如常用欧氏距离 d(x,y)=∑i=1n(xi−yi)2d(x,y)=\sqrt{\displaystyle \sum^n_{i = 1}{{(x_i-y_i)}^2}}d(x,y)=i=1∑n(xi−yi)2
2.计算 r
和其他样本点之间的距离 dird_{ir}dir,其中 i∈(1,N)i\in(1,N)i∈(1,N)
3.得到与 r
最接近的 k
个样本
4.将 k
个样本中最多归属类别的分类标签赋予新纪录 r
,分类结束
1.2 KNN 的优缺点
优点:
原理简单,容易理解,容易实现 重新训练代价较低 时间复杂度、空间复杂度取决于训练集(一般不会太大)缺点:
KNN 属于 lazy-learning 算法(对于每一个新加入的预测点,都要从头开始计算与每个样本点的距离),得到的结果及时性差 k 值对结果影响较大 不同类记录相差较大时容易误判 样本点较多时,计算量较大 相对于决策树,结果可解释性不强 2. python 实现已知分类如图所示(由于是随机产生,所以具体的样本点可能不一样)
其中顺时针依次是第1、2、3类,即红色是第 1 类,蓝色是第 2 类, 灰色是第 3 类
# coding=utf-8
"""
@author: shenke
@project: AITest
@file: knn.py
@date: 2020/2/26
@description: python 实现 KNN(K-最邻近)分类算法
"""
import numpy as np
import matplotlib.pyplot as plt
from math import sqrt
class KNN():
def __init__(self, k):
self.k = k
def generate_points(self, x_scope, y_scope, size):
"""
产生给定范围内的二维坐标点
"""
x = np.random.randint(x_scope[0], x_scope[1], size=size)
y = np.random.randint(y_scope[0], y_scope[1], size=size)
points = np.dstack((x, y))[0]
return points
def generate_data(self, size):
"""
随机产生三个范围内的数据
"""
points1 = self.generate_points([0, 8], [12, 20], size)
labels1 = [1] * size
points2 = self.generate_points([12, 20], [12, 20], size)
labels2 = [2] * size
points3 = self.generate_points([7, 13], [0, 8], size)
labels3 = [3] * size
plt.scatter(points1[:size, 0], points1[:size, 1], color='red')
plt.scatter(points2[:size, 0], points2[:size, 1], color='blue')
plt.scatter(points3[:size, 0], points3[:size, 1], color='gray')
data = np.concatenate([points1, points2, points3])
label = np.concatenate([labels1, labels2, labels3])
return data, label
def classify(self, target):
"""
实现 KNN 分类
"""
k = self.k
# 设定每个类别中有 10 个样本点
data, label = self.generate_data(10)
# 计算欧氏距离
distance = [sqrt(np.sum((target - point) ** 2)) for point in data]
# 返回距离最近的 k 个样本的下标
k_index = np.argsort(distance)[:k]
# 返回 k 个样本的标签
k_labels = [label[item] for item in k_index]
# 返回 k 个样本中最多归属类别的分类标签
res = max(k_labels, key=k_labels.count)
print('该目标点为:第 %d 类' % (res))
# 展示结果
# 标出距离最近的 k 个样本点
plt.scatter([data[index][0] for index in k_index], [data[index][1] for index in k_index], color='', marker='o',
edgecolors='green', s=200)
# 标出目标点
plt.scatter(target[0], target[1], color='green')
plt.show()
测试
from algorithm import knn
if __name__ == '__main__':
# 设定 k 值为 4,预测点坐标为(10,10)
knn.KNN(4).classify([10, 10])
预测结果
上图中标出了预测点(绿色)并圈出了与预测点距离最近的四个点,其中属于第 3 类的样本点个数最多,故预测该点属于第 3 类
但是由于 k 值对预测结果影响较大,可能对预测结果产生误判。如以下情况,四个点中属于第 1 类和第 3 类的样本点个数一样多,这时就无法准确判断出该点的类别