Keras load_model with custom objects doesn't work properly


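Keras saves a function-based loss under the function's __name__, and the function returned by weighted_cross_entropy is named loss (i.e. def loss(y_true, y_pred):), not weighted_loss. Therefore, when loading the model back you need to register it under the key 'loss':

from keras.models import load_model
model = load_model(path, custom_objects={'loss': weighted_loss})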
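Since the key has to equal the name the function was saved under, one option is to derive the dictionary from __name__ instead of typing it by hand. The sketch below uses a stand-in for weighted_cross_entropy (no Keras backend, so it runs anywhere), and custom_objects_for is a hypothetical helper, not part of Keras.

```python
def weighted_cross_entropy(weights):
    def loss(y_true, y_pred):  # stand-in body; the real one uses the Keras backend
        raise NotImplementedError

    return loss


def custom_objects_for(*functions):
    """Hypothetical helper: map each function to the name it was saved under."""
    return {fn.__name__: fn for fn in functions}


weighted_loss = weighted_cross_entropy([0.1, 0.9])
print(weighted_loss.__name__)  # loss
print(custom_objects_for(weighted_loss).keys())  # dict_keys(['loss'])
```

With the real Keras loss, load_model(path, custom_objects=custom_objects_for(weighted_loss)) passes the same mapping as above. If two factory-made losses both use the inner name loss, they would collide in this dict, which is one more reason to give each inner function a distinct name.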
Author: pafi, ML Engineer with math background.

Updated on July 19, 2022

Comments

  • pafi

    Setting

    As the title says, I have a problem with my custom loss function when trying to load the saved model. My loss looks as follows:

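    from keras import backend as K
    
    def weighted_cross_entropy(weights):
    
        # Store the class weights as a backend variable
        weights = K.variable(weights)
    
        def loss(y_true, y_pred):
            # Clip predictions so that log() never sees 0 or 1
            y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
    
            # Weighted cross-entropy, summed over the class axis
            loss = y_true * K.log(y_pred) * weights
            loss = -K.sum(loss, -1)
            return loss
    
        return loss
    
    weighted_loss = weighted_cross_entropy([0.1, 0.9])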
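To sanity-check what loss computes, here is a plain-Python reimplementation for a single sample. It is a sketch, not the Keras code: K.clip, K.log and K.sum become min/max, math.log and sum, and eps=1e-7 is assumed to match the default K.epsilon(). For y_true = [0, 1], y_pred = [0.2, 0.8] and weights [0.1, 0.9], only the second class contributes: -0.9 * ln(0.8), roughly 0.2008.

```python
import math


def weighted_cross_entropy(weights, eps=1e-7):
    """Return a per-sample weighted cross-entropy over lists of floats."""

    def loss(y_true, y_pred):
        # Clip predictions into [eps, 1 - eps] so log() never sees 0.
        clipped = [min(max(p, eps), 1 - eps) for p in y_pred]
        # Weighted sum over the class axis, negated (cf. -K.sum(..., -1)).
        return -sum(t * math.log(p) * w for t, p, w in zip(y_true, clipped, weights))

    return loss


weighted_loss = weighted_cross_entropy([0.1, 0.9])
value = weighted_loss([0, 1], [0.2, 0.8])
print(round(value, 4))  # 0.2008
```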
    

    During training I used weighted_loss as the loss function and everything worked well. When training finished, I saved the model as an .h5 file with the standard model.save function from the Keras API.
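In the Keras 2.x HDF5 format (keras/engine/saving.py in the traceback), model.save stores the compile settings as JSON, and a plain function passed as the loss is written out by its __name__ alone. The snippet imitates that step with the standard json module; the to_json_type fallback and the one-key training_config are simplified stand-ins, not the actual Keras code.

```python
import json


def weighted_cross_entropy(weights):
    def loss(y_true, y_pred):  # body omitted; only the name matters here
        raise NotImplementedError

    return loss


def to_json_type(obj):
    """Simplified stand-in for the JSON fallback used when saving."""
    if callable(obj):
        return obj.__name__  # functions are stored by name only
    raise TypeError(f"Not JSON serializable: {obj!r}")


weighted_loss = weighted_cross_entropy([0.1, 0.9])
training_config = {"loss": weighted_loss}
saved = json.dumps(training_config, default=to_json_type)
print(saved)  # {"loss": "loss"}
```

Neither the variable name weighted_loss nor the weights [0.1, 0.9] reach the file, which is why the closure has to be rebuilt and passed in again when loading.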

    Problem

    When I try to load the model via

    model = load_model(path, custom_objects={"weighted_loss": weighted_loss})
    

    I get a ValueError saying that the loss function is unknown.

    Error

    The error message looks as follows:

    File "...\predict.py", line 29, in my_script
      "weighted_loss": weighted_loss})
    File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 419, in load_model
      model = _deserialize_model(f, custom_objects, compile)
    File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 312, in _deserialize_model
      sample_weight_mode=sample_weight_mode)
    File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\training.py", line 139, in compile
      loss_function = losses.get(loss)
    File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 133, in get
      return deserialize(identifier)
    File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 114, in deserialize
      printable_module_name='loss function')
    File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\utils\generic_utils.py", line 165, in deserialize_keras_object
      ':' + function_name)
    ValueError: Unknown loss function:loss
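At load time the stored name is looked up again, first in custom_objects and then among the built-in losses. The simplified lookup below reproduces the message from the traceback and shows why the key "weighted_loss" misses while "loss" hits; resolve_loss and BUILTIN_LOSSES are hypothetical stand-ins, not Keras API.

```python
def weighted_cross_entropy(weights):
    def loss(y_true, y_pred):  # body omitted; only the name matters here
        raise NotImplementedError

    return loss


# Hypothetical registry standing in for Keras's built-in losses.
BUILTIN_LOSSES = {"categorical_crossentropy": lambda y_true, y_pred: None}


def resolve_loss(name, custom_objects=None):
    """Look a saved loss name up in custom_objects, then in the built-ins."""
    custom_objects = custom_objects or {}
    fn = custom_objects.get(name) or BUILTIN_LOSSES.get(name)
    if fn is None:
        raise ValueError("Unknown loss function:" + name)
    return fn


weighted_loss = weighted_cross_entropy([0.1, 0.9])
saved_name = weighted_loss.__name__  # "loss", the name stored in the .h5 file

try:
    resolve_loss(saved_name, {"weighted_loss": weighted_loss})
except ValueError as err:
    print(err)  # Unknown loss function:loss

print(resolve_loss(saved_name, {"loss": weighted_loss}) is weighted_loss)  # True
```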
    

    Questions

    How can I fix this problem? Could the cause be my wrapped loss definition, i.e. that Keras doesn't know how to handle the weights variable?
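The wrapper is involved only through the inner function's name; the captured weights are never looked up by name. If you would rather keep the original custom_objects={"weighted_loss": weighted_loss} call, one option (a sketch, assuming the model is saved again after the change, since an existing .h5 file still records "loss") is to rename the closure before returning it. The body here is a stand-in so the snippet runs without Keras.

```python
def weighted_cross_entropy(weights):
    def loss(y_true, y_pred):
        # Stand-in body; the Keras computation from the question goes here.
        raise NotImplementedError

    # Rename the closure so model.save records "weighted_loss" instead of "loss".
    loss.__name__ = "weighted_loss"
    return loss


weighted_loss = weighted_cross_entropy([0.1, 0.9])
print(weighted_loss.__name__)  # weighted_loss

# After re-saving the model, the original key would then match.
custom_objects = {"weighted_loss": weighted_loss}
```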