Masked Tensor Operations in PyTorch

As far as I know, PyTorch does not have built-in masked tensor operations in its stable API (such as those available in numpy.ma), although recent releases do ship a prototype torch.masked package.

The other day, I needed to do some aggregation operations on a tensor while ignoring the masked elements in the operations. Specifically, I needed to do a mean() along a specific dimension, but ignore the masked elements. Fortunately, it’s easy enough to implement these operations manually. Let’s implement the mean() operation.

Let’s say you have a matrix a and a bool mask m with the same shape as a, where True marks the elements to keep and False marks the elements to ignore. You want to compute a.mean(dim=1), but only over the elements that are kept. Here’s a small function that does this for you:

import torch

def masked_mean(tensor, mask, dim):
    masked = torch.mul(tensor, mask)  # Zero out ignored elements (mask is True where an element counts)
    # Divide by the number of kept elements; a slice with none kept yields NaN (0/0)
    return masked.sum(dim=dim) / mask.sum(dim=dim)
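To see the function in action, here is a small sketch with made-up values. The tensors a and m are purely illustrative, and masked_mean_safe is a hypothetical variant of mine (not part of PyTorch) that clamps the count so fully masked rows come out as 0 instead of NaN. Whether 0 is the right fallback depends on your use case.

```python
import torch

def masked_mean(tensor, mask, dim):
    # Same approach as above: zero out ignored entries, divide by the count of kept entries
    masked = torch.mul(tensor, mask)
    return masked.sum(dim=dim) / mask.sum(dim=dim)

def masked_mean_safe(tensor, mask, dim):
    # Hypothetical variant: clamp the count to at least 1 so 0/0 never happens
    count = mask.sum(dim=dim).clamp(min=1)
    return torch.mul(tensor, mask).sum(dim=dim) / count

# Illustrative data: True means "keep this element"
a = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0]])
m = torch.tensor([[True, True, False],
                  [False, True, True],
                  [False, False, False]])  # last row is fully masked

result = masked_mean(a, m, dim=1)       # rows 0 and 1 give 1.5 and 5.5; row 2 is NaN
safe = masked_mean_safe(a, m, dim=1)    # rows 0 and 1 give 1.5 and 5.5; row 2 is 0.0
print(result)
print(safe)
```

The first two rows average only the kept elements, (1 + 2) / 2 and (5 + 6) / 2, while the last row shows what happens when nothing is kept.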

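We can implement a similar function for finding (say) max() along a specific dimension. Here, zeroing the ignored elements is not enough, because a zero would beat the real values in a row of negatives, so the ignored positions are set to negative infinity instead: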

import math

def masked_max(tensor, mask, dim):
    # Overwrite ignored positions with -inf so they can never be the maximum.
    # (Multiplying by the mask first would turn inf or NaN in ignored positions into NaN.)
    filled = tensor.masked_fill(~mask, -math.inf)
    return filled.max(dim=dim)[0]  # max() returns (values, indices); keep the values
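A quick check with made-up values shows why the negative infinity matters. This sketch writes masked_max with Tensor.masked_fill, which has the same effect as adding -inf at the ignored positions. The tensors a and m and the name naive are illustrative only.

```python
import math
import torch

def masked_max(tensor, mask, dim):
    # Replace ignored entries with -inf so they can never win the max
    return tensor.masked_fill(~mask, -math.inf).max(dim=dim)[0]

# Illustrative data: True means "keep this element"
a = torch.tensor([[1.0, 9.0, 3.0],
                  [-4.0, -5.0, 7.0]])
m = torch.tensor([[True, False, True],
                  [True, True, False]])

result = masked_max(a, m, dim=1)          # expected values: 3.0 and -4.0
naive = torch.mul(a, m).max(dim=1)[0]     # expected values: 3.0 and 0.0 (second row is wrong)
print(result)
print(naive)
```

In the second row the kept values are -4 and -5, so the true maximum is -4. The multiply-only version reports 0, which is the zeroed-out ignored element rather than anything actually in the row.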

Simple, no? 🙂
