pytorch-ai/other_scripts/setcounter.py
Clemens-Dautermann d2971cf070 added editor
2020-04-25 16:59:20 +02:00

29 lines
671 B
Python
Executable file

from tqdm import tqdm
import torch
from torchvision import transforms, datasets
def count_labels(batches, num_classes=10):
    """Count how often each class label occurs across all batches.

    Every label in every batch is counted (not only the first label of each
    batch), so the result reflects the full dataset.

    Args:
        batches: iterable of ``(inputs, labels)`` pairs, e.g. a
            ``torch.utils.data.DataLoader``.  ``labels`` is assumed to be a
            1-D sequence of integer class indices (a tensor or a list) --
            TODO confirm for datasets other than MNIST.
        num_classes: number of classes.  Keys ``'0'`` .. ``str(num_classes - 1)``
            are pre-initialised to 0 so classes that never occur still appear.

    Returns:
        dict mapping each label (as a string) to its number of occurrences.

    Raises:
        KeyError: if a label outside ``range(num_classes)`` is encountered.
    """
    counts = {str(digit): 0 for digit in range(num_classes)}
    for _, labels in batches:
        for label in labels:
            # int() works for 0-d tensors (like .item()) and for plain ints.
            counts[str(int(label))] += 1
    return counts


def to_percentages(counts):
    """Return a new dict with each count expressed as a percentage of the total.

    Args:
        counts: mapping of key -> non-negative count.

    Returns:
        dict with the same keys, values ``count / total * 100``.  If the total
        is 0 (e.g. an empty dataset), every value is 0.0 instead of raising
        ZeroDivisionError.
    """
    total = sum(counts.values())
    if total == 0:
        return {key: 0.0 for key in counts}
    return {key: value / total * 100 for key, value in counts.items()}


if __name__ == "__main__":
    # Guarded so importing this module does not download/iterate MNIST.
    data = datasets.MNIST('../datasets', train=True, download=True,
                          transform=transforms.Compose([
                              transforms.ToTensor()
                          ]))
    loader = torch.utils.data.DataLoader(data, batch_size=15, shuffle=False)

    counts = count_labels(tqdm(loader))
    print(counts)
    print(sum(counts.values()))
    print(to_percentages(counts))