首页 技术 正文
技术 2022年11月20日
0 收藏 700 点赞 4,071 浏览 6875 个字

【机器学*】k-*邻算法(kNN) 学*笔记

标签(空格分隔): 机器学*


kNN简介

kNN算法是做分类问题的。思想如下:

KNN算法的思想总结一下:就是在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训练集中与之最为相似的前K个数据,则该测试数据对应的类别就是K个数据中出现次数最多的那个分类,其算法的描述为:

  1. 计算测试数据与各个训练数据之间的距离;
  2. 按照距离的递增关系进行排序;
  3. 选取距离最小的K个点;
  4. 确定前K个点所在类别的出现频率;
  5. 返回前K个点中出现频率最高的类别作为测试数据的预测分类。

更为详细的介绍见这个博客:机器学*(一)——K-*邻(KNN)算法
kNN的优缺点见:KNN算法理解
这个博客的内容来自《机器学*实战》一书。

这个博客主要讲解kNN的python实现,把每行的代码都弄明白。

kNN代码实现

下面classify0()就是kNN,这些代码做了对一个点的分类。

# coding=utf-8
import operator
from numpy import *def createDataSet():
group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labelsdef classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
# 矩阵有一个shape属性,是一个(行,列)形式的元组
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
# 输入的点到每个点的横纵坐标差
# tile是把矩阵重复多次
sqDiffMat = diffMat ** 2
# 横纵坐标差的平方
sqDistances = sqDiffMat.sum(axis=1)
# axis=0, 表示列。axis=1, 表示行。
distances = sqDistances ** 0.5
# 开方
sortedDistIndicies = distances.argsort()
# argsort函数返回的是数组值从小到大的索引值
classCount = {}
# 保存A,B出现次数的字典
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
# 获取索引值对应的是A还是B
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
# 在字典中保存A,B出现的次数
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
# 按照A,B出现的次数排序
return sortedClassCount[0][0] # 返回A,B出现最多的那个group, labels = createDataSet()
answer = classify0([0, 0], group, labels, 3)
print answer

关于tile()函数,可以见文章:【python】tile函数简单介绍
关于sorted()函数:

sorted函数sorted(iterable, cmp=None, key=None, reverse=False)
iterable:是可迭代类型;
cmp:用于比较的函数,比较什么由key决定;
key:用列表元素的某个属性或函数进行作为关键字,有默认值,迭代集合中的一项;
operator.itemgetter(1)表示用第2个数据项排序
reverse:排序规则. reverse = True 降序 或者 reverse = False 升序,有默认值。

kNN实战一 改进约会网站配对效果

我只给出代码和每行代码的解释,这个实战项目的更具体介绍见:机器学*(一)——K-*邻(KNN)算法

