本文共 2376 字,大约阅读时间需要 7 分钟。
# Third-party dependencies for the kNN hyperparameter-search walkthrough.
# (The statements were fused onto one line, which is a syntax error;
# they are restored here as one import per line.)
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
1、获取数据
# Load the digits dataset bundled with scikit-learn and pull out the
# feature matrix (X) and the target labels (y).
digits = datasets.load_digits()
X = digits.data
y = digits.target
2、分割数据,得到训练集和测试集
# Hold out 20% of the samples as a test set; the fixed random_state makes
# the split reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=666,
)
附:手动寻找最佳超参数(以下代码已全部注释,仅供与网格搜索对比)
# Reference only: manual hyperparameter search, kept commented out.
# It loops over candidate values by hand and scores each one on the test set.
#
# def temp():
#     knn_clf = KNeighborsClassifier(3)
#     knn_clf.fit(X_train, y_train)
#     y_predict = knn_clf.predict(X_test)
#     accuracy_score(y_test, y_predict)
#
#     # Find the best k
#     best_score = 0.0
#     best_k = -1
#     for k in range(1,11):
#         knn_clf = KNeighborsClassifier(k)
#         knn_clf.fit(X_train, y_train)
#         y_predict = knn_clf.predict(X_test)
#         score= accuracy_score(y_test, y_predict)
#         if score > best_score:
#             best_k = k
#             best_score = score
#     print("best_k:", best_k)
#     print("best_score:", best_score)
#
#     # Weight neighbours by distance, or not?
#     best_method = ""
#     best_score = 0.0
#     best_k = -1
#     for method in ["uniform", "distance"]:
#         for k in range(1,11):
#             knn_clf = KNeighborsClassifier(n_neighbors=k, weights=method)
#             knn_clf.fit(X_train, y_train)
#             y_predict = knn_clf.predict(X_test)
#             score= accuracy_score(y_test, y_predict)
#             if score > best_score:
#                 best_k = k
#                 best_score = score
#                 best_method = method
#     print("best_k:", best_k)
#     print("best_score:", best_score)
#     print("best_method:", best_method)
#
#     # Explore the p parameter of the Minkowski distance
#     # Find the best hyperparameters: Grid Search
3、超参数配置
# Search space for the grid search: a list of independent sub-grids.
param_grid = [
    # Uniform voting: only the neighbour count k (1..10) is varied.
    dict(
        weights=["uniform"],
        n_neighbors=list(range(1, 11)),
    ),
    # Distance-weighted voting: vary k (1..10) together with the
    # Minkowski distance exponent p (1..5).
    dict(
        weights=["distance"],
        n_neighbors=list(range(1, 11)),
        p=list(range(1, 6)),
    ),
]
4、实例化分类器
# Create a kNN classifier with no explicit arguments; it is the base
# estimator handed to the grid search, which supplies the hyperparameters.
knn_clf = KNeighborsClassifier()
5、为分类器和超参数搭建模型
from sklearn.model_selection import GridSearchCV

# Wrap the base classifier and the parameter grid in a cross-validated
# grid search. n_jobs=-1 asks for all available CPU cores and verbose=2
# prints progress for each fit.
# (The import and the assignment were fused onto one line, which is a
# syntax error; they are separate statements again here.)
grid_search = GridSearchCV(knn_clf, param_grid, n_jobs=-1, verbose=2)
6、用训练集拟合(fit)网格搜索模型(逐一尝试多种参数配置的分类器)
# In essence, the grid search splits the training set further into
# training and validation parts to find the best parameter configuration.
# Every parameter combination is tried with cross-validation, so this
# step is very time-consuming.
# The fit call had been swallowed into the comment above, so the search
# never ran; it is an executable statement again here.
grid_search.fit(X_train, y_train)
7、最终拿到最佳参数配置分类器 best_estimator_
# Rebind knn_clf to the estimator the grid search selected as best.
# best_estimator_ only exists after grid_search has been fitted.
knn_clf = grid_search.best_estimator_
8、使用最佳分类器对测试集预测
# Predict labels for the held-out test samples with the selected classifier.
y_predict = knn_clf.predict(X_test)
9、打印准确率
# Report the share of test samples whose predicted label matches the truth.
test_accuracy = accuracy_score(y_test, y_predict)
print(test_accuracy)
转载地址:http://iwubz.baihongyu.com/