diff --git a/TicTacToe_AI/Net/datasets/MNIST/processed/test.pt b/TicTacToe_AI/Net/datasets/MNIST/processed/test.pt new file mode 100644 index 0000000..43775b9 Binary files /dev/null and b/TicTacToe_AI/Net/datasets/MNIST/processed/test.pt differ diff --git a/TicTacToe_AI/Net/datasets/MNIST/processed/training.pt b/TicTacToe_AI/Net/datasets/MNIST/processed/training.pt new file mode 100644 index 0000000..10d193d Binary files /dev/null and b/TicTacToe_AI/Net/datasets/MNIST/processed/training.pt differ diff --git a/TicTacToe_AI/Net/datasets/MNIST/raw/t10k-images-idx3-ubyte b/TicTacToe_AI/Net/datasets/MNIST/raw/t10k-images-idx3-ubyte new file mode 100644 index 0000000..1170b2c Binary files /dev/null and b/TicTacToe_AI/Net/datasets/MNIST/raw/t10k-images-idx3-ubyte differ diff --git a/TicTacToe_AI/Net/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz b/TicTacToe_AI/Net/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz new file mode 100644 index 0000000..5ace8ea Binary files /dev/null and b/TicTacToe_AI/Net/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz differ diff --git a/TicTacToe_AI/Net/datasets/MNIST/raw/t10k-labels-idx1-ubyte b/TicTacToe_AI/Net/datasets/MNIST/raw/t10k-labels-idx1-ubyte new file mode 100644 index 0000000..d1c3a97 Binary files /dev/null and b/TicTacToe_AI/Net/datasets/MNIST/raw/t10k-labels-idx1-ubyte differ diff --git a/TicTacToe_AI/Net/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz b/TicTacToe_AI/Net/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz new file mode 100644 index 0000000..a7e1415 Binary files /dev/null and b/TicTacToe_AI/Net/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz differ diff --git a/TicTacToe_AI/Net/datasets/MNIST/raw/train-images-idx3-ubyte b/TicTacToe_AI/Net/datasets/MNIST/raw/train-images-idx3-ubyte new file mode 100644 index 0000000..bbce276 Binary files /dev/null and b/TicTacToe_AI/Net/datasets/MNIST/raw/train-images-idx3-ubyte differ diff --git a/TicTacToe_AI/Net/datasets/MNIST/raw/train-images-idx3-ubyte.gz b/TicTacToe_AI/Net/datasets/MNIST/raw/train-images-idx3-ubyte.gz new file mode 100644 index 0000000..b50e4b6 Binary files /dev/null and b/TicTacToe_AI/Net/datasets/MNIST/raw/train-images-idx3-ubyte.gz differ diff --git a/TicTacToe_AI/Net/datasets/MNIST/raw/train-labels-idx1-ubyte b/TicTacToe_AI/Net/datasets/MNIST/raw/train-labels-idx1-ubyte new file mode 100644 index 0000000..d6b4c5d Binary files /dev/null and b/TicTacToe_AI/Net/datasets/MNIST/raw/train-labels-idx1-ubyte differ diff --git a/TicTacToe_AI/Net/datasets/MNIST/raw/train-labels-idx1-ubyte.gz b/TicTacToe_AI/Net/datasets/MNIST/raw/train-labels-idx1-ubyte.gz new file mode 100644 index 0000000..707a576 Binary files /dev/null and b/TicTacToe_AI/Net/datasets/MNIST/raw/train-labels-idx1-ubyte.gz differ diff --git a/TicTacToe_AI/Net/nets/net_0.pt b/TicTacToe_AI/Net/nets/net_0.pt new file mode 100644 index 0000000..4c970bd Binary files /dev/null and b/TicTacToe_AI/Net/nets/net_0.pt differ diff --git a/TicTacToe_AI/Net/nets/net_1.pt b/TicTacToe_AI/Net/nets/net_1.pt new file mode 100644 index 0000000..c7b5e1c Binary files /dev/null and b/TicTacToe_AI/Net/nets/net_1.pt differ diff --git a/TicTacToe_AI/Net/nets/net_2.pt b/TicTacToe_AI/Net/nets/net_2.pt new file mode 100644 index 0000000..9177878 Binary files /dev/null and b/TicTacToe_AI/Net/nets/net_2.pt differ diff --git a/TicTacToe_AI/Net/pytorch_ai.py b/TicTacToe_AI/Net/pytorch_ai.py index 72e19e2..c111016 100644 --- a/TicTacToe_AI/Net/pytorch_ai.py +++ b/TicTacToe_AI/Net/pytorch_ai.py @@ -1,5 +1,8 @@ import random import torch +import torch.optim as optim +from torch import nn +import torch.nn.functional as F from tqdm import tqdm @@ -10,9 +13,11 @@ def to_set(raw_list): raw_board, raw_label = line.split('|')[0], line.split('|')[1] # convert string label to tensor - label = torch.zeros([1, 9]) + label = torch.zeros([1, 1], dtype=torch.long) if not (int(raw_label) is -1): - label[0][int(raw_label)] = 1 + label[0][0] = int(raw_label) + else: + label[0][0] = 9 # convert board to tensor raw_board = raw_board.split(',') @@ -30,13 +35,65 @@ def to_set(raw_list): return out_set -with open('boards.bds', 'r') as infile: - print('Loading file...') - alllines = infile.readlines() - random.shuffle(alllines) +def buildsets(): + with open('boards.bds', 'r') as infile: + print('Loading file...') + alllines = infile.readlines() + print(len(alllines)) + random.shuffle(alllines) - print('Generating testset...') - testset = to_set(alllines[0:50000]) + print('Generating testset...') + testset = to_set(alllines[0:10000]) - print('Generating trainset...') - trainset = to_set(alllines[50001:]) + print('Generating trainset...') + trainset = to_set(alllines[10001:200000]) + + return trainset, testset + + +def testnet(net, testset): + correct = 0 + total = 0 + with torch.no_grad(): + for X, label in testset: + output = net(X) + if torch.argmax(output) == label[0]: + correct += 1 + total += 1 + print("Accuracy: ", round(correct / total, 3)) + + +class Net(torch.nn.Module): + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Linear(9, 9) + self.fc2 = nn.Linear(9, 20) + self.fc3 = nn.Linear(20, 50) + self.fc4 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = F.relu(self.fc3(x)) + x = self.fc4(x) + return F.log_softmax(x, dim=1) + + +net = Net() + +optimizer = optim.Adam(net.parameters(), lr=0.001) + +trainset, testset = buildsets() + +for epoch in range(100): + print('Epoch: ' + str(epoch)) + for X, label in tqdm(trainset): + net.zero_grad() + output = net(X) + loss = F.nll_loss(output.view(1, 10), label[0]) + loss.backward() + optimizer.step() + + print(loss) + torch.save(net, './nets/net_' + str(epoch) + '.pt') + testnet(net, testset) diff --git a/mnist_classifier.py b/mnist_classifier.py index f328aad..8eac991 100644 --- a/mnist_classifier.py +++ b/mnist_classifier.py @@ -14,7 +14,7 @@ test = datasets.MNIST('./datasets', train=False, download=True, transforms.ToTensor() ])) -trainset = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True) +trainset = torch.utils.data.DataLoader(train, batch_size=15, shuffle=True) testset = torch.utils.data.DataLoader(test, batch_size=10, shuffle=False)