
Quick rules to tune a neural net

Here are the quick rules I use to tune a neural network.

1. Very high valid_loss

Total time: 00:13
epoch  train_loss  valid_loss  error_rate       
1      12.220007   1144.000000  0.765957    (00:13)

The valid_loss usually ends up below 1. When it explodes to a value this large, the learning rate is too high: lower it and train again.
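One way to recover, assuming the log comes from a fastai v1 learner stored in a variable named learn (names and values below are illustrative): run the learning-rate finder, pick a smaller rate from the plot, and retrain.

learn.lr_find()          # sweep a range of learning rates over a few batches
learn.recorder.plot()    # pick a rate where the loss is still decreasing steadily
learn.fit_one_cycle(1, max_lr=1e-3)   # retrain with a smaller learning rate than the one that blew up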

2. train_loss bigger than valid_loss

Total time: 00:14
epoch  train_loss  valid_loss  error_rate
1      0.602823    0.119616    0.049645    (00:14)

The train_loss should end up smaller than the valid_loss; otherwise we are not learning enough from the training set (underfitting). To get the train_loss below the valid_loss, increase the number of epochs and/or the learning rate.
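A minimal sketch of the fix, under the same fastai v1 assumption (the epoch count and max_lr are illustrative, not tuned values):

# train for more than a single epoch, with a somewhat higher learning rate
learn.fit_one_cycle(5, max_lr=3e-3)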

3. error_rate improves slowly

Total time: 01:07
epoch  train_loss  valid_loss  error_rate
1      1.349151    1.062807    0.609929    (00:13)
2      1.373262    1.045115    0.546099    (00:13)

When the error_rate improves across epochs but only very slowly, the learning rate is too low and should be increased.
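Still assuming a fastai v1 learner, the change is just a larger max_lr on the next run (the values below are illustrative, roughly a 10x bump):

# if the previous run used max_lr=1e-4, try roughly 10x higher
learn.fit_one_cycle(2, max_lr=1e-3)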

4. error_rate gets worse after a while

Total time: 06:39
epoch  train_loss  valid_loss  error_rate
1      1.513021    1.041628    0.507326    (00:13)
2      1.290093    0.994758    0.443223    (00:09)
3      1.185764    0.936145    0.410256    (00:09)
4      1.117229    0.838402    0.322344    (00:09)
5      1.022635    0.734872    0.252747    (00:09)
6      0.951374    0.627288    0.192308    (00:10)
7      0.916111    0.558621    0.184982    (00:09)
8      0.839068    0.503755    0.177656    (00:09)
9      0.749610    0.433475    0.144689    (00:09)
10     0.678583    0.367560    0.124542    (00:09)
11     0.615280    0.327029    0.100733    (00:10)
12     0.558776    0.298989    0.095238    (00:09)
13     0.518109    0.266998    0.084249    (00:09)
14     0.476290    0.257858    0.084249    (00:09)
15     0.436865    0.227299    0.067766    (00:09)
16     0.457189    0.236593    0.078755    (00:10)
17     0.420905    0.240185    0.080586    (00:10)
18     0.395686    0.255465    0.082418    (00:09)
19     0.373232    0.263469    0.080586    (00:09)

When the error_rate improves for a while and then starts getting worse again (here after epoch 15), it is a sign of overfitting. In this case, either train for fewer epochs or add more regularization (e.g. weight decay, dropout, data augmentation).
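A sketch of both options, still assuming fastai v1; path, ps, wd and the image size are illustrative placeholders, not values from the runs above.

from fastai.vision import *

# option 1: stop around the best epoch (valid_loss and error_rate bottom out near epoch 15 above)
learn.fit_one_cycle(15, max_lr=1e-3)

# option 2: add regularization -- data augmentation, more dropout (ps) and weight decay (wd)
data = ImageDataBunch.from_folder(path, ds_tfms=get_transforms(), size=224)
learn = cnn_learner(data, models.resnet34, metrics=error_rate, ps=0.6, wd=1e-1)
learn.fit_one_cycle(20, max_lr=1e-3)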