Implemented batching for TicTacToe AI

This commit is contained in:
Clemens Dautermann 2020-01-28 14:45:00 +01:00
parent 55cff9b18f
commit 56ee2635b5
96 changed files with 8426 additions and 7 deletions

View file

@ -0,0 +1,9 @@
wandb_version: 1
_wandb:
desc: null
value:
cli_version: 0.8.22
framework: torch
is_jupyter_run: false
python_version: 3.7.5

View file

@ -0,0 +1,135 @@
diff --git a/TicTacToe_AI/Net/pytorch_ai.py b/TicTacToe_AI/Net/pytorch_ai.py
index efea5ae..ba862ae 100644
--- a/TicTacToe_AI/Net/pytorch_ai.py
+++ b/TicTacToe_AI/Net/pytorch_ai.py
@@ -4,6 +4,11 @@ import torch.optim as optim
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
+import wandb
+
+wandb.init(project="tictactoe")
+
+BATCH_SIZE = 3
def to_set(raw_list):
@@ -35,6 +40,40 @@ def to_set(raw_list):
return out_set
+def to_batched_set(raw_list):
+ counter = 0
+ out_set = []
+ boardtensor = torch.zeros((BATCH_SIZE, 1, 9))
+ labeltensor = torch.zeros(BATCH_SIZE, dtype=torch.long)
+ for line in tqdm(raw_list):
+ line = line.replace('\n', '')
+ raw_board, raw_label = line.split('|')[0], line.split('|')[1]
+
+ if not (int(raw_label) is -1):
+ labeltensor[counter] = int(raw_label)
+ else:
+ labeltensor[counter] = 9
+
+ raw_board = raw_board.split(',')
+ for n, block in enumerate(raw_board):
+ if int(block) is -1:
+ boardtensor[counter][0][n] = 0
+ elif int(block) is 0:
+ boardtensor[counter][0][n] = 0.5
+ elif int(block) is 1:
+ boardtensor[counter][0][n] = 1
+
+ if counter == (BATCH_SIZE - 1):
+ out_set.append([boardtensor, labeltensor])
+ boardtensor = torch.zeros((BATCH_SIZE, 1, 9))
+ labeltensor = torch.zeros(BATCH_SIZE, dtype=torch.long)
+ counter = 0
+ else:
+ counter += 1
+
+ return out_set
+
+
def buildsets():
with open('boards.bds', 'r') as infile:
print('Loading file...')
@@ -43,10 +82,10 @@ def buildsets():
random.shuffle(alllines)
print('Generating testset...')
- testset = to_set(alllines[0:10000])
+ testset = to_batched_set(alllines[0:10000])
print('Generating trainset...')
- trainset = to_set(alllines[10001:200000])
+ trainset = to_batched_set(alllines[10001:20000])
return trainset, testset
@@ -60,6 +99,7 @@ def testnet(net, testset):
if torch.argmax(output) == label[0]:
correct += 1
total += 1
+ wandb.log({'test_accuracy': correct / total})
print("Accuracy: ", round(correct / total, 3))
@@ -79,7 +119,15 @@ class Net(torch.nn.Module):
return F.log_softmax(x, dim=1)
-net = torch.load('./nets/net_3.pt')
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print('running on %s' % device)
+
+# net = torch.load('./nets/net_3.pt')
+
+net = Net()
+wandb.watch(net)
+
+net.to(device)
optimizer = optim.Adam(net.parameters(), lr=0.001)
@@ -87,13 +135,18 @@ trainset, testset = buildsets()
for epoch in range(100):
print('Epoch: ' + str(epoch))
+ wandb.log({'epoch': epoch})
for X, label in tqdm(trainset):
net.zero_grad()
+ X.to(device)
output = net(X)
- loss = F.nll_loss(output.view(1, 10), label[0])
+ output.cpu()
+ print(output)
+ print(label)
+ loss = F.nll_loss(output.view(-1, 10), label)
loss.backward()
optimizer.step()
+ wandb.log({'loss': loss})
- print(loss)
- torch.save(net, './nets/net_' + str(epoch + 3) + '.pt')
+ torch.save(net, './nets/gpunets/net_' + str(epoch) + '.pt')
testnet(net, testset)
diff --git a/other_scripts/setcounter.py b/other_scripts/setcounter.py
index 9735f20..e9eb00c 100644
--- a/other_scripts/setcounter.py
+++ b/other_scripts/setcounter.py
@@ -7,9 +7,12 @@ data = datasets.MNIST('../datasets', train=True, download=True,
transforms.ToTensor()
]))
-loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
+loader = torch.utils.data.DataLoader(data, batch_size=15, shuffle=False)
set = {'0': 0, '1': 0, '2': 0, '3': 0, '4': 0, '5': 0, '6': 0, '7': 0, '8': 0, '9': 0}
+for data in loader:
+ print(data[1].shape)
+
for _, label in tqdm(loader):
set[str(label[0].item())] += 1

