corrected editor bug
This commit is contained in:
parent
d2971cf070
commit
c0043b8997
2 changed files with 21 additions and 27 deletions
18
editor.py
18
editor.py
|
|
@ -4,6 +4,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Module):
|
class Net(nn.Module):
|
||||||
|
|
@ -32,17 +33,17 @@ def transform_to_rgb(value):
|
||||||
|
|
||||||
|
|
||||||
def update_screen_from_array(array, screen):
|
def update_screen_from_array(array, screen):
|
||||||
for x in range(array.shape[0]):
|
for y in range(array.shape[0]):
|
||||||
for y in range(array.shape[1]):
|
for x in range(array.shape[1]):
|
||||||
x_transformed = x * 30
|
x_transformed = x * 30
|
||||||
y_transformed = y * 30 + 161
|
y_transformed = y * 30 + 161
|
||||||
|
|
||||||
color = transform_to_rgb(array[x, y])
|
color = transform_to_rgb(array[y, x])
|
||||||
pygame.draw.rect(screen, color, (x_transformed, y_transformed, 30, 30), 0)
|
pygame.draw.rect(screen, color, (x_transformed, y_transformed, 30, 30), 0)
|
||||||
|
|
||||||
|
|
||||||
########################
|
########################
|
||||||
HARDNESS = 0.7
|
HARDNESS = 1
|
||||||
########################
|
########################
|
||||||
|
|
||||||
pygame.init()
|
pygame.init()
|
||||||
|
|
@ -84,12 +85,17 @@ while not stopFlag:
|
||||||
|
|
||||||
for coord in coords:
|
for coord in coords:
|
||||||
if not coord[0] < 0 and not coord[0] > 27 and not coord[1] < 0 and not coord[1] > 27:
|
if not coord[0] < 0 and not coord[0] > 27 and not coord[1] < 0 and not coord[1] > 27:
|
||||||
img_array[coord[0], coord[1]] += HARDNESS
|
if img_array[coord[1], coord[0]] < 1:
|
||||||
|
img_array[coord[1], coord[0]] += HARDNESS
|
||||||
|
|
||||||
# Clear image on right mouse button press
|
# Clear image on right mouse button press
|
||||||
if pygame.mouse.get_pressed()[2] == 1:
|
if pygame.mouse.get_pressed()[2] == 1:
|
||||||
img_array = np.zeros((28, 28))
|
img_array = np.zeros((28, 28))
|
||||||
|
|
||||||
|
if pygame.mouse.get_pressed()[1] == 1:
|
||||||
|
plt.imshow(img_array)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
update_screen_from_array(img_array, screen)
|
update_screen_from_array(img_array, screen)
|
||||||
# Draw vertical lines
|
# Draw vertical lines
|
||||||
for i in range(30, 28 * 30, 30):
|
for i in range(30, 28 * 30, 30):
|
||||||
|
|
@ -101,7 +107,7 @@ while not stopFlag:
|
||||||
|
|
||||||
pygame.display.flip()
|
pygame.display.flip()
|
||||||
clock.tick(60)
|
clock.tick(60)
|
||||||
np.rot90(img_array, k=1)
|
|
||||||
tensor = torch.from_numpy(img_array).view(1, 28 * 28).float()
|
tensor = torch.from_numpy(img_array).view(1, 28 * 28).float()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
prediction = torch.argmax(net(tensor))
|
prediction = torch.argmax(net(tensor))
|
||||||
|
|
|
||||||
|
|
@ -1,29 +1,17 @@
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
from torchvision import transforms, datasets
|
from torchvision import transforms, datasets
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
data = datasets.MNIST('../datasets', train=True, download=True,
|
data = datasets.MNIST('../datasets', train=True, download=True,
|
||||||
transform=transforms.Compose([
|
transform=transforms.Compose([
|
||||||
transforms.ToTensor()
|
transforms.ToTensor()
|
||||||
]))
|
]))
|
||||||
|
|
||||||
loader = torch.utils.data.DataLoader(data, batch_size=15, shuffle=False)
|
loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
|
||||||
set = {'0': 0, '1': 0, '2': 0, '3': 0, '4': 0, '5': 0, '6': 0, '7': 0, '8': 0, '9': 0}
|
|
||||||
|
|
||||||
for data in loader:
|
for data in loader:
|
||||||
print(data[1].shape)
|
tensor = data[0].view([28, 28])
|
||||||
|
plt.imshow(tensor)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
for _, label in tqdm(loader):
|
|
||||||
set[str(label[0].item())] += 1
|
|
||||||
|
|
||||||
print(set)
|
|
||||||
|
|
||||||
num = 0
|
|
||||||
for x in set:
|
|
||||||
num += set[x]
|
|
||||||
print(num)
|
|
||||||
|
|
||||||
for x in set:
|
|
||||||
set[x] /= num
|
|
||||||
set[x] *= 100
|
|
||||||
print(set)
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue