Note: I’ve uploaded a fully functional, self-contained Jupyter notebook example here, if you’re more of a Shift+Enter type of person.
fastai v1 introduced the concept of callbacks: a way of modifying the training loop at runtime, without having to dig into the internals of a model.
However, that’s not all they’re good for. I often found myself having to drop into a training loop somewhere and see how the outputs were doing, what they looked like, what the loss function was up to, etc. It turns out that callbacks are excellent for dropping set_trace()s all around your runtime!
If anyone from fastai is reading this, I’m genuinely sorry for calling them “hooks” instead of “callbacks”, but I’m going to have to side with PyTorch here; it’s just the more common term in my experience. See Wikipedia on hooks, git hooks, webhooks, etc.
An often-cited example of a fastai callback is the GradientClipping callback (see the docs here). If you haven’t heard of gradient clipping before: it essentially makes sure your gradients don’t “explode”, by clipping them at the end of the backward pass of a batch. Here’s what it looks like:
class GradientClipping(LearnerCallback):
    "Gradient clipping during training."
    def __init__(self, learn:Learner, clip:float = 0.):
        super().__init__(learn)
        self.clip = clip

    def on_backward_end(self, **kwargs):
        "Clip the gradient before the optimizer step."
        if self.clip: nn.utils.clip_grad_norm_(self.learn.model.parameters(), self.clip)
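To try it out, attaching it to a training run is one line. Here’s a quick sketch, assuming you already have a DataBunch called data and a model; the clip value of 0.1 is an arbitrary example, not a recommendation:

# Sketch: attach GradientClipping to a fastai v1 Learner.
# `data` and `model` are assumed to already exist; clip=0.1 is arbitrary.
learn = Learner(data, model)
learn.fit(1, callbacks=[GradientClipping(learn, clip=0.1)])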
It’s worth pausing for a second to notice how much elegance the concept of callbacks buys us. Instead of having to subclass an entire Learner/model just to make a near-trivial modification to the backward pass, we can hook into the right place, throw down the relevant modification, and call it a day.
As you’ll notice, at this stage the docs are… a bit lacking (to be fair, they were only launched from scratch a few months ago!), and the blog posts that provide a good introduction (I like this one and this one) don’t get concrete enough for my taste. Given that the gentle introduction is covered by those two posts, let’s focus on concretely writing your very own custom fastai callback.
In order to get debugging working with callbacks, you’ll need to know three things:
1. Callbacks are plain Python methods that run inside the training loop, so you can drop a set_trace() in there and interactively play with the debugger.
2. When the different callback methods fire (e.g. when is on_backward_end() called?), so you have a good feeling of where in the training loop they get called.
3. Once something like on_loss_begin() gets called, how do you access the variables that are currently being passed around in the training loop? This is actually explained in the Callback docs: you can unpack from **kwargs any of the variables listed there (see the sketch right after this list).
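For instance, here’s a tiny callback (StateSpy is a hypothetical name I made up for illustration; the import path assumes fastai v1’s layout) that just prints everything floating around when the loss is about to be computed:

from fastai.basic_train import LearnerCallback

class StateSpy(LearnerCallback):
    "Print which training-loop variables are available at on_loss_begin."
    def on_loss_begin(self, **kwargs):
        # kwargs carries the training-loop state (last_output, last_target,
        # iteration, ...) -- the variables listed in the Callback docs.
        print(sorted(kwargs.keys()))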
To take an example, I was writing a custom loss function, but I wasn’t certain what things looked like once the loss function got called. So I figured I’d drop into the start of the loss function, inspect the variables, and take it from there.
Concretely:
1. I import set_trace from IPython.core.debugger to get nice debugger colors in my Jupyter notebook.
2. I hook into on_loss_begin() and unpack the last_output and last_target of the model (which are the two things that get fed to the loss function).

Here’s what that would look like:
from IPython.core.debugger import set_trace

class LossDebug(LearnerCallback):
    def __init__(self, learn:Learner):
        super().__init__(learn)

    def on_loss_begin(self, last_output, last_target, **kwargs):
        # Drop into the debugger right before the loss is computed
        set_trace()
Some things to note:
1. I subclass LearnerCallback and not the vanilla Callback class. LearnerCallback isn’t all that special, it just gives me a couple of conveniences, mostly a way to access the current Learner object (so I can take a look at Learner.model for example, or Learner.model.parameters(), which is sometimes useful). I can’t find anything on LearnerCallback in the docs, but I’m still learning to navigate those; have a link to the source code of LearnerCallback in the meantime. Again, don’t worry too much about the semantics: all you have to remember is that thanks to LearnerCallback you now have access to self.learn, which can be nice (see the sketch right after this list).
2. I hook into on_loss_begin() and immediately unpack last_output and last_target because I’ll need them. I can of course also just dig into **kwargs if I need anything else.
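To make the self.learn point concrete, here’s a sketch of a callback (GradNormSpy is a hypothetical name, purely for illustration) that reaches through self.learn to the model’s parameters at the end of each backward pass:

class GradNormSpy(LearnerCallback):
    "Peek at the model's gradients through self.learn."
    def on_backward_end(self, **kwargs):
        # self.learn is the Learner we were constructed with, so the model
        # and its parameters are one attribute access away:
        norms = [p.grad.norm().item()
                 for p in self.learn.model.parameters() if p.grad is not None]
        print(f"mean grad norm this batch: {sum(norms) / len(norms):.4f}")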
All that’s left to do is to add my callback to the training loop:
learn = Learner(databunch, model, loss_func=customBCEWithLogitsFlat(axis=1))
learn.fit(1, callbacks=[LossDebug(learn)])
This should open up the debugger at just the right time.
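Once the prompt appears, these are the kinds of things I’d typically poke at (illustrative; the exact shapes depend on your data and model):

# At the (i)pdb prompt:
#   p last_output.shape    # what the model just produced
#   p last_target.shape    # what the loss function will compare it against
#   u                      # walk up the stack into fastai's training loop
#   c                      # continue to the next batch (and the next trace)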
What about %debug though?

It’s true that if you’re in a Jupyter environment, you could just run the fit() function, let it crash, then drop right into the post-mortem with the %debug magic (seriously, if you didn’t know this one it’ll blow your mind; read the %debug docs). I just don’t find it as flexible or comfortable.
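For reference, that flow looks roughly like this (a sketch; %debug is an IPython magic, so it runs on its own in a cell after the failing one):

learn.fit(1)   # suppose this crashes somewhere inside your loss function
# then, in the next cell:
# %debug   <- drops you into a post-mortem pdb session at the failing frame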