In some situations, you will need to do some advanced indexing/selection with PyTorch, for example to answer the question: “how can I select elements from Tensor A following the indices specified in Tensor B?”
In this post we will present the three most common methods for such tasks, namely torch.index_select, torch.gather and torch.take. We will explain them all in detail and contrast them against each other.
Admittedly, one motivation for this post was that I kept forgetting how and when to use which function. I ended up googling, browsing Stack Overflow and reading the official documentation, which in my opinion is relatively brief and not very helpful. Therefore, as mentioned, here we delve into these functions: we motivate when to use them, give 2D and 3D examples, and show the resulting selection graphically.
I hope this post brings clarity to these functions and removes the need for further searching. Thank you for reading!
And now, without further ado, let's dive into the functions one by one. For each one, we first start with a 2D example and visualize the resulting selection, and then move on to a somewhat more complex 3D example. Furthermore, we reimplement the executed operation in simple Python; you can refer to this pseudocode as another source of information about what these functions do. At the end, we summarize the functions and their differences in a table.
torch.index_select selects elements along one dimension, while leaving the others unchanged. That is: keep all elements from all other dimensions, but select elements in the target dimension following the index tensor. Let's demonstrate this with a 2D example, where we select along dimension 1:
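import torch
# Example sizes (assumed here for illustration; any positive values work)
len_dim_0 = 3
len_dim_1 = 4
num_picks = 2
values = torch.rand((len_dim_0, len_dim_1))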
indices = torch.randint(0, len_dim_1, size=(num_picks,))
# (len_dim_0, num_picks)
picked = torch.index_select(values, 1, indices)
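To make the semantics concrete, the same selection can be sketched with plain Python loops. This is only an illustrative reference for the 2D, dimension-1 case, not how PyTorch implements it; the helper name index_select_loops and the tensor sizes are assumptions made for this example.

```python
import torch


def index_select_loops(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Loop-based reference for torch.index_select(values, 1, indices) on a 2D tensor."""
    len_dim_0 = values.shape[0]
    num_picks = indices.shape[0]
    picked = torch.zeros((len_dim_0, num_picks), dtype=values.dtype)
    for i in range(len_dim_0):      # every row along dimension 0 is kept
        for k in range(num_picks):  # one output column per entry of `indices`
            picked[i, k] = values[i, indices[k]]
    return picked


# Sizes chosen for illustration only.
torch.manual_seed(0)
values = torch.rand((3, 4))
indices = torch.randint(0, 4, size=(2,))

# The loop version should match the built-in function exactly.
assert torch.equal(index_select_loops(values, indices),
                   torch.index_select(values, 1, indices))
```

Note that the same column indices are reused for every row, which is exactly what distinguishes index_select from a per-row selection.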
The resulting tensor has the shape (len_dim_0, num_picks): for each element along dimension 0, we have picked the same elements from dimension 1. Let's visualize this: