Note: I’ve uploaded a fully-functional, self-contained Jupyter notebook example here if you’re more of a Shift+Enter type of person.
With fastai v1, the concept of callbacks was introduced. The idea of callbacks is centered around modifying the training loop at runtime, without having to dig into the internals of a model.
However, that’s not all they’re good for. I often found myself having to drop into a training loop somewhere and see how the outputs were doing, what they looked like, what the loss function was doing, etc. It turns out that callbacks are excellent for dropping set_trace()s all around your runtime!
If anyone from fastai is reading this, I’m genuinely sorry for calling them “hooks” instead of “callbacks”, but I’m going to have to side with PyTorch here; it’s just the more frequent term in my experience. See Wikipedia on hooks, git hooks, webhooks, etc.
An often-cited example of a fastai callback is the GradientClipping callback (see the docs here). If you hadn’t heard of gradient clipping before, it essentially makes sure your gradients don’t “explode” by clipping them at the end of the backward pass of a batch. Here’s what it looks like:
```python
class GradientClipping(LearnerCallback):
    "Gradient clipping during training."
    def __init__(self, learn:Learner, clip:float = 0.):
        super().__init__(learn)
        self.clip = clip

    def on_backward_end(self, **kwargs):
        "Clip the gradient before the optimizer step."
        if self.clip:
            nn.utils.clip_grad_norm_(self.learn.model.parameters(), self.clip)
```
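If you want to see what the clipping itself does numerically, here’s a minimal pure-Python sketch of clipping by global L2 norm. This is only an illustration of the math; clip_gradients_ is a made-up helper, not a fastai or PyTorch API (the real thing is torch.nn.utils.clip_grad_norm_, which operates on parameter tensors):

```python
import math

def clip_gradients_(grads, max_norm):
    """Scale grads in place so their global L2 norm is at most max_norm,
    mirroring the idea behind torch.nn.utils.clip_grad_norm_."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for i in range(len(grads)):
            grads[i] *= scale
    return total_norm

grads = [3.0, 4.0]           # global norm = 5.0
clip_gradients_(grads, 1.0)  # norm exceeds 1.0, so scale everything by 1/5
print(grads)                 # -> [0.6, 0.8]
```

The direction of the gradient is preserved; only its magnitude gets capped, which is exactly why clipping tames exploding gradients without changing where the optimizer is headed.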
It’s worth standing still for a second and noticing how elegant the concept of callbacks allows us to be. Instead of having to subclass an entire Learner/model just so you can make a near trivial modification in the backward pass, we can just hook into the right place, throw down the relevant modification and call it a day.
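For intuition, the dispatch mechanism can be sketched in a few lines of plain Python. This is a toy illustration of the pattern, not fastai’s actual implementation (run_training and PrintEveryBatch are made-up names):

```python
class PrintEveryBatch:
    "Toy callback: records which hook points fire, in order."
    def __init__(self):
        self.events = []
    def on_batch_begin(self, **kwargs):
        self.events.append("on_batch_begin")
    def on_backward_end(self, **kwargs):
        self.events.append("on_backward_end")

def run_training(n_batches, callbacks):
    "Toy training loop: call each hook on every callback, every batch."
    for b in range(n_batches):
        for cb in callbacks:
            cb.on_batch_begin(batch=b)
        # ... forward pass, loss, and backward pass would happen here ...
        for cb in callbacks:
            cb.on_backward_end(batch=b)

cb = PrintEveryBatch()
run_training(2, [cb])
print(cb.events)
# -> ['on_batch_begin', 'on_backward_end', 'on_batch_begin', 'on_backward_end']
```

The training loop doesn’t know or care what each callback does at the hook points; that’s the inversion of control that makes the GradientClipping example above so compact.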
As you’ll notice, at this stage the docs are… a bit lacking (they were only launched from scratch a few months ago!), and the blog posts that provide a good introduction (I like this one and this one) don’t get concrete enough for my taste. Given that the gentle introduction is covered by the two posts above, let’s focus on concretely writing your very own custom fastai callback.
In order to get debugging working with callbacks, you’ll need to know three things:
1. How to write a basic callback, so you can drop a set_trace() in there and interactively play with the debugger.
2. Which methods a callback can implement (when is on_backward_end() called?), so you have a good feeling of where in the training loop they get called.
3. Once, say, on_loss_begin() gets called, how do I access variables that are currently being passed around in the training loop? This is actually explained in the Callback docs: you can unpack from **kwargs any of the variables listed there.
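That last point is plain Python: the training loop passes its current state along as keyword arguments, and each callback method names only the variables it cares about, letting everything else fall into **kwargs. A minimal illustration (with toy values; InspectLoss and the state dict are made up, not fastai’s real internal state):

```python
class InspectLoss:
    "Pull last_output and last_target out of the state the loop passes along."
    def on_loss_begin(self, last_output, last_target, **kwargs):
        # Everything we didn't name (epoch, iteration, ...) stays in kwargs.
        return last_output, last_target, sorted(kwargs)

state = {"last_output": [0.2, 0.8], "last_target": [0, 1],
         "epoch": 0, "iteration": 42}
out, tgt, rest = InspectLoss().on_loss_begin(**state)
print(out, tgt, rest)  # -> [0.2, 0.8] [0, 1] ['epoch', 'iteration']
```

This is why every callback method signature ends in **kwargs: the loop can pass its whole state to every callback, and each callback cherry-picks what it needs.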
To take an example, I was writing a custom loss function, but I wasn’t certain what things looked like once the loss function would get called. So I figured I’d drop into the start of the loss function, inspect the variables, and take it from there.
Concretely, I wanted to:
- import set_trace from IPython.core.debugger to have nice debugger colors in my Jupyter notebook
- inspect the last_output and last_target of the model (which are the two things that get fed to the loss function)
Here’s what that would look like:
```python
from IPython.core.debugger import set_trace

class LossDebug(LearnerCallback):
    def __init__(self, learn:Learner):
        super().__init__(learn)

    def on_loss_begin(self, last_output, last_target, **kwargs):
        set_trace()
```
Some things to note:
- I subclass LearnerCallback and not the vanilla Callback. LearnerCallback isn’t all that special, but it gives me a couple of extra semantics, mostly a way to access the current Learner object (so I can take a look at Learner.model for example, or Learner.model.parameters(), which is sometimes useful). I can’t find anything on LearnerCallback in the docs, but I’m still learning to navigate those; have a link to the source code of LearnerCallback in the meantime. Again, don’t worry too much about the semantics. All you have to remember is that thanks to LearnerCallback you now have access to self.learn, which can be nice.
- I implement on_loss_begin() and immediately unpack last_output and last_target because I’ll need them. I can of course also just dig into **kwargs if I need anything else.
All that’s left to do is add my callback to the training loop:
```python
learn = Learner(databunch, model, loss_func=customBCEWithLogitsFlat(axis=1))
learn.fit(1, callbacks=[LossDebug(learn)])
```
This should open up the debugger at just the right time.
It’s true that if you’re in a Jupyter environment, you could just run the fit() function, let it crash, then drop right into the post mortem with the %debug magic (seriously, if you didn’t know this one, it’ll blow your mind; read the %debug docs). I just don’t find it as flexible or comfortable.