I have a tensor `a` that I would like to mask with `mask`, keeping only the selected values and discarding the unmasked frames. To keep the output tensor the same shape as `a`, the remaining positions at the end of each row should be filled with a padding value. I can assume there is only a single contiguous run of `True` values in each row of the mask.

For example:

```python
import torch

a = torch.arange(1, 17).reshape(4, 4)
# tensor([[ 1,  2,  3,  4],
#         [ 5,  6,  7,  8],
#         [ 9, 10, 11, 12],
#         [13, 14, 15, 16]])

mask = torch.tensor([[False,  True,  True, False],
                     [False,  True,  True,  True],
                     [ True, False, False, False],
                     [ True,  True,  True,  True]])

# desired output (assuming padding value is 0):
# tensor([[ 2,  3,  0,  0],
#         [ 6,  7,  8,  0],
#         [ 9,  0,  0,  0],
#         [13, 14, 15, 16]])
```

I can get the desired output by applying `torch.masked_select` followed by `torch.nn.functional.pad` to each row in a loop, but I am struggling to find a way to do this efficiently in batches.
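For reference, the row-by-row approach described above might look like the following minimal sketch. `pad_masked_rows_loop` is a hypothetical helper name, not a PyTorch function; it selects each row's masked values and right-pads them back to the full row length.

```python
import torch
import torch.nn.functional as F

a = torch.arange(1, 17).reshape(4, 4)
mask = torch.tensor([[False,  True,  True, False],
                     [False,  True,  True,  True],
                     [ True, False, False, False],
                     [ True,  True,  True,  True]])

def pad_masked_rows_loop(a, mask, pad_value=0):
    """Hypothetical baseline: process one row at a time."""
    rows = []
    for row, row_mask in zip(a, mask):
        kept = torch.masked_select(row, row_mask)   # 1-D tensor of the selected values
        n_pad = row.shape[0] - kept.shape[0]        # number of slots left to fill
        rows.append(F.pad(kept, (0, n_pad), value=pad_value))  # right-pad only
    return torch.stack(rows)

out = pad_masked_rows_loop(a, mask)
print(out)
```

This is correct but runs one Python iteration per row, which is what makes it slow for large batches.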

I have also looked into using `torch.roll` and then zeroing out everything after the appropriate index, but `torch.roll` shifts every row along a dimension by the same amount; it cannot apply a different shift to each row.
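A per-row roll can be emulated with `torch.gather` and modular column indices. The following is a minimal sketch that relies on the single-contiguous-run assumption: each row is rolled left so that its first `True` lands in column 0, then everything past the run length is replaced with padding. `per_row_roll_left` is a hypothetical helper, not part of PyTorch.

```python
import torch

a = torch.arange(1, 17).reshape(4, 4)
mask = torch.tensor([[False,  True,  True, False],
                     [False,  True,  True,  True],
                     [ True, False, False, False],
                     [ True,  True,  True,  True]])

def per_row_roll_left(x, shifts):
    """Hypothetical helper: roll each row of a 2-D tensor left by its own amount."""
    n = x.shape[1]
    cols = torch.arange(n, device=x.device)                 # [0, 1, ..., n-1]
    idx = (cols.unsqueeze(0) + shifts.unsqueeze(1)) % n     # source column per output slot
    return x.gather(1, idx)

# Assumes one contiguous run of True per row: the run starts at the first True.
start = mask.int().argmax(dim=1)     # argmax returns the first maximal index
count = mask.sum(dim=1)              # length of the run in each row
rolled = per_row_roll_left(a, start)
keep = torch.arange(a.shape[1]).unsqueeze(0) < count.unsqueeze(1)
out = torch.where(keep, rolled, torch.zeros_like(rolled))
print(out)
```

Unlike a sort-based approach, this only works when the `True` values really are contiguous.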

You can achieve the desired result by applying `torch.sort` to the mask itself. If you sort the boolean values in descending order, the `True` values move to the beginning of each row and the `False` values to the end.

Note that the result can depend on the sorting algorithm: an unstable sort may shuffle equal elements, so the `True` values would not necessarily keep their original order. As @Seraf Fej pointed out, you can pass `stable=True` to `torch.sort` so that the order of equivalent items is preserved.

Then use the indices returned by the sort to gather the values from `a` with `torch.gather`. Finally, mask the resulting matrix to replace the discarded values with the appropriate padding.
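Put together, the three steps (stable sort, gather, pad) fit in a single function. This is a sketch under a few assumptions: `pad_masked_rows` is a hypothetical name, the mask is cast to `int8` before sorting so the code does not depend on `torch.sort` supporting bool inputs, and the installed PyTorch must accept the `stable=` keyword.

```python
import torch

def pad_masked_rows(a, mask, pad_value=0):
    """Hypothetical helper: move the masked values of each row to the front, pad the rest."""
    # Stable descending sort: 1s (True) come first and keep their left-to-right order.
    values, indices = torch.sort(mask.to(torch.int8), dim=1, descending=True, stable=True)
    gathered = a.gather(1, indices)       # reorder each row of `a` the same way
    keep = values.bool()                  # True where a selected value landed
    return torch.where(keep, gathered, torch.full_like(gathered, pad_value))

a = torch.arange(1, 17).reshape(4, 4)
mask = torch.tensor([[False,  True,  True, False],
                     [False,  True,  True,  True],
                     [ True, False, False, False],
                     [ True,  True,  True,  True]])

out = pad_masked_rows(a, mask)
out_neg = pad_masked_rows(a, mask, pad_value=-1)
print(out)
print(out_neg)
```

Because the sort is stable, this does not actually need the single-contiguous-run assumption: all selected values end up at the front of their row in their original order, whatever the mask pattern.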

```python
>>> a
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])

>>> mask
tensor([[False,  True,  True, False],
        [False,  True,  True,  True],
        [ True, False, False, False],
        [ True,  True,  True,  True]])
```

```python
>>> values, indices = mask.sort(dim=1, descending=True, stable=True)

>>> values
tensor([[ True,  True, False, False],
        [ True,  True,  True, False],
        [ True, False, False, False],
        [ True,  True,  True,  True]])

>>> indices
tensor([[1, 2, 0, 3],
        [1, 2, 3, 0],
        [0, 1, 2, 3],
        [0, 1, 2, 3]])
```

Gather from `a` using `indices`, then mask with `values`:

```python
>>> a.gather(1, indices) * values
tensor([[ 2,  3,  0,  0],
        [ 6,  7,  8,  0],
        [ 9,  0,  0,  0],
        [13, 14, 15, 16]])
```

You can easily extend this to any padding value using `torch.where`:

```python
>>> torch.where(values, a.gather(1, indices), -1)
tensor([[ 2,  3, -1, -1],
        [ 6,  7,  8, -1],
        [ 9, -1, -1, -1],
        [13, 14, 15, 16]])
```

Or using the inverse mask `~values`, weighted by the padding value:

```python
>>> a.gather(1, indices) * values - 1 * ~values
tensor([[ 2,  3, -1, -1],
        [ 6,  7,  8, -1],
        [ 9, -1, -1, -1],
        [13, 14, 15, 16]])
```