pytorch-ai/TicTacToe_AI/Net/wandb/run-20200128_092953-yloo6l66/diff.patch
2020-01-28 14:45:00 +01:00

67 lines
1.8 KiB
Diff

diff --git a/TicTacToe_AI/Net/pytorch_ai.py b/TicTacToe_AI/Net/pytorch_ai.py
index efea5ae..701918f 100644
--- a/TicTacToe_AI/Net/pytorch_ai.py
+++ b/TicTacToe_AI/Net/pytorch_ai.py
@@ -4,6 +4,9 @@ import torch.optim as optim
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
+import wandb
+
+wandb.init(project="tictactoe")
def to_set(raw_list):
@@ -46,7 +49,7 @@ def buildsets():
testset = to_set(alllines[0:10000])
print('Generating trainset...')
- trainset = to_set(alllines[10001:200000])
+ trainset = to_set(alllines[10001:])
return trainset, testset
@@ -60,6 +63,7 @@ def testnet(net, testset):
if torch.argmax(output) == label[0]:
correct += 1
total += 1
+ wandb.log({'test_accuracy': correct / total})
print("Accuracy: ", round(correct / total, 3))
@@ -79,7 +83,15 @@ class Net(torch.nn.Module):
return F.log_softmax(x, dim=1)
-net = torch.load('./nets/net_3.pt')
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print('running on %s' % device)
+
+# net = torch.load('./nets/net_3.pt')
+
+net = Net()
+wandb.watch(net)
+
+net.to(device)
optimizer = optim.Adam(net.parameters(), lr=0.001)
@@ -87,13 +99,16 @@ trainset, testset = buildsets()
for epoch in range(100):
print('Epoch: ' + str(epoch))
+ wandb.log({'epoch': epoch})
for X, label in tqdm(trainset):
net.zero_grad()
+ X.to(device)
output = net(X)
+ output.cpu()
loss = F.nll_loss(output.view(1, 10), label[0])
loss.backward()
optimizer.step()
+ wandb.log({'loss': loss})
- print(loss)
- torch.save(net, './nets/net_' + str(epoch + 3) + '.pt')
+ torch.save(net, './nets/gpunets/net_' + str(epoch) + '.pt')
testnet(net, testset)