You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
374 lines
10 KiB
374 lines
10 KiB
require 'lalrtlib.util.strict' |
|
local class = require 'lalrtlib.util.class' |
|
local List = require 'lalrtlib.util.list' |
|
|
|
local GtpEngine = class()

--- Constructor: set up an engine wrapper that is not yet connected.
-- The per-instance prefix stored in self.token is built from mp.token();
-- it is prepended to "stdout"/"stderr" to name the message channels the
-- background engine process writes to. The board fields start out empty.
function GtpEngine:init()
    self.token = tostring(mp.token()) .. "_"
    self.cur_board = {}
    self.cur_board_state_str = ""
end
|
|
|
-- Announce trim_cr as an intentional global to lalrtlib.util.strict (required
-- at the top of this file), presumably so that defining it is not rejected as
-- an undeclared global.
DECLARE("trim_cr")
|
--- Remove one trailing carriage return, as left behind by "\r\n" line endings.
-- @param str the line to clean
-- @return str without its final "\r"; str itself if it does not end in "\r"
function trim_cr(str)
    local last = string.sub(str, -1)
    if last ~= "\r" then
        return str
    end
    return string.sub(str, 1, -2)
end
|
|
|
--- Block until the next message arrives on one of this engine's channels.
-- @param instream channel suffix (e.g. "stdout"); self.token is prepended
-- @return the message payload (field 4) with a trailing "\r" removed; the
--         cleaned text is also written back into the message table
function GtpEngine:read_line(instream)
    local msg = mp.wait_infinite(self.token .. instream)
    local line = trim_cr(msg[4])
    msg[4] = line
    -- print("LINE[" .. line .. "]")  -- debug aid
    return line
end
|
|
|
--- Shorthand for reading the next line from the engine's "stdout" channel.
-- @return the line, as returned by read_line
function GtpEngine:read_stdout_line()
    local channel = "stdout"
    return self:read_line(channel)
end
|
|
|
--- Drain every message currently available on the engine's "stderr" and
-- "stdout" channels, echoing the non-empty ones.
-- Lines from the stdout channel get a ">>>OUT: " prefix, all others
-- ">>>ERR: ". Stops as soon as mp.check_available reports nothing
-- (presumably a non-blocking poll).
function GtpEngine:consume_output()
    local stdout_chan = self.token .. "stdout"
    local stderr_chan = self.token .. "stderr"

    while true do
        local msg = mp.check_available({ stderr_chan, stdout_chan })
        if not msg then
            break
        end

        local text = msg[4]
        if text ~= "" then
            local prefix = (msg[3] == stdout_chan) and ">>>OUT: " or ">>>ERR: "
            print(prefix .. text)
        end
    end
end
|
|
|
--- Flush pending engine output, then report whether the background
-- process's read loop has stopped.
-- @return true (after printing an error message) if it has stopped,
--         false otherwise
function GtpEngine:check_stopped()
    self:consume_output()

    local stopped = bgproc.is_readloop_stopped(self.bg)
    if not stopped then
        return false
    end

    print("ERROR: process stopped!")
    return true
end
|
|
|
-- Split a GTP status line into its status character and payload.
-- A status line is "=" (success) or "?" (failure), optionally followed by a
-- numeric command id, then either the end of the line or one whitespace
-- character and the payload. Returns nil for any other line (e.g. other
-- output that shares the stream).
local function parse_gtp_status_line(line)
    local status, payload = string.match(line, "^([=?])%d*%s(.*)$")
    if status then
        return status, payload
    end
    status = string.match(line, "^([=?])%d*$")
    if status then
        return status, ""
    end
    return nil
end

--- Send one GTP command to the engine and collect its reply.
-- Lines that arrive before the status line are collected as extra output,
-- as are the response lines after it, up to the empty line that terminates
-- every GTP response.
-- @param cmd command text without the trailing newline
-- @return the reply payload on "=" success; nil if the engine answered "?"
--         (failure); nothing at all if the engine process has stopped
-- @return List of the extra output lines (absent if the process has stopped)
-- @return on "?" failure only: the engine's error message
function GtpEngine:send_gtp(cmd)
    if self:check_stopped() then return end

    print("<<< " .. cmd)
    bgproc.send(self.bg, cmd .. "\n")

    local lines = List()

    -- Wait for the status line. Accepting "?" as well as "=" matters: a
    -- failure reply would otherwise never match, and this loop would block
    -- forever.
    local status, payload
    while true do
        local r = self:read_stdout_line()
        status, payload = parse_gtp_status_line(r)
        if status then
            break
        end
        lines:push(r)
    end

    -- The rest of a (possibly multi-line) response runs up to an empty line.
    local r = self:read_stdout_line()
    while r ~= "" do
        lines:push(r)
        r = self:read_stdout_line()
    end

    if status == "?" then
        print(">>> ? " .. payload)
        return nil, lines, payload
    end

    print(">>> " .. payload)
    return payload, lines
