corrected editor bug

This commit is contained in:
Clemens-Dautermann 2020-04-25 17:59:23 +02:00
parent d2971cf070
commit c0043b8997
2 changed files with 21 additions and 27 deletions

View file

@ -4,6 +4,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import matplotlib.pyplot as plt
class Net(nn.Module): class Net(nn.Module):
@ -32,17 +33,17 @@ def transform_to_rgb(value):
def update_screen_from_array(array, screen): def update_screen_from_array(array, screen):
for x in range(array.shape[0]): for y in range(array.shape[0]):
for y in range(array.shape[1]): for x in range(array.shape[1]):
x_transformed = x * 30 x_transformed = x * 30
y_transformed = y * 30 + 161 y_transformed = y * 30 + 161
color = transform_to_rgb(array[x, y]) color = transform_to_rgb(array[y, x])
pygame.draw.rect(screen, color, (x_transformed, y_transformed, 30, 30), 0) pygame.draw.rect(screen, color, (x_transformed, y_transformed, 30, 30), 0)
######################## ########################
HARDNESS = 0.7 HARDNESS = 1
######################## ########################
pygame.init() pygame.init()
@ -84,12 +85,17 @@ while not stopFlag:
for coord in coords: for coord in coords:
if not coord[0] < 0 and not coord[0] > 27 and not coord[1] < 0 and not coord[1] > 27: if not coord[0] < 0 and not coord[0] > 27 and not coord[1] < 0 and not coord[1] > 27:
img_array[coord[0], coord[1]] += HARDNESS if img_array[coord[1], coord[0]] < 1:
img_array[coord[1], coord[0]] += HARDNESS
# Clear image on right mouse button press # Clear image on right mouse button press
if pygame.mouse.get_pressed()[2] == 1: if pygame.mouse.get_pressed()[2] == 1:
img_array = np.zeros((28, 28)) img_array = np.zeros((28, 28))
if pygame.mouse.get_pressed()[1] == 1:
plt.imshow(img_array)
plt.show()
update_screen_from_array(img_array, screen) update_screen_from_array(img_array, screen)
# Draw vertical lines # Draw vertical lines
for i in range(30, 28 * 30, 30): for i in range(30, 28 * 30, 30):
@ -101,7 +107,7 @@ while not stopFlag:
pygame.display.flip() pygame.display.flip()
clock.tick(60) clock.tick(60)
np.rot90(img_array, k=1)
tensor = torch.from_numpy(img_array).view(1, 28 * 28).float() tensor = torch.from_numpy(img_array).view(1, 28 * 28).float()
with torch.no_grad(): with torch.no_grad():
prediction = torch.argmax(net(tensor)) prediction = torch.argmax(net(tensor))

View file

@ -1,29 +1,17 @@
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from torchvision import transforms, datasets from torchvision import transforms, datasets
import matplotlib.pyplot as plt
data = datasets.MNIST('../datasets', train=True, download=True, data = datasets.MNIST('../datasets', train=True, download=True,
transform=transforms.Compose([ transform=transforms.Compose([
transforms.ToTensor() transforms.ToTensor()
])) ]))
loader = torch.utils.data.DataLoader(data, batch_size=15, shuffle=False) loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
set = {'0': 0, '1': 0, '2': 0, '3': 0, '4': 0, '5': 0, '6': 0, '7': 0, '8': 0, '9': 0}
for data in loader: for data in loader:
print(data[1].shape) tensor = data[0].view([28, 28])
plt.imshow(tensor)
plt.show()
for _, label in tqdm(loader):
set[str(label[0].item())] += 1
print(set)
num = 0
for x in set:
num += set[x]
print(num)
for x in set:
set[x] /= num
set[x] *= 100
print(set)