PyTorch Tricks
Published:
How to efficiently get a one-hot representation with PyTorch? Here's one solution.
======
import torch
import torch.nn as nn
import torch.nn.functional as F
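# Softmax over the class dimension (dim=1) so each row sums to 1
x = F.softmax(torch.rand(10, 5), dim=1)
# Index of the most likely class in each row
pred = torch.argmax(x, dim=1)
# Write 1.0 at each predicted index of an all-zero tensor shaped like x
one_hot_pred = torch.zeros_like(x).scatter_(1, pred.unsqueeze(1), 1.)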
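
As a cross-check, the same result can probably be reached with the built-in `torch.nn.functional.one_hot`. The sketch below is a comparison under a few assumptions rather than a replacement for the trick above: the fixed seed, the variable names (`probs`, `one_hot_scatter`, `one_hot_builtin`) and the cast back to the float dtype are illustrative choices, not part of the original snippet.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # assumption: fixed seed only so the sketch is reproducible

# Same setup as above: 10 samples, 5 classes
probs = F.softmax(torch.rand(10, 5), dim=1)
pred = torch.argmax(probs, dim=1)

# scatter_-based one-hot, as in the snippet above
one_hot_scatter = torch.zeros_like(probs).scatter_(1, pred.unsqueeze(1), 1.0)

# F.one_hot returns an int64 tensor, so cast it to match the float tensor
one_hot_builtin = F.one_hot(pred, num_classes=probs.size(1)).to(probs.dtype)

# Both approaches should mark exactly the same positions
print(torch.equal(one_hot_scatter, one_hot_builtin))
```

One difference worth noting: `scatter_` fills a tensor that already has the dtype and device of `x`, while `F.one_hot` produces integer output that usually needs a cast before it can be mixed with float tensors.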