end
|
|
|
--- (Re)start the GTP engine and query its identity.
-- A previously started engine is sent "quit" (if its read loop is still
-- running) and waited for before the new process is launched.
-- On success, gtp_engine_name / gtp_engine_version /
-- gtp_engine_proto_version hold the replies to "name", "version" and
-- "protocol_version". They are nil if the engine does not come up.
-- @param binary path of the engine executable
-- @param args table of command-line arguments
-- Example: ge:connect("C:/Arbeit/p/go/gnugo-3.8/gnugo.exe", { "--mode=gtp" })
function GtpEngine:connect(binary, args)
    if self.bg then
        if not bgproc.is_readloop_stopped(self.bg) then
            self:send_gtp("quit")
        end
        bgproc.wait(self.bg)
    end

    -- Forget the previous engine's identity so a failed start does not
    -- leave stale values behind (get_heatmap selects its code path by
    -- gtp_engine_name).
    self.gtp_engine_name = nil
    self.gtp_engine_version = nil
    self.gtp_engine_proto_version = nil

    -- mp.set_debug_logging(true)

    -- Both channel arguments name the "<token>stdout" channel (presumably
    -- the stdout and stderr targets), so send_gtp() also sees lines the
    -- engine writes to stderr.
    self.bg = sys.exec("background", self.token .. "stdout", self.token .. "stdout")
    bgproc.start(self.bg, binary, args)
    bgproc.start_readloop(self.bg, "\n")

    print("starting " .. binary .. " (" .. rt.dump(args) .. ")")

    -- Give the engine a moment to start up (or to exit on bad arguments).
    mp.thread_sleep(1000)
    if self:check_stopped() then
        return
    end

    self.gtp_engine_name = self:send_gtp("name")
    self.gtp_engine_version = self:send_gtp("version")
    self.gtp_engine_proto_version = self:send_gtp("protocol_version")

    -- tostring(): send_gtp() yields nil if the engine died meanwhile, and
    -- concatenating nil would raise here.
    print("started " .. tostring(self.gtp_engine_name)
        .. " - " .. tostring(self.gtp_engine_version))
end
|
|
|
--- GTP column letters, indexed by 1-based column number.
-- GTP skips the letter I. A..Z without I gives 25 columns, which covers the
-- largest board size the protocol allows (25x25). Entries 1..19 are
-- unchanged, so 19x19 coordinates are unaffected.
local nr2gtp = {
    "A", "B", "C", "D", "E", "F", "G", "H", "J", "K", "L", "M", "N",
    "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z",
}
|
|
|
DECLARE("pt2gtp_coord")

--- Convert a board point {column, row} (both 1-based) into a GTP vertex,
-- e.g. {4, 16} -> "D16". The column letter is looked up in nr2gtp.
-- @param pt table { column, row }
-- @return vertex string
function pt2gtp_coord(pt)
    local column, row = pt[1], pt[2]
    local letter = nr2gtp[column]
    return letter .. tostring(row)
end
|
|
|
--- GTP column letter (uppercase) -> 1-based column number.
-- Covers A..Z without I, the letter GTP skips, so 25 columns in total. That
-- is the largest board size the protocol allows. Letters A..T keep their
-- previous numbers 1..19.
local gtp2nr = {}
do
    local letters = "ABCDEFGHJKLMNOPQRSTUVWXYZ"
    for col = 1, #letters do
        gtp2nr[string.sub(letters, col, col)] = col
    end
end
|
|
|
DECLARE 'gtp2pt_coord'

--- Convert a GTP vertex (e.g. "Q16") into a board point {column, row}.
-- The letter is matched case-insensitively, because GTP vertices are
-- case-insensitive. Surrounding whitespace is ignored.
-- Anything that is not <letter><digits> yields { nil, nil }. That includes
-- "pass", "PASS" and "resign". Before this change, "PASS" was mis-parsed as
-- column P with a nil row.
-- @param coord vertex string
-- @return { column, row }; column is nil for letters gtp2nr does not know
--         (e.g. "I")
function gtp2pt_coord(coord)
    local letter, digits = string.match(string.upper(coord), "^%s*(%a)(%d+)%s*$")
    if letter == nil then
        return { nil, nil }
    end
    return { gtp2nr[letter], tonumber(digits) }
