I have a three-dimensional tensor in PyTorch, and I would like to set the maximum value along dim=1 to 1 and every other value to 0. How can I do this?
Thank you in advance for your help.
I looked into torch.argmax, torch.max, slicing, and so on, but I couldn't work out how to combine them...
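For reference, both functions mentioned above can reduce along a single dimension. The snippet below is a minimal sketch on a small hand-written tensor (the values are made up for illustration) showing what each returns for dim=1: torch.max returns both the maximum values and their indices, while torch.argmax returns only the indices.

```python
import torch

# A small hand-written tensor of shape (2, 3, 1) so the results are predictable
x = torch.tensor([[[0.4], [2.0], [0.7]],
                  [[0.8], [0.3], [-0.4]]])

# torch.max with a dim returns a named tuple (values, indices);
# keepdim=True keeps the reduced axis with size 1
values, indices = torch.max(x, dim=1, keepdim=True)
print(values.shape, indices.shape)  # torch.Size([2, 1, 1]) torch.Size([2, 1, 1])
print(indices.flatten().tolist())   # [1, 0]

# torch.argmax returns only the indices; without keepdim the axis is dropped
idx = torch.argmax(x, dim=1)
print(idx.shape)                    # torch.Size([2, 1])
```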
import torch
x = torch.randn(4,3,1)
print(x)
tensor([[[ 0.4082],
         [ 2.0627],
         [ 0.7252]],

        [[ 0.7946],
         [ 0.2679],
         [-0.4184]],

        [[ 0.3380],
         [ 0.8403],
         [-1.7227]],

        [[-1.1250],
         [-1.8144],
         [ 1.4441]]])
The result I want is:
tensor([[[0],
         [1],
         [0]],

        [[1],
         [0],
         [0]],

        [[0],
         [1],
         [0]],

        [[0],
         [0],
         [1]]])
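One way to build the desired result directly from torch.argmax is to start from a tensor of zeros and write a 1 at each argmax position with Tensor.scatter_. This is a sketch under the assumption that exactly one 1 per slice is wanted (argmax picks a single index even if the maximum is tied); the helper name max_to_one_hot is hypothetical, and the input reuses the values printed in the question.

```python
import torch

def max_to_one_hot(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Hypothetical helper: int64 tensor shaped like x, 1 at the argmax along `dim`, 0 elsewhere."""
    # Index of the maximum along `dim`; keepdim=True keeps the rank so scatter_ can use it
    idx = x.argmax(dim=dim, keepdim=True)
    # Start from zeros and write a single 1 at each argmax position
    out = torch.zeros_like(x, dtype=torch.int64)
    out.scatter_(dim, idx, 1)
    return out

# Values copied from the printed x in the question
x = torch.tensor([[[0.4082], [2.0627], [0.7252]],
                  [[0.7946], [0.2679], [-0.4184]],
                  [[0.3380], [0.8403], [-1.7227]],
                  [[-1.1250], [-1.8144], [1.4441]]])
y = max_to_one_hot(x, dim=1)
print(y.squeeze(-1).tolist())  # [[0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
```

Because keepdim=True keeps idx at shape (4, 1, 1), the same function should work for other shapes as long as dim points at the axis to reduce.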
I referred to the answers on other sites that were linked in the comments on your question.
(Correction) A bit late, but using torch.stack turned out to be simpler:
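import torch

torch.manual_seed(17)
x = torch.randn(4, 3, 1)
# For each (3, 1) slice a = x[i], mark where a equals its own maximum,
# then stack the four results back into a (4, 3, 1) tensor
y = torch.stack([(a == a.max()).int() for a in x])
print(f'{x}\n\n{y}')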
tensor([[[-1.4135],
         [ 0.2336],
         [ 0.0340]],

        [[ 0.3499],
         [-0.0145],
         [-0.6124]],

        [[-1.1835],
         [-1.4831],
         [ 1.8004]],

        [[ 0.0096],
         [ 0.1534],
         [-2.6631]]])

tensor([[[0],
         [1],
         [0]],

        [[1],
         [0],
         [0]],

        [[0],
         [0],
         [1]],

        [[0],
         [1],
         [0]]], dtype=torch.int32)
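The list comprehension works because each slice a has shape (3, 1), so a.max() over the whole slice equals the maximum along dim=1. A loop-free sketch of the same comparison, plus an alternative based on torch.nn.functional.one_hot, is below; it reuses the x values printed above rather than relying on the seed.

```python
import torch
import torch.nn.functional as F

# Values copied from the printed x above
x = torch.tensor([[[-1.4135], [0.2336], [0.0340]],
                  [[0.3499], [-0.0145], [-0.6124]],
                  [[-1.1835], [-1.4831], [1.8004]],
                  [[0.0096], [0.1534], [-2.6631]]])

# Same idea as the list comprehension, without a Python loop:
# keepdim=True gives shape (4, 1, 1), which broadcasts against x
y_cmp = (x == x.max(dim=1, keepdim=True).values).int()

# Alternative via one_hot: argmax gives shape (4, 1), one_hot appends a
# class axis -> (4, 1, 3), and transpose(1, 2) restores (4, 3, 1)
y_hot = F.one_hot(x.argmax(dim=1), num_classes=x.size(1)).transpose(1, 2).int()

print(torch.equal(y_cmp, y_hot))  # True
```

The two can differ when a slice has a tied maximum: the comparison marks every tied position with 1, whereas argmax (and therefore one_hot) marks only one of them.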
© 2024 OneMinuteCode. All rights reserved.