How to use k-fold cross validation in a neural network


Solution 1

You seem to be a bit confused (I remember I was too) so I am going to simplify things for you. ;)

Sample Neural Network Scenario

Whenever you are given a task such as devising a neural network, you are often also given a sample dataset to use for training purposes. Let us assume you are training a simple neural network system Y = W · X, where Y is the output computed as the scalar product (·) of the weight vector W with a given sample vector X. Now, the naive way to go about this would be to use the entire dataset of, say, 1000 samples to train the neural network. Assuming that the training converges and your weights stabilise, you can then safely say that your network will correctly classify the training data. But what happens to the network when it is presented with previously unseen data? Clearly the purpose of such systems is to generalise and correctly classify data other than the data used for training.

Overfitting Explained

In any real-world situation, however, previously unseen data is only available once your neural network is deployed in, let's call it, a production environment. But since you have not tested it adequately you are probably going to have a bad time. :) The phenomenon by which a learning system matches its training set almost perfectly but consistently fails with unseen data is called overfitting.

The Three Sets

This is where the validation and testing parts of the algorithm come in. Let's go back to the original dataset of 1000 samples. You split it into three sets -- training, validation and testing (Tr, Va and Te) -- using carefully selected proportions. (80-10-10)% is usually a good proportion, where:

  • Tr = 80%
  • Va = 10%
  • Te = 10%
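As a minimal sketch of the split above, assuming the dataset is just a Python list of samples: the helper name `split_dataset`, its parameters and the fixed seed are illustrative choices of mine, not something prescribed by any library.

```python
import random

def split_dataset(samples, train_frac=0.8, val_frac=0.1, seed=0):
    """Shuffle the samples and cut them into Tr, Va and Te (hypothetical helper)."""
    shuffled = list(samples)               # copy so the caller's list is untouched
    random.Random(seed).shuffle(shuffled)  # seeded shuffle, for reproducibility
    n_train = round(len(shuffled) * train_frac)
    n_val = round(len(shuffled) * val_frac)
    tr = shuffled[:n_train]
    va = shuffled[n_train:n_train + n_val]
    te = shuffled[n_train + n_val:]        # whatever remains is the testing set
    return tr, va, te

data = list(range(1000))                   # stand-in for the 1000 samples
tr, va, te = split_dataset(data)
print(len(tr), len(va), len(te))           # 800 100 100
```

Shuffling before cutting matters: if the file happens to be sorted by class, a plain slice would put whole classes into a single set.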

Training and Validation

Now the neural network is trained on the Tr set and its weights are updated accordingly. The validation set Va is then used to compute the classification error E = M - Y using the weights resulting from the training, where M is the expected output vector taken from the validation set and Y is the computed output resulting from the classification (Y = W · X). If the error is higher than a user-defined threshold, the whole training-validation epoch is repeated. The training phase ends when the error computed on the validation set is deemed low enough.
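Here is a rough sketch of that training-validation loop for the Y = W · X model. The delta-rule update, the learning rate, the toy data and the threshold value are all assumptions made for illustration; a real network would use its own training rule.

```python
import random

def predict(w, x):
    """Y = W . X, the scalar product of the weights and one sample."""
    return sum(wi * xi for wi, xi in zip(w, x))

def train_epoch(w, samples, lr=0.05):
    """One pass of the delta rule over the training set (assumed update rule)."""
    for x, m in samples:
        e = m - predict(w, x)                        # E = M - Y for this sample
        w = [wi + lr * e * xi for wi, xi in zip(w, x)]
    return w

def mean_abs_error(w, samples):
    """Average |M - Y| over a set; used here as the validation error."""
    return sum(abs(m - predict(w, x)) for x, m in samples) / len(samples)

# Toy data: targets come from a known weight vector, so training can converge.
rng = random.Random(0)
true_w = [2.0, -1.0]
xs = [[rng.uniform(-1, 1), rng.uniform(-1, 1)] for _ in range(200)]
data = [(x, predict(true_w, x)) for x in xs]
tr, va = data[:160], data[160:]

w = [0.0, 0.0]
threshold, max_epochs = 0.01, 500                    # user-defined stopping values
for epoch in range(1, max_epochs + 1):
    w = train_epoch(w, tr)                           # train on Tr
    val_error = mean_abs_error(w, va)                # validate on Va
    if val_error < threshold:                        # low enough: stop training
        break
print(epoch, round(val_error, 4))
```

The `max_epochs` cap is there so the loop still terminates if the threshold is never reached.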

Smart Training

Now, a smart ruse here is to randomly reselect which samples to use for training and which for validation from the total set Tr + Va at each epoch iteration. This makes it less likely that the network over-fits one particular training set.
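A small sketch of that re-selection, assuming the 900 Tr + Va samples sit in one list. The `resplit` helper is a made-up name, and the training step itself is left as a comment.

```python
import random

def resplit(pool, val_frac, rng):
    """Draw a fresh random Tr/Va partition from the combined Tr + Va pool."""
    shuffled = list(pool)
    rng.shuffle(shuffled)
    n_val = round(len(shuffled) * val_frac)
    return shuffled[n_val:], shuffled[:n_val]        # (Tr, Va)

pool = list(range(900))                  # stand-in for the 900 Tr + Va samples
rng = random.Random(42)
val_sets = []
for epoch in range(3):
    tr, va = resplit(pool, val_frac=1 / 9, rng=rng)  # 800 / 100, as in 80-10-10
    val_sets.append(set(va))
    # ... train on tr, then compute the validation error on va ...
    print(epoch, len(tr), len(va))
```

Note that the testing set Te is never part of `pool`, so it stays unseen no matter how often Tr and Va are redrawn.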

Testing

The testing set Te is then used to measure the performance of the network. This data is perfect for this purpose as it was never used throughout the training and validation phase. It is effectively a small set of previously unseen data, which is supposed to mimic what would happen once the network is deployed in the production environment.

The performance is again measured in terms of classification error, as explained above. The performance can also (or maybe even should) be measured in terms of precision and recall so as to know where and how the error occurs, but that's a topic for another Q&A.
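For the curious, precision and recall for one class can be computed from the expected and predicted labels like this. It is a plain-Python sketch with made-up labels, where class 1 is treated as the "positive" class.

```python
def precision_recall(expected, predicted, positive=1):
    """Precision and recall of the `positive` class from two label lists."""
    pairs = list(zip(expected, predicted))
    tp = sum(1 for m, y in pairs if y == positive and m == positive)
    fp = sum(1 for m, y in pairs if y == positive and m != positive)
    fn = sum(1 for m, y in pairs if y != positive and m == positive)
    precision = tp / (tp + fp) if tp + fp else 0.0   # correct among predicted positives
    recall = tp / (tp + fn) if tp + fn else 0.0      # found among actual positives
    return precision, recall

expected  = [1, 1, 1, 0, 0, 0, 1, 0]     # M: labels from the testing set
predicted = [1, 0, 1, 0, 1, 0, 1, 0]     # Y: labels produced by the network
print(precision_recall(expected, predicted))         # (0.75, 0.75)
```

With 7 classes, as in the question, you would call this once per class, each time treating that class as the positive one.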

Cross-Validation

Having understood this training-validation-testing mechanism, one can further strengthen the network against over-fitting by performing K-fold cross-validation. This is something of an evolution of the smart ruse I explained above. The technique involves performing K rounds of training-validation-testing, each on a different, equally-proportioned split into Tr, Va and Te, chosen so that the K testing sets do not overlap.

Given k = 10, for each of the k rounds (i = 1, ..., k) you will split your dataset into Tr+Va = 90% and Te = 10%, run the algorithm, and record the testing performance.

k = 10
Folds <-- SplitIntoFolds(Data, k)    # k non-overlapping, equally sized folds
for i in 1:k
     # Fold i is the testing set; the other k-1 folds form Tr + Va
     KFoldTesting <-- Folds[i]
     KFoldTraining <-- Data without Folds[i]

     # Train and record performance
     KFoldPerformance[i] <-- SmartTrain(KFoldTraining, KFoldTesting)
end

# Compute overall performance (e.g. the mean over the k folds)
TotalPerformance <-- ComputePerformance(KFoldPerformance)
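The pseudocode above could look roughly like this in Python. To keep it short I made several assumptions. The "network" is a one-weight model y = w * x fitted in closed form, standing in for the real training. The inner Tr/Va split is left out, since this stand-in has nothing to tune. The data is synthetic. All function names are mine.

```python
import random

def k_fold_splits(n_samples, k, seed=0):
    """Yield (train_indices, test_indices) for k non-overlapping test folds."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::k] for i in range(k)]        # k folds of near-equal size
    for i in range(k):
        test = folds[i]
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train, test

def fit_weight(samples):
    """Closed-form least-squares fit of y = w * x (stand-in for real training)."""
    return sum(x * m for x, m in samples) / sum(x * x for x, m in samples)

def mean_abs_error(w, samples):
    """Average |M - Y| over a set of (x, m) pairs."""
    return sum(abs(m - w * x) for x, m in samples) / len(samples)

# Synthetic data: m = 3x plus a little Gaussian noise.
rng = random.Random(1)
xs = [rng.uniform(-1, 1) for _ in range(1000)]
data = [(x, 3.0 * x + rng.gauss(0, 0.1)) for x in xs]

k = 10
fold_errors = []
for train_idx, test_idx in k_fold_splits(len(data), k):
    w = fit_weight([data[j] for j in train_idx])               # train on 90%
    fold_errors.append(mean_abs_error(w, [data[j] for j in test_idx]))  # test on 10%

total_performance = sum(fold_errors) / k             # overall performance
print(round(total_performance, 3))
```

Every sample ends up in exactly one testing fold, so the averaged figure uses the whole dataset for testing without ever testing on data the model was fitted on in that round.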

Overfitting Shown

I am taking the world-famous plot below from Wikipedia to show how the validation set helps prevent overfitting. The training error, in blue, tends to decrease as the number of epochs increases: the network is therefore attempting to match the training set exactly. The validation error, in red, on the other hand follows a different, U-shaped profile. The minimum of the validation curve is where the training should ideally be stopped: beyond that point the training error keeps falling but the validation error starts to rise, which means the network is overfitting.

Figure: Overfitting reduced by validating neural network
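Stopping at the minimum of the validation curve is usually called early stopping. Here is a sketch over a made-up list of per-epoch validation errors. The `patience` parameter (how many epochs without improvement to tolerate before giving up) is a common convention that I am adding here; it is not part of the plot.

```python
def early_stopping_epoch(val_errors, patience=3):
    """Return (best_epoch, best_error), stopping once the validation error has
    failed to improve for `patience` consecutive epochs."""
    best_epoch, best_error = 0, float("inf")
    for epoch, err in enumerate(val_errors):
        if err < best_error:
            best_epoch, best_error = epoch, err      # new minimum of the curve
        elif epoch - best_epoch >= patience:
            break                                    # no improvement for too long
    return best_epoch, best_error

# A U-shaped validation curve: it falls, bottoms out at epoch 4, then rises.
val_errors = [0.90, 0.60, 0.45, 0.38, 0.35, 0.36, 0.39, 0.44, 0.50, 0.58]
print(early_stopping_epoch(val_errors))              # (4, 0.35)
```

In practice you would save a copy of the weights each time a new minimum is found and restore that copy when the loop stops.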

References

For more references this excellent book will give you both a sound knowledge of machine learning as well as several migraines. Up to you to decide if it's worth it. :)

Solution 2

  1. Divide your data into k non-overlapping folds. Have each fold contain an equal number of items from each of the m classes (stratified cross-validation; if you have 100 items from class A and 50 from class B and you do 2-fold cross-validation, each fold should contain a random 50 items from A and 25 from B).

  2. For i in 1..k:

      • Designate fold i the test fold
      • Designate one of the remaining k-1 folds the validation fold (this can either be random or a function of i, doesn't really matter)
      • Designate all remaining folds the training fold
      • Do a grid search for all free parameters (e.g. learning rate, # of neurons in hidden layer) training on your training data and computing loss on your validation data. Pick parameters minimising loss
      • Use the classifier with the winning parameters to evaluate test loss. Accumulate results

You have now collected aggregate results across all the folds. This is your final performance. If you're going to apply this for real, in the wild, use the best parameters from the grid search to train on all the data.
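The steps above can be sketched as follows. Everything concrete here is an assumption chosen to keep the example short and self-contained: a toy 1-D nearest-neighbour classifier stands in for the network, its number of neighbours is the only free parameter in the grid, and the two-class data is synthetic.

```python
import random
from collections import defaultdict

def stratified_folds(labels, k, seed=0):
    """Split sample indices into k folds with (near-)equal class counts per fold."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    folds = [[] for _ in range(k)]
    for members in by_class.values():
        rng.shuffle(members)
        for pos, idx in enumerate(members):
            folds[pos % k].append(idx)       # deal each class out round-robin
    return folds

def knn_predict(train, x, n_neighbors):
    """Majority vote among the n nearest 1-D training points (toy classifier)."""
    nearest = sorted(train, key=lambda s: abs(s[0] - x))[:n_neighbors]
    votes = sum(label for _, label in nearest)       # number of class-1 neighbours
    return 1 if 2 * votes > len(nearest) else 0

def loss(train, held_out, n_neighbors):
    """Misclassification rate on the held-out samples."""
    wrong = sum(knn_predict(train, x, n_neighbors) != y for x, y in held_out)
    return wrong / len(held_out)

# Toy data: 100 items of class 0 and 50 of class 1, drawn around different means.
rng = random.Random(0)
data = ([(rng.gauss(0.0, 1.0), 0) for _ in range(100)]
        + [(rng.gauss(2.0, 1.0), 1) for _ in range(50)])

k = 5
grid = [1, 5, 15]                            # candidate values of the free parameter
folds = stratified_folds([y for _, y in data], k)
test_losses = []
for i in range(k):
    val_i = (i + 1) % k                      # validation fold as a function of i
    test = [data[j] for j in folds[i]]
    val = [data[j] for j in folds[val_i]]
    train = [data[j] for f in range(k) if f not in (i, val_i) for j in folds[f]]
    best = min(grid, key=lambda n: loss(train, val, n))   # grid search on validation
    test_losses.append(loss(train, test, best))           # evaluate on the test fold

print(sum(test_losses) / k)                  # aggregate result across the folds
```

The test fold is only touched after the parameter has been picked, which is what keeps the aggregate number an honest estimate.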

Author: Ortixx

Updated on July 09, 2022

Comments

  • Ortixx
    Ortixx almost 2 years

    We are writing a small ANN which is supposed to categorize 7000 products into 7 classes based on 10 input variables.

    In order to do this we have to use k-fold cross validation but we are kind of confused.

    We have this excerpt from the presentation slide:

    k-fold cross validation diagram

    What exactly are the validation and test sets?

    From what we understand, we run through the 3 training sets and adjust the weights (single epoch). Then what do we do with the validation set? As I understand it, the test set is used to get the error of the network.

    What happens next is also confusing to me. When does the crossover take place?

    If it's not too much to ask, a bullet list of steps would be appreciated.