View file

@ -0,0 +1,126 @@
running on cpu
Loading file...
986410
Generating testset...
0%| | 0/10000 [00:00<?, ?it/s] 4%|█▎ | 351/10000 [00:00<00:02, 3501.36it/s] 7%|██▌ | 702/10000 [00:00<00:02, 3503.08it/s] 11%|███▊ | 1055/10000 [00:00<00:02, 3508.80it/s] 14%|█████ | 1414/10000 [00:00<00:02, 3530.39it/s] 18%|██████▍ | 1783/10000 [00:00<00:02, 3574.41it/s] 21%|███████▌ | 2110/10000 [00:00<00:02, 3477.22it/s] 25%|████████▉ | 2477/10000 [00:00<00:02, 3532.52it/s] 28%|██████████▏ | 2843/10000 [00:00<00:02, 3567.42it/s] 32%|███████████▌ | 3208/10000 [00:00<00:01, 3591.39it/s] 36%|████████████▊ | 3576/10000 [00:01<00:01, 3616.43it/s] 39%|██████████████▏ | 3939/10000 [00:01<00:01, 3618.86it/s] 43%|███████████████▍ | 4299/10000 [00:01<00:01, 3611.50it/s] 47%|████████████████▊ | 4669/10000 [00:01<00:01, 3637.32it/s] 50%|██████████████████ | 5034/10000 [00:01<00:01, 3638.78it/s] 54%|███████████████████▍ | 5396/10000 [00:01<00:01, 3626.96it/s] 58%|████████████████████▋ | 5758/10000 [00:01<00:02, 2008.02it/s] 61%|██████████████████████ | 6123/10000 [00:01<00:01, 2320.98it/s] 65%|███████████████████████▎ | 6490/10000 [00:02<00:01, 2608.03it/s] 68%|████████████████████████▌ | 6819/10000 [00:02<00:01, 2780.02it/s] 72%|█████████████████████████▊ | 7181/10000 [00:02<00:00, 2986.41it/s] 75%|███████████████████████████▏ | 7544/10000 [00:02<00:00, 3153.54it/s] 79%|████████████████████████████▍ | 7911/10000 [00:02<00:00, 3290.70it/s] 83%|█████████████████████████████▊ | 8280/10000 [00:02<00:00, 3400.46it/s] 87%|███████████████████████████████▏ | 8651/10000 [00:02<00:00, 3486.08it/s] 90%|████████████████████████████████▍ | 9021/10000 [00:02<00:00, 3545.16it/s] 94%|█████████████████████████████████▊ | 9390/10000 [00:02<00:00, 3587.24it/s] 98%|███████████████████████████████████▏| 9760/10000 [00:02<00:00, 3618.90it/s] 100%|███████████████████████████████████| 10000/10000 [00:03<00:00, 3290.24it/s]
Generating trainset...
0%| | 0/9999 [00:00<?, ?it/s] 4%|█▎ | 350/9999 [00:00<00:02, 3495.50it/s] 7%|██▋ | 715/9999 [00:00<00:02, 3538.49it/s] 11%|████ | 1081/9999 [00:00<00:02, 3573.92it/s] 14%|█████▎ | 1448/9999 [00:00<00:02, 3599.96it/s] 18%|██████▋ | 1819/9999 [00:00<00:02, 3629.98it/s] 22%|████████ | 2186/9999 [00:00<00:02, 3640.84it/s] 26%|█████████▍ | 2555/9999 [00:00<00:02, 3655.21it/s] 29%|██████████▊ | 2923/9999 [00:00<00:01, 3661.51it/s] 33%|████████████▏ | 3292/9999 [00:00<00:01, 3667.77it/s] 37%|█████████████▌ | 3661/9999 [00:01<00:01, 3672.84it/s] 40%|██████████████▊ | 4019/9999 [00:01<00:01, 3612.48it/s] 44%|████████████████▏ | 4384/9999 [00:01<00:01, 3622.65it/s] 48%|█████████████████▌ | 4752/9999 [00:01<00:01, 3638.28it/s] 51%|██████████████████▉ | 5114/9999 [00:01<00:01, 3631.61it/s] 55%|████████████████████▎ | 5475/9999 [00:01<00:01, 3592.49it/s] 58%|█████████████████████▌ | 5836/9999 [00:01<00:01, 3596.28it/s] 62%|██████████████████████▉ | 6195/9999 [00:01<00:01, 3556.03it/s] 66%|████████████████████████▎ | 6561/9999 [00:01<00:00, 3586.47it/s] 69%|█████████████████████████▋ | 6931/9999 [00:01<00:00, 3618.99it/s] 73%|██████████████████████████▉ | 7293/9999 [00:02<00:00, 3606.47it/s] 77%|████████████████████████████▎ | 7662/9999 [00:02<00:00, 3630.22it/s] 80%|█████████████████████████████▋ | 8029/9999 [00:02<00:00, 3640.63it/s] 84%|███████████████████████████████ | 8399/9999 [00:02<00:00, 3656.62it/s] 88%|████████████████████████████████▍ | 8766/9999 [00:02<00:00, 3659.70it/s] 91%|█████████████████████████████████▊ | 9133/9999 [00:02<00:00, 3662.23it/s] 95%|███████████████████████████████████▏ | 9500/9999 [00:02<00:00, 3636.64it/s] 99%|████████████████████████████████████▌| 9865/9999 [00:02<00:00, 3637.81it/s] 100%|█████████████████████████████████████| 9999/9999 [00:02<00:00, 3630.58it/s]
Epoch: 0
0%| | 0/3333 [00:00<?, ?it/s]tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([2, 7, 8])
0%| | 1/3333 [00:00<14:08, 3.93it/s]tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([9, 9, 3])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([3, 9, 9])
0%| | 3/3333 [00:00<11:33, 4.80it/s]tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([9, 1, 4])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([4, 2, 9])
0%| | 5/3333 [00:00<09:44, 5.69it/s]tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([6, 4, 9])
0%| | 6/3333 [00:00<08:30, 6.52it/s]tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([9, 9, 1])
0%| | 7/3333 [00:00<07:36, 7.28it/s]tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([9, 7, 9])
0%| | 8/3333 [00:00<06:59, 7.93it/s]tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([1, 3, 0])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([5, 2, 9])
0%| | 10/3333 [00:01<06:33, 8.45it/s]tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([9, 3, 9])
0%|▏ | 11/3333 [00:01<06:14, 8.86it/s]tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([5, 4, 0])
0%|▏ | 12/3333 [00:01<06:01, 9.17it/s]tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([7, 9, 3])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([9, 9, 1])
0%|▏ | 14/3333 [00:01<05:52, 9.41it/s]tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([9, 9, 3])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)
tensor([9, 5, 0])
0%|▏ | 16/3333 [00:01<05:46, 9.58it/s]tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
grad_fn=<LogSoftmaxBackward>)

