Limitations

Warning

Capsa wheels are only available on Linux, amd64, CPython 3.8 - 3.11. If you would like to use Capsa in other environments, please let us know.

Current Limitations

We are working to reduce these limitations in future versions of capsa_torch.

PyTorch version

capsa-torch requires torch>=2.1.0

PyTorch Lightning

We have designed capsa-torch to work with pytorch-lightning, but this functionality hasn’t been thoroughly tested yet, and not all features are guaranteed to work.

In-place Modifications

Modules that perform in-place modifications on their weights/buffers during the forward call are not officially supported and may not behave as expected.

class InPlaceMod(torch.nn.Module):
    ...

    def forward(self, x):
        x.mul_(2)  # In-place modification of the input is ok
        x = self.layer(x)

        self.forward_count += 1  # In-place modification of a module attribute (not supported)

        return x

Warning

If you need to store this kind of data during the forward pass, modify the module’s forward function to take forward_count as an argument.
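
For example, a minimal sketch of that pattern (the layer and counter handling here are illustrative, not a specific Capsa API):

class CountingMod(torch.nn.Module):
    ...

    def forward(self, x, forward_count):
        x = self.layer(x)
        # Return the updated counter alongside the output instead of
        # mutating a module attribute inside forward
        return x, forward_count + 1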

Non-Tensor code in the Module

Our approach works best with Tensor-only modules, that is, modules that take Tensors as inputs and output Tensors (ideally floating point Tensors). Non-Tensor code may prevent the module from being wrapped or produce unexpected results.

class BadExampleModule(torch.nn.Module):
    ...

    def forward(self, input_prompt: str):
        # Steps like encoding an input string into a Tensor value should be done
        # outside the wrapped forward function
        encoded: torch.Tensor = self.encode_str(input_prompt)
        ...
        # print statements and other non-Tensor code may interfere with wrapping
        print(out)
        ...
        # Similar to above, decoding/postprocess of output Tensors should be done
        # outside of the wrapped forward function.
        out_str = self.decode(out)

        # Our wrappers work best with models that return unprocessed floating
        # point outputs. Operations like argmax or even last layer softmax /
        # sigmoid activations should be done outside the torch module.
        top_ind = torch.argmax(out)

        return out_str, top_ind
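
For contrast, a rough sketch of the preferred Tensor-only pattern (encode_str and decode are hypothetical helpers standing in for your own pre/post-processing):

class TensorOnlyModule(torch.nn.Module):
    ...

    def forward(self, encoded: torch.Tensor) -> torch.Tensor:
        # Tensor in, floating point Tensor out: no prints, no string handling,
        # no argmax or final activation
        return self.layers(encoded)

# Pre/post-processing happens outside the wrapped forward
# (wrapped_mod is the Capsa wrapped TensorOnlyModule)
encoded = encode_str(input_prompt)  # hypothetical preprocessing helper
out = wrapped_mod(encoded)
top_ind = torch.argmax(out)
out_str = decode(out)               # hypothetical postprocessing helper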

Nested Wrapping and Sub-modules

Currently we only support using a single wrapper at a time (let us know if wrapper composition is a feature you’re interested in).

A Capsa wrapper only wraps the forward function of a single torch.nn.Module. If your module has other methods that call the forward function, those methods will use the wrapped forward logic. However, if you wrap a module that contains sub-modules, only the forward function of the outer module is modified.

class OuterModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner_mod = InnerModule()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        x = self.inner_mod(x)
        x = self.linear(x)
        x = self.helper_fn(x)
        return x

    def add_3_and_call(self, x):
        return self.forward(x + 3)

    def helper_fn(self, x):
        return x * 5

from capsa_torch import vote

mod = OuterModule()
wrapper = vote.Wrapper()
wrapped_mod = wrapper(mod)
x = torch.randn(2, 5)

# Calls Capsa wrapped forward function
out = wrapped_mod(x)

# `add_3_and_call` will call the Capsa wrapped forward function
add_3_out = wrapped_mod.add_3_and_call(x)

# The logic in `helper_fn` will be used in the Capsa wrapped forward
# function but calling the function directly will not use Capsa
# wrapped code
helper_out = wrapped_mod.helper_fn(x)

# Same as above for sub-modules called by the outer module
inner_out = wrapped_mod.inner_mod(x)

# Same as above
linear_out = wrapped_mod.linear(x)

Current Limitations

We are working to reduce these limitations in future versions of capsa_tf.

TensorFlow Version

capsa_tf requires tensorflow >= 2.12.* and Python >= 3.8, < 3.12

Keras

We have designed capsa-tf to work with the Keras library, but this functionality is still undergoing testing. If you experience any issues, please report them to us at help@themisai.io.

Please see below for example usage of the Keras3 library:

Note

The “layers” added by capsa-tf will not be visible under model.summary()

  1. Initializing and wrapping a Keras3 model:

    import keras
    from capsa_tf import vote

    # define your sequential/functional/subclassed model
    model = keras.Sequential(...)
    model.compile(optimizer="adam", loss="mse")
    
    # wrap with capsa
    model = vote.Wrapper()(model)
    
    # now can optionally return risk
    # NOTE: only this method expects return_risk argument
    # (other keras methods like .fit, .predict_on_batch, etc
    # do not expect this argument to be passed)
    out, risk = model(x, return_risk=True)
    
    # contains both original and capsa added variables
    model.trainable_variables
    
    # use keras API as usual
    model.fit(x, y, ...)
    model.evaluate(x, y)
    model.predict(x)
    
  2. Saving / Loading a Keras3 model:

    You can save a capsa wrapped model like so:

    # define your sequential/functional/subclassed model
    model = keras.Sequential(...)
    model.compile(...)
    model = vote.Wrapper()(model)
    
    # after training your model with e.g., .fit,
    # .train_on_batch, or your custom training function
    model.train_step((x, y))
    
    filepath = "./model.weights.h5"
    model.save_weights(filepath)
    

    You can load your wrapped model like so:

    # define your sequential/functional/subclassed model
    restored_model = keras.Sequential(...)
    restored_model.compile(...)
    restored_model = vote.Wrapper()(restored_model)
    
    # call your model on inputs before loading the weights
    # this initializes new capsa variables
    restored_model(x)
    
    # now can load the saved variables
    restored_model.load_weights(filepath)
    
    # we get the same loss as before saving/loading
    # (remember that the sample wrapper is non-deterministic,
    # so the loss value can differ slightly)
    restored_model.evaluate(x, y)
    
  3. Using Sculpt wrapper:

    Warning

    The steps below are specific to the sculpt wrapper and do not apply to training models wrapped by other wrappers. Please visit the corresponding wrapper page for any additional info on how to train models wrapped by the other wrappers.

    Sculpt differs from the other wrappers in that it requires a capsa-tf loss function to be used, in addition to your loss function.

    If you compiled your Keras3 model with a loss function, then as part of wrapping, capsa-tf will automatically combine your loss function with capsa_tf.sculpt.Normal.loss_function. The combined loss function is a simple sum of your compiled loss function and the capsa-tf loss function. No further action is needed on your end.

    from capsa_tf import sculpt
    
    model = keras.Sequential(...)
    model.compile(optimizer="adam", loss="mse")
    # capsa automatically combines your `MSE` loss
    # with `capsa_tf.sculpt.Normal.loss_function`
    model = sculpt.Wrapper()(model)
    

    If your model wasn’t compiled with a loss function, or if you want more control over combining the two loss functions, then as part of wrapping you will see the following warning message.

    [Warning]
    
    You are using a wrapper that requires loss function to be modified,
    but we couldn't find your loss function to automatically replace it.
    
    Please either:
        a) compile your model before wrapping with: `.compile(loss=your_loss_fn)`; or
        b) manually combine your loss with `sculpt.Normal.loss_function` and assign it to `model.compute_loss`
    

    These are the code changes you need to implement:

    import tensorflow as tf
    from capsa_tf import sculpt
    
    model = keras.Sequential(...)
    model.compile(optimizer="adam")
    # the warning will be printed
    model = sculpt.Wrapper()(model)
    
    def compute_loss(x=None, y=None, y_pred=None, sample_weight=None, training=None):
        # you need to unpack y_pred to get risk values
        y_hat, risk = y_pred
        # then feed these risk values to the sculpt wrapper's loss function
        capsa_loss = sculpt.Normal.loss_function(y, y_hat, risk)
        # use your custom loss function
        usr_loss = my_custom_loss(y, y_hat)
        # manually combine the two
        return tf.reduce_mean([capsa_loss, usr_loss])
    
    # assign the above to this attribute of the wrapped model
    model.compute_loss = compute_loss
    

    Additional note: for sculpt, the methods below always run with return_risk=True. This default is not tunable.

    model.train_step((x, y))
    model.train_on_batch(x, y)
    model.fit(x, y)
    
    model.test_step((x, y))
    model.test_on_batch(x, y)
    model.evaluate(x, y)
    
    model.predict_step(x)
    model.predict_on_batch(x)
    model.predict(x)
    

Please refer to the following instructions for using Keras2:

  1. For Sequential and Functional models, you can capsa-decorate any function which calls the original model and then use that decorated function instead of the original model. Concretely, for the example below, use wrapped_model instead of model.

    You can also add any additional computations inside the Capsa-decorated function if needed (e.g. inside wrapped_model below); that additional logic, together with the forward pass of the original model, will be transformed by the applied Capsa wrapper.

    For an end-to-end training example, see: MNIST DCGAN

    from capsa_tf import vote

    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, 3, activation='tanh', input_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(8, 3, activation='tanh'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation='tanh'),
    ])
    
    @vote.Wrapper()
    def wrapped_model(x):
        logits = model(x, training=False)
        return logits
    
    # calls `model` forward that has been wrapped
    out = wrapped_model(x)
    
    # but calling `model` itself (instead of the `wrapped_model`) will
    # call the original **unmodified** `model`
    out = model(x)
    

    Warning

    train_step, test_step, predict_step, and higher-level functions that build on top of them (like fit) will all use the original, unmodified model under the hood, so do not use these methods.

  2. For subclassed models, you can capsa-decorate any method of the model. Most often, the call method is decorated, as in the sketch below.
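
    A minimal sketch, assuming a simple hypothetical subclassed model with a single Dense layer:

    from capsa_tf import vote

    class MyModel(tf.keras.Model):
        def __init__(self):
            super().__init__()
            self.dense = tf.keras.layers.Dense(10)

        # capsa-decorate the method that runs the forward pass
        @vote.Wrapper()
        def call(self, x, training=False):
            return self.dense(x)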

Do not wrap backward pass

Please don’t try to apply capsa-tf wrappers to a function that computes the backward pass.

Concretely:

  • if you want to wrap an entire training step like this:

    @capsa_tf.vote.Wrapper(...)   # WRONG!
    def my_func(...):
        # run forward pass
        with tf.GradientTape() as tape:
            y_pred = model(...)
            loss = compute_loss(...)
        # run backwards pass
        optimizer.minimize(...)
        return ...
    
  • rewrite it like this, and wrap only the forward pass with capsa-tf:

    @capsa_tf.vote.Wrapper(...)   # CORRECT!
    def my_forward_func(...):
        return model(...)
    
    def my_func(...):
        # run forward pass
        with tf.GradientTape() as tape:
            # now this uses the wrapped `model`
            y_pred = my_forward_func(...)
            loss = compute_loss(...)
        # run backwards pass
        optimizer.minimize(...)
        return ...
    
    # NOTE: wrap only the forward (NOT backward) function with capsa-tf -- in this case,
    # we only decorated `my_forward_func` (NOT `my_func`) with a capsa-tf wrapper
    

Nested Wrapping

Currently we only support using a single wrapper at a time (let us know if wrapper composition is a feature you’re interested in).
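
For clarity, a sketch of the unsupported pattern (the specific wrappers below are just examples):

from capsa_tf import sample, vote

# model is your Keras model or function from the examples above
# NOT supported: applying a second wrapper on top of an already wrapped model/function
wrapped_once = vote.Wrapper()(model)
wrapped_twice = sample.Wrapper()(wrapped_once)  # wrapper composition is not supported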

Nested tf.function

When working with tf.function, there is potential for behavior that may seem unintuitive; you need to be aware of it.

Concretely, consider the following snippets:

### snippet 1

def inner_fn(...):
    ...

@capsa_tf.vote.Wrapper()
def outer_fn(...):
    inner_fn(...)
    ...

outer_fn(...)

inner_fn is a regular python function, so inner_fn will be wrapped only inside outer_fn

  • the code that inner_fn executes inside outer_fn will be wrapped

  • but if you call inner_fn standalone (outside of outer_fn), your input will flow through the original (not wrapped) inner_fn

  • to sum up, if you wrap a tf.Function which calls a regular python function (not a tf.Function), the latter will be wrapped only inside the former
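
Continuing snippet 1, the calling behavior is (illustrative only):

# calling outer_fn runs the Capsa wrapped code,
# including the code executed inside inner_fn
out = outer_fn(...)

# calling inner_fn directly (outside of outer_fn) runs the
# original, unwrapped python function
out = inner_fn(...)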

Wrapping an already traced tf.function is not supported

Trying to wrap a tf.Function that has already been called on inputs (before wrapping) is intentionally disallowed; please re-declare your function and don’t call it on inputs before you wrap it.

@tf.function
def my_func(...):
    ...

# traces and runs the original (not Capsa wrapped) function
_ = my_func(...)

# if you try to wrap now (after it has
# already been traced), you'll get an
# informative error message
my_func = capsa_tf.sample.Wrapper(...)(my_func)

TensorFlow v1 not supported

We don’t support TF1 or its related legacy components (e.g. RefVariable).