欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > 名人名企 > 机器学习(模型的保存和加载)

机器学习(模型的保存和加载)

2026/5/15 8:10:50 来源:https://blog.csdn.net/mqsdfghjk/article/details/145861536  浏览:    关键词:机器学习(模型的保存和加载)

在机器学习中,模型训练通常需要耗费大量的时间和计算资源。为了避免重复训练,同时方便在不同环境中使用已经训练好的模型,我们需要对模型进行保存和加载。以下将介绍几种常见的模型保存与加载的方法,以 scikit - learn 和 TensorFlow 模型为例。

1. 使用 joblib 保存和加载 scikit - learn 模型

joblib 是 scikit - learn 推荐的用于保存和加载模型的工具,它在处理大型 numpy 数组时比 Python 内置的 pickle 模块更高效。

保存模型 

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from joblib import dump# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 训练模型
model = KNeighborsClassifier()
model.fit(X_train, y_train)# 保存模型
dump(model, 'knn_model.joblib')

 加载模型

from joblib import load# 加载模型
loaded_model = load('knn_model.joblib')# 使用加载的模型进行预测
predictions = loaded_model.predict(X_test)
print(predictions)

2. 使用 pickle 保存和加载 scikit - learn 模型

pickle 是 Python 内置的用于对象序列化和反序列化的模块,也可以用于保存和加载 scikit - learn 模型。

保存模型

import pickle
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 训练模型
model = KNeighborsClassifier()
model.fit(X_train, y_train)# 保存模型
with open('knn_model.pkl', 'wb') as f:pickle.dump(model, f)

 加载模型

import pickle# 加载模型
with open('knn_model.pkl', 'rb') as f:loaded_model = pickle.load(f)# 使用加载的模型进行预测
predictions = loaded_model.predict(X_test)
print(predictions)

3. 使用 TensorFlow 保存和加载深度学习模型

保存模型

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten# 加载数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()# 数据预处理
x_train = x_train / 255.0
x_test = x_test / 255.0# 构建模型
model = Sequential([Flatten(input_shape=(28, 28)),Dense(128, activation='relu'),Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(x_train, y_train, epochs=5)# 保存模型
model.save('mnist_model.h5')

加载模型

import tensorflow as tf
from tensorflow.keras.datasets import mnist# 加载数据集
(_, _), (x_test, y_test) = mnist.load_data()# 数据预处理
x_test = x_test / 255.0# 加载模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')# 使用加载的模型进行预测
predictions = loaded_model.predict(x_test)
print(predictions)

总结

  • 对于 scikit - learn 模型,推荐使用 joblib 进行保存和加载,尤其是处理大型 numpy 数组时。
  • pickle 是 Python 内置的通用序列化工具,也可以用于保存和加载 scikit - learn 模型。
  • 对于 TensorFlow 深度学习模型,可以使用 model.save() 方法保存为 .h5 文件,使用 tf.keras.models.load_model() 方法加载模型

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词