end
|
|
|
--- Reset the engine to an empty board with the given size and komi, then
-- optionally replay a move list.
-- The size and komi are also remembered in self.boardsize / self.komi.
-- @param boardsize board edge length
-- @param komi komi value (sent as-is)
-- @param moves optional List of moves, passed to play_moves
function GtpEngine:set_board(boardsize, komi, moves)
    self.komi = komi
    self.boardsize = boardsize

    local setup = {
        "komi " .. komi,
        "boardsize " .. tostring(boardsize),
        "clear_board",
    }
    for _, command in ipairs(setup) do
        self:send_gtp(command)
    end

    if moves then
        self:play_moves(moves)
    end
end
|
|
|
--- Replay a list of moves on the engine's current board, then send
-- "showboard".
-- @param moves List of { color, column, row } entries; color 1 plays white,
--        any other value plays black
function GtpEngine:play_moves(moves)
    moves:foreach(function(move)
        local color = (move[1] == 1) and "W" or "B"
        self:send_gtp("play " .. color .. " " .. pt2gtp_coord({ move[2], move[3] }))
    end)

    -- The reply is not used here; send_gtp() already logs the command and
    -- its status line.
    self:send_gtp("showboard")
end
|
|
|
--- Rebuild the position from scratch and let the engine choose a move.
-- @param moves List of { color, column, row } replayed after "clear_board"
-- @param color 1 for white, anything else for black (see genmove)
-- @param time_s optional; when given, sent as "time_settings 0 <time_s> 1"
--        (GTP: main time, byo-yomi time, byo-yomi stones). When absent, the
--        engine keeps its current time settings.
-- @return whatever genmove returns
function GtpEngine:genmove_from_board(moves, color, time_s)
    local time_cmd = time_s and ("time_settings 0 " .. tostring(time_s) .. " 1")
    if time_cmd then
        self:send_gtp(time_cmd)
    end

    self:send_gtp("clear_board")
    self:play_moves(moves)

    return self:genmove(color)
end
|
|
|
--- Ask the engine to generate a move for the given color.
-- @param color_n 1 for white, anything else for black
-- @return the move as a board point converted by gtp2pt_coord, or nil if the
--         engine produced no reply. For example, send_gtp returns nothing
--         once the process has stopped; previously gtp2pt_coord(nil) raised
--         an error here.
-- @return List of extra output lines collected by send_gtp (may be nil)
function GtpEngine:genmove(color_n)
    local color = (color_n == 1) and "W" or "B"
    local move, restlines = self:send_gtp("genmove " .. color)
    if move == nil then
        return nil, restlines
    end
    return gtp2pt_coord(move), restlines
end
|
|
|
-- Announce split_ws as an intentional global to lalrtlib.util.strict
-- (required at the top of this file), presumably so that defining it is not
-- rejected as an undeclared global.
DECLARE("split_ws")
|
--- Split a string on runs of whitespace via util.re.
-- NOTE(review): the { pattern, "s" } argument is presumed to select
-- util.re's split mode; confirm against lalrtlib's util.re.
-- @param str text to split
-- @return exactly one value: the first result of util.re (the parentheses
--         drop any further results)
function split_ws(str)
    return (util.re(str, { [[\s+]], "s" }))
end
|
|
|
--- Handle genmove output lines in the Leela Zero format.
-- For now this only echoes every line with a "#> " prefix and returns
-- nothing. The parser kept below is a disabled draft.
--
-- Example line (spacing may differ):
--   R4 -> 9 (V: 52.49%) (N: 21.75%) PV: R4 Q16 C17 C16 D17
-- One of these percentages is the winrate (presumably V:, per the original
-- note). The draft pattern still expects "(W: ...)" and would need
-- adjusting before it is enabled.
function GtpEngine:analyze_lz(lines)
    local function echo(line)
        print("#> " .. line)
    end
    lines:foreach(echo)

    -- local r = util.re(lines:table(), {
    --     [[\s*(\S+)\s+->\s*\d+\s*\(W:\s*(\d+\.\d+)%\).*PV: (.*)]]
    -- });
    -- r = List(r):map(function(l)
    --     return {
    --         move = gtp2pt_coord(l[1]),
    --         winrate = tonumber(l[2]),
    --         variation = List(split_ws(l[3])):map(gtp2pt_coord):table()
    --     }
    -- end)
    -- return r
