corrected editor bug
This commit is contained in:
parent
d2971cf070
commit
c0043b8997
2 changed files with 21 additions and 27 deletions
26
editor.py
26
editor.py
|
|
@ -4,6 +4,7 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
|
|
@ -32,17 +33,17 @@ def transform_to_rgb(value):
|
|||
|
||||
|
||||
def update_screen_from_array(array, screen):
|
||||
for x in range(array.shape[0]):
|
||||
for y in range(array.shape[1]):
|
||||
for y in range(array.shape[0]):
|
||||
for x in range(array.shape[1]):
|
||||
x_transformed = x * 30
|
||||
y_transformed = y * 30 + 161
|
||||
|
||||
color = transform_to_rgb(array[x, y])
|
||||
color = transform_to_rgb(array[y, x])
|
||||
pygame.draw.rect(screen, color, (x_transformed, y_transformed, 30, 30), 0)
|
||||
|
||||
|
||||
########################
|
||||
HARDNESS = 0.7
|
||||
HARDNESS = 1
|
||||
########################
|
||||
|
||||
pygame.init()
|
||||
|
|
@ -84,12 +85,17 @@ while not stopFlag:
|
|||
|
||||
for coord in coords:
|
||||
if not coord[0] < 0 and not coord[0] > 27 and not coord[1] < 0 and not coord[1] > 27:
|
||||
img_array[coord[0], coord[1]] += HARDNESS
|
||||
if img_array[coord[1], coord[0]] < 1:
|
||||
img_array[coord[1], coord[0]] += HARDNESS
|
||||
|
||||
# Clear image on right mouse button press
|
||||
if pygame.mouse.get_pressed()[2] == 1:
|
||||
img_array = np.zeros((28, 28))
|
||||
|
||||
if pygame.mouse.get_pressed()[1] == 1:
|
||||
plt.imshow(img_array)
|
||||
plt.show()
|
||||
|
||||
update_screen_from_array(img_array, screen)
|
||||
# Draw vertical lines
|
||||
for i in range(30, 28 * 30, 30):
|
||||
|
|
@ -101,14 +107,14 @@ while not stopFlag:
|
|||
|
||||
pygame.display.flip()
|
||||
clock.tick(60)
|
||||
np.rot90(img_array, k=1)
|
||||
|
||||
tensor = torch.from_numpy(img_array).view(1, 28 * 28).float()
|
||||
with torch.no_grad():
|
||||
prediction = torch.argmax(net(tensor))
|
||||
|
||||
text = font.render('Prediction: ' + str(prediction.item()), True, WHITE, BLACK)
|
||||
textRect = text.get_rect()
|
||||
textRect.center = (420, 80)
|
||||
screen.blit(text, textRect)
|
||||
text = font.render('Prediction: ' + str(prediction.item()), True, WHITE, BLACK)
|
||||
textRect = text.get_rect()
|
||||
textRect.center = (420, 80)
|
||||
screen.blit(text, textRect)
|
||||
|
||||
pygame.quit()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue