Masked Tensor Operations in PyTorch

As far as I know, PyTorch does not natively support masked tensor operations (such as those available in

The other day, I needed to run some aggregations on a tensor while ignoring masked elements. Specifically, I needed mean() along one dimension, computed only over the unmasked elements. Fortunately, these operations are easy to implement by hand. Let’s start with mean().

Let’s say you have a matrix a and a bool mask m with the same shape as a, and you want to compute a.mean(dim=1) using only the elements that are not masked. Here’s a small function that does this for you:
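A minimal sketch of such a function follows. It assumes that a is a floating-point tensor and that True in m marks an element as masked (excluded); the name masked_mean is hypothetical. The idea is to zero out the masked entries before summing, then divide by the number of unmasked entries rather than by the full length of the dimension.

```python
import torch


def masked_mean(a: torch.Tensor, m: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Mean of `a` along `dim`, ignoring elements where `m` is True.

    Assumption: `m` is a bool tensor with the same shape as `a`, and
    True means "masked out". A slice whose elements are all masked
    yields NaN (0 / 0).
    """
    # Masked entries become 0, so they contribute nothing to the sum.
    total = a.masked_fill(m, 0.0).sum(dim=dim)
    # Number of unmasked entries per slice, cast to a's dtype for division.
    count = (~m).sum(dim=dim).to(a.dtype)
    return total / count


a = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
m = torch.tensor([[False, True, False],
                  [True, True, False]])
print(masked_mean(a, m, dim=1))  # tensor([2., 6.])
```

Zero is a safe fill value here because it is the identity for addition; the unmasked count in the denominator is what keeps the masked elements from diluting the mean.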

We can implement a similar function for finding (say) max() along a specific dimension:
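Here is a sketch under the same assumptions (floating-point a, True in m means masked); masked_max is a hypothetical name. Instead of zeros, the masked entries are filled with -inf, the identity for max, so they can never be selected. It returns the (values, indices) pair that Tensor.max(dim=...) produces.

```python
import torch


def masked_max(a: torch.Tensor, m: torch.Tensor, dim: int = 1):
    """Max of `a` along `dim`, ignoring elements where `m` is True.

    Assumption: `a` is floating point (so it can hold -inf) and True in
    `m` means "masked out". A fully masked slice returns -inf.
    """
    # -inf loses every comparison, so masked entries never win the max.
    filled = a.masked_fill(m, float("-inf"))
    # Tensor.max(dim=...) returns a (values, indices) named tuple.
    values, indices = filled.max(dim=dim)
    return values, indices


a = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
m = torch.tensor([[False, True, False],
                  [True, True, False]])
values, indices = masked_max(a, m, dim=1)
print(values)   # tensor([3., 6.])
print(indices)  # tensor([2, 2])
```

The same pattern gives a masked min by filling with +inf and calling min() instead.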

Simple, no? 🙂
