As far as I know, PyTorch does not inherently have masked tensor operations (such as those available in `numpy.ma`).

The other day, I needed to do some aggregation operations on a tensor while ignoring the masked elements. Specifically, I needed to do a `mean()` along a specific dimension, skipping the masked elements. Fortunately, it’s easy enough to implement these operations manually. Let’s implement the `mean()` operation.

Let’s say you have a matrix `a` and a bool mask `m` (with the same shape as `a`, where `True` marks the elements to keep), and you want to compute `a.mean(dim=1)` but only on the elements that are not masked out. Here’s a small function that does this for you:

```python
import torch

def masked_mean(tensor, mask, dim):
    masked = torch.mul(tensor, mask)  # Apply the mask using an element-wise multiply
    return masked.sum(dim=dim) / mask.sum(dim=dim)  # Find the average!
```
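One edge case to keep in mind: if every element along `dim` is masked out, the denominator is zero and the result is `nan`. Here is a minimal sketch of a guarded variant. The name `masked_mean_safe` and the choice to return 0 for fully masked slices are my own assumptions for illustration, not part of the function above.

```python
import torch

def masked_mean_safe(tensor, mask, dim):
    # Hypothetical variant: True in `mask` marks the elements to keep.
    masked = torch.where(mask, tensor, torch.zeros_like(tensor))
    count = mask.sum(dim=dim)
    # Clamp the count so a fully masked slice yields 0 instead of nan.
    return masked.sum(dim=dim) / count.clamp(min=1)

a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
m = torch.tensor([[True, False, True], [False, False, False]])
result = masked_mean_safe(a, m, dim=1)
print(result)  # first row averages 1 and 3; second row is fully masked
```

Whether 0 is the right fallback depends on what you do with the result afterwards. You might prefer to keep the `nan` so that fully masked rows stay easy to spot.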

We can implement a similar function for finding (say) `max()` along a specific dimension:

```python
import math
import torch

def masked_max(tensor, mask, dim):
    masked = torch.mul(tensor, mask)
    neg_inf = torch.zeros_like(tensor)
    neg_inf[~mask] = -math.inf  # Place the smallest values possible in masked positions
    return (masked + neg_inf).max(dim=dim)[0]
```
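The same idea can probably be written more compactly with `Tensor.masked_fill`, which writes `-inf` straight into the masked-out positions. This is only a sketch: the name `masked_max_fill` is hypothetical, and it assumes a floating-point tensor, since `-inf` cannot be stored in an integer tensor.

```python
import math
import torch

def masked_max_fill(tensor, mask, dim):
    # Hypothetical variant: True in `mask` marks the elements to keep.
    # Masked-out positions become -inf so they can never win the max.
    filled = tensor.masked_fill(~mask, -math.inf)
    return filled.max(dim=dim)[0]

a = torch.tensor([[1.0, 9.0, 3.0], [-4.0, -5.0, 6.0]])
m = torch.tensor([[True, False, True], [True, True, False]])
max_result = masked_max_fill(a, m, dim=1)
print(max_result)  # row maxima over kept elements: 3 and -4
```

The second row shows why zeroing the masked positions is not enough on its own: when all the kept values are negative, a 0 would wrongly win the max.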

Simple, no? 🙂