View file

@ -0,0 +1,109 @@
apturl==0.5.2
argh==0.26.2
asn1crypto==0.24.0
bcrypt==3.1.6
binwalk==2.1.2
blinker==1.4
brlapi==0.6.7
certifi==2018.8.24
chardet==3.0.4
click==7.0
command-not-found==0.3
configparser==4.0.2
cryptography==2.6.1
cupshelpers==1.0
cycler==0.10.0
dbus-python==1.2.12
decorator==4.3.0
defer==1.0.6
distro-info==0.21ubuntu4
distro==1.3.0
docker-pycreds==0.4.0
duplicity==0.8.4
entrypoints==0.3
fasteners==0.12.0
future==0.16.0
gitdb2==2.0.6
gitpython==3.0.5
gql==0.2.0
graphql-core==1.1
httplib2==0.11.3
idna==2.6
keyring==18.0.1
keyrings.alt==3.1.1
kiwisolver==1.0.1
language-selector==0.1
launchpadlib==1.10.7
lazr.restfulclient==0.14.2
lazr.uri==1.0.3
lockfile==0.12.2
louis==3.10.0
macaroonbakery==1.2.3
mako==1.0.7
markupsafe==1.1.0
matplotlib==3.0.2
monotonic==1.5
netifaces==0.10.4
numpy==1.16.2
nvidia-ml-py3==7.352.0
oauth==1.0.1
oauthlib==2.1.0
olefile==0.46
paramiko==2.6.0
pathtools==0.1.2
pexpect==4.6.0
pillow==6.1.0
pip==18.1
promise==2.3
protobuf==3.6.1
psutil==5.6.7
pycairo==1.16.2
pycrypto==2.6.1
pycups==1.9.73
pygments==2.3.1
pygobject==3.34.0
pyjwt==1.7.0
pymacaroons==0.13.0
pynacl==1.3.0
pyopengl==3.1.0
pyparsing==2.2.0
pyqt5==5.12.3
pyqtgraph==0.11.0.dev0
pyrfc3339==1.1
python-apt==1.9.0+ubuntu1.3
python-dateutil==2.7.3
python-debian==0.1.36
pytz==2019.2
pyxdg==0.25
pyyaml==5.1.2
reportlab==3.5.23
requests-unixsocket==0.1.5
requests==2.21.0
scipy==1.2.2
secretstorage==2.3.1
sentry-sdk==0.14.0
setuptools==41.1.0
shortuuid==0.5.0
simplejson==3.16.0
sip==4.19.18
six==1.12.0
smmap2==2.0.5
subprocess32==3.5.4
system-service==0.3
systemd-python==234
torch==1.3.1+cpu
torchvision==0.4.2+cpu
tqdm==4.41.0
ubuntu-advantage-tools==19.5
ubuntu-drivers-common==0.0.0
ufw==0.36
unattended-upgrades==0.1
urllib3==1.24.1
usb-creator==0.3.7
virtualenv==15.1.0
wadllib==1.3.3
wandb==0.8.22
watchdog==0.9.0
wheel==0.32.3
xkit==0.0.0
zope.interface==4.3.2

