PyTorch Tricks

How can you efficiently get a one-hot representation in PyTorch? One solution: take the argmax of each row, then use scatter_ to write a 1 at that index in a tensor of zeros.

======

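import torch
import torch.nn as nn
import torch.nn.functional as F

# Dummy class probabilities: 10 samples, 5 classes.
x = F.softmax(torch.rand(10, 5), dim=1)
# Index of the most likely class for each sample, shape (10,).
pred = torch.argmax(x, dim=1)
# Write 1.0 at each predicted index along dim 1 of a zero tensor shaped like x.
one_hot_pred = torch.zeros_like(x).scatter_(1, pred.unsqueeze(1), 1.)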
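
As a cross-check on the scatter_ approach above, here is a minimal sketch that compares it with the built-in torch.nn.functional.one_hot. This is an alternative, not the method from this post. The fixed seed and the names one_hot_scatter and one_hot_builtin are only for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed, only so the sketch is reproducible

# Same setup as above: 10 samples, 5 class probabilities each.
x = F.softmax(torch.rand(10, 5), dim=1)
pred = torch.argmax(x, dim=1)

# scatter_-based version from the post.
one_hot_scatter = torch.zeros_like(x).scatter_(1, pred.unsqueeze(1), 1.0)

# Built-in alternative: F.one_hot returns an integer (int64) tensor,
# so cast it to x's dtype when a float tensor is needed.
one_hot_builtin = F.one_hot(pred, num_classes=x.size(1)).to(x.dtype)

# Both approaches should yield the same (10, 5) one-hot matrix.
assert torch.equal(one_hot_scatter, one_hot_builtin)
```

The scatter_ version keeps the dtype and device of x without an extra cast, while F.one_hot is shorter when an integer tensor is acceptable.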

Good luck!