pytorch-ai/editor.py
2020-04-25 17:59:23 +02:00

120 lines
3.4 KiB
Python

import pygame
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
class Net(nn.Module):
    """Fully connected classifier for flattened 28x28 inputs with 10 outputs.

    Layer widths: 784 -> 64 -> 120 -> 120 -> 64 -> 10, ReLU between layers,
    log-softmax on the output.

    NOTE(review): the script presumably unpickles its checkpoint against this
    class, so keep the class name and the fc1..fc5 attribute names unchanged.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in fc1..fc5 order, which also fixes the order in
        # which they consume the random-number generator during initialisation.
        self.fc1 = nn.Linear(28 * 28, 64)
        self.fc2 = nn.Linear(64, 120)
        self.fc3 = nn.Linear(120, 120)
        self.fc4 = nn.Linear(120, 64)
        self.fc5 = nn.Linear(64, 10)

    def forward(self, x):
        """Return log-probabilities over the 10 outputs (softmax taken along dim 1).

        Expects ``x`` with 784 features in its last dimension, e.g. (batch, 784).
        """
        for hidden in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = F.relu(hidden(x))
        logits = self.fc5(x)
        return F.log_softmax(logits, dim=1)
def transform_to_rgb(value):
    """Map a cell intensity to a grey RGB colour tuple usable by pygame.

    Values in [0, 1] map linearly to grey levels 0..255, truncated to ``int``
    because pygame colour components are integers. Any value outside that
    range -- including NaN, which fails every comparison -- is drawn white.

    Returns:
        A 3-tuple of ints ``(level, level, level)``.
    """
    # A chained range test (rather than "< 0 or > 1") also routes NaN to white.
    if not 0 <= value <= 1:
        return (255, 255, 255)
    level = int(value * 255)
    return (level, level, level)
def update_screen_from_array(array, screen, cell_size=30, top_offset=161):
    """Draw every cell of a 2-D intensity array as a filled square on ``screen``.

    Args:
        array: 2-D array indexed as ``array[row, col]``; each value is turned
            into a colour with ``transform_to_rgb``.
        screen: pygame surface to draw on.
        cell_size: edge length of each square, in pixels (default 30).
        top_offset: y pixel at which row 0 starts (default 161). Nothing is
            drawn above it.
    """
    for row in range(array.shape[0]):
        top = row * cell_size + top_offset
        for col in range(array.shape[1]):
            color = transform_to_rgb(array[row, col])
            # Width 0 means the rectangle is filled.
            pygame.draw.rect(screen, color, (col * cell_size, top, cell_size, cell_size), 0)
########################
HARDNESS = 1  # intensity added to a cell on each frame it is being painted
########################

# Canvas geometry: a GRID_SIZE x GRID_SIZE grid of CELL_SIZE px squares drawn
# below a HEADER_HEIGHT px header that shows the prediction.
GRID_SIZE = 28
CELL_SIZE = 30
HEADER_HEIGHT = 160
GRID_TOP = HEADER_HEIGHT + 1  # first pixel row of the grid, just below the separator line
WINDOW_HEIGHT = 1000
MODEL_PATH = './nets/net_gpu_large_batch_199.pt'
# Plus-shaped brush as (dx, dy) offsets from the cell under the cursor.
BRUSH_OFFSETS = ((0, 0), (1, 0), (-1, 0), (0, 1), (0, -1))

pygame.init()
screen = pygame.display.set_mode((GRID_SIZE * CELL_SIZE, WINDOW_HEIGHT))
clock = pygame.time.Clock()
stopFlag = False
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
img_array = np.zeros((GRID_SIZE, GRID_SIZE))
font = pygame.font.SysFont('linuxbiolinum', 80)

# The loaded object is called directly as a model below, so the checkpoint
# appears to be a whole pickled module rather than a state_dict
# (NOTE(review): confirm). Newer torch versions default torch.load to
# weights_only=True, which rejects such files, so full unpickling is requested
# explicitly. Only load checkpoints from a trusted source.
# map_location='cpu' keeps the weights on the same device as the CPU input tensor.
try:
    net = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
except TypeError:
    # Older torch releases do not accept the weights_only argument.
    net = torch.load(MODEL_PATH, map_location='cpu')
net.eval()

while not stopFlag:
    # get events
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            stopFlag = True

    width = screen.get_width()
    # Clear the header so remnants of a wider previous prediction text vanish.
    screen.fill(BLACK, (0, 0, width, HEADER_HEIGHT))
    pygame.draw.line(screen, WHITE, [0, HEADER_HEIGHT], [width, HEADER_HEIGHT], 1)  # separator line

    buttons = pygame.mouse.get_pressed()
    if buttons[0]:
        pos_x, pos_y = pygame.mouse.get_pos()
        # Pixel position -> grid indices. Floor division keeps clicks above
        # the grid negative instead of clamping them onto row 0, so clicking
        # in the header no longer paints the top row.
        col = pos_x // CELL_SIZE
        row = (pos_y - GRID_TOP) // CELL_SIZE
        for dx, dy in BRUSH_OFFSETS:
            c, r = col + dx, row + dy
            if 0 <= c < GRID_SIZE and 0 <= r < GRID_SIZE and img_array[r, c] < 1:
                img_array[r, c] += HARDNESS
    # Clear image on right mouse button press
    if buttons[2]:
        img_array = np.zeros((GRID_SIZE, GRID_SIZE))
    # Middle button: show the current image in a matplotlib window
    if buttons[1]:
        plt.imshow(img_array)
        plt.show()

    update_screen_from_array(img_array, screen)
    # Grid lines, placed on the cell boundaries
    for x in range(CELL_SIZE, GRID_SIZE * CELL_SIZE, CELL_SIZE):
        pygame.draw.line(screen, WHITE, [x, GRID_TOP], [x, WINDOW_HEIGHT], 1)
    for y in range(GRID_TOP + CELL_SIZE, WINDOW_HEIGHT, CELL_SIZE):
        pygame.draw.line(screen, WHITE, [0, y], [GRID_SIZE * CELL_SIZE, y], 1)

    # Predict and draw the label before flipping, so the displayed prediction
    # matches the drawing shown in the same frame.
    tensor = torch.from_numpy(img_array).view(1, GRID_SIZE * GRID_SIZE).float()
    with torch.no_grad():
        prediction = torch.argmax(net(tensor)).item()
    text = font.render('Prediction: ' + str(prediction), True, WHITE, BLACK)
    textRect = text.get_rect()
    textRect.center = (width // 2, HEADER_HEIGHT // 2)
    screen.blit(text, textRect)

    pygame.display.flip()
    clock.tick(60)
pygame.quit()