View file

@ -0,0 +1 @@
{"system.cpu": 49.09, "system.memory": 48.41, "system.disk": 8.1, "system.proc.memory.availableMB": 3974.46, "system.proc.memory.rssMB": 165.72, "system.proc.memory.percent": 2.15, "system.proc.cpu.threads": 2.71, "system.network.sent": 45246, "system.network.recv": 121493, "_wandb": true, "_timestamp": 1580207075, "_runtime": 11}

View file

@ -0,0 +1,17 @@
{"epoch": 0, "_runtime": 10.482649564743042, "_timestamp": 1580207073.8952324, "_step": 0}
{"loss": 0.0, "_runtime": 10.755571603775024, "_timestamp": 1580207074.1681545, "_step": 1}
{"loss": 0.0, "_runtime": 10.839263677597046, "_timestamp": 1580207074.2518466, "_step": 2}
{"loss": 0.0, "_runtime": 10.93910551071167, "_timestamp": 1580207074.3516884, "_step": 3}
{"loss": 0.0, "_runtime": 11.038821935653687, "_timestamp": 1580207074.4514048, "_step": 4}
{"loss": 0.0, "_runtime": 11.138564825057983, "_timestamp": 1580207074.5511477, "_step": 5}
{"loss": 0.0, "_runtime": 11.237011909484863, "_timestamp": 1580207074.6495948, "_step": 6}
{"loss": 0.0, "_runtime": 11.339718103408813, "_timestamp": 1580207074.752301, "_step": 7}
{"loss": 0.0, "_runtime": 11.439802885055542, "_timestamp": 1580207074.8523858, "_step": 8}
{"loss": 0.0, "_runtime": 11.541845321655273, "_timestamp": 1580207074.9544282, "_step": 9}
{"loss": 0.0, "_runtime": 11.639827728271484, "_timestamp": 1580207075.0524106, "_step": 10}
{"loss": 0.0, "_runtime": 11.738152503967285, "_timestamp": 1580207075.1507354, "_step": 11}
{"loss": 0.0, "_runtime": 11.839798212051392, "_timestamp": 1580207075.252381, "_step": 12}
{"loss": 0.0, "_runtime": 11.939958333969116, "_timestamp": 1580207075.3525412, "_step": 13}
{"loss": 0.0, "_runtime": 12.040019989013672, "_timestamp": 1580207075.4526029, "_step": 14}
{"loss": 0.0, "_runtime": 12.139089345932007, "_timestamp": 1580207075.5516722, "_step": 15}
{"loss": 0.0, "_runtime": 12.240127325057983, "_timestamp": 1580207075.6527102, "_step": 16}

