Limitations
Warning
Capsa wheels are only available on Linux, amd64, CPython 3.8 - 3.11. If you would like to use Capsa in other environments, please let us know.
Current Limitations
We are working to reduce these limitations in future versions of capsa_torch.
PyTorch version
capsa-torch requires torch>=2.1.0
PyTorch Lightning
We have designed capsa-torch to work with pytorch-lightning, but this functionality hasn't been thoroughly tested yet and not all features are guaranteed to work.
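As a rough illustration, one common way to combine the two (a sketch of standard Lightning usage, not an officially tested recipe; the module, loss, and hyperparameters below are made up) is to wrap the plain torch.nn.Module and use it inside a LightningModule:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from capsa_torch import vote

class LitModel(pl.LightningModule):
    def __init__(self, module: torch.nn.Module):
        super().__init__()
        # wrap the plain nn.Module before handing it to Lightning
        self.model = vote.Wrapper()(module)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        return F.mse_loss(y_hat, y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)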
In-place Modifications
Modules that perform in-place modifications on their weights/buffers during the forward call are not officially supported and may not behave as expected.
class InPlaceMod(torch.nn.Module):
    ...
    def forward(self, x):
        x.mul_(2)  # In-place modification of input is ok
        x = self.layer(x)
        self.forward_count += 1  # In-place modification of module attribute (not supported)
        return x
Warning
If you need to store this kind of data during the forward pass, modify the module's forward function to take forward_count as an argument.
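One possible shape of that change (a sketch with an illustrative layer) threads the counter through the call and returns the updated value instead of mutating module state:
import torch

class CounterMod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(5, 5)  # sizes illustrative

    def forward(self, x, forward_count):
        x = self.layer(x)
        # return the updated count alongside the output instead of
        # writing it back onto the module in place
        return x, forward_count + 1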
Non-Tensor code in the Module
Our approach works best with Tensor-only modules, that is, modules that take Tensors as inputs and output Tensors (ideally floating point Tensors). Non-Tensor code may prevent the module from being wrapped or produce unexpected results.
class BadExampleModule(torch.nn.Module):
    ...
    def forward(self, input_prompt: str):
        # Steps like encoding an input string into a Tensor value should be done
        # outside the wrapped forward function
        encoded: torch.Tensor = self.encode_str(input_prompt)
        ...
        # print statements and other non-Tensor code may interfere with wrapping
        print(out)
        ...
        # Similar to above, decoding/postprocessing of output Tensors should be done
        # outside of the wrapped forward function.
        out_str = self.decode(out)
        top_ind = torch.argmax(out)
        # Our wrappers work best with models that return floating point unprocessed
        # outputs. Operations like argmax or even last layer softmax / sigmoid
        # activations should be done outside the torch module.
        return out_str, top_ind
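By contrast, a sketch of the preferred structure (encode_str and decode are hypothetical stand-ins for your own pre/post-processing, and the layer sizes are made up) keeps the wrapped module Tensor-in / Tensor-out and moves everything else outside:
import torch
from capsa_torch import vote

class TensorOnlyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Linear(16, 8)  # sizes illustrative

    def forward(self, encoded: torch.Tensor) -> torch.Tensor:
        # Tensor-only computation returning raw floating point outputs
        return self.layers(encoded)

wrapped = vote.Wrapper()(TensorOnlyModule())

def run(input_prompt: str):
    encoded = encode_str(input_prompt)  # hypothetical pre-processing, outside the wrapper
    out = wrapped(encoded)              # wrapped, Tensor-only forward
    out_str = decode(out)               # hypothetical post-processing, outside the wrapper
    top_ind = torch.argmax(out)
    return out_str, top_ind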
Nested Wrapping and Sub-modules
Currently we only support using a single wrapper at a time (let us know if wrapper composition is a feature you’re interested in).
This wrapper only wraps the forward function of a single torch.nn.Module. If your module has other methods that call the forward function, then these other methods will use the wrapped forward logic. However, if you wrap a module that contains sub-modules, only the forward function of the outer module will be modified.
class OuterModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner_mod = InnerModule()
        self.linear = torch.nn.Linear(5, 5)  # sizes illustrative

    def forward(self, x):
        x = self.inner_mod(x)
        x = self.linear(x)
        x = self.helper_fn(x)
        return x

    def add_3_and_call(self, x):
        return self.forward(x + 3)

    def helper_fn(self, x):
        return x * 5

mod = OuterModule()
wrapper = vote.Wrapper()
wrapped_mod = wrapper(mod)

x = torch.randn(2, 5)

# Calls the Capsa wrapped forward function
out = wrapped_mod(x)

# `add_3_and_call` will call the Capsa wrapped forward function
add_3_out = wrapped_mod.add_3_and_call(x)

# The logic in `helper_fn` will be used in the Capsa wrapped forward
# function, but calling the function directly will not use Capsa
# wrapped code
helper_out = wrapped_mod.helper_fn(x)

# Same as above for sub-modules called by the outer module
inner_out = wrapped_mod.inner_mod(x)

# Same as above
linear_out = wrapped_mod.linear(x)
Current Limitations
We are working to reduce these limitations in future versions of capsa_tf.
Tensorflow Version
capsa_tf requires tensorflow>=2.12.* and python>=3.8.*,<3.12
Keras
We have designed capsa-tf to work with the Keras library, but this functionality is still undergoing testing. If you experience any issues, please report them to us at help@themisai.io.
Please see below for example usage of the Keras3 library:
Note
capsa-tf added "layers" will not be visible under model.summary()
Initializing and wrapping a Keras3 model:
from capsa_tf import vote

# define your sequential/functional/subclassed model
model = keras.Sequential(...)
model.compile(optimizer="adam", loss="mse")

# wrap with capsa
model = vote.Wrapper()(model)

# now can optionally return risk
# NOTE: only this method expects return_risk argument
# (other keras methods like .fit, .predict_on_batch, etc
# do not expect this argument to be passed)
out, risk = model(x, return_risk=True)

# contains both original and capsa added variables
model.trainable_variables

# use keras API as usual
model.fit(x, y, ...)
model.evaluate(x, y)
model.predict(x)
Saving / Loading a Keras3 model:
You can save a capsa wrapped model like so:
# define your sequential/functional/subclassed model
model = keras.Sequential(...)
model.compile(...)
model = vote.Wrapper()(model)

# after training your model with e.g., .fit,
# .train_on_batch, or your custom training function
model.train_step((x, y))

filepath = "./model.weights.h5"
model.save_weights(filepath)
You can load your wrapped model like so:
# define your sequential/functional/subclassed model
restored_model = keras.Sequential(...)
restored_model.compile(...)
restored_model = vote.Wrapper()(restored_model)

# call your model on inputs before loading the weights
# this initializes new capsa variables
restored_model(x)

# now can load the saved variables
restored_model.load_weights(filepath)

# we get same loss as before saving/loading
# (remember that sample wrapper is non
# deterministic so loss value can differ slightly)
restored_model.evaluate(x, y)
Using Sculpt wrapper:
Warning
The steps below are specific to the sculpt wrapper and do not apply to training models wrapped by other wrappers. Please visit the corresponding wrapper page for any additional info on how to train models wrapped by the other wrappers.
Sculpt differs from the other wrappers in that it requires a capsa-tf loss function to be used, in addition to your loss function.
If you compiled your Keras3 model with a loss function, then as part of wrapping, capsa-tf will automatically combine your loss function with capsa-tf.sculpt.Normal.loss_function. The combined loss function is a simple sum of your compiled loss function and the capsa-tf loss function. No further actions are needed on your end.
from capsa_tf import sculpt

model = keras.Sequential(...)
model.compile(optimizer="adam", loss="mse")

# capsa automatically combines your `MSE` loss
# with `capsa-tf.sculpt.Normal.loss_function`
model = sculpt.Wrapper()(model)
If your model wasn't compiled with a loss function, or if you want more control over combining the two loss functions, then as part of wrapping you will see the following warning message.
[Warning] You are using a wrapper that requires loss function to be modified, but we couldn't find your loss function to automatically replace it. Please either: a) compile your model before wrapping with: `.compile(loss=your_loss_fn)`; or b) manually combine your loss with `sculpt.Normal.loss_function` and assign it to `model.compute_loss`
These are the code changes you need to implement:
from capsa_tf import sculpt

model = keras.Sequential(...)
model.compile(optimizer="adam")

# the warning will be printed
model = sculpt.Wrapper()(model)

def compute_loss(x=None, y=None, y_pred=None, sample_weight=None, training=None):
    # you need to unpack y_pred to get risk values
    y_hat, risk = y_pred
    # then feed these risk values to the sculpt wrapper's loss function
    capsa_loss = sculpt.Normal.loss_function(y, y_hat, risk)
    # use your custom loss function
    usr_loss = my_custom_loss(y, y_hat)
    # manually combine the two
    return tf.reduce_mean([capsa_loss, usr_loss])

# assign the above to this attribute of the wrapped model
model.compute_loss = compute_loss
Additional notes: for sculpt, the methods below run with return_risk=True. This default is not tunable.
model.train_step((x, y))
model.train_on_batch(x, y)
model.fit(x, y)

model.test_step((x, y))
model.test_on_batch(x, y)
model.evaluate(x, y)

model.predict_step(x)
model.predict_on_batch(x)
model.predict(x)
Please refer to the following instructions for using Keras2:
For Sequential and Functional models, you can capsa-decorate any function which calls the original model and then use that decorated function instead of the original model. Concretely, for the example below, use wrapped_model instead of model.
You can also add any additional computations inside the Capsa decorated function if needed (e.g. inside of wrapped_model below); that additional logic, together with the forward of the original model, will be transformed with the applied Capsa wrapper.
For an end-to-end training example see: MNIST DCGAN
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, 3, activation='tanh', input_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(8, 3, activation='tanh'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation='tanh'),
])

@vote.Wrapper()
def wrapped_model(x):
    logits = model(x, training=False)
    return logits

# calls `model` forward that has been wrapped
out = wrapped_model(x)

# but calling `model` itself (instead of the `wrapped_model`) will
# call the original **unmodified** `model`
out = model(x)
Warning
'train_step', 'test_step', 'predict_step', and higher level functions that build on top of them, like 'fit', will all use the original unmodified 'model' under the hood. So do not use these methods.
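Instead, drive training with your own loop that calls the decorated function. A sketch (ordinary TensorFlow code, not a Capsa-specific API) might look like:
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()

def train_step(x, y):
    with tf.GradientTape() as tape:
        # forward pass goes through the Capsa decorated function,
        # not through model.fit / model.train_step
        y_pred = wrapped_model(x)
        loss = loss_fn(y, y_pred)
    # NOTE: depending on the wrapper, Capsa may add its own variables;
    # this sketch only optimizes the original model's variables
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss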
For subclassed models, you can capsa-decorate any method of the model. Most often the call method is decorated.
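For example, a sketch of decorating call on a subclassed model (the layer and sizes are made up, and the exact decorator form may vary by wrapper):
import tensorflow as tf
from capsa_tf import vote

class MySubclassedModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(10, activation='tanh')

    # decorate the method whose logic you want Capsa to transform
    @vote.Wrapper()
    def call(self, x):
        return self.dense(x)

model = MySubclassedModel()
x = tf.random.normal((2, 5))
out = model(x)  # forward passes now run through the Capsa wrapped call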
Do not wrap backward pass
Please don't try to apply capsa-tf wrappers to a function that computes the backward pass.
Concretely, if you want to wrap an entire training step:
@capsa_tf.vote.Wrapper(...)  # WRONG!
def my_func(...):
    # run forward pass
    with tf.GradientTape() as tape:
        y_pred = model(...)
        loss = compute_loss(...)
    # run backwards pass
    optimizer.minimize(...)
    return ...
rewrite it like this, and wrap with capsa-tf only the forward pass:
@capsa_tf.vote.Wrapper(...)  # CORRECT!
def my_forward_func(...):
    return model(...)

def my_func(...):
    # run forward pass
    with tf.GradientTape() as tape:
        # now this uses the wrapped `model`
        y_pred = my_forward_func(...)
        loss = compute_loss(...)
    # run backwards pass
    optimizer.minimize(...)
    return ...

# NOTE: wrap only the forward (NOT backward) function with capsa-tf -- in this case,
# we only decorated `my_forward_func` (NOT `my_func`) with a capsa-tf wrapper
Nested Wrapping
Currently we only support using a single wrapper at a time (let us know if wrapper composition is a feature you’re interested in).
Nested tf.function
When working with tf.function, there is potential for behavior that may seem unintuitive, and you need to be aware of it.
Concretely, consider the following snippets:
### snippet 1
def inner_fn(...):
    ...

@capsa_tf.vote.Wrapper
def outer_fn(...):
    inner_fn(...)
    ...

outer_fn(...)
inner_fn is a regular python function, so inner_fn will be wrapped only inside outer_fn:
the code that inner_fn executes inside outer_fn will be wrapped, but if you call inner_fn standalone (outside of outer_fn), your input will flow through the original (not wrapped) inner_fn (see the sketch below).
To sum up, if you wrap a tf.Function which calls a python function (not a tf.Function), the latter will be wrapped only inside the former.
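As an illustration of this behavior (a sketch with made-up function bodies):
import tensorflow as tf
import capsa_tf

def inner_fn(x):
    # plain python function, not wrapped directly
    return x * 2.0

@capsa_tf.vote.Wrapper
def outer_fn(x):
    # inner_fn's code runs wrapped here, because it executes inside outer_fn
    return inner_fn(x) + 1.0

x = tf.random.normal((2, 5))
out = outer_fn(x)  # Capsa wrapped, including the logic from inner_fn
raw = inner_fn(x)  # standalone call: runs the original, unwrapped inner_fn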
Wrapping an already traced tf.function is not supported
Trying to wrap a tf.Function that has already been called on inputs (before wrapping) is intentionally disallowed. Please re-declare your function and don't call it on inputs before wrapping it.
@tf.function
def my_func(...):
    ...

# traces and runs the original (not Capsa wrapped) function
_ = my_func(...)

# if you try to wrap now (after it
# has been already traced), you'll
# get an informative error message
my_func = capsa_tf.sample.Wrapper(...)(my_func)
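The fix is simply to wrap before the first call, for example (a sketch using the same placeholder style as above):
@tf.function
def my_func(...):
    ...

# wrap first, before the function has ever been traced on inputs
my_func = capsa_tf.sample.Wrapper(...)(my_func)

# subsequent calls run the Capsa wrapped function
_ = my_func(...)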
Tensorflow v1 not supported
We don't support tf1 and the related legacy components (e.g. RefVariable).