diff --git a/TicTacToe_AI/Net/pytorch_ai.py b/TicTacToe_AI/Net/pytorch_ai.py
index efea5ae..701918f 100644
--- a/TicTacToe_AI/Net/pytorch_ai.py
+++ b/TicTacToe_AI/Net/pytorch_ai.py
@@ -4,6 +4,9 @@ import torch.optim as optim
 from torch import nn
 import torch.nn.functional as F
 from tqdm import tqdm
+import wandb
+
+wandb.init(project="tictactoe")
 
 
 def to_set(raw_list):
@@ -46,7 +49,7 @@ def buildsets():
     testset = to_set(alllines[0:10000])
 
     print('Generating trainset...')
-    trainset = to_set(alllines[10001:200000])
+    trainset = to_set(alllines[10001:])
 
     return trainset, testset
 
@@ -60,6 +63,7 @@ def testnet(net, testset):
             if torch.argmax(output) == label[0]:
                 correct += 1
             total += 1
 
+    wandb.log({'test_accuracy': correct / total})
     print("Accuracy: ", round(correct / total, 3))
 
@@ -79,7 +83,15 @@ class Net(torch.nn.Module):
         return F.log_softmax(x, dim=1)
 
 
-net = torch.load('./nets/net_3.pt')
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print('running on %s' % device)
+
+# net = torch.load('./nets/net_3.pt')
+
+net = Net()
+wandb.watch(net)
+
+net.to(device)
 
 optimizer = optim.Adam(net.parameters(), lr=0.001)
 
@@ -87,13 +99,16 @@ trainset, testset = buildsets()
 
 for epoch in range(100):
     print('Epoch: ' + str(epoch))
+    wandb.log({'epoch': epoch})
     for X, label in tqdm(trainset):
         net.zero_grad()
+        X = X.to(device)  # Tensor.to() is not in-place; rebind the moved tensor
        output = net(X)
+        output = output.cpu()  # bring the logits back to the CPU, where label lives
         loss = F.nll_loss(output.view(1, 10), label[0])
         loss.backward()
         optimizer.step()
+        wandb.log({'loss': loss.item()})
 
-    print(loss)
-    torch.save(net, './nets/net_' + str(epoch + 3) + '.pt')
+    torch.save(net, './nets/gpunets/net_' + str(epoch) + '.pt')
     testnet(net, testset)
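
A note on the two rebinding fixes in the last hunk: Tensor.to() and Tensor.cpu() return new tensors instead of mutating their receiver, so the bare `X.to(device)` and `output.cpu()` calls were no-ops and the loss computation would have raised a CPU/CUDA device-mismatch error on a GPU machine. Below is a minimal, self-contained sketch of the pitfall; it is illustrative only and not part of the patch.

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

x = torch.zeros(3)
x.to(device)      # returns a moved copy that is immediately discarded
print(x.device)   # still cpu

x = x.to(device)  # rebinding the name is what actually moves the tensor
print(x.device)   # cpu or cuda:0, depending on availability

nn.Module.to() is the exception: it moves the module's parameters in place and returns self, which is why the patch's bare `net.to(device)` works as written. Logging `loss.item()` rather than the tensor follows the same idea: wandb receives a plain Python float instead of a reference to a (possibly GPU-resident) tensor.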