Skip to content

Commit

Permalink
Merge genboard
Browse files Browse the repository at this point in the history
  • Loading branch information
lightvector committed Nov 13, 2020
2 parents 0e6ca47 + 41d1c12 commit 78cafee
Show file tree
Hide file tree
Showing 5 changed files with 841 additions and 9 deletions.
17 changes: 10 additions & 7 deletions python/board.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import random
import numpy as np

class IllegalMoveError(ValueError):
pass

#Implements legal moves without superko
class Board:
EMPTY = 0
Expand Down Expand Up @@ -290,9 +293,9 @@ def is_on_board(self,loc):
#Set a given location with error checking. Suicide setting allowed.
def set_stone(self,pla,loc):
if pla != Board.EMPTY and pla != Board.BLACK and pla != Board.WHITE:
raise ValueError("Invalid pla for board.set")
raise IllegalMoveError("Invalid pla for board.set")
if not self.is_on_board(loc):
raise ValueError("Invalid loc for board.set")
raise IllegalMoveError("Invalid loc for board.set")

if self.board[loc] == pla:
pass
Expand All @@ -312,17 +315,17 @@ def set_stone(self,pla,loc):
#Single stone suicide is disallowed but suicide is allowed, to support rule sets and sgfs that have suicide
def play(self,pla,loc):
if pla != Board.BLACK and pla != Board.WHITE:
raise ValueError("Invalid pla for board.play")
raise IllegalMoveError("Invalid pla for board.play")

if loc != Board.PASS_LOC:
if not self.is_on_board(loc):
raise ValueError("Invalid loc for board.set")
raise IllegalMoveError("Invalid loc for board.set")
if self.board[loc] != Board.EMPTY:
raise ValueError("Location is nonempty")
raise IllegalMoveError("Location is nonempty")
if self.would_be_single_stone_suicide(pla,loc):
raise ValueError("Move would be illegal single stone suicide")
raise IllegalMoveError("Move would be illegal single stone suicide")
if loc == self.simple_ko_point:
raise ValueError("Move would be illegal simple ko recapture")
raise IllegalMoveError("Move would be illegal simple ko recapture")

self.playUnsafe(pla,loc)

Expand Down
6 changes: 4 additions & 2 deletions python/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from board import Board

class Metadata:
def __init__(self, size, bname, wname, brank, wrank, komi):
def __init__(self, size, bname, wname, brank, wrank, komi, handicap):
self.size = size
self.bname = bname
self.wname = wname
self.brank = brank
self.wrank = wrank
self.komi = komi
self.handicap = handicap

#Returns (metadata, list of setup stones, list of move stones)
#Setup and move stones are both pairs of (pla,loc)
Expand Down Expand Up @@ -89,6 +90,7 @@ def load_sgf_moves_exn(path):
wrank = (root.get("WR") if root.has_property("WR") else None)
komi = (root.get("KM") if root.has_property("KM") else None)
rulesstr = (root.get("RU") if root.has_property("RU") else None)
handicap = (root.get("HA") if root.has_property("HA") else None)

rules = None
if rulesstr is not None:
Expand Down Expand Up @@ -157,5 +159,5 @@ def load_sgf_moves_exn(path):
else:
raise Exception("Could not parse rules: " + origrulesstr)

metadata = Metadata(size, bname, wname, brank, wrank, komi)
metadata = Metadata(size, bname, wname, brank, wrank, komi, handicap)
return metadata, setup, moves, rules
131 changes: 131 additions & 0 deletions python/genboard_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import traceback
import json
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
def __init__(self,num_channels,scale_init):
super(ResBlock, self).__init__()
kernel_size = 3
self.biasa = nn.Parameter(torch.zeros(num_channels,1,1))
self.conva = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False)
torch.nn.init.normal_(self.conva.weight,std=math.sqrt(2.0 / num_channels / kernel_size / kernel_size)*scale_init)
self.biasb = nn.Parameter(torch.zeros(num_channels,1,1))
self.scalb = nn.Parameter(torch.ones(num_channels,1,1))
self.convb = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False)
torch.nn.init.zeros_(self.convb.weight)

def forward(self, trunk):
x = F.relu(trunk+self.biasa)
x = self.conva(x)
x = F.relu(x*self.scalb+self.biasb)
x = self.convb(x)
return trunk+x

