二维示例:

import torch 
a=torch.tensor([[1,2,3],[-1,3,2]])
print(a.shape)
b=torch.argmax(a,dim=0) #dim=0就是代表比较第0维 dim=1表示比较第1维
print(b)
 
#output
#torch.Size([2, 3])
#tensor([0, 1, 0]) 
 
 
import torch 
a=torch.tensor([[1,2,3],[-1,3,2]])
print(a.shape)
b=torch.argmax(a,dim=1)
print(b)
#output 
#torch.Size([2, 3])
#tensor([2, 1])

三维示例

import torch 
a=torch.tensor([[[1,2,3],[-1,3,2]],[[0,4,-2],[2,6,0]]])
print(a.shape)
b=torch.argmax(a,dim=0)
print(b)
 
'''
torch.Size([2, 2, 3])
tensor([[0, 1, 0],
        [1, 1, 0]])
'''
 
import torch 
a=torch.tensor([[[1,2,3],[-1,3,2]],[[0,4,-2],[2,6,0]]])
print(a.shape)
b=torch.argmax(a,dim=1)
print(b)
 
'''
torch.Size([2, 2, 3])
tensor([[0, 1, 0],
        [1, 1, 1]])
'''
 
 
import torch 
a=torch.tensor([[[1,2,3],[-1,3,2]],[[0,4,-2],[2,6,0]]])
print(a.shape)
b=torch.argmax(a,dim=2)
print(b)
 
'''
torch.Size([2, 2, 3])
tensor([[2, 1],
        [1, 1]])
'''

从三维来看

dim=0就是最外层的进行比较
即 [[1,2,3],[-1,3,2]] 和 [[0,4,-2],[2,6,0]]来比较

注:下面的小括号都代表着取最大值的下标例如: (5,6) ---> 1

也就是[[(1,0),(2,4),(3,-2)],[(-1,2),(3,6),(2,0)]] ---> [[0,1,0],[1,1,0]]

dim=1就是里面一层进行比较
即 [[1,2,3],[-1,3,2]] 比较 和 [[0,4,-2],[2,6,0]] 比较

也就是 [(1,-1),(2,3),(3,2)] -----> [0,1,0] [(0,2),(4,6),(-2,0)] -----> [1,1,1]

两个拼起来: [[0,1,0],[1,1,1]]

dim=2就是最里面一层了
即[1,2,3] 和 [-1,3,2] 和 [0,4,-2] 和 [2,6,0]

也就是(1,2,3) ---> 2 和 (-1,3,2) ---> 1 和 (0,4,-2) ---> 1 和 (2,6,0) ---> 1

合起来[[2,1],[1,1]]

Last modification:February 4th, 2021 at 03:56 pm
如果觉得我的文章对你有用,请随意赞赏