Using fastai's callbacks for debugging

Note: I’ve uploaded a fully functional, self-contained Jupyter notebook example here if you’re more of a Shift+Enter type of person.

fastai v1 introduced the concept of callbacks. The idea behind them is to modify the training loop at runtime, without having to dig into the internals of a model.
However, that’s not all they’re good for. I often found myself needing to drop into a training loop to see how the outputs were doing, what they looked like, what the loss function was doing, etc. It turns out that callbacks are excellent for dropping set_trace()s all over your runtime!

If anyone from fastai is reading this, I’m genuinely sorry for calling them “hooks” instead of “callbacks”, but I’m going to have to side with PyTorch here: it’s simply the more common term in my experience. See Wikipedia on hooks, git hooks, webhooks, etc.

A not so gentle introduction

An often-cited example of a fastai callback is the GradientClipping callback (see the docs here). If you haven’t heard of gradient clipping before, it essentially keeps your gradients from “exploding” by clipping them at the end of each batch’s backward pass, before the optimizer step. Here’s what it looks like:

from torch import nn
from fastai.basic_train import Learner, LearnerCallback

class GradientClipping(LearnerCallback):
    "Gradient clipping during training."
    def __init__(self, learn:Learner, clip:float = 0.):
        super().__init__(learn)
        self.clip = clip

    def on_backward_end(self, **kwargs):
        "Clip the gradient before the optimizer step."
        if self.clip: nn.utils.clip_grad_norm_(self.learn.model.parameters(), self.clip)
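To make “clipping” a bit more concrete: nn.utils.clip_grad_norm_ computes the combined L2 norm of all the gradients and, if it exceeds clip, scales every gradient down by the same factor. The sketch below reproduces that idea on a flat list of floats instead of tensors; it’s a simplified illustration, not PyTorch’s actual implementation (which works on tensors and supports other norm types).

```python
import math
from typing import List


def clip_grad_norm(grads: List[float], max_norm: float, eps: float = 1e-6) -> float:
    """Rescale `grads` in place so their combined L2 norm is at most `max_norm`.

    Returns the norm measured *before* clipping. Gradients whose norm is
    already within the limit are left untouched.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)  # eps guards against division by zero
    if clip_coef < 1.0:
        for i in range(len(grads)):
            grads[i] *= clip_coef
    return total_norm


grads = [3.0, 4.0]            # norm 5.0, well above the limit
before = clip_grad_norm(grads, max_norm=1.0)
print(before, grads)          # same direction, but the vector now has norm ~1
```

Note that all gradients are scaled by one shared factor, so the update keeps its direction and only its size changes.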

It’s worth pausing for a second to notice how elegant callbacks allow us to be. Instead of having to subclass an entire Learner or model just to make a near-trivial modification to the backward pass, we can simply hook into the right place, drop in the relevant modification and call it a day.
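Why is hooking in enough? Because the training loop only fires named events at fixed points and hands its current state to whichever callbacks are listening. The sketch below is a hypothetical miniature loop (not fastai’s actual CallbackHandler); only the event names are borrowed from fastai v1. Its toy callback uses value clipping, a simpler cousin of the norm clipping above, and changes the training behaviour without the loop knowing anything about it.

```python
from typing import Any, List


class Callback:
    """Toy base class: every hook is a no-op, so subclasses override only what they need."""
    def on_batch_begin(self, **kwargs: Any) -> None: pass
    def on_loss_begin(self, **kwargs: Any) -> None: pass
    def on_backward_end(self, **kwargs: Any) -> None: pass


class ClipGradValues(Callback):
    """Toy analogue of GradientClipping that clamps each gradient to [-clip, clip]
    (value clipping, not the norm clipping fastai uses)."""
    def __init__(self, clip: float) -> None:
        self.clip = clip

    def on_backward_end(self, grads: List[float], **kwargs: Any) -> None:
        # Mutate the gradients in place, just as clip_grad_norm_ does.
        for i, g in enumerate(grads):
            grads[i] = max(-self.clip, min(self.clip, g))


def toy_fit(xs: List[float], callbacks: List[Callback], lr: float = 0.1) -> List[float]:
    """Hypothetical miniature training loop fitting y = 2x with a single weight.

    It never needs to know what the callbacks do; it just fires named events
    at fixed points and passes the current state along as keyword arguments.
    """
    def fire(event: str, **state: Any) -> None:
        for cb in callbacks:
            getattr(cb, f"on_{event}")(**state)

    w, history = 0.0, []
    for x in xs:
        fire("batch_begin", last_input=x)
        output, target = w * x, 2.0 * x
        fire("loss_begin", last_output=output, last_target=target)
        grads = [2.0 * (output - target) * x]  # d/dw of (w*x - target)**2
        fire("backward_end", grads=grads)
        w -= lr * grads[0]                      # the "optimizer step"
        history.append(w)
    return history


print(toy_fit([1.0, 1.0], callbacks=[]))                    # unclipped steps
print(toy_fit([1.0, 1.0], callbacks=[ClipGradValues(0.5)]))  # steps capped by the callback
```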

As you’ll notice, the docs are… a bit lacking at this stage (they were only launched, from scratch, a few months ago!), and the blog posts that provide a good introduction (I like this one and this one) don’t get concrete enough for my taste. Since those two posts already cover the gentle introduction, let’s focus on concretely writing your very own custom fastai callback.

What you’ll need to know

To debug with callbacks, you’ll need to know three things:

  • How do you want to debug? My preferred approach is to just slam a set_trace() in there and interactively play with the debugger.
  • Where/when do you want to hook? At what stage of the training loop do you want to intercept/modify/analyse the surroundings? For inspiration, check out the fastai docs on which hooks are available. The hooks are transparently named (can you guess when on_backward_end() is called?), so you get a good feel for where in the training loop each one is called.
  • What training-loop state will you need? Figuring this one out was the hardest part for me. If I drop into a debugging session right when on_loss_begin() gets called, how do I access the variables that are currently being passed around in the training loop? This is actually explained in the Callback docs: you can unpack any of the variables listed there from **kwargs.
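The **kwargs unpacking is plain Python argument binding: in fastai v1 the loop passes its whole state to each hook as keyword arguments, so any name you list in the hook’s signature is bound directly, and everything else lands in kwargs. A framework-free sketch; the state values below are made up, and the real key names are the ones listed in the Callback docs.

```python
from typing import Any, Dict, List, Tuple

# Made-up snapshot of the loop's state; in fastai v1 the real key names
# (last_output, last_target, epoch, iteration, ...) are listed in the Callback docs.
state: Dict[str, Any] = {
    "epoch": 0,
    "iteration": 12,
    "last_output": [0.2, 0.9],
    "last_target": [0.0, 1.0],
}


class PeekAtLossInputs:
    def on_loss_begin(self, last_output: Any, last_target: Any,
                      **kwargs: Any) -> Tuple[Any, Any, List[str]]:
        # Parameters named in the signature are bound from the state by keyword;
        # every other entry ends up in kwargs, still available if you need it.
        return last_output, last_target, sorted(kwargs)


# The loop passes the *whole* state; the hook only names what it cares about.
output, target, rest = PeekAtLossInputs().on_loss_begin(**state)
print(output, target, rest)
```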

Debugging the loss function with a callback

For example, I was writing a custom loss function, but I wasn’t sure what its inputs would look like by the time it got called. So I figured I’d drop in right at the start of the loss function, inspect the variables, and take it from there.

Concretely:

  • I’d use set_trace() from IPython.core.debugger to have nice debugger colors in my Jupyter notebook
  • I’d hook into the start of the loss function using on_loss_begin()
  • I’d want to look at the last_output and last_target of the model (which are the two things that get fed to the loss function)

Here’s what that would look like:

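from IPython.core.debugger import set_trace
from fastai.basic_train import Learner, LearnerCallback

class LossDebug(LearnerCallback):
    def __init__(self, learn:Learner):
        super().__init__(learn)

    def on_loss_begin(self, last_output, last_target, **kwargs):
        # Pause here: last_output and last_target are what the loss function is about to receive
        set_trace()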

Some things to note:

  • I’m subclassing LearnerCallback rather than the vanilla Callback class. LearnerCallback isn’t all that special; it mostly gives me a way to access the current Learner object (so I can take a look at Learner.model, for example, or Learner.model.parameters(), which is sometimes useful). I can’t find anything on LearnerCallback in the docs, but I’m still learning to navigate them. In the meantime, here’s a link to the LearnerCallback source code. Again, don’t worry too much about the details: all you have to remember is that LearnerCallback gives you access to self.learn, which can be handy.
  • I’m dropping into on_loss_begin() and immediately unpacking last_output and last_target because I’ll need them. I can of course still dig into **kwargs if I need anything else.
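If you’d rather not stop the loop at all, the same hook can record what it sees for later inspection. The sketch below is a framework-free, non-interactive cousin of LossDebug; ToyModel, ToyLearner, ToyLearnerCallback and LossRecorder are all hypothetical names. ToyLearnerCallback only mimics the part of LearnerCallback discussed above (the self.learn handle); the real class does a bit more, so check its source for details.

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class ToyModel:
    weights: List[float] = field(default_factory=lambda: [0.5, -1.0])

    def parameters(self) -> List[float]:
        return self.weights


@dataclass
class ToyLearner:
    model: ToyModel


class ToyLearnerCallback:
    """Stand-in for LearnerCallback: the part we care about is the self.learn handle."""
    def __init__(self, learn: ToyLearner) -> None:
        self.learn = learn


class LossRecorder(ToyLearnerCallback):
    """Non-interactive cousin of LossDebug: records instead of stopping."""
    def __init__(self, learn: ToyLearner) -> None:
        super().__init__(learn)
        self.seen: List[Tuple[Any, Any, List[float]]] = []

    def on_loss_begin(self, last_output: Any, last_target: Any, **kwargs: Any) -> None:
        # self.learn lets the callback peek at the model, not just the loop state.
        snapshot = list(self.learn.model.parameters())
        self.seen.append((last_output, last_target, snapshot))


learn = ToyLearner(ToyModel())
recorder = LossRecorder(learn)
recorder.on_loss_begin(last_output=[0.3], last_target=[1.0], epoch=0)
print(recorder.seen)
```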

All that’s left to do is add my callback to the training loop:

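# databunch and model are assumed to be defined already; customBCEWithLogitsFlat is my own loss function
learn = Learner(databunch, model, loss_func=customBCEWithLogitsFlat(axis=1))
learn.fit(1, callbacks=[LossDebug(learn)])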

This should open up the debugger at just the right time.

What about %debug though?

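It’s true that if you’re in a Jupyter environment, you could just run fit(), let it crash, and then drop right into the post-mortem with the %debug magic (seriously, if you didn’t know this one, it’ll blow your mind; read the %debug docs). I just find it less flexible and less comfortable: %debug can only inspect the state at the moment of a crash, whereas a callback can pause at whichever stage of the loop you choose, even when nothing goes wrong.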
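One concrete form of that flexibility is a conditional breakpoint: stop only when something looks off, for example the first time the loss turns NaN, whether or not anything ever raises. A framework-free sketch; BreakOnNanLoss and its breakpoint_fn parameter are hypothetical (breakpoint_fn exists only so you can swap in something non-interactive), and by default it uses IPython’s set_trace as above.

```python
import math
from typing import Any, Callable, List, Optional


class BreakOnNanLoss:
    """Hypothetical callback: opens the debugger only when the loss turns NaN.

    `breakpoint_fn` is an illustrative knob for swapping in something else
    (e.g. a logger); by default it uses IPython's set_trace.
    """

    def __init__(self, breakpoint_fn: Optional[Callable[[], None]] = None) -> None:
        self.breakpoint_fn = breakpoint_fn

    def on_backward_begin(self, last_loss: Any, **kwargs: Any) -> None:
        # float() also accepts a one-element PyTorch tensor, not just plain floats.
        if math.isnan(float(last_loss)):
            if self.breakpoint_fn is None:
                from IPython.core.debugger import set_trace  # imported lazily, only when we break
                set_trace()
            else:
                self.breakpoint_fn()


hits: List[str] = []
cb = BreakOnNanLoss(breakpoint_fn=lambda: hits.append("nan!"))
cb.on_backward_begin(last_loss=0.42)          # healthy loss: nothing happens
cb.on_backward_begin(last_loss=float("nan"))  # NaN loss: breakpoint_fn fires
print(hits)
```

To use this with fastai v1 you’d have it subclass Callback (so the hooks it doesn’t define fall back to the base class’s no-ops) and pass it to fit() via callbacks=[...], just like LossDebug. I’m assuming here that last_loss is already available in on_backward_begin; that’s worth double-checking against the Callback docs.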