Training a Wrapped Model¶
Model training is an important step in any machine learning workflow. The transformations that Capsa applies to your model add new parameters, or change how existing parameters are used, in ways that require further training.
If the model was already trained before applying Capsa, minimal additional training may be sufficient to obtain useful risk values. In general, however, it is best to wrap models before training, or to resume training after wrapping.
Fortunately, Capsa is designed to integrate seamlessly with your existing training pipelines.
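As a rough sketch of the recommended ordering (MyModel here is a placeholder for your own architecture; the wrapper call and training details are covered in the sections below):
model = MyModel()               # your model, trained or untrained
wrapped_model = wrapper(model)  # wrap before training (or before resuming training)
# ... then train wrapped_model as usual, as shown in the next section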
Initializing an Optimizer¶
In PyTorch, an optimizer is typically initialized by passing in a list of the model's parameters, alongside the other hyperparameters for the optimizer.
opt = torch.optim.Adam(model.parameters(), lr=0.001, ...)
When wrapping with Capsa, it is important to remember that Capsa will add additional parameters to your model. This happens the first time you call your Capsa-wrapped model with inputs; future calls to the wrapped model then reuse these added parameters. Because of this, it is necessary to call your Capsa-wrapped model once before initializing the optimizer.
wrapped_model = wrapper(model)
_ = wrapped_model(x) # Call wrapped model with inputs
# It is now safe to initialize the optimizer
opt = torch.optim.Adam(wrapped_model.parameters(), lr=0.001, ...)
Warning
Failure to follow this approach will prevent your model from training correctly.
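With the optimizer initialized, training proceeds as in any other PyTorch workflow. The following is a minimal sketch of a training loop, where loss_fn, dataloader, and num_epochs are placeholders for your own loss function, data pipeline, and schedule (the exact output of the wrapped model depends on the wrapper you chose):
for epoch in range(num_epochs):
    for x, y in dataloader:
        opt.zero_grad()
        out = wrapped_model(x)   # reuses the parameters Capsa added on the first call
        loss = loss_fn(out, y)
        loss.backward()
        opt.step()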
In Capsa-TensorFlow, the same caveat applies: Capsa adds additional parameters to your model the first time you call the wrapped model, and future calls reuse these added parameters. Because of this, it is necessary to call your Capsa-wrapped model at least once before accessing its trainable_variables attribute.
wrapped_model = wrapper(model)
opt = tf.optimizers.Adam(learning_rate=1e-4)

with tf.GradientTape() as tape:
    # Call wrapped model with inputs
    out = wrapped_model(x)
    loss = loss_fn(out, ...)

# Calling .trainable_variables on a Capsa wrapped model
# returns a list of two elements, where:
#   - the first element is a flat list of Capsa-added variables
#   - the second element is a flat list of the original model's variables
capsa_vars, user_vars = wrapped_model.trainable_variables

# For convenience, unpack these into a single flat list
trainable_vars = [*capsa_vars, *user_vars]
# Alternative syntax:
# trainable_vars = tape.watched_variables()

# Compute the gradient of the loss wrt all trainable_variables of the wrapped_model
gradients = tape.gradient(loss, trainable_vars)

# Update the wrapped_model's trainable_variables with the computed gradients
opt.apply_gradients(zip(gradients, trainable_vars))
Warning
Failure to follow this approach will prevent your model from training correctly.
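To put the gradient step above into context, a full training loop might look like the following sketch; dataset, loss_fn, and num_epochs are placeholders for your own data pipeline, loss function, and schedule:
for epoch in range(num_epochs):
    for x, y in dataset:  # e.g. a tf.data.Dataset yielding (input, target) batches
        with tf.GradientTape() as tape:
            out = wrapped_model(x)
            loss = loss_fn(out, y)
        trainable_vars = tape.watched_variables()
        gradients = tape.gradient(loss, trainable_vars)
        opt.apply_gradients(zip(gradients, trainable_vars))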
For an end-to-end training example, see Sample Wrapper: GAN.
Saving model weights¶
In Capsa-Torch, model saving works as usual in PyTorch. After initializing, wrapping, and training your model, simply call
torch.save(model.state_dict(), PATH)
as usual to save your model's state dict to disk.
In Capsa-TensorFlow, after initializing, wrapping, and training your model, you can use pickle to save your model's weights to disk:
import pickle

with open(PATH, "wb") as f:
    pickle.dump(wrapped_model.trainable_variables, f)
Loading model weights¶
Loading the model is similarly easy. The only thing to note is that if you are loading wrapped-model weights, you must wrap the model before loading them. In Capsa-Torch:
wrapped_model = wrapper(model)
wrapped_model.load_state_dict(torch.load(PATH))
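Once the weights are loaded, the wrapped model can be used for prediction and risk estimation as before; a minimal sketch, assuming your wrapper supports the return_risk argument shown in the TensorFlow example below:
out, risk = wrapped_model(x, return_risk=True)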
To load the weights of a wrapped model into a new model in Capsa-TensorFlow, wrap the new model using the same Capsa wrapper that was used with your original model, then call set_capsa_variables, passing the loaded Capsa variables:
import pickle

with open(PATH, "rb") as f:
    restored_capsa_vars, restored_user_vars = pickle.load(f)

restored_wrapped_model = wrapper(model)
restored_wrapped_model.set_capsa_variables(restored_capsa_vars)

# The wrapped model will reuse the previously created Capsa variables
out, risk = restored_wrapped_model(x, return_risk=True)