View file

@ -0,0 +1,23 @@
{
"root": "/home/clemens/repositorys/pytorch-ai",
"program": "pytorch_ai.py",
"git": {
"remote": "git@github.com:Clemens-Dautermann/pytorch-ai.git",
"commit": "55cff9b18f8558ae7a9170e56a3d5c6f6665d9ab"
},
"email": "clemens.dautermann@gmail.com",
"startedAt": "2020-01-28T10:24:24.158270",
"host": "ubuntu-laptop",
"username": "clemens",
"executable": "/usr/bin/python3",
"os": "Linux-5.3.0-26-generic-x86_64-with-Ubuntu-19.10-eoan",
"python": "3.7.5",
"cpu_count": 2,
"args": [],
"state": "killed",
"jobType": null,
"mode": "run",
"project": "tictactoe",
"heartbeatAt": "2020-01-28T10:24:36.253604",
"exitcode": 255
}

View file

@ -0,0 +1 @@
{"epoch": 0, "_step": 16, "_runtime": 12.240127325057983, "_timestamp": 1580207075.6527102, "graph_0": {"_type": "graph", "format": "torch", "nodes": [{"name": "fc1", "id": 139735090147280, "class_name": "Linear(in_features=9, out_features=9, bias=True)", "parameters": [["weight", [9, 9]], ["bias", [9]]], "output_shape": [[3, 1, 9]], "num_parameters": [81, 9]}, {"name": "fc2", "id": 139735110388688, "class_name": "Linear(in_features=9, out_features=20, bias=True)", "parameters": [["weight", [20, 9]], ["bias", [20]]], "output_shape": [[3, 1, 20]], "num_parameters": [180, 20]}, {"name": "fc3", "id": 139735090146960, "class_name": "Linear(in_features=20, out_features=50, bias=True)", "parameters": [["weight", [50, 20]], ["bias", [50]]], "output_shape": [[3, 1, 50]], "num_parameters": [1000, 50]}, {"name": "fc4", "id": 139735090146768, "class_name": "Linear(in_features=50, out_features=10, bias=True)", "parameters": [["weight", [10, 50]], ["bias", [10]]], "output_shape": [[3, 1, 10]], "num_parameters": [500, 10]}], "edges": []}, "loss": 0.0}