# coding=utf-8
import operator
from numpy import *
import matplotlib
import matplotlib.pyplot as pltdef classify0(inX, dataSet, labels, k):
"""
:param inX: 样本点
:param dataSet: 初始样本集合
:param labels: 样本集合对应的标签集合
:param k: 选取的k
:return: kNN分类结果
"""
dataSetSize = dataSet.shape[0]
# 矩阵有一个shape属性,是一个(行,列)形式的元组
diffMat = tile(inX, (dataSetSize, 1)) - dataSet # 输入的点到每个点的横纵坐标差
# tile是把矩阵重复多次
sqDiffMat = diffMat ** 2 # 横纵坐标差的平方
sqDistances = sqDiffMat.sum(axis=1) # axis=0, 表示列。axis=1, 表示行。
distances = sqDistances ** 0.5 # 开方
sortedDistIndicies = distances.argsort() # argsort函数返回的是数组值从小到大的索引值
classCount = {} # 保存A,B出现次数的字典
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]] # 获取索引值对应的是A还是B
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1 # 在字典中保存A,B出现的次数
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
# 按照A,B出现的次数排序
# sorted函数sorted(iterable, cmp=None, key=None, reverse=False)
'''
iterable:是可迭代类型;
cmp:用于比较的函数,比较什么由key决定;
key:用列表元素的某个属性或函数进行作为关键字,有默认值,迭代集合中的一项;
operator.itemgetter(1)表示用第2个数据项排序
reverse:排序规则. reverse = True 降序 或者 reverse = False 升序,有默认值。
'''
return sortedClassCount[0][0] # 返回A,B出现最多的那个group, labels = createDataSet()
answer = classify0([0, 0], group, labels, 3)
print answerdef file2matrix(filename):
"""
:param filename: 文件名称
:return: 文件中的数据和标签
"""
fr = open(filename)
arrayOLines = fr.readlines()
numberOfLines = len(arrayOLines)
# 获取文件的行数
returnMat = zeros((numberOfLines, 3))
# 创建返回的NumPy矩阵,二维矩阵
# zeros函数功能是创建给定类型的矩阵,并初始化为0
classLabelVector = []
# 创建返回的标签
index = 0 # index
for line in arrayOLines:
# 循环每列
line = line.strip()
# 去除每行回车字符
listFromLine = line.split('\t')
# 分割
returnMat[index, :] = listFromLine[0:3]
# 把数据的前三列都放到要返回的矩阵中,3这个索引是不包括的
classLabelVector.append(int(listFromLine[-1]))
# 把数据的每列最后一个元素转换成整数放到标签list里
index += 1
# index自增
return returnMat, classLabelVectordef autoNorm(dataSet):
"""
:param dataSet: 数据集
:return: 归一化结果
"""
minVals = dataSet.min(0) # 0代表列
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet)) # 创建了行列数与dataSet一致的全0矩阵
m = dataSet.shape[0] # 行数
normDataSet = dataSet - tile(minVals, (m, 1)) # 每个元素都减去该列最小值
normDataSet = normDataSet / tile(ranges, (m, 1)) # 具体数值的除,归一化;不是矩阵相除
return normDataSet, ranges, minValsdef datingClassTest():
"""
测试算法的函数
:return:
"""
hoRatio = 0.10 # hold out 10%
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt') # load data setfrom file
normMat, ranges, minVals = autoNorm(datingDataMat) # 归一化
m = normMat.shape[0] # 行数
numTestVecs = int(m * hoRatio) # 抽出的行数
print "numTestVecs", numTestVecs
errorCount = 0.0 # 错误率
for i in range(numTestVecs):
classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
if classifierResult != datingLabels[i]:
errorCount += 1.0
print "the total error rate is: %f" % (errorCount / float(numTestVecs))
print errorCountdef classifyPerson():
"""
用户输入点作为测试点
:return: 无
"""
resultList = ['not at all', 'in small doses', 'in large doses']
percentTats = float(raw_input("percentage of time spent playing video games?"))
ffMiles = float(raw_input("frequent flier miles earned per year?"))
iceCream = float(raw_input("liters of ice cream consumed per year?"))
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt') # load data setfrom file
normMat, ranges, minVals = autoNorm(datingDataMat) # 归一化
inArr = array([ffMiles, percentTats, iceCream]) # 把用户输出的点当做要求点
classifierResult = classify0((inArr - minVals) / ranges, normMat, datingLabels, 3) # 用kNN做分类
print "you weill probably like this person:", resultList[classifierResult - 1] # 转换成真名datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:, 0], datingDataMat[:, 1], 15.0 * array(datingLabels), 15.0 * array(datingLabels))
plt.show()
datingClassTest()
classifyPerson()

上述程序将把数据画出图来,然后计算kNN判断准确率,并且最后让用户输入数据,对该数据进行分类。

kNN实战二:手写体识别

数据集下载: http://www.ituring.com.cn/book/download/0019ab9d-0fda-4c17-941b-afe639fcccac

