diff --git a/editor.py b/editor.py index 1211bc7..c4cb577 100644 --- a/editor.py +++ b/editor.py @@ -4,6 +4,7 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +import matplotlib.pyplot as plt class Net(nn.Module): @@ -32,17 +33,17 @@ def transform_to_rgb(value): def update_screen_from_array(array, screen): - for x in range(array.shape[0]): - for y in range(array.shape[1]): + for y in range(array.shape[0]): + for x in range(array.shape[1]): x_transformed = x * 30 y_transformed = y * 30 + 161 - color = transform_to_rgb(array[x, y]) + color = transform_to_rgb(array[y, x]) pygame.draw.rect(screen, color, (x_transformed, y_transformed, 30, 30), 0) ######################## -HARDNESS = 0.7 +HARDNESS = 1 ######################## pygame.init() @@ -84,12 +85,17 @@ while not stopFlag: for coord in coords: if not coord[0] < 0 and not coord[0] > 27 and not coord[1] < 0 and not coord[1] > 27: - img_array[coord[0], coord[1]] += HARDNESS + if img_array[coord[1], coord[0]] < 1: + img_array[coord[1], coord[0]] += HARDNESS # Clear image on right mouse button press if pygame.mouse.get_pressed()[2] == 1: img_array = np.zeros((28, 28)) + if pygame.mouse.get_pressed()[1] == 1: + plt.imshow(img_array) + plt.show() + update_screen_from_array(img_array, screen) # Draw vertical lines for i in range(30, 28 * 30, 30): @@ -101,14 +107,14 @@ while not stopFlag: pygame.display.flip() clock.tick(60) - np.rot90(img_array, k=1) + tensor = torch.from_numpy(img_array).view(1, 28 * 28).float() with torch.no_grad(): prediction = torch.argmax(net(tensor)) - text = font.render('Prediction: ' + str(prediction.item()), True, WHITE, BLACK) - textRect = text.get_rect() - textRect.center = (420, 80) - screen.blit(text, textRect) + text = font.render('Prediction: ' + str(prediction.item()), True, WHITE, BLACK) + textRect = text.get_rect() + textRect.center = (420, 80) + screen.blit(text, textRect) pygame.quit() diff --git a/other_scripts/setcounter.py b/other_scripts/setcounter.py index e9eb00c..79133c0 100755 --- a/other_scripts/setcounter.py +++ b/other_scripts/setcounter.py @@ -1,29 +1,17 @@ from tqdm import tqdm import torch from torchvision import transforms, datasets +import matplotlib.pyplot as plt data = datasets.MNIST('../datasets', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor() ])) -loader = torch.utils.data.DataLoader(data, batch_size=15, shuffle=False) -set = {'0': 0, '1': 0, '2': 0, '3': 0, '4': 0, '5': 0, '6': 0, '7': 0, '8': 0, '9': 0} +loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False) for data in loader: - print(data[1].shape) + tensor = data[0].view([28, 28]) + plt.imshow(tensor) + plt.show() -for _, label in tqdm(loader): - set[str(label[0].item())] += 1 - -print(set) - -num = 0 -for x in set: - num += set[x] -print(num) - -for x in set: - set[x] /= num - set[x] *= 100 -print(set)