解决 Keras 的 Bad marshal data 错误

最近接手了一个机器学习项目,之前项目小组成员已经训练好了一个模型,但是后面可能由于运行环境的变化或者是其他种种问题,这个模型现在无法正常通过 Keras 的 load_model() 方法实例化,会报 bad marshal data 错误。

第一次接触 Keras,之前一直用的 Tensorflow / TFLearn,不是很清除这个问题产生的原因。就直接 Google 搜索报错信息,发现 Keras 的 Github issues 上面很早就有人提了有关这个报错的 issue:https://github.com/keras-team/keras/issues/7440

其中有个 comment 的原文如下:

This is not caused by Keras encoding of the model but by the serialization of the custom objects. So the only way of solving this is actually by re-creating the model with the custom objects in python 3.x, loading the weights and saving the model again.

从这个 comment 可以知道,这个错误主要是由于在网络中自定义了对象造成了 Keras 通过保存的模型文件反序列化的时候无法识别自定义的对象。

解决方法也很简单:

  1. 用代码构建原始的网络结构
  2. 通过 load_weights()方法把保存的模型中的 Weights 加载到构建好的网络中

具体可用的加载模型的代码如下:

1
2
3
4
5
# Create a model manually
model = create_model()
model.load_weights('path_to_model.h5')
# Do what you want to do with loaded model
model.predict(preprocessed_input)
Buy me a cup of coffee