博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
机器学习实战第二章——KNN算法(源码解析)
阅读量:3977 次
发布时间:2019-05-24

本文共 5737 字,大约阅读时间需要 19 分钟。

机器学习实战中的内容讲的都比较清楚,一般都能看懂,这里就不再讲述了,这里主要是对代码进行解析,如果你很熟悉python,这个可以不用看。

#coding=utf-8'''Created on 2015年12月29日@author: admin'''from numpy import arrayfrom numpy import tilefrom numpy import zerosimport operatorfrom os import listdir# 创建数据集,并返回数据集和分类标签def createDataSet():    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])    labels = ['A','A','B','B']    return group,labels# 对新数据进行分类def classify0(inX,dataSet,labels,k):    #dataSet.shape[0]是dataSet第一维的数目    dataSetSize = dataSet.shape[0]     #要分类的新数据与原始数据做差    diffMat = tile(inX,(dataSetSize,1)) - dataSet    #求差的平方    sqDiffMat = diffMat**2    #求差的平方的和    sqDistance = sqDiffMat.sum(axis=1)     #求标准差    distances = sqDistance**0.5    #距离排序    sortDistIndicies = distances.argsort()     #定义元字典    classCount = {}     #遍历前k个元素    for i in range(k):        #获得前k个元素的标签        voteIlabel = labels[sortDistIndicies[i]]        #计算前k个数据标签出现的次数        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1     #对得到的标签字典按降序排列    sortedClassCount =sorted(classCount.iteritems(),key = operator.itemgetter(1),reverse = True)    #返回出现次数最多的标签    return sortedClassCount[0][0]# 读取文本文件中的数据def file2matrix(filename):    # 打开文件    fr = open(filename)    # 计算文本文件的行数    numberOfLines = len(fr.readlines())    # 创建返回的数据矩阵    returnMat = zeros((numberOfLines,3))    # 创建类标签    classLabelVector = []    # 打开文件    fr = open(filename)    # 定义索引    index = 0    # 读取文件的每一行并处理    for line in fr.readlines():        # 去除行的尾部的换行符        line = line.strip()        # 将一行数据按空进行分割        listFromLine = line.split('\t')        # 0:3列为数据集的数据        returnMat[index,:] = listFromLine[0:3]        # 最后一列为数据的分类标签        classLabelVector.append(int(listFromLine[-1]))        # 索引加1        index += 1    # 返回数据集和对应的类标签    return returnMat,classLabelVector# 归一化函数def autoNorm(dataSet):    # 求数据矩阵每一列的最小值    minVals = dataSet.min(0)    # 求数据矩阵每一列的最大值    maxVals = dataSet.max(0)    # 求数据矩阵每一列的最大最小值差值    ranges = maxVals - minVals#    normDataSet = zeros(shape(dataSet))    # 返回数据矩阵第一维的数目    m = dataSet.shape[0]    # 求矩阵每一列减去该列最小值,得出差值    normDataSet = dataSet - tile(minVals,(m,1))    # 用求的差值除以最大最小值差值,即数据的变化范围,即归一化    normDataSet = normDataSet / tile(ranges,(m,1))    # 返回归一化后的数据,最大最小值差值,最小值    return normDataSet,ranges,minVals# 分类器测试函数def datingClassTest():    # 测试集所占的比例    hoRatio = 0.10    # 从文件中读取数据    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')    # 对数据进行归一化    normMat,ranges,minVals = autoNorm(datingDataMat)    # 求数据的条数    m = normMat.shape[0]    # 求测试集的数据数目    numTestVecs = int(m*hoRatio)    # 定义误判数目    errorCount = 0.0    # 对测试数据进行遍历    for i in range(numTestVecs):        # 对每一条数据进行分类        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],2)        # 输出分类结果和实际的类别        print "the classifer came back with: %d,the real answer is: %d" %(classifierResult,datingLabels[i])        # 如果分类结果与实际结果不一致        if(classifierResult != datingLabels[i]):            # 误分类数加1            errorCount += 1.0    # 输出错误率    print "the total error rate is: %f" %(errorCount/float(numTestVecs))# 对人分类def classiyPerson():    # 定义分类结果的类别    resultList = ['not at all','in small doses','in large doses']    # 读取输入数据    percentTats = float(raw_input("percentage of time spent playing video games?"))    # 读取输入数据    ffMiles = float(raw_input("frequent flier miles earned per year?"))    # 读取输入数据    iceCream = float(raw_input("liters of ice cream consumed per year?"))    # 从文件中读取已有数据    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')    # 对数据进行归一化    normMat,ranges,minVals = autoNorm(datingDataMat)    # 将单个输入数据定义成一条数据    inArr =array([ffMiles,percentTats,iceCream])    # 对输入数据进行分类    classifierResult = classify0(inArr,datingDataMat,datingLabels,3)    # 输出预测的分类类别    print "You will probably like this person:",resultList[classifierResult - 1]# 将单个手写字符文件变成向量 def img2vector(filename):    # 定义要返回的向量    returnVect = zeros((1,1024))    # 打开文件    fr = open(filename)    # 遍历文件中的每一行和每一列    for i in range(32):        # 读取一行        lineStr = fr.readline()        # 对读取数据赋值到returnVect中        for j in range(32):            returnVect[0,32*i+j] = int(lineStr[j])    # 返回向量    return returnVect# 手写字符识别测试def handwritingClassTest():    # 定义手写字符标签(类别)    hwLabels = []    # 列出目录下所有的文件    trainingFileList = listdir('digits/trainingDigits')    # 计算文件的数目    m = len(trainingFileList)    # 定义手写字符数据矩阵    trainingMat = zeros((m,1024))    # 依次读取每个文件    for i in range(m):        # 定义文件名        fileNameStr = trainingFileList[i]        # 对文件名进行分割        fileStr = fileNameStr.split('.')[0]        # 获得文件名中的类标签        classNumStr = int(fileStr.split('_')[0])        # 把类标签放到hwLabels中        hwLabels.append(classNumStr)        # 把文件变成向量并赋值到trainingMat中        trainingMat[i,:] = img2vector('digits/trainingDigits/%s' % fileNameStr)    # 列出测试目录下的所有文件    testFileList = listdir('digits/testDigits')    # 定义错误率    errorCount = 0.0    # 定义测试文件数目    mTest = len(testFileList)    # 遍历测试    for i in range(mTest):        # 定义测试文件名        fileNameStr = testFileList[i]        # 对测试文件名进行分割        fileStr = fileNameStr.split('.')[0]        # 获得测试文件的类标签        classNumStr = int(fileStr.split('_')[0])        # 将测试文件转换成向量        vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)        # 进行分类        classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)        # 输出预测类别和实际类别        print "the classifier came back with:%d,the real answer is %d" % (classifierResult,classNumStr)        # 如果二者不一致,累加错误数量        if(classifierResult != classNumStr):            errorCount += 1.0    # 输出分类错误的数目    print "\nthe total number of errors is:%d" % errorCount    # 输出分类的错误率    print "\nthe total error rate is:%f" % (errorCount/float(mTest))