end
|
|
|
--- Parse Leela 0.11 genmove analysis lines into candidate moves.
-- Every line is first echoed with a "#11> " prefix. Lines are then matched
-- with util.re against
--   <move> -> <visits> (W: <winrate>%) ... PV: <move> <move> ...
-- @param lines List of output lines collected by send_gtp
-- @return List of { move, visits, winrate, variation } tables, where move
--         and each variation entry are gtp2pt_coord points; nil when no line
--         matched
function GtpEngine:analyze_l11(lines)
    lines:foreach(function(line)
        print("#11> " .. line)
    end)

    local matches = util.re(lines:table(), {
        [[\s*(\S+)\s+->\s*(\d+)\s*\(W:\s*(\d+\.\d+)%\).*PV: (.*)]]
    })

    local candidates = List(matches):map(function(cap)
        return {
            move = gtp2pt_coord(cap[1]),
            visits = tonumber(cap[2]),
            winrate = tonumber(cap[3]),
            variation = List(split_ws(cap[4])):map(gtp2pt_coord):table(),
        }
    end)

    if #candidates > 0 then
        return candidates
    end
    return nil
end
|
|
|
-- Raw heatmap output for the connected engine: the extra output lines of the
-- relevant GTP reply, or an empty table for engines other than Leela and
-- Leela Zero.
local function fetch_heatmap_lines(engine)
    local hm_lines = {}

    if engine.gtp_engine_name == "Leela" then
        -- XXX: Sometimes Leela only sends 17 lines in time before the
        -- version reply, so retry (at most 10 times) until at least 18 lines
        -- have arrived.
        local attempts = 0
        while attempts < 10 and #hm_lines < 18 do
            attempts = attempts + 1
            engine:send_gtp("heatmap average")
            -- XXX: Due to a bug in Leela, the heatmap can only be read as
            -- the response to the next command.
            local _, lines = engine:send_gtp("version")
            -- send_gtp() returns nothing once the engine process has stopped.
            hm_lines = lines or {}
        end
    elseif engine.gtp_engine_name == "Leela Zero" then
        local _, lines = engine:send_gtp("heatmap average")
        hm_lines = lines or {}
    end

    return hm_lines
end

-- Parse the first 19 raw lines into a List of 19 rows of 19 numbers each,
-- echoing every parsed row with a "rw> " prefix.
local function parse_heatmap_rows(hm_lines)
    local rows = List()
    for row = 1, 19 do
        local vals = split_ws(hm_lines[row])
        -- Drop the first field. It is presumably an empty field caused by
        -- leading whitespace, or a row label (TODO confirm against engine
        -- output).
        table.remove(vals, 1)

        local cells = {}
        for col = 1, 19 do
            vals[col] = tonumber(vals[col])
            cells[col] = string.format("%3d", vals[col])
        end
        print("rw>  " .. table.concat(cells, " "))

        rows:push(vals)
    end
    return rows
end

-- Scale every value in place so the largest becomes 100, rounded to one
-- decimal place. Skipped when the maximum is not positive, which would
-- otherwise divide by zero and fill the grid with NaN/inf.
local function normalize_heatmap(rows)
    local max_value = 0
    rows:foreach(function(vals)
        for col = 1, 19 do
            if vals[col] > max_value then
                max_value = vals[col]
            end
        end
    end)

    if max_value <= 0 then
        return
    end

    rows:foreach(function(vals)
        for col = 1, 19 do
            local pct = (vals[col] * 100) / max_value
            vals[col] = math.floor(pct * 10 + 0.5) / 10
        end
    end)
end

--- Fetch the engine's "heatmap average" output as a 19x19 grid, scaled so
-- that the largest value is 100 (one decimal place).
-- Only engines named "Leela" or "Leela Zero" are queried.
-- @return List of 19 rows (tables of 19 numbers), or nil when fewer than 19
--         heatmap lines were received (unsupported engine, stopped process,
--         or truncated output)
function GtpEngine:get_heatmap()
    local hm_lines = fetch_heatmap_lines(self)

    -- All 19 rows are parsed below; with fewer lines a missing row would
    -- raise an error instead of yielding a result.
    if #hm_lines < 19 then
        return nil
    end

    local rows = parse_heatmap_rows(hm_lines)
    normalize_heatmap(rows)
    return rows
