Use Case: Robust Object Tracking with TCOW

In this tutorial, we show how to integrate capsa-torch into an existing codebase, using the TCOW repository as the example. The updated codebase with the capsa-torch integration is available here.

What is TCOW?

TCOW (Tracking through Containers and Occluders in the Wild) is a segmentation model that can track objects even when they are hidden behind an occluder or inside a container, i.e., not visible in the current frame.

Example visualization generated by the TCOW model.

The TCOW model is a vision transformer that takes as input a video, given as a sequence of frames, and a mask of the object to track in a single query frame. It then outputs three masks per frame of the video:

  1. Target object’s mask

  2. Main occluder’s mask

  3. Main container’s mask

Input-output definition of the TCOW model
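The three outputs listed above are stacked along one channel dimension. The pipeline code later in this tutorial annotates the Seeker output as (B, 3, T, Hf, Wf), and the training script orders the channels as target, frontmost occluder, outermost container. A minimal NumPy sketch that pulls the three masks apart; split_output_mask is a hypothetical helper, and the random array only stands in for a real Seeker output:

```python
import numpy as np

# Channel order taken from the training script's comment:
# target/snitch, frontmost occluder, outermost container.
CHANNEL_NAMES = ("target", "occluder", "container")


def split_output_mask(output_mask):
    """Split a (B, 3, T, Hf, Wf) array into three named (B, T, Hf, Wf) arrays."""
    if output_mask.ndim != 5 or output_mask.shape[1] != len(CHANNEL_NAMES):
        raise ValueError(f"expected shape (B, 3, T, Hf, Wf), got {output_mask.shape}")
    return {name: output_mask[:, i] for i, name in enumerate(CHANNEL_NAMES)}


# Stand-in output: a batch of 2 clips, 10 frames of 60x80 pixels.
scores = np.random.rand(2, 3, 10, 60, 80)
masks = split_output_mask(scores)
print({name: m.shape for name, m in masks.items()})
```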

It is also worth noting what the repository published by the paper's authors provides. Besides the model, it defines the training and evaluation pipelines as well as visualization tools. New models and datasets can be plugged into the repository and then trained and evaluated with ease.

Wrapping TCOW with capsa-torch

Wrapping TCOW with capsa-torch adds risk-awareness to the model. The wrapped model outputs risk values for each predicted mask, which tell users how confident the model is in its predictions. When such a risk-aware model is used in high-risk scenarios (autopilot, medical imaging, etc.), these risk values support more informed decisions.

Example risk visualization
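As a sketch of how such risk values might feed a decision, the snippet below flags the frames whose average risk exceeds a threshold. The (T, H, W) layout of one risk map, the mean reduction, and the threshold value are illustrative assumptions, not part of capsa-torch or TCOW; flag_uncertain_frames is a hypothetical helper:

```python
import numpy as np


def flag_uncertain_frames(risk, threshold):
    """Return indices of frames whose mean per-pixel risk exceeds `threshold`.

    risk: array of shape (T, H, W) holding per-pixel risk for one predicted mask.
    threshold: hypothetical, application-specific cut-off chosen by the user.
    """
    per_frame = risk.reshape(risk.shape[0], -1).mean(axis=1)  # (T,)
    return np.flatnonzero(per_frame > threshold)


# Toy example: 4 frames, the third one is much riskier than the others.
risk = np.full((4, 2, 2), 0.1)
risk[2] = 0.9
print(flag_uncertain_frames(risk, threshold=0.5))  # [2]
```

In practice the threshold would have to be calibrated on held-out data for the target application.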

Changes to the TCOW repository

Args

User arguments are defined in tcow/args.py. They are accessible to almost all modules initialized in the repository, which makes them a convenient way to pass the information capsa-torch needs.

We add all the arguments capsa-torch needs. The most important one is --wrapper, which selects how the model is wrapped (none, vote, sculpt, or sample). The --symbolic_trace and --verbose arguments are accepted by every wrapper and are therefore shared; all other arguments are wrapper-specific. For example, --n_samples is passed to the sample wrapper to set the number of samples taken from the model. These arguments are explained in more detail in the rest of the documentation.

@@ -93,6 +93,45 @@ def shared_args(parser):
    parser.add_argument('--wandb_group', default='group', type=str,
                        help='Group to put this experiment in on weights and biases.')

+    # Capsa  options
+    parser.add_argument('--wrapper', default="none", type=str,
+                        help='Capsa wrapper to use for the model. Options: none, vote, sculpt, sample')
+    parser.add_argument('--symbolic_trace', default=False, type=_str2bool,
+                        help='')
+    parser.add_argument('--verbose', default=0, type=int,
+                        help='')
+
+    # Capsa sample wrapper options
+    parser.add_argument('--n_samples', default=5, type=int,
+                        help='')
+    parser.add_argument('--distribution', default=0.1, type=float,
+                        help='')    
+    parser.add_argument('--trainable', default=False, type=_str2bool,
+                        help='')
+
+    # Capsa vote wrapper options
+    parser.add_argument('--n_voters', default=3, type=int,
+                        help='')
+    parser.add_argument('--alpha', default=1, type=int,
+                        help='')    
+    parser.add_argument('--use_bias', default=True, type=_str2bool,
+                        help='')
+    parser.add_argument('--independent', default=True, type=_str2bool,
+                        help='')    
+    parser.add_argument('--finetune', default=False, type=_str2bool,
+                        help='')
+
+    # Capsa sculpt wrapper options
+    parser.add_argument('--n_layers', default=3, type=int,
+                        help='')
@@ -265,6 +304,8 @@ def verify_args(args, is_train=False):

    args.wandb_group = ('train' if is_train else 'test') + \
                       ('_debug' if args.is_debug else '')

+   assert args.wrapper in ["none", "vote", "sculpt", "sample"]

    if is_train:

In verify_args(), we add a check that the --wrapper value is valid, since the model and optimizer initialization depend on its value.
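The assert in verify_args() is skipped when Python runs with -O. One alternative, sketched below with a reduced set of the arguments from the diff, is to let argparse validate the value through choices=, so an unknown wrapper fails at parse time with a usage message. str2bool is a stand-in written for this sketch, not the repository's _str2bool:

```python
import argparse


def str2bool(value):
    """Parse common true/false spellings (a stand-in for the repository's _str2bool)."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("1", "true", "yes", "y"):
        return True
    if lowered in ("0", "false", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    parser = argparse.ArgumentParser()
    # `choices` makes argparse reject unknown wrappers while parsing.
    parser.add_argument("--wrapper", default="none",
                        choices=["none", "vote", "sculpt", "sample"],
                        help="capsa wrapper to apply to the Seeker")
    parser.add_argument("--symbolic_trace", default=False, type=str2bool)
    parser.add_argument("--n_samples", default=5, type=int)
    return parser


args = build_parser().parse_args(["--wrapper", "sample", "--n_samples", "8"])
print(args.wrapper, args.n_samples, args.symbolic_trace)  # sample 8 False
```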

Fixes

data/data_kubric.py contains a helper, _load_example_preprocess(), for loading the Kubric dataset. It never closes the PIL image it opens for each depth frame, which causes a memory leak. We fix this by keeping a handle to the opened image and calling image.close() once the frame has been read, at the end of each loop iteration.

@@ -261,11 +261,12 @@ def _load_example_preprocess(self, scene_idx, scene_dp, frame_inds_load):
            segm_fp = os.path.join(frames_dp, f'segmentation_{t:05d}.png')
            if not os.path.exists(rgb_fp):
                break

+           image = PIL.Image.open(depth_fp)
            rgb = plt.imread(rgb_fp)[..., 0:3]  # (H, W, 3) floats.
-           depth = np.array(PIL.Image.open(depth_fp))  # (H, W) floats.
+           depth = np.array(image)  # (H, W) floats.
            depth = depth[..., None]  # (H, W, 1).
            segm = plt.imread(segm_fp)[..., 0:3]  # (H, W, 3) floats.
+           image.close()

            pv_rgb.append(rgb)
            pv_depth.append(depth)
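The same fix can also be written with a context manager, which closes the file even if decoding raises. A minimal sketch; load_depth is a hypothetical helper, not a function in the repository, and the temporary 16-bit PNG only exists to make the example runnable:

```python
import os
import tempfile

import numpy as np
import PIL.Image


def load_depth(depth_fp):
    """Read a depth PNG into an array and release the file handle immediately."""
    # The `with` block closes the image even if np.array() raises.
    with PIL.Image.open(depth_fp) as image:
        depth = np.array(image)  # (H, W)
    return depth[..., None]  # (H, W, 1)


# Round-trip a small 16-bit image through a temporary file.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "depth_00000.png")
    PIL.Image.fromarray(np.arange(12, dtype=np.uint16).reshape(3, 4)).save(path)
    depth = load_depth(path)
print(depth.shape)  # (3, 4, 1)
```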

Inference and Loading Models for Evaluation

In eval/test.py, main() calls inference.load_networks(), which loads the model from a checkpoint. We change load_networks() to receive test_args instead of only the checkpoint path, because it needs access to all user arguments, including --wrapper; the checkpoint path is now read from test_args.resume.

If the wrapper is not none, load_networks() builds the corresponding capsa-torch wrapper. Since the Seeker model now takes extra parameters (wrapper and wrapper_arg), we pass them when instantiating it. A wrapped Seeker calls wrap() before its weights are loaded, and the state dict is then loaded with strict=False.

@@ -186,7 +186,7 @@ def main(test_args, logger):
    else:
        device = torch.device(test_args.device)
    (networks, train_args, train_dset_args, model_args, epoch) = \
-        inference.load_networks(test_args.resume, device, logger, epoch=test_args.epoch)
+        inference.load_networks(test_args, device, logger, epoch=test_args.epoch)

    logger.info(f'Took {time.time() - start_time:.3f}s')
+import torch
+from capsa_torch import sample, vote, sculpt
+from capsa_torch.sample.distribution import Bernoulli

-def load_networks(checkpoint_path, device, logger, epoch=-1):
+def load_networks(test_args, device, logger, epoch=-1):

    print_fn = logger.info if logger is not None else print

+   checkpoint_path = test_args.resume
    assert os.path.exists(checkpoint_path)
    if os.path.isdir(checkpoint_path):
        model_fn = f'model_{epoch}.pth' if epoch >= 0 else 'checkpoint.pth'
        checkpoint_path = os.path.join(checkpoint_path, model_fn)

    print_fn('Loading weights from: ' + checkpoint_path)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')


    # Load all arguments for later use.
    train_args = checkpoint['train_args']
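+    # Checkpoints saved before the capsa integration do not record a wrapper;
+    # fall back to the wrapper requested at test time.
+    if not hasattr(train_args, 'wrapper') and hasattr(test_args, 'wrapper'):
+        train_args.wrapper = test_args.wrapper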
    train_dset_args = checkpoint['dset_args']


    # Get network instance parameters.
    seeker_args = checkpoint['seeker_args']

    model_args = {'seeker': seeker_args}


+    if test_args.wrapper == "sample":
+        wrapper = sample.Wrapper(
+            symbolic_trace=test_args.symbolic_trace,
+            n_samples=test_args.n_samples,
+            distribution=sample.Bernoulli(test_args.distribution),
+            trainable=test_args.trainable,
+            verbose=test_args.verbose)
+    elif test_args.wrapper == "vote":
+        wrapper = vote.Wrapper(
+            symbolic_trace=test_args.symbolic_trace,
+            finetune=test_args.finetune,
+            n_voters=test_args.n_voters,
+            alpha=test_args.alpha,
+            use_bias=test_args.use_bias,
+            verbose=test_args.verbose,
+            independent=test_args.independent)
+    elif test_args.wrapper == "sculpt":
+        wrapper = sculpt.Wrapper(
+            symbolic_trace=test_args.symbolic_trace,
+            n_layers=test_args.n_layers,
+            verbose=test_args.verbose)
+    else:
+        wrapper = None


    # Instantiate networks.
-    seeker_net = seeker.Seeker(logger, **seeker_args)
+    seeker_net = seeker.Seeker(logger,wrapper=wrapper,wrapper_arg=test_args.wrapper, **seeker_args)
    seeker_net = seeker_net.to(device)
-    seeker_net.load_state_dict(checkpoint['net_seeker'])
+    if test_args.wrapper == "none": 
+        seeker_net.load_state_dict(checkpoint['net_seeker'],strict=True)
+    else:
+        seeker_net.wrap()
+        seeker_net.load_state_dict(checkpoint['net_seeker'],strict=False)


    networks = {'seeker': seeker_net}
    epoch = checkpoint['epoch']
    print_fn('=> Loaded epoch (1-based): ' + str(epoch + 1))
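The same if/elif chain that builds the wrapper appears both here and in train.py's main(). If the two copies drift apart, a checkpoint could be trained with one configuration and reloaded with another. A hedged sketch of a single shared helper; build_wrapper and the factories mapping are hypothetical names, and the stand-in factories replace the capsa-torch constructors so the example runs on its own:

```python
from types import SimpleNamespace


def build_wrapper(args, factories):
    """Return the wrapper selected by args.wrapper, or None for "none".

    factories maps a wrapper name to a callable that takes the parsed args.
    """
    if args.wrapper == "none":
        return None
    try:
        factory = factories[args.wrapper]
    except KeyError:
        raise ValueError(f"unknown wrapper {args.wrapper!r}") from None
    return factory(args)


# Stand-in factories so the sketch runs without capsa-torch installed.
factories = {
    "sample": lambda a: ("sample", a.n_samples),
    "sculpt": lambda a: ("sculpt", a.n_layers),
}
args = SimpleNamespace(wrapper="sculpt", n_layers=3, n_samples=5)
print(build_wrapper(args, factories))  # ('sculpt', 3)
```

In the repository, the mapping would hold the constructor calls shown in the diff, for example `"sculpt": lambda a: sculpt.Wrapper(symbolic_trace=a.symbolic_trace, n_layers=a.n_layers, verbose=a.verbose)`, and both load_networks() and main() would call the one helper.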

Seeker

The Seeker is the main model in TCOW. It takes the input video and the query mask and outputs mask predictions for the target, the main occluder, and the main container.

We modify the Seeker so that it can be wrapped with different wrappers. __init__() gains wrapper and wrapper_arg parameters, and a new wrap() method wraps the inner QueryMaskTracker with the chosen wrapper and stores the result as wrapped_seeker. Finally, we add one forward method per wrapper (sculpt_forward(), vote_forward(), and sample_forward()); __init__() rebinds forward to the one matching wrapper_arg, so the wrapped_seeker is called whenever the wrapper is not none. All forward methods now also take a phase argument.

Note that the model does not output risk values if the wrapper argument is none or if the current phase is train.

# Internal imports.
import mask_tracker

+from capsa_torch.sculpt.distribution import Normal


class Seeker(torch.nn.Module):

-    def __init__(self, logger, **kwargs):
+    def __init__(self, logger,wrapper=None,wrapper_arg=None, **kwargs):
        super().__init__()
        self.logger = logger
+        self.wrapper = wrapper
        self.seeker = mask_tracker.QueryMaskTracker(logger, **kwargs)
+        self.wrapper_arg = wrapper_arg

+        if self.wrapper_arg == 'sculpt':
+            self.distribution = Normal()
+            self.forward = self.sculpt_forward
+        elif self.wrapper_arg == 'vote':
+            self.forward = self.vote_forward
+        elif self.wrapper_arg == 'sample':
+            self.forward = self.sample_forward


+    def wrap(self):
+        self.wrapped_seeker = self.wrapper(self.seeker) if self.wrapper is not None else None

+    def forward(self,phase, seeker_input, seeker_query_mask):
+        return self.seeker(seeker_input, seeker_query_mask)

+    def sculpt_forward(self,phase, seeker_input, seeker_query_mask):
+        if phase == "train":
+            risk_output = self.wrapped_seeker(seeker_input, seeker_query_mask,return_risk=True)
+            return self.distribution.sample(risk_output)
+        else:
+            return self.wrapped_seeker(seeker_input, seeker_query_mask,return_risk=True)

+    def vote_forward(self,phase, seeker_input, seeker_query_mask):
+        return self.wrapped_seeker(seeker_input, seeker_query_mask,return_risk=False,tile_and_reduce=False) if phase == "train" else self.wrapped_seeker(seeker_input, seeker_query_mask,return_risk=True)

+    def sample_forward(self,phase, seeker_input, seeker_query_mask):
+        return self.wrapped_seeker(seeker_input, seeker_query_mask,return_risk=False) if phase == 'train' else self.wrapped_seeker(seeker_input, seeker_query_mask,return_risk=True)



-    def forward(self, *args):
-        return self.seeker(*args)
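Rebinding self.forward in __init__() works because attribute lookup finds the instance attribute before the class method, but it hides the dispatch from readers. A hedged alternative, shown with plain-Python stand-ins rather than torch modules, is to look up the per-phase keyword arguments explicitly. The table covers only the sample and vote paths copied from the diff; sculpt is left out because its training path also samples from a Normal distribution. SeekerSketch, FakeWrapped, and CALL_KWARGS are hypothetical names:

```python
class FakeWrapped:
    """Stand-in for a capsa-wrapped QueryMaskTracker; echoes the kwargs it receives."""

    def __call__(self, seeker_input, seeker_query_mask, **kwargs):
        return kwargs


# Keyword arguments passed to the wrapped model, copied from sample_forward()/vote_forward().
CALL_KWARGS = {
    ("sample", "train"): {"return_risk": False},
    ("sample", "eval"): {"return_risk": True},
    ("vote", "train"): {"return_risk": False, "tile_and_reduce": False},
    ("vote", "eval"): {"return_risk": True},
}


class SeekerSketch:
    def __init__(self, wrapped_seeker, wrapper_arg):
        self.wrapped_seeker = wrapped_seeker
        self.wrapper_arg = wrapper_arg

    def forward(self, phase, seeker_input, seeker_query_mask):
        # One explicit lookup instead of rebinding self.forward in __init__().
        key = (self.wrapper_arg, "train" if phase == "train" else "eval")
        return self.wrapped_seeker(seeker_input, seeker_query_mask, **CALL_KWARGS[key])


seeker = SeekerSketch(FakeWrapped(), "vote")
print(seeker.forward("train", None, None))  # {'return_risk': False, 'tile_and_reduce': False}
print(seeker.forward("test", None, None))   # {'return_risk': True}
```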

Pipeline

The pipeline (tcow/pipeline.py) knows how to run the Seeker model and also handles some of the pre-processing and post-processing. Because the two main datasets used with the Seeker (Kubric and plugin) are highly specific, the pipeline has two methods, forward_plugin() and forward_kubric(), each specialized in preparing the corresponding dataset for inference.

Here, we edit forward_kubric(). The Seeker is now called with keyword arguments and the current phase (self.phase), which determines whether risk values are returned. We also add an all_output_mask_risk list that collects the risk values for each query. Finally, the stacked risks are stored in model_retval['output_mask_risk'], which the method returns.

@@ -130,6 +130,7 @@ def forward_kubric(self, data_retval):
        all_full_occl_cont_id = []
        all_target_mask = []
        all_output_mask = []
+        all_output_mask_risk = []

        for q in range(Qs):

@@ -153,9 +154,14 @@ def forward_kubric(self, data_retval):
                raise RuntimeError(
                    f'target_mask all zero? q: {q} query_idx: {query_idx} qt_idx: {qt_idx}')

-            # Run seeker to recover hierarchical masks over time.
-            (output_mask, output_flags) = self.networks['seeker'](
-                seeker_input, seeker_query_mask)  # (B, 3, T, Hf, Wf), (B, T, 3).
+            if 'train' in self.phase or 'none' in self.train_args.wrapper:
+                # Run seeker to recover hierarchical masks over time.
+                (output_mask, output_flags) = self.networks['seeker'](
+                    seeker_input=seeker_input, seeker_query_mask=seeker_query_mask,phase=self.phase)  # (B, 3, T, Hf, Wf), (B, T, 3).
+            else:
+                # Run seeker to recover hierarchical masks over time.
+                (output_mask, output_flags),(output_mask_risk, output_flags_risk) = self.networks['seeker'](
+                    seeker_input=seeker_input, seeker_query_mask=seeker_query_mask,phase=self.phase)  # (B, 3, T, Hf, Wf), (B, T, 3).

            # Save some ground truth metadata, e.g. weighted query desirability, to get a feel for
            # this example or dataset.
@@ -172,6 +178,7 @@ def forward_kubric(self, data_retval):
            all_full_occl_cont_id.append(full_occl_cont_id)  # (B, T, 2).
            all_target_mask.append(target_mask)  # (B, 3, T, Hf, Wf).
            all_output_mask.append(output_mask)  # (B, 1/3, T, Hf, Wf).
+            if self.train_args.wrapper != 'none' and 'train' not in self.phase: all_output_mask_risk.append(output_mask_risk)  # (B, 1/3, T, Hf, Wf).

        sel_occl_fracs = torch.stack(all_occl_fracs, dim=1)  # (B, Qs, T, 3).
        sel_desirability = torch.stack(all_desirability, dim=1)  # (B, Qs).
@@ -180,6 +187,7 @@ def forward_kubric(self, data_retval):
        full_occl_cont_id = torch.stack(all_full_occl_cont_id, dim=1)  # (B, Qs, T, 2).
        target_mask = torch.stack(all_target_mask, dim=1)  # (B, Qs, 3, T, Hf, Wf).
        output_mask = torch.stack(all_output_mask, dim=1)  # (B, Qs, 1/3, T, Hf, Wf).
+        if self.train_args.wrapper != 'none' and 'train' not in self.phase: output_mask_risk = torch.stack(all_output_mask_risk, dim=1)  # (B, Qs, 1/3, T, Hf, Wf).

        # Organize & return relevant info.
        # Ensure that everything is on a CUDA device.
@@ -196,6 +204,7 @@ def forward_kubric(self, data_retval):
        # (B, Qs, 1, T, Hf, Wf).
        model_retval['target_mask'] = target_mask.to(self.device)  # (B, Qs, 3, T, Hf, Wf).
        model_retval['output_mask'] = output_mask.to(self.device)  # (B, Qs, 1/3, T, Hf, Wf).
+        if self.train_args.wrapper != 'none' and 'train' not in self.phase: model_retval['output_mask_risk'] = output_mask_risk.to(self.device)  # (B, Qs, 1/3, T, Hf, Wf).

        return model_retval
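The three repeated conditions in this diff all encode the same rule. A minimal sketch that names that rule once and reproduces the per-query stacking, using NumPy's np.stack(axis=1) in place of torch.stack(dim=1) so it runs without PyTorch; should_return_risk and stack_queries are hypothetical helpers, and the random arrays stand in for Seeker outputs:

```python
import numpy as np


def should_return_risk(wrapper, phase):
    """Same condition the pipeline uses: a wrapped model outside a training phase."""
    return wrapper != "none" and "train" not in phase


def stack_queries(per_query):
    """Stack Qs arrays shaped (B, 3, T, Hf, Wf) into (B, Qs, 3, T, Hf, Wf)."""
    return np.stack(per_query, axis=1)


B, Qs, T, Hf, Wf = 2, 4, 10, 30, 40
all_output_mask_risk = []
for q in range(Qs):
    output_mask_risk = np.random.rand(B, 3, T, Hf, Wf)  # stand-in for the Seeker's risk output
    if should_return_risk("sculpt", "test"):
        all_output_mask_risk.append(output_mask_risk)
print(stack_queries(all_output_mask_risk).shape)  # (2, 4, 3, 10, 30, 40)
```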

Next, we edit forward_plugin() in the same way: the Seeker receives the current phase (self.phase), and the returned risk tensors are stored in model_retval['output_mask_risk'] and model_retval['output_flags_risk'].

@@ -224,9 +233,16 @@ def forward_plugin(self, data_retval):
        if not seeker_query_mask.any():
            raise RuntimeError(f'seeker_query_mask all zero?')

-        # Run seeker to recover hierarchical masks over time.
-        (output_mask, output_flags) = self.networks['seeker'](
-            seeker_input, seeker_query_mask)  # (B, 3, T, Hf, Wf), (B, T, 3).
+        if 'train' in self.phase or 'none' in self.train_args.wrapper:
+            # Run seeker to recover hierarchical masks over time.
+            (output_mask, output_flags) = self.networks['seeker'](
+                seeker_input=seeker_input, seeker_query_mask=seeker_query_mask,phase=self.phase)  # (B, 3, T, Hf, Wf), (B, T, 3).
+        else:
+            # Run seeker to recover hierarchical masks over time.
+            (output_mask, output_flags),(output_mask_risk, output_flags_risk) = self.networks['seeker'](
+                seeker_input=seeker_input, seeker_query_mask=seeker_query_mask,phase=self.phase)  # (B, 3, T, Hf, Wf), (B, T, 3).



        # Organize & return relevant info.
        # Ensure that everything is on a CUDA device.
@@ -236,6 +252,9 @@ def forward_plugin(self, data_retval):
        model_retval['target_mask'] = target_mask.to(self.device)  # (B, 3, T, Hf, Wf).
        model_retval['output_mask'] = output_mask.to(self.device)  # (B, 3, T, Hf, Wf).
        model_retval['output_flags'] = output_flags.to(self.device)  # (B, T, 3).

+        if self.train_args.wrapper != 'none' and 'train' not in self.phase: model_retval['output_flags_risk'] = output_flags_risk.to(self.device)
+        if self.train_args.wrapper != 'none' and 'train' not in self.phase: model_retval['output_mask_risk'] = output_mask_risk.to(self.device)

        return model_retval
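forward_plugin() and forward_kubric() both branch on the same condition to unpack the Seeker's two possible return shapes. A hedged sketch of two small helpers that keep the unpacking in one place; unpack_seeker_output and add_risk are hypothetical names, and the .to(self.device) calls are omitted. With them, downstream code could check 'output_mask_risk' in model_retval instead of repeating the wrapper/phase test:

```python
def unpack_seeker_output(result, with_risk):
    """Return (output_mask, output_flags, output_mask_risk, output_flags_risk).

    Without risk the Seeker returns (mask, flags); with risk it returns
    ((mask, flags), (mask_risk, flags_risk)). Missing risk entries become None.
    """
    if with_risk:
        (output_mask, output_flags), (output_mask_risk, output_flags_risk) = result
    else:
        output_mask, output_flags = result
        output_mask_risk = output_flags_risk = None
    return output_mask, output_flags, output_mask_risk, output_flags_risk


def add_risk(model_retval, output_mask_risk, output_flags_risk):
    """Store risk entries only when they exist."""
    if output_mask_risk is not None:
        model_retval["output_mask_risk"] = output_mask_risk
    if output_flags_risk is not None:
        model_retval["output_flags_risk"] = output_flags_risk
    return model_retval


# Strings stand in for tensors.
mask, flags, mask_risk, flags_risk = unpack_seeker_output((("m", "f"), ("mr", "fr")), with_risk=True)
print(add_risk({"output_mask": mask}, mask_risk, flags_risk))
```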

Training

Training is done in tcow/train.py, which holds all the logic for wrapping the model, iterating through the data loader, and checkpointing.

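In _train_one_epoch(), we make sure the model stays on the correct device: when a batch raises an exception, the pipeline is moved back to the device and the batch is skipped, and once three exceptions have occurred (instead of twenty) the exception is logged and re-raised. We also add an epoch == -1 case that stops after the first successful batch and skips the learning-rate scheduler step.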

import seeker


+import torch
+from capsa_torch import sample,vote,sculpt

def _get_learning_rate(optimizer):
    if isinstance(optimizer, dict):
        optimizer = my_utils.any_value(optimizer)
@@ -77,11 +81,14 @@ def _train_one_epoch(args, train_pipeline, networks_nodp, phase, epoch, optimize
        except Exception as e:

            num_exceptions += 1
-            if num_exceptions >= 20:
+            if num_exceptions >= 3:
+                logger.exception(e)
                raise e
            else:
-                logger.exception(e)
+                train_pipeline[0].to(device)
                continue
+        if epoch == -1: break

        # Perform backpropagation to update model parameters.
        if phase == 'train':
@@ -105,7 +112,7 @@ def _train_one_epoch(args, train_pipeline, networks_nodp, phase, epoch, optimize
            logger.warning('Cutting epoch short for debugging...')
            break

-    if phase == 'train':
+    if phase == 'train' and epoch != -1:
        for (k, v) in lr_schedulers.items():
            v.step()
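Stripped of the TCOW specifics, the modified loop is a bounded skip-on-error pattern: failed batches are skipped after a recovery step, and the third failure in an epoch is logged and re-raised. A minimal plain-Python sketch; run_batches, step, and on_error are hypothetical names:

```python
import logging

logger = logging.getLogger("train_sketch")


def run_batches(batches, step, on_error, max_exceptions=3):
    """Call step(batch) for each batch, skipping failures until max_exceptions is reached.

    on_error runs after each tolerated failure (in the diff, this moves
    train_pipeline[0] back to the device). The failure that reaches the
    limit is logged and re-raised.
    """
    num_exceptions = 0
    results = []
    for batch in batches:
        try:
            results.append(step(batch))
        except Exception as e:
            num_exceptions += 1
            if num_exceptions >= max_exceptions:
                logger.exception(e)
                raise
            on_error()
    return results


def step(x):
    if x < 0:
        raise ValueError("bad batch")
    return x * 2


recoveries = []
print(run_batches([1, -1, 3], step, lambda: recoveries.append("to(device)")))  # [2, 6]
```

Because tolerated failures are no longer logged in the diff, they stay silent until the limit is hit; logging them at warning level is a cheap way to keep them visible.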

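In main(), we build the wrapper selected by --wrapper and pass it to the Seeker model, which is then wrapped. We also move the data loader initialization before the optimizers are created, and load network weights from a checkpoint with strict=False, so that checkpoints whose keys do not exactly match the (possibly wrapped) Seeker still load.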

@@ -180,6 +187,15 @@ def main(args, logger):
    logger.info('Initializing model...')
    start_time = time.time()

+    if args.wrapper == "sample":
+        wrapper = sample.Wrapper(
+            symbolic_trace=args.symbolic_trace,
+            n_samples=args.n_samples,
+            distribution=sample.Bernoulli(args.distribution),
+            trainable=args.trainable,
+            verbose=args.verbose)
+    elif args.wrapper == "vote":
+        wrapper = vote.Wrapper(
+            symbolic_trace=args.symbolic_trace,
+            finetune=args.finetune,
+            n_voters=args.n_voters,
+            alpha=args.alpha,
+            use_bias=args.use_bias,
+            verbose=args.verbose,
+            independent=args.independent)
+    elif args.wrapper == "sculpt":
+        wrapper = sculpt.Wrapper(
+            symbolic_trace=args.symbolic_trace,
+            n_layers=args.n_layers,
+            verbose=args.verbose)
+    else:
+        wrapper = None

    # Instantiate networks.
    networks = dict()

@@ -203,7 +219,8 @@ def main(args, logger):
    seeker_args['query_channels'] = 1
    seeker_args['output_channels'] = 3  # Target/snitch + frontmost occluder + outermost container.
    seeker_args['flag_channels'] = 3  # (occluded, contained, percentage).
-    seeker_net = seeker.Seeker(logger, **seeker_args)
+    seeker_net = seeker.Seeker(logger,wrapper,args.wrapper, **seeker_args)
+    if args.wrapper != "none": seeker_net.wrap()

    networks['seeker'] = seeker_net

@@ -236,6 +253,14 @@ def main(args, logger):
    milestones = [(args.num_epochs * 2) // 5,
                  (args.num_epochs * 3) // 5,
                  (args.num_epochs * 4) // 5]

+    # Instantiate datasets.
+    logger.info('Initializing data loaders...')
+    start_time = time.time()
+    (train_loader, val_aug_loader, val_noaug_loader, dset_args) = \
+        data.create_train_val_data_loaders(args, logger)
+    logger.info(f'Took {time.time() - start_time:.3f}s')

    for (k, v) in networks.items():
        if len(list(v.parameters())) != 0:
            optimizers[k] = optimizer_class(v.parameters(), lr=args.learn_rate)
@@ -247,23 +272,18 @@ def main(args, logger):
        logger.info('Loading weights from: ' + args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        for (k, v) in networks_nodp.items():
-            v.load_state_dict(checkpoint['net_' + k])
+            v.load_state_dict(checkpoint['net_' + k],strict=False)
        for (k, v) in optimizers.items():
            v.load_state_dict(checkpoint['optim_' + k])
        for (k, v) in lr_schedulers.items():
            v.load_state_dict(checkpoint['lr_sched_' + k])
        start_epoch = checkpoint['epoch'] + 1
    else:
        start_epoch = 0


    logger.info(f'Took {time.time() - start_time:.3f}s')

-    # Instantiate datasets.
-    logger.info('Initializing data loaders...')
-    start_time = time.time()
-    (train_loader, val_aug_loader, val_noaug_loader, dset_args) = \
-        data.create_train_val_data_loaders(args, logger)
-    logger.info(f'Took {time.time() - start_time:.3f}s')

    # Define logic for how to store checkpoints.
    def save_model_checkpoint(epoch):
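Loading network weights with strict=False lets a wrapped Seeker accept a checkpoint whose key set differs from its own, but it also hides genuine mismatches. nn.Module.load_state_dict() returns the missing and unexpected keys, so they can be logged. A small PyTorch sketch; the extra nn.Linear layer is a hypothetical stand-in for any parameters a wrapper might add, not capsa-torch's actual structure:

```python
import torch.nn as nn

# A plain model and a "wrapped" model with one extra layer.
plain = nn.Sequential(nn.Linear(4, 4))
wrapped = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))

# strict=False loads every matching key and reports the rest instead of raising.
result = wrapped.load_state_dict(plain.state_dict(), strict=False)
print(result.missing_keys)     # ['1.weight', '1.bias']
print(result.unexpected_keys)  # []
```

Only nn.Module.load_state_dict() accepts strict; the load_state_dict() methods of torch optimizers and learning-rate schedulers take just the state dict.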

Visualizing and Logging Risks

The TCOW repository has its own visualization toolset, which mainly generates videos for the train and test steps. These videos usually show the input video overlaid with the model's mask predictions (target, main occluder, main container). Here's an example screenshot of a generated video:

Example risk visualization

We modify the visualization toolset to also show the risk values for each query.

In utils/logvis.py, we extend handle_train_step_mask_track(). The new vis_risk variable holds the risk video for each query, and vis_risk_pause holds the same video with its first frame repeated three times, so that playback briefly pauses on the first frame. Finally, vis_allout_pause and vis_risk_pause are concatenated side by side and appended to vis_extra, the list of extra videos to be generated.

Keep in mind that risk visualizations are only generated when the wrapper argument is not none, the current phase is not train, and the extra_visuals argument is passed.

@@ -141,6 +141,8 @@ def handle_train_step_mask_track(self, epoch, phase, cur_step, total_step, steps
        # (Qs, 3, T, H, W).
        target_mask = model_retval['target_mask'][0].detach().cpu().numpy()
        # (Qs, 3, T, H, W).
+        if train_args.wrapper != "none" and phase != "train": output_mask_risk = model_retval['output_mask_risk'][0].detach().cpu().numpy()
+        # (Qs, 3, T, H, W).
        if 'snitch_weights' in model_retval:
            snitch_weights = model_retval['snitch_weights'][0].detach().cpu().numpy()
        # (Qs, 1, T, H, W).
@@ -151,6 +153,7 @@ def handle_train_step_mask_track(self, epoch, phase, cur_step, total_step, steps
            # Add fake query count dimension (Qs = 1).
            seeker_query_mask = seeker_query_mask[None]  # Add fake query count dimension (Qs = 1).
            output_mask = output_mask[None]
+            if train_args.wrapper != "none" and phase != "train": output_mask_risk = output_mask_risk[None]
            target_mask = target_mask[None]  # Add fake query count dimension (Qs = 1).

            # We want to slow down plugin videos according to how much we are subsampling them
@@ -216,6 +219,9 @@ def handle_train_step_mask_track(self, epoch, phase, cur_step, total_step, steps
                vis_intgt = visualization.create_model_input_target_video(
                    seeker_rgb, seeker_query_mask[q, 0], target_mask[q], query_border,
                    snitch_border, frontmost_border, outermost_border, grayscale=False)

+            if train_args.wrapper != "none" and phase != "train":
+                vis_risk = visualization.create_model_output_risk_video(seeker_rgb,output_mask_risk[q],0)

            vis_extra = []
            if ('test' in phase and test_args.extra_visuals) or \
@@ -230,11 +236,14 @@ def handle_train_step_mask_track(self, epoch, phase, cur_step, total_step, steps
                # Include temporally concatenated & spatially horizontally concatenated versions of
                # (input) + (output + target) or (input + target) + (output + target).
                vis_allout_pause = np.concatenate([vis_allout[0:1]] * 3 + [vis_allout[1:]], axis=0)
+                if train_args.wrapper != "none" and phase != "train":
+                    vis_risk_pause = np.concatenate([vis_risk[0:1]] * 3 + [vis_risk[1:]], axis=0)
                vis_intgt_pause = np.concatenate([vis_intgt[0:1]] * 3 + [vis_intgt[1:]], axis=0)
                vis_extra.append(np.concatenate([vis_input, vis_allout], axis=0))  # (T, H, W, 3).
                vis_extra.append(np.concatenate([vis_intgt_pause, vis_allout], axis=0))  # (T, H, W, 3).
                vis_extra.append(np.concatenate([vis_input, vis_allout_pause], axis=2))  # (T, H, W, 3).
                vis_extra.append(np.concatenate([vis_intgt_pause, vis_allout_pause], axis=2))  # (T, H, W, 3).
+                if train_args.wrapper != "none" and phase != "train": vis_extra.append(np.concatenate([vis_allout_pause, vis_risk_pause], axis=2))  # (T, H, W, 3).

            file_name_suffix_q = file_name_suffix + f'_q{q}'
            # Easily distinguish all-zero outputs
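The pause trick used for vis_allout_pause and vis_risk_pause can be checked in isolation. A NumPy sketch (pause_on_first_frame is a hypothetical name) showing that the first frame is held for three frames, which adds two frames in total, and that the horizontal concatenation along axis=2 requires both videos to share T and H:

```python
import numpy as np


def pause_on_first_frame(video, repeats=3):
    """Hold frame 0 for `repeats` frames: (T, H, W, 3) -> (T + repeats - 1, H, W, 3)."""
    return np.concatenate([video[0:1]] * repeats + [video[1:]], axis=0)


# Four 2x2 RGB frames whose pixel values equal the frame index.
video = np.stack([np.full((2, 2, 3), t, dtype=np.float32) for t in range(4)])
paused = pause_on_first_frame(video)
print(paused.shape)        # (6, 2, 2, 3)
print(paused[:, 0, 0, 0])  # [0. 0. 0. 1. 2. 3.]

# Side-by-side layout, as used for (vis_allout_pause, vis_risk_pause).
side_by_side = np.concatenate([paused, paused], axis=2)  # (6, 2, 4, 3)
```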

In utils/visualization.py, we add the create_model_output_risk_video() function. This function is responsible for preparing risk videos for each query. We also add the get_cbar() function, which is used to generate the relevant colorbar for the risk video. The colorbar image returned by this function gets appended next to the risk videos.

from __init__ import *

+import matplotlib.colors as mcolors
+import matplotlib.cm as cm
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+from io import BytesIO
+from PIL import Image
+from math import floor

def draw_text(image, topleft, label, color, size_mult=1.0):
    '''
@@ -251,3 +260,47 @@ def create_model_input_target_video(

    video = np.clip(vis, 0.0, 1.0)
    return video

+def create_model_output_risk_video(seeker_rgb,output_mask_risk,target_index):

+    colormap = cm.magma
+    normalize = mcolors.Normalize(vmin=np.min(output_mask_risk), vmax=np.max(output_mask_risk))
+    s_map = cm.ScalarMappable(norm=normalize, cmap=colormap)

+    output_mask_risk = np.concatenate((output_mask_risk[0],output_mask_risk[1],output_mask_risk[2]),axis=-1)

+    mapped = colormap(normalize(output_mask_risk))[..., :3]  # Normalize first so pixel colors match the colorbar.
+    img_bar = get_cbar(mapped,s_map.get_clim()[0],s_map.get_clim()[1],s_map=s_map)

+    img_bar = np.repeat(img_bar[np.newaxis,...]/255.,output_mask_risk.shape[-3],axis=0)
+    final_img = np.append(mapped,img_bar,axis=-2)

+    return final_img


+def get_cbar(img, min_val, max_val, s_map):
+    fig, ax = plt.subplots(figsize=(15,15))

+    img_cmap = ax.imshow(img[0,...], cmap=s_map.cmap, vmin=min_val, vmax=max_val)
+    plt.axis('off')
+    divider = make_axes_locatable(ax)
+    cax = divider.append_axes("right", size="5%", pad=0.)

+    num_ticks = 5
+    ticks = np.linspace(min_val, max_val, num_ticks)

+    cbar = plt.colorbar(mappable=s_map, cax=cax, orientation='vertical', ticks=ticks)
+    cbar.ax.set_yticklabels(["{:4.2f}".format(val) for val in ticks])

+    buf = BytesIO()
+    plt.savefig(buf, format='png', bbox_inches = 'tight', pad_inches = 0)
+    buf.seek(0)

+    img_bar = Image.open(buf)  # Opens the image in the buffer
+    img_bar = np.array(img_bar.resize((floor(1731*(240/383)),240)))[:,-85:-1,:3]  # To NumPy; resize/crop constants assume 240-pixel-tall frames and keep only the rightmost (colorbar) columns.


+    plt.close(fig)    
+    buf.close()

+    return img_bar
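As a lighter-weight alternative, the sketch below normalizes risk values before applying the colormap, so pixel colors and the colorbar agree, and builds the colorbar as a plain gradient strip instead of rendering and cropping a Matplotlib figure. This avoids the hard-coded resize and crop constants in get_cbar() at the cost of tick labels. risk_to_rgb and colorbar_strip are hypothetical names, and the random input only stands in for one query's (3, T, H, W) risk array:

```python
import numpy as np
from matplotlib import cm
from matplotlib.colors import Normalize


def risk_to_rgb(output_mask_risk, cmap=cm.magma):
    """Map a (3, T, H, W) risk array to a (T, H, 3*W, 3) RGB video in [0, 1].

    The three channels (target, occluder, container) are placed side by side,
    and values are min-max normalized before the colormap is applied.
    """
    side_by_side = np.concatenate(list(output_mask_risk), axis=-1)  # (T, H, 3*W)
    norm = Normalize(vmin=side_by_side.min(), vmax=side_by_side.max())
    return cmap(norm(side_by_side))[..., :3]


def colorbar_strip(height, width, cmap=cm.magma):
    """Vertical gradient (height, width, 3): highest risk at the top, lowest at the bottom."""
    ramp = np.linspace(1.0, 0.0, height)[:, None].repeat(width, axis=1)
    return cmap(ramp)[..., :3]


risk = np.random.rand(3, 5, 24, 32)  # (3, T, H, W)
video = risk_to_rgb(risk)
bar = np.repeat(colorbar_strip(24, 8)[None], video.shape[0], axis=0)  # (T, H, 8, 3)
framed = np.concatenate([video, bar], axis=-2)
print(video.shape, framed.shape)  # (5, 24, 96, 3) (5, 24, 104, 3)
```

The strip's height follows the frame height automatically, so it works for any --frame_height; the minimum and maximum risk would still need to be shown somewhere, for example in the video's file name or a text overlay.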

How to use

The user interface of the repository is unchanged: arguments are passed when running train.py or test.py, and any of the capsa-related arguments described above can be added to those commands.

Training

Here's an example training command that uses the sculpt wrapper (--wrapper sculpt). The newly created, wrapped model is checkpointed under the name sculpt_v0.

CUDA_VISIBLE_DEVICES=0 python train.py --name sculpt_v0 --data_path /data/kubric_random --batch_size 2 --num_workers 1 --num_queries 1 --num_frames 10 --frame_height 240 --frame_width 320 --causal_attention 1 --seeker_query_time 0 --do_val_noaug True --do_val_aug False --avoid_wandb 1 --wrapper sculpt

Evaluation (generating risk visualizations)

After training with the command above, we can generate risk visualizations. The test command must receive the same --wrapper argument so that the correct wrapper is used. The checkpoint name (sculpt_v0) is passed as the --resume argument, and this evaluation run is named eval_v0. Note that the --extra_visuals argument must be passed to generate risk visualizations.

CUDA_VISIBLE_DEVICES=3 python eval/test.py --resume sculpt_v0 --name eval_v0 --gpu_id 0 --data_path datasets/rubric_all_videos.txt --num_queries 1 --extra_visuals 1 --avoid_wandb 2 --wrapper sculpt