我们可以使用Python自带的pickle模块或sklearn的joblib模块将已训练好的模型持久化到本地。
第一种:使用pickle
将训练好的模型本地持久化
fromsklearn.neighborsimportKNeighborsClassifier
fromsklearnimportdatasets
importnumpyasnp
importpickle
iris=datasets.load_iris()
print(iris.data)
i=np.random.permutation(len(iris.data))
x_train=iris.data[i[:-20]]
y_train=iris.target[i[:-20]]
x_test=iris.data[i[-20:]]
y_tets=iris.target[i[-20:]]
model=KNeighborsClassifier()
model.fit(x_train,y_train)
s=pickle.dumps(model)
f=open('knn_testp.m','wb')
f.write(s)
f.close()
使用本地模型进行预测
importpickle
f=open('knn_testp.m','rb')
s=f.read()
model=pickle.loads(s)
print(model.predict([[5.1,3.5,1.4,0.2]]))
第二种:使用joblib
将训练好的模型本地持久化
fromsklearn.neighborsimportKNeighborsClassifier
fromsklearnimportdatasets
importnumpyasnp
fromsklearn.externalsimportjoblib
iris=datasets.load_iris()
print(iris.data)
i=np.random.permutation(len(iris.data))
x_train=iris.data[i[:-20]]
y_train=iris.target[i[:-20]]
x_test=iris.data[i[-20:]]
y_tets=iris.target[i[-20:]]
model=KNeighborsClassifier()
model.fit(x_train,y_train)
print(model.score(x_test,y_tets))
joblib.dump(model,'knn_test.m')
利用本地模型进行预测
fromsklearn.externalsimportjoblib
model=joblib.load('knn_test.m')
print(model.predict([[5.1,3.5,1.4,0.2]]))
当数据量比较大时,使用joblib将更加高效。
领取专属 10元无门槛券
私享最新 技术干货