PyTorch Indexing
This post really helps:
Let's assume we have a tensor A of shape [B, H, C]. The basic way to select elements of A with index tensors is:
A[tensor, tensor, ...]
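This form reads the values at the same position in every index tensor, combines them into one index tuple, and uses that tuple to pick a single element of A. When the index tensors all have the same shape, the result has that shape too.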
For example:
import torch
# creating A
t = torch.tensor([[1, 2], [3, 4]])
w = -t
A = torch.cat((t.unsqueeze(0), w.unsqueeze(0)))  # [2, 2, 2]
"""
A:
tensor([[[ 1,  2],
         [ 3,  4]],

        [[-1, -2],
         [-3, -4]]])
"""
batch = torch.LongTensor([[0,0],[1,1]]) # [2,2]
row = torch.LongTensor([[0,1],[1,1]]) # [2,2]
col = torch.LongTensor([[1,0],[1,0]]) # [2,2]
res = A[batch, row, col]
"""
tensor([[ 2,  3],
        [-4, -3]])
"""
In the above code, the index tuples, read position by position from batch, row and col, are (0, 0, 1), (0, 1, 0), (1, 1, 1) and (1, 1, 0).
The result is [[A[0,0,1], A[0,1,0]], [A[1,1,1], A[1,1,0]]].
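To check the rule above, here is a minimal sketch that rebuilds the result with an explicit loop over positions and compares it with advanced indexing. It recreates A with `torch.stack((t, -t))`, which gives the same [2, 2, 2] tensor as the `torch.cat` call above. The helper name `loop_index` is hypothetical and only used for illustration.

```python
import torch

# Rebuild A from the example above: shape [2, 2, 2]
t = torch.tensor([[1, 2], [3, 4]])
A = torch.stack((t, -t))

batch = torch.tensor([[0, 0], [1, 1]])
row = torch.tensor([[0, 1], [1, 1]])
col = torch.tensor([[1, 0], [1, 0]])


def loop_index(A, batch, row, col):
    """Hypothetical helper: out[i, j] = A[batch[i, j], row[i, j], col[i, j]]."""
    out = torch.empty(batch.shape, dtype=A.dtype)
    for i in range(batch.shape[0]):
        for j in range(batch.shape[1]):
            # The index tuple for position (i, j) comes from the same
            # position in each of the three index tensors.
            b, r, c = batch[i, j].item(), row[i, j].item(), col[i, j].item()
            out[i, j] = A[b, r, c]
    return out


res = A[batch, row, col]
manual = loop_index(A, batch, row, col)
print(res)
print(torch.equal(res, manual))
```

The loop is only for understanding; the single `A[batch, row, col]` call does the same lookup in one vectorized operation.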
A[tensor]
Now assume we index with a single tensor.
Taking the example above, we index A with only the batch tensor, so the result has shape [2, 2, 2, 2]:
A[batch].shape
"""
torch.Size([2, 2, 2, 2])
"""
The principle is simple: each element i of the index tensor is replaced by the corresponding slice A[i] along the first dimension.
Therefore, the result is [[A[0], A[0]], [A[1], A[1]]].
Each A[i] has shape [2, 2], so the result has shape [2, 2, 2, 2]: the index shape [2, 2] followed by the remaining dimensions of A.
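A minimal sketch that checks this principle, assuming the same A and batch tensors as above (rebuilt here so the block runs on its own). It compares `A[batch]` with an explicit version that swaps each index i for the slice `A[i]` and stacks the slices back into the shape of `batch`.

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
A = torch.stack((t, -t))                   # shape [2, 2, 2]
batch = torch.tensor([[0, 0], [1, 1]])     # shape [2, 2]

res = A[batch]                             # same as A[batch, :, :]
print(res.shape)

# Explicit version: replace every index i in batch with the slice A[i] (shape [2, 2])
manual = torch.stack([
    torch.stack([A[i] for i in idx_row.tolist()])  # one row of batch -> [2, 2, 2]
    for idx_row in batch
])                                         # all rows -> [2, 2, 2, 2]
print(torch.equal(res, manual))
```

The nested `torch.stack` mirrors the nesting of `batch`: the outer list follows its first dimension, the inner list its second, and each entry contributes a full [2, 2] slice of A.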