class GPoolResBlock(nn.Module):
def __init__(self,num_channels,scale_init):
super(GPoolResBlock, self).__init__()
kernel_size = 3
self.biasa = nn.Parameter(torch.zeros(num_channels,1,1))
self.conva = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False)
torch.nn.init.normal_(self.conva.weight,std=math.sqrt(1.0 / num_channels / kernel_size / kernel_size)*scale_init)
self.convg = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False)
torch.nn.init.normal_(self.convg.weight,std=math.sqrt(1.0 / num_channels / kernel_size / kernel_size)*math.sqrt(scale_init))
self.matg = nn.Parameter(torch.zeros(num_channels,num_channels))
torch.nn.init.normal_(self.matg,std=math.sqrt(1.0 / num_channels)*math.sqrt(scale_init))
self.biasb = nn.Parameter(torch.zeros(num_channels,1,1))
self.scalb = nn.Parameter(torch.ones(num_channels,1,1))
self.convb = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False)
torch.nn.init.zeros_(self.convb.weight)

def forward(self, trunk):
x = F.relu(trunk+self.biasa)
x = self.conva(x)
g = self.convg(x)
gsize = g.size()
g = torch.sum(g,(2,3)) / (gsize[2] * gsize[3]) # nchw -> nc
g = torch.matmul(g,self.matg)
g = g.view(gsize[0],gsize[1],1,1)
x = x + g
x = F.relu(x*self.scalb+self.biasb)
x = self.convb(x)
return trunk+x


class Model(nn.Module):
def __init__(self, num_channels, num_blocks):
super(Model, self).__init__()
# Channel 0: Next inference point
# Channel 1: On-board
# Channel 2: Black
# Channel 3: White
# Channel 4: Unknown
# Channel 5: Turn number / 100
# Channel 6: Noise stdev in turn number / 50
# Channel 7: Source

self.inference_channel = 0
self.num_channels = num_channels
self.num_blocks = num_blocks
self.conv0 = nn.Conv2d(in_channels=8, out_channels=self.num_channels, kernel_size=3, padding=1, bias=False)

self.blocks = nn.ModuleList([])
self.fixup_scale_init = 1.0 / math.sqrt(self.num_blocks)
self.blocks.append(ResBlock(self.num_channels,self.fixup_scale_init))
self.blocks.append(ResBlock(self.num_channels,self.fixup_scale_init))

next_is_gpool = True
for b in range(num_blocks-2):
if next_is_gpool:
self.blocks.append(GPoolResBlock(self.num_channels,self.fixup_scale_init))
else:
self.blocks.append(ResBlock(self.num_channels,self.fixup_scale_init))
next_is_gpool = not next_is_gpool

assert(len(self.blocks) == self.num_blocks)

self.endtrunk_bias_focus = nn.Parameter(torch.zeros(self.num_channels,1,1))
self.endtrunk_bias_g = nn.Parameter(torch.zeros(self.num_channels,1,1))
self.convg = nn.Conv2d(in_channels=self.num_channels, out_channels=self.num_channels, kernel_size=1, padding=0, bias=False)

self.fc1 = nn.Linear(self.num_channels*2, self.num_channels)
self.fc2 = nn.Linear(self.num_channels,3)
self.convaux = nn.Conv2d(in_channels=self.num_channels, out_channels=3, kernel_size=1, padding=0, bias=True)

def forward(self, inputs):
trunk = self.conv0(inputs)
for i in range(self.num_blocks):
trunk = self.blocks[i](trunk)

head_focus = F.relu(trunk+self.endtrunk_bias_focus)
head_g = F.relu(trunk+self.endtrunk_bias_g)
aux = self.convaux(head_focus)
gsize = head_g.size()

x = torch.sum(head_focus * inputs[:,self.inference_channel:self.inference_channel+1,:,:],(2,3))
g = torch.sum(head_g,(2,3)) / (gsize[2] * gsize[3]) # nchw -> nc

x = torch.cat((x,g),dim=1)

x = F.relu(self.fc1(x))
x = self.fc2(x)
return x,aux

def save_to_file(self, filename):
state_dict = self.state_dict()
data = {}
data["num_channels"] = self.num_channels
data["num_blocks"] = self.num_blocks
data["state_dict"] = state_dict
torch.save(data, filename)

@staticmethod
def load_from_file(filename):
data = torch.load(filename)
model = Model(data["num_channels"], data["num_blocks"])
model.load_state_dict(data["state_dict"])
return model


Loading

0 comments on commit 78cafee

Please sign in to comment.