Keras load_model with custom objects doesn't work properly
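The function you actually pass to Keras is the inner function returned by weighted_cross_entropy, and its name is loss (i.e. def loss(y_true, y_pred):), not weighted_loss; weighted_loss is only the variable that holds it. Therefore, when loading the model back you need to use 'loss' as its key in custom_objects: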
model = load_model(path, custom_objects={'loss': weighted_loss})
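Why the key has to be 'loss' can be checked in plain Python: the closure returned by the factory keeps the name of the inner def, and that is the name in the error message ("Unknown loss function:loss"). The sketch below keeps the question's factory shape but replaces the loss body with a placeholder (an assumption made only so it runs without Keras), then builds the custom_objects dictionary from the function's own __name__ so the key cannot drift from the actual name.

```python
def weighted_cross_entropy(weights):
    # Same shape as the question's factory. The body is a placeholder:
    # only the name of the inner function matters for this demonstration.
    def loss(y_true, y_pred):
        return 0.0

    return loss


weighted_loss = weighted_cross_entropy([0.1, 0.9])

# The variable is called weighted_loss, but the function object is named "loss".
print(weighted_loss.__name__)  # loss

# Building the key from __name__ keeps it in sync with the function's name.
custom_objects = {weighted_loss.__name__: weighted_loss}
print(list(custom_objects))  # ['loss']
```

With Keras available, this dictionary would go to load_model(path, custom_objects=custom_objects), which is equivalent to the {'loss': weighted_loss} above. Renaming the inner def to weighted_loss would also make the key "weighted_loss" work, but only for models saved after the rename, assuming (as the error message suggests) that the name is taken from __name__ when the model is saved.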
Setting
As mentioned in the title, I have a problem with my custom loss function when trying to load the saved model. My loss looks as follows:
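from keras import backend as K


def weighted_cross_entropy(weights):
    weights = K.variable(weights)  # class weights as a backend variable

    # Note: the inner function is named `loss`.
    def loss(y_true, y_pred):
        # Clip predictions away from 0 so K.log stays finite.
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        loss = y_true * K.log(y_pred) * weights
        loss = -K.sum(loss, -1)
        return loss

    return loss


weighted_loss = weighted_cross_entropy([0.1, 0.9])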
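For reference, this is what the loss computes for a single sample: predictions are clipped to [ε, 1 − ε], multiplied elementwise by y_true and the class weights, summed over the last axis and negated. The plain-Python rendering below is only a sketch of that arithmetic without Keras; weighted_cross_entropy_single is a hypothetical name, and EPSILON = 1e-7 is assumed to match the default of K.epsilon().

```python
import math

EPSILON = 1e-7  # assumed to match the default of K.epsilon()
WEIGHTS = [0.1, 0.9]  # the class weights from the question


def weighted_cross_entropy_single(y_true, y_pred, weights=WEIGHTS):
    # Mirrors K.clip -> y_true * K.log(y_pred) * weights -> -K.sum(..., -1)
    clipped = [min(max(p, EPSILON), 1 - EPSILON) for p in y_pred]
    return -sum(t * math.log(p) * w for t, p, w in zip(y_true, clipped, weights))


# True class is the second one (weight 0.9), predicted with probability 0.8.
print(round(weighted_cross_entropy_single([0, 1], [0.2, 0.8]), 4))  # 0.2008
# Same confidence in the true class, but for the first class (weight 0.1).
print(round(weighted_cross_entropy_single([1, 0], [0.8, 0.2]), 4))  # 0.0223
```

The 0.9 weight makes a mistake on the second class about nine times as costly as the same mistake on the first.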
During training, I used the weighted_loss function as the loss function and everything worked well. When training finished, I saved the model as an .h5 file with the standard model.save function from the Keras API.
Problem
When I try to load the model via
model = load_model(path, custom_objects={"weighted_loss": weighted_loss})
I get a ValueError telling me that the loss is unknown.
Error
The error message looks as follows:
  File "...\predict.py", line 29, in my_script
    "weighted_loss": weighted_loss})
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 312, in _deserialize_model
    sample_weight_mode=sample_weight_mode)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\training.py", line 139, in compile
    loss_function = losses.get(loss)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 133, in get
    return deserialize(identifier)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 114, in deserialize
    printable_module_name='loss function')
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\utils\generic_utils.py", line 165, in deserialize_keras_object
    ':' + function_name)
ValueError: Unknown loss function:loss
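The traceback shows where loading fails: load_model recompiles the model (training.py, compile), which hands the stored loss to losses.get and finally to deserialize_keras_object, where the name "loss" cannot be resolved. The sketch below is a deliberately simplified, hypothetical imitation of that name lookup, not Keras's actual code; resolve_loss and BUILTIN_LOSSES are invented names. It only illustrates why a custom_objects key of "weighted_loss" misses while "loss" hits.

```python
# Hypothetical, simplified imitation of how a stored loss name is resolved.
BUILTIN_LOSSES = {"mean_squared_error", "categorical_crossentropy"}


def resolve_loss(stored_name, custom_objects=None):
    custom_objects = custom_objects or {}
    if stored_name in custom_objects:  # user-supplied objects are checked first
        return custom_objects[stored_name]
    if stored_name in BUILTIN_LOSSES:
        return stored_name  # placeholder for the built-in function
    raise ValueError("Unknown loss function:" + stored_name)


def loss(y_true, y_pred):  # stand-in for the closure from weighted_cross_entropy
    return 0.0


stored_name = loss.__name__  # the name recorded for the model: "loss"

try:
    resolve_loss(stored_name, {"weighted_loss": loss})
except ValueError as err:
    print(err)  # Unknown loss function:loss

print(resolve_loss(stored_name, {"loss": loss}) is loss)  # True
```

Only the string key is compared, so the variable name used in the loading script plays no role.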
Questions
How can I fix this problem? Could the reason be my wrapped loss definition, i.e. that Keras doesn't know how to handle the weights variable?
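As for the weights: assuming only the loss's name is stored in the .h5 file (the error above mentions nothing but the name loss), the [0.1, 0.9] values are never saved along with it. They live in the closure that weighted_cross_entropy returns, which is why weighted_loss has to be rebuilt before calling load_model. The plain-Python sketch below (no Keras, a plain list instead of K.variable, and EPSILON as an assumed stand-in for K.epsilon()) makes the closure visible and shows that a name-based lookup would not notice a closure built with different weights.

```python
import math

EPSILON = 1e-7  # assumed stand-in for K.epsilon()


def weighted_cross_entropy(weights):
    def loss(y_true, y_pred):
        total = 0.0
        for t, p, w in zip(y_true, y_pred, weights):
            p = min(max(p, EPSILON), 1 - EPSILON)  # same role as K.clip
            total += t * math.log(p) * w
        return -total

    return loss


saved_time_loss = weighted_cross_entropy([0.1, 0.9])

# The weights are held in the closure, not under any name a file could store.
print(saved_time_loss.__code__.co_freevars)  # ('weights',)
print(saved_time_loss.__closure__[0].cell_contents)  # [0.1, 0.9]

# Rebuilding with the same weights before loading gives the same loss ...
load_time_loss = weighted_cross_entropy([0.1, 0.9])
# ... whereas different weights would still carry the name "loss".
wrong_loss = weighted_cross_entropy([0.5, 0.5])

y_true, y_pred = [0, 1], [0.3, 0.7]
print(saved_time_loss(y_true, y_pred) == load_time_loss(y_true, y_pred))  # True
print(wrong_loss.__name__ == saved_time_loss.__name__)  # True
```

Under that assumption, the wrapped definition itself is fine; what matters is calling weighted_cross_entropy with the same weights as during training and registering the result under the key 'loss'.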