end
|
|
|
--- Let the engine think about the current position for one color, take the
-- generated move back, and return the engine's analysis.
-- @param in_color 1 for white, anything else for black
-- @param time optional; when given, sent as "time_settings 0 <time> 1"
-- @return List of candidate tables from analyze_l11, each tagged with
--         .color = in_color; nil if that parser matched nothing or the
--         engine produced no output (e.g. the process has stopped)
function GtpEngine:analyze_position(in_color, time)
    local color = (in_color == 1) and "W" or "B"

    if time then
        self:send_gtp("time_settings 0 " .. tostring(time) .. " 1")
    end

    local _, restlines = self:send_gtp("genmove " .. color)
    -- GTP genmove plays the chosen move on the engine's board; take it back
    -- so the position is unchanged for the caller.
    self:send_gtp("undo")

    -- send_gtp() returns nothing when the engine process has stopped; the
    -- parsers below would fail on a nil line list.
    if restlines == nil then
        return nil
    end

    local ana_res = self:analyze_l11(restlines)
    if ana_res then
        ana_res:foreach(function(cand) cand.color = in_color end)
        return ana_res
    end

    -- Fallback: hand the lines to the Leela Zero handler; its result is not
    -- returned.
    self:analyze_lz(restlines)
    return nil
end
|
|
|
--function GtpEngine:handle_msg(msg) |
|
-- if (msg[3] == (self.token .. "stdout")) then |
|
-- print("GOT MSG:" .. self.token) |
|
-- end |
|
--end |
|
|
|
local m = {}

--- Actor entry point: announce startup, then serve requests forever.
-- Messages are polled with mp.wait(nil, 1000). Each received message is
-- dumped with an "R:" prefix and dispatched on its field 3, with field 4 as
-- the payload:
--   "gtp_connect"         payload { path = ..., args = ... }
--   "genmove_from_board"  payload { moves = ..., color = ..., think_time_s = ... };
--                         answered with { 'genmove_from_board', point, color }
-- Messages with any other name are ignored.
function m.main()
    mp.send({ "gtp_startup" })

    local ge = GtpEngine()

    -- Manual experiments (disabled), kept for reference:
    -- ge:connect("C:/Arbeit/p/go/gnugo-3.8/gnugo.exe", { "--mode=gtp" })
    -- ge:connect("C:/Arbeit/p/go/Leela0110GTP/Leela0110GTP.exe", { "-g", "--noponder", "--nobook" })
    -- ge:connect("C:/Entwicklung/git/go/Leela0110GTP/Leela0110GTP.exe",
    --            { "-g", "--noponder", "--nobook" })
    -- ge:connect("C:/Arbeit/p/go/leela-zero-0.15-cpuonly-win32/leelaz.exe",
    --            { "-w",
    --              "C:/Arbeit/p/go/leela-zero-0.15-cpuonly-win32/68d7c8fcabe792dfe2b8e8360629d08171ec8e02530b14f0451e59fa181733ce.gz",
    --              "-g", "--noponder" })
    -- ge:set_board(19, 7.5, List({ { 0, 4, 4 }, { 1, 16, 4 } }))
    -- print(rt.dump({ ge:genmove(0) }))
    -- local r = ge:analyze_position(0, 10)
    -- local r, l = ge:send_gtp("list_commands")
    -- print(rt.dump(r))
    -- ge:play_moves(List({ { 0, 4, 16 } }))
    -- r = ge:analyze_position(1, 10)
    -- print(rt.dump(r))
    -- local hm = ge:get_heatmap()
    -- hm:foreach(function (row)
    --     local s = ""
    --     List(row):foreach(function (v) s = s .. " " .. string.format("%3.0f", v) end)
    --     print("hm> " .. s)
    -- end)
    -- print(rt.dump(hm))

    -- mp.set_debug_logging(true)

    local handlers = {}

    function handlers.gtp_connect(req)
        ge:connect(req.path, req.args)
    end

    function handlers.genmove_from_board(req)
        local pt = ge:genmove_from_board(List(req.moves), req.color, req.think_time_s)
        mp.send({ 'genmove_from_board', pt, req.color })
    end

    while true do
        local msg = mp.wait(nil, 1000)
        if msg ~= nil then
            print("R:" .. rt.dump(msg))
            local handler = handlers[msg[3]]
            if handler then
                handler(msg[4])
            end
        end
    end
end

-- m.main()  -- uncomment to run directly instead of through the module table

return m
|
|
|