• Creates a copy of the initialization weights
  • Runs an iteration of gradient descent for a random task on the copy
  • Backpropagates the loss on a test set through that gradient-descent step and back to the initial weights, so that the initial weights are updated in a direction that makes them easier to adapt.

So basically it’s backpropagation through time, or BPTT, across different tasks.
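
In symbols, the standard MAML meta-update with a single inner gradient step is the following (the code below uses Adam for the outer step rather than plain SGD, and a meta-batch of one task; $\alpha$ corresponds to lr_inner and $\beta$ to the outer learning rate):

$$\theta \;\leftarrow\; \theta - \beta \, \nabla_\theta \sum_{\tau} \mathcal{L}_\tau\big(\theta - \alpha \nabla_\theta \mathcal{L}_\tau(\theta)\big)$$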

def maml_sine(model, epochs, lr_inner=0.01, batch_size=1, first_order=False):
    optimizer = torch.optim.Adam(model.params())
    
    for _ in range(epochs):
        for i, t in enumerate(random.sample(SINE_TRAIN, len(SINE_TRAIN))):
            # Fresh model whose weights are the *same* tensors as the initial
            # weights, so gradients can flow back to them
            new_model = SineModel()
            new_model.copy(model, same_var=True)
            # One inner gradient-descent step on the sampled task; keep the
            # graph unless we only want the first-order approximation
            loss = sine_fit1(new_model, t, create_graph=not first_order)
            for name, param in new_model.named_params():
                grad = param.grad
                new_model.set_param(name, param - lr_inner * grad)

            # Test loss on the adapted weights: backpropagating it goes through
            # the inner update and back to the initial weights
            sine_fit1(new_model, t, force_new=True)
 
            if (i + 1) % batch_size == 0:
                optimizer.step()
                optimizer.zero_grad()
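
The copy, named_params, and set_param calls aren’t standard nn.Module methods; the idea is that the weights are stored as plain tensors rather than registered parameters, so an inner-loop update can replace them with tensors that still carry autograd history back to the initial weights. A minimal sketch of what such a class might look like (hypothetical, one hidden layer; the actual SineModel used here is defined elsewhere in the post and may differ):

import torch
import torch.nn.functional as F

class SineModel:
    def __init__(self):
        # Plain tensors, not nn.Parameter, so set_param can later swap in
        # non-leaf tensors that carry autograd history (initialization not tuned)
        self.w1 = torch.randn(40, 1, requires_grad=True)
        self.b1 = torch.zeros(40, requires_grad=True)
        self.w2 = torch.randn(1, 40, requires_grad=True)
        self.b2 = torch.zeros(1, requires_grad=True)

    def forward(self, x):
        h = F.relu(F.linear(x, self.w1, self.b1))
        return F.linear(h, self.w2, self.b2)

    def params(self):
        return [p for _, p in self.named_params()]

    def named_params(self):
        return [('w1', self.w1), ('b1', self.b1), ('w2', self.w2), ('b2', self.b2)]

    def set_param(self, name, value):
        # value may be a non-leaf tensor such as param - lr_inner * grad
        setattr(self, name, value)

    def copy(self, other, same_var=False):
        for name, param in other.named_params():
            # same_var=True shares the same tensor objects, so backprop through
            # the copy reaches the original model's weights
            self.set_param(name, param if same_var else param.detach().clone().requires_grad_())

With pieces like these in place, meta-training is just a matter of constructing a model and calling maml_sine(model, epochs), with sine_fit1 and SINE_TRAIN being the task-fitting helper and task list defined elsewhere in the post.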