```
if current validation loss is lower than candidate validation loss:
    save model to disk, overwriting the previous candidate
    set candidate validation loss to current validation loss
```
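If you were writing the training loop by hand in plain PyTorch, this would only take a few lines. Here is a minimal sketch (`train_one_epoch` and `evaluate` are hypothetical stand-ins for your own training and validation code):

```python
import torch

n_epochs = 10  # example value
best_valid_loss = float('inf')

for epoch in range(n_epochs):
    train_one_epoch(model, train_dl)        # hypothetical training helper
    valid_loss = evaluate(model, valid_dl)  # hypothetical validation helper
    if valid_loss < best_valid_loss:
        # Improvement: overwrite the previous best candidate on disk
        torch.save(model.state_dict(), 'models/best.pth')
        best_valid_loss = valid_loss
```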
fastai, which "simplifies training fast and accurate neural nets using modern best practices", implements exactly this logic in its `SaveModelCallback`. Using it only takes one extra argument to the fit call:
```python
learner.fit_one_cycle(..., cbs=[..., SaveModelCallback(monitor='valid_loss')])
```
`learner` is a standard fastai `Learner` object. By default, the callback tracks the validation loss to determine when to save a new best model; use the `monitor` argument to set it to any other metric tracked by your `learner` object. After each epoch during training, the current value of the target metric is compared to the previous best value. If it is an improvement, the model is persisted in the `models` directory, overwriting the previous best candidate, if present.

The callback infers whether an improvement corresponds to a smaller value (metrics whose name contains `loss` or `error`) or a larger value (everything else). This behavior can be overridden using the `comp` argument.
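For instance, here is a sketch of checkpointing on accuracy (the metric and file names are arbitrary), with the "larger is better" comparison made explicit:

```python
import numpy as np
from fastai.callback.tracker import SaveModelCallback

# Track accuracy instead of validation loss; comp=np.greater makes an
# increase count as an improvement. fname sets the checkpoint file name.
cb = SaveModelCallback(monitor='accuracy', comp=np.greater, fname='best-accuracy')
learner.fit_one_cycle(5, cbs=[cb])
```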
The model is persisted using fastai's `save_model` function, which is a wrapper for PyTorch's native `torch.save`.
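Conceptually, and leaving aside optimizer state and fastai's path handling, this amounts to something like:

```python
import torch

# Serialize the model weights to disk...
torch.save(learn.model.state_dict(), 'models/model.pth')

# ...and restore them later
learn.model.load_state_dict(torch.load('models/model.pth'))
```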
For this project, I use a custom version of `SaveModelCallback` that logs all metrics tracked by fastai during training. The code for it can be found here. Whenever a new best model is saved, it also records the current values of those metrics (as `last_saved_metadata`) associated with the best model. How to make use of this? All is to be revealed in the next section!
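The linked code is not reproduced here, but to give an idea, a minimal sketch of the approach could look like the following. The class name is mine, and the sketch assumes fastai's tracker internals (the `new_best` flag set by `TrackerCallback` and the metric names and values exposed by `Recorder`):

```python
from fastai.callback.tracker import SaveModelCallback

class MetadataSaveModelCallback(SaveModelCallback):  # hypothetical name
    def after_epoch(self):
        super().after_epoch()  # saves the model when the monitored metric improves
        if self.new_best:
            # Snapshot every tracked metric at the epoch the best model was saved
            self.last_saved_metadata = dict(zip(
                self.recorder.metric_names[1:-1],  # drop 'epoch' and 'time'
                self.recorder.values[-1]))
```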
W&B provides its own `WandbCallback` for fastai. To use it, one needs to initialize a W&B run and add the callback to the `learner` object like so:

```python
# Import the W&B package and the fastai integration
import wandb
from fastai.callback.wandb import WandbCallback

# Initialize the W&B run (can optionally set project name, run name, etc.)
wandb.init()

# Add the callback to the Learner to track training metrics and log best models
learn = Learner(..., cbs=WandbCallback())
```
`WandbCallback` works in conjunction with `SaveModelCallback` -- at the end of the training process, the best performing model will be automatically logged as an artifact of the W&B run.
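For example, combining the two callbacks in a single fit call (a sketch; `log_model=True` asks `WandbCallback` to upload the checkpoint saved by `SaveModelCallback`):

```python
from fastai.callback.tracker import SaveModelCallback
from fastai.callback.wandb import WandbCallback

learn.fit_one_cycle(
    5,
    cbs=[SaveModelCallback(monitor='valid_loss'),
         WandbCallback(log_model=True)])
```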
There is, however, a problem with the stock `WandbCallback`: the metadata associated with the model is recorded at the end of the run, not at the epoch when the best model was saved. In other words, the metadata does not correspond to the saved model at all and can be misleading (for example, when the tracked metric diverged towards the end of training due to overfitting).

This is where the custom `SaveModelCallback` discussed in the previous section comes in: it saves all the information needed to associate the model with its actual metadata. To take advantage of it, it is also necessary to use a custom version of `WandbCallback`, which can be found here. Its `after_fit` method logs the best model together with the metadata captured when it was saved:
, which can be found here.def after_fit(self):
if self.log_model:
if self.save_model.last_saved_path is None:
print('WandbCallback could not retrieve a model to upload')
else:
log_model(self.save_model.last_saved_path, metadata=self.save_model.last_saved_metadata)
for metadata_key in self.save_model.last_saved_metadata:
wandb.run.summary[f'best_{metadata_key}'] = self.save_model.last_saved_metadata[metadata_key]``
Additionally, each metric is written to the W&B run summary with a `best_` prefix. This allows runs to be sorted and compared based on the performance of their respective best model.
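To illustrate (the entity and project names below are placeholders), the public W&B API can then be used to rank runs by these summary fields:

```python
import wandb

api = wandb.Api()
runs = api.runs('my-entity/my-project')  # placeholder entity/project

# Sort runs by the validation loss of their best saved model
for run in sorted(runs, key=lambda r: r.summary.get('best_valid_loss', float('inf'))):
    print(run.name, run.summary.get('best_valid_loss'))
```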
[Figure: W&B runs sorted by `best_matthews_corrcoef`, showing the metadata associated with their respective best models]