def img2vector(filename):
returnVect = zeros((1, 1024)) # 每个文件一行结果,1024个0
fr = open(filename) # 打开文件
for i in range(32): # 遍历32行
lineStr = fr.readline() # 读行
for j in range(32): # 读每个字符
returnVect[0, 32 * i + j] = int(lineStr[j]) # 把字符放到结果中
return returnVectdef handwritingClassTest():
hwLabels = [] # 保存标签
trainingFileList = listdir('digits/trainingDigits') # 加载训练集
m = len(trainingFileList) # 训练集文件个数
trainingMat = zeros((m, 1024)) # 训练集数据矩阵
for i in range(m): # 遍历文件
fileNameStr = trainingFileList[i] # 训练集文件名
fileStr = fileNameStr.split('.')[0] # 去掉文件名结尾的.txt
classNumStr = int(fileStr.split('_')[0]) # 把文件名分割之后,获得前半部分,即这个文件表示的字符标签
hwLabels.append(classNumStr) # 把文件表示的字符标签放到标签list中
trainingMat[i, :] = img2vector('digits/trainingDigits/%s' % fileNameStr)
# 把每个文件中的字符画转成行向量
testFileList = listdir('digits/testDigits') # 得到测试集所有文件目录
errorCount = 0.0 # 错误率
mTest = len(testFileList) # 测试集长度
for i in range(mTest): # 遍历测试集
fileNameStr = testFileList[i] # 测试集文件名
fileStr = fileNameStr.split('.')[0] # 去掉文件名结尾的.txt
classNumStr = int(fileStr.split('_')[0]) # 把文件名分割之后,获得前半部分,即这个文件表示的字符标签
vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
# 把每个文件中的字符画转成行向量
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) # 做分类
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
if classifierResult != classNumStr: errorCount += 1.0
print "\nthe total number of errors is: %d" % errorCount
print "\nthe total error rate is: %f" % (errorCount / float(mTest))handwritingClassTest()

这个算法的执行效率并不高,每个测试向量做2000次的距离计算,每个距离计算包括了1024个维度浮点计算,总计要执行900次。

最后运行结果,错误率1.2%:

the total number of errors is: 11
the total error rate is: 0.011628

kNN是分类问题最有效最简单的算法,但是要保存全部数据集,对每个数据计算距离值。实际使用很耗时。而且无法给出任何数据的基础结构信息,无法知晓平均实例样本和典型实例样本之间具有什么特征。

这篇博客是对《机器学*实战》一书的学*笔记,如有不明白之处,请阅读该书。

相关推荐
python开发_常用的python模块及安装方法
adodb:我们领导推荐的数据库连接组件bsddb3:BerkeleyDB的连接组件Cheetah-1.0:我比较喜欢这个版本的cheeta…
日期:2022-11-24 点赞:878 阅读:9,076
Educational Codeforces Round 11 C. Hard Process 二分
C. Hard Process题目连接:http://www.codeforces.com/contest/660/problem/CDes…
日期:2022-11-24 点赞:807 阅读:5,552
下载Ubuntn 17.04 内核源代码
zengkefu@server1:/usr/src$ uname -aLinux server1 4.10.0-19-generic #21…
日期:2022-11-24 点赞:569 阅读:6,400
可用Active Desktop Calendar V7.86 注册码序列号
可用Active Desktop Calendar V7.86 注册码序列号Name: www.greendown.cn Code: &nb…
日期:2022-11-24 点赞:733 阅读:6,176
Android调用系统相机、自定义相机、处理大图片
Android调用系统相机和自定义相机实例本博文主要是介绍了android上使用相机进行拍照并显示的两种方式,并且由于涉及到要把拍到的照片显…
日期:2022-11-24 点赞:512 阅读:7,812
Struts的使用
一、Struts2的获取  Struts的官方网站为:http://struts.apache.org/  下载完Struts2的jar包,…
日期:2022-11-24 点赞:671 阅读:4,894