转载地址:http://swwui.baihongyu.com/

你可能感兴趣的文章
CentOS6.8二进制安装MySQL5.6
查看>>
centos 6x系统下源码安装mysql操作记录
查看>>
Centos搭建Mysql主从复制
查看>>
centos下部署redis服务环境及其配置说明
查看>>
Centos7下部署两套python版本并存环境的操作记录
查看>>
利用阿里云的源yum方式安装Mongodb
查看>>
Mysql的二进制日志binlog的模式说明
查看>>
zabbix监控交换机、防火墙等网络设备
查看>>
Redis数据"丢失"讨论及规避和解决的几点总结
查看>>
Redis日常操作命令小结
查看>>
线程安全的单例模式
查看>>
fastjson深度源码解析- 序列化(五) - json内部注册序列化解析
查看>>
fastjson深度源码解析- 序列化(六) - json特定序列化实现解析
查看>>
fastjson深度源码解析- 词法和语法解析(二) - 基础类型实现解析
查看>>
fastjson深度源码解析- 词法和语法解析(三) - 针对对象实现解析
查看>>
fastjson深度源码解析- 反序列化(一) - 反序列化解析介绍
查看>>
fastjson深度源码解析- 反序列化(二) - 内部注册反序列化解析
查看>>
通过爱效率网站获取百度统计数据说明
查看>>
百度统计接口调用——登录接口
查看>>
百度统计接口调用——获取站点列表
查看>>