corrected editor bug

This commit is contained in:
Clemens-Dautermann 2020-04-25 17:59:23 +02:00
parent d2971cf070
commit c0043b8997
2 changed files with 21 additions and 27 deletions

View file

@ -4,6 +4,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
class Net(nn.Module):
@ -32,17 +33,17 @@ def transform_to_rgb(value):
def update_screen_from_array(array, screen):
for x in range(array.shape[0]):
for y in range(array.shape[1]):
for y in range(array.shape[0]):
for x in range(array.shape[1]):
x_transformed = x * 30
y_transformed = y * 30 + 161
color = transform_to_rgb(array[x, y])
color = transform_to_rgb(array[y, x])
pygame.draw.rect(screen, color, (x_transformed, y_transformed, 30, 30), 0)
########################
HARDNESS = 0.7
HARDNESS = 1
########################
pygame.init()
@ -84,12 +85,17 @@ while not stopFlag:
for coord in coords:
if not coord[0] < 0 and not coord[0] > 27 and not coord[1] < 0 and not coord[1] > 27:
img_array[coord[0], coord[1]] += HARDNESS
if img_array[coord[1], coord[0]] < 1:
img_array[coord[1], coord[0]] += HARDNESS
# Clear image on right mouse button press
if pygame.mouse.get_pressed()[2] == 1:
img_array = np.zeros((28, 28))
if pygame.mouse.get_pressed()[1] == 1:
plt.imshow(img_array)
plt.show()
update_screen_from_array(img_array, screen)
# Draw vertical lines
for i in range(30, 28 * 30, 30):
@ -101,7 +107,7 @@ while not stopFlag:
pygame.display.flip()
clock.tick(60)
np.rot90(img_array, k=1)
tensor = torch.from_numpy(img_array).view(1, 28 * 28).float()
with torch.no_grad():
prediction = torch.argmax(net(tensor))

View file

@ -1,29 +1,17 @@
from tqdm import tqdm
import torch
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
data = datasets.MNIST('../datasets', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor()
]))
loader = torch.utils.data.DataLoader(data, batch_size=15, shuffle=False)
set = {'0': 0, '1': 0, '2': 0, '3': 0, '4': 0, '5': 0, '6': 0, '7': 0, '8': 0, '9': 0}
loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
for data in loader:
print(data[1].shape)
tensor = data[0].view([28, 28])
plt.imshow(tensor)
plt.show()
for _, label in tqdm(loader):
set[str(label[0].item())] += 1
print(set)
num = 0
for x in set:
num += set[x]
print(num)
for x in set:
set[x] /= num
set[x] *= 100
print(set)