00001
00002
00003
00004
00005 #include "SgSystem.h"
00006 #include "GoUctCommands.h"
00007
00008 #include <fstream>
00009 #include <boost/format.hpp>
00010 #include "GoEyeUtil.h"
00011 #include "GoGtpCommandUtil.h"
00012 #include "GoBoardUtil.h"
00013 #include "GoSafetySolver.h"
00014 #include "GoUctDefaultPriorKnowledge.h"
00015 #include "GoUctDefaultRootFilter.h"
00016 #include "GoUctEstimatorStat.h"
00017 #include "GoUctGlobalSearch.h"
00018 #include "GoUctPatterns.h"
00019 #include "GoUctPlayer.h"
00020 #include "GoUctPlayoutPolicy.h"
00021 #include "GoUctUtil.h"
00022 #include "GoUtil.h"
00023 #include "SgException.h"
00024 #include "SgPointSetUtil.h"
00025 #include "SgRestorer.h"
00026 #include "SgUctTreeUtil.h"
00027 #include "SgWrite.h"
00028
00029 using namespace std;
00030 using boost::format;
00031 using GoGtpCommandUtil::BlackWhiteArg;
00032 using GoGtpCommandUtil::EmptyPointArg;
00033 using GoGtpCommandUtil::PointArg;
00034
/** Concrete player type the commands in this file operate on: a GoUctPlayer
    instantiated with GoUctGlobalSearch and GoUctGlobalSearchState, both
    using GoUctPlayoutPolicy on GoUctBoard. */
typedef GoUctPlayer<GoUctGlobalSearch<GoUctPlayoutPolicy<GoUctBoard>,
                    GoUctPlayoutPolicyFactory<GoUctBoard> >,
                    GoUctGlobalSearchState<GoUctPlayoutPolicy<GoUctBoard> > >
    GoUctPlayerType;
00039
00040
00041
00042
00043 namespace {
00044
00045 GoUctLiveGfx LiveGfxArg(const GtpCommand& cmd, size_t number)
00046 {
00047 string arg = cmd.ArgToLower(number);
00048 if (arg == "none")
00049 return GOUCT_LIVEGFX_NONE;
00050 if (arg == "counts")
00051 return GOUCT_LIVEGFX_COUNTS;
00052 if (arg == "sequence")
00053 return GOUCT_LIVEGFX_SEQUENCE;
00054 throw GtpFailure() << "unknown live-gfx argument \"" << arg << '"';
00055 }
00056
00057 string LiveGfxToString(GoUctLiveGfx mode)
00058 {
00059 switch (mode)
00060 {
00061 case GOUCT_LIVEGFX_NONE:
00062 return "none";
00063 case GOUCT_LIVEGFX_COUNTS:
00064 return "counts";
00065 case GOUCT_LIVEGFX_SEQUENCE:
00066 return "sequence";
00067 default:
00068 SG_ASSERT(false);
00069 return "?";
00070 }
00071 }
00072
00073 SgUctMoveSelect MoveSelectArg(const GtpCommand& cmd, size_t number)
00074 {
00075 string arg = cmd.ArgToLower(number);
00076 if (arg == "value")
00077 return SG_UCTMOVESELECT_VALUE;
00078 if (arg == "count")
00079 return SG_UCTMOVESELECT_COUNT;
00080 if (arg == "bound")
00081 return SG_UCTMOVESELECT_BOUND;
00082 if (arg == "estimate")
00083 return SG_UCTMOVESELECT_ESTIMATE;
00084 throw GtpFailure() << "unknown move select argument \"" << arg << '"';
00085 }
00086
00087 string MoveSelectToString(SgUctMoveSelect moveSelect)
00088 {
00089 switch (moveSelect)
00090 {
00091 case SG_UCTMOVESELECT_VALUE:
00092 return "value";
00093 case SG_UCTMOVESELECT_COUNT:
00094 return "count";
00095 case SG_UCTMOVESELECT_BOUND:
00096 return "bound";
00097 case SG_UCTMOVESELECT_ESTIMATE:
00098 return "estimate";
00099 default:
00100 SG_ASSERT(false);
00101 return "?";
00102 }
00103 }
00104
00105 GoUctGlobalSearchMode SearchModeArg(const GtpCommand& cmd, size_t number)
00106 {
00107 string arg = cmd.ArgToLower(number);
00108 if (arg == "playout_policy")
00109 return GOUCT_SEARCHMODE_PLAYOUTPOLICY;
00110 if (arg == "uct")
00111 return GOUCT_SEARCHMODE_UCT;
00112 if (arg == "one_ply")
00113 return GOUCT_SEARCHMODE_ONEPLY;
00114 throw GtpFailure() << "unknown search mode argument \"" << arg << '"';
00115 }
00116
00117 string SearchModeToString(GoUctGlobalSearchMode mode)
00118 {
00119 switch (mode)
00120 {
00121 case GOUCT_SEARCHMODE_PLAYOUTPOLICY:
00122 return "playout_policy";
00123 case GOUCT_SEARCHMODE_UCT:
00124 return "uct";
00125 case GOUCT_SEARCHMODE_ONEPLY:
00126 return "one_ply";
00127 default:
00128 SG_ASSERT(false);
00129 return "?";
00130 }
00131 }
00132
00133 string KnowledgeThresholdToString(const std::vector<SgUctValue>& t)
00134 {
00135 if (t.empty())
00136 return "0";
00137 std::ostringstream os;
00138 os << '\"';
00139 for (std::size_t i = 0; i < t.size(); ++i)
00140 {
00141 if (i > 0)
00142 os << ' ';
00143 os << t[i];
00144 }
00145 os << '\"';
00146 return os.str();
00147 }
00148
00149 std::vector<SgUctValue> KnowledgeThresholdFromString(const std::string& val)
00150 {
00151 std::vector<SgUctValue> v;
00152 std::istringstream is(val);
00153 std::size_t t;
00154 while (is >> t)
00155 v.push_back(SgUctValue(t));
00156 if (v.size() == 1 && v[0] == 0)
00157 v.clear();
00158 return v;
00159 }
00160
00161 }
00162
00163
00164
00165 GoUctCommands::GoUctCommands(const GoBoard& bd, GoPlayer*& player)
00166 : m_bd(bd),
00167 m_player(player)
00168 {
00169 }
00170
00171 void GoUctCommands::AddGoGuiAnalyzeCommands(GtpCommand& cmd)
00172 {
00173 cmd <<
00174 "gfx/Uct Bounds/uct_bounds\n"
00175 "plist/Uct Default Policy/uct_default_policy\n"
00176 "gfx/Uct Gfx/uct_gfx\n"
00177 "none/Uct Max Memory/uct_max_memory %s\n"
00178 "plist/Uct Moves/uct_moves\n"
00179 "param/Uct Param GlobalSearch/uct_param_globalsearch\n"
00180 "param/Uct Param Policy/uct_param_policy\n"
00181 "param/Uct Param Player/uct_param_player\n"
00182 "param/Uct Param RootFilter/uct_param_rootfilter\n"
00183 "param/Uct Param Search/uct_param_search\n"
00184 "plist/Uct Patterns/uct_patterns\n"
00185 "pstring/Uct Policy Moves/uct_policy_moves\n"
00186 "gfx/Uct Prior Knowledge/uct_prior_knowledge\n"
00187 "sboard/Uct Rave Values/uct_rave_values\n"
00188 "plist/Uct Root Filter/uct_root_filter\n"
00189 "none/Uct SaveGames/uct_savegames %w\n"
00190 "none/Uct SaveTree/uct_savetree %w\n"
00191 "gfx/Uct Sequence/uct_sequence\n"
00192 "hstring/Uct Stat Player/uct_stat_player\n"
00193 "none/Uct Stat Player Clear/uct_stat_player_clear\n"
00194 "hstring/Uct Stat Policy/uct_stat_policy\n"
00195 "none/Uct Stat Policy Clear/uct_stat_policy_clear\n"
00196 "hstring/Uct Stat Search/uct_stat_search\n"
00197 "dboard/Uct Stat Territory/uct_stat_territory\n";
00198 }
00199
00200
00201
00202
00203
00204
00205 void GoUctCommands::CmdBounds(GtpCommand& cmd)
00206 {
00207 cmd.CheckArgNone();
00208 const GoUctSearch& search = Search();
00209 const SgUctTree& tree = search.Tree();
00210 const SgUctNode& root = tree.Root();
00211 bool hasPass = false;
00212 SgUctValue passBound = 0;
00213 cmd << "LABEL";
00214 for (SgUctChildIterator it(tree, root); it; ++it)
00215 {
00216 const SgUctNode& child = *it;
00217 SgPoint move = child.Move();
00218 SgUctValue bound = search.GetBound(search.Rave(), root, child);
00219 if (move == SG_PASS)
00220 {
00221 hasPass = true;
00222 passBound = bound;
00223 }
00224 else
00225 cmd << ' ' << SgWritePoint(move) << ' ' << fixed
00226 << setprecision(2) << bound;
00227 }
00228 cmd << '\n';
00229 if (hasPass)
00230 cmd << "TEXT PASS=" << fixed << setprecision(2) << passBound << '\n';
00231 }
00232
00233
00234 void GoUctCommands::CmdDefaultPolicy(GtpCommand& cmd)
00235 {
00236 cmd.CheckArgNone();
00237 GoUctDefaultPriorKnowledge knowledge(m_bd, GoUctPlayoutPolicyParam());
00238 SgPointSet pattern;
00239 SgPointSet atari;
00240 GoPointList empty;
00241 knowledge.FindGlobalPatternAndAtariMoves(pattern, atari, empty);
00242 cmd << SgWritePointSet(atari, "", false) << '\n';
00243 }
00244
00245
00246
00247
00248 void GoUctCommands::CmdEstimatorStat(GtpCommand& cmd)
00249 {
00250 cmd.CheckNuArg(4);
00251 size_t trueValueMaxGames = cmd.Arg<size_t>(0);
00252 size_t maxGames = cmd.Arg<size_t>(1);
00253 size_t stepSize = cmd.Arg<size_t>(2);
00254 string fileName = cmd.Arg(3);
00255 GoUctEstimatorStat::Compute(Search(), trueValueMaxGames, maxGames,
00256 stepSize, fileName);
00257 }
00258
00259
00260
00261
00262 void GoUctCommands::CmdFinalScore(GtpCommand& cmd)
00263 {
00264 cmd.CheckArgNone();
00265 SgPointSet deadStones = DoFinalStatusSearch();
00266 float score;
00267 if (! GoBoardUtil::ScorePosition(m_bd, deadStones, score))
00268 throw GtpFailure("cannot score");
00269 cmd << GoUtil::ScoreToString(score);
00270 }
00271
00272
00273
00274
00275
00276
00277
00278 void GoUctCommands::CmdFinalStatusList(GtpCommand& cmd)
00279 {
00280 string arg = cmd.Arg();
00281 if (arg == "seki")
00282 return;
00283 bool getDead;
00284 if (arg == "alive")
00285 getDead = false;
00286 else if (arg == "dead")
00287 getDead = true;
00288 else
00289 throw GtpFailure("invalid final status argument");
00290 SgPointSet deadPoints = DoFinalStatusSearch();
00291
00292
00293 for (GoBlockIterator it(m_bd); it; ++it)
00294 {
00295 if ((getDead && deadPoints.Contains(*it))
00296 || (! getDead && ! deadPoints.Contains(*it)))
00297 {
00298 for (GoBoard::StoneIterator it2(m_bd, *it); it2; ++it2)
00299 cmd << SgWritePoint(*it2) << ' ';
00300 cmd << '\n';
00301 }
00302 }
00303 }
00304
00305
00306
00307
00308 void GoUctCommands::CmdGfx(GtpCommand& cmd)
00309 {
00310 cmd.CheckArgNone();
00311 const GoUctSearch& s = Search();
00312 SgBlackWhite toPlay = s.ToPlay();
00313 GoUctUtil::GfxBestMove(s, toPlay, cmd);
00314 GoUctUtil::GfxMoveValues(s, toPlay, cmd);
00315 GoUctUtil::GfxCounts(s.Tree(), cmd);
00316 GoUctUtil::GfxStatus(s, cmd);
00317 }
00318
00319
00320
00321
00322
00323 void GoUctCommands::CmdMaxMemory(GtpCommand& cmd)
00324 {
00325 cmd.CheckNuArgLessEqual(1);
00326 if (cmd.NuArg() == 0)
00327 cmd << Search().MaxNodes() * 2 * sizeof(SgUctNode);
00328 else
00329 {
00330 std::size_t memory = cmd.ArgMin<size_t>(0, 2 * sizeof(SgUctNode));
00331 Search().SetMaxNodes(memory / 2 / sizeof(SgUctNode));
00332 }
00333 }
00334
00335
00336
00337
00338
00339 void GoUctCommands::CmdMoves(GtpCommand& cmd)
00340 {
00341 cmd.CheckArgNone();
00342 vector<SgUctMoveInfo> moves;
00343 Search().GenerateAllMoves(moves);
00344 for (std::size_t i = 0; i < moves.size(); ++i)
00345 cmd << SgWritePoint(moves[i].m_move) << ' ';
00346 cmd << '\n';
00347 }
00348
00349
00350
00351
00352
00353
00354
00355
00356
00357
00358
00359
00360
00361 void GoUctCommands::CmdParamGlobalSearch(GtpCommand& cmd)
00362 {
00363 cmd.CheckNuArgLessEqual(2);
00364 GoUctGlobalSearch<GoUctPlayoutPolicy<GoUctBoard>,
00365 GoUctPlayoutPolicyFactory<GoUctBoard> >&
00366 s = GlobalSearch();
00367 GoUctGlobalSearchStateParam& p = s.m_param;
00368 if (cmd.NuArg() == 0)
00369 {
00370
00371
00372 cmd << "[bool] live_gfx " << s.GlobalSearchLiveGfx() << '\n'
00373 << "[bool] mercy_rule " << p.m_mercyRule << '\n'
00374 << "[bool] territory_statistics " << p.m_territoryStatistics
00375 << '\n'
00376 << "[string] length_modification " << p.m_lengthModification
00377 << '\n'
00378 << "[string] score_modification " << p.m_scoreModification
00379 << '\n';
00380 }
00381 else if (cmd.NuArg() == 2)
00382 {
00383 string name = cmd.Arg(0);
00384 if (name == "live_gfx")
00385 s.SetGlobalSearchLiveGfx(cmd.Arg<bool>(1));
00386 else if (name == "mercy_rule")
00387 p.m_mercyRule = cmd.Arg<bool>(1);
00388 else if (name == "territory_statistics")
00389 p.m_territoryStatistics = cmd.Arg<bool>(1);
00390 else if (name == "length_modification")
00391 p.m_lengthModification = cmd.Arg<SgUctValue>(1);
00392 else if (name == "score_modification")
00393 p.m_scoreModification = cmd.Arg<SgUctValue>(1);
00394 else
00395 throw GtpFailure() << "unknown parameter: " << name;
00396 }
00397 else
00398 throw GtpFailure() << "need 0 or 2 arguments";
00399 }
00400
00401
00402
00403
00404
00405
00406
00407
00408
00409
00410
00411
00412
00413
00414
00415
00416
00417 void GoUctCommands::CmdParamPlayer(GtpCommand& cmd)
00418 {
00419 cmd.CheckNuArgLessEqual(2);
00420 GoUctPlayerType& p = Player();
00421 if (cmd.NuArg() == 0)
00422 {
00423
00424
00425 cmd << "[bool] auto_param " << p.AutoParam() << '\n'
00426 << "[bool] early_pass " << p.EarlyPass() << '\n'
00427 << "[bool] forced_opening_moves " << p.ForcedOpeningMoves() << '\n'
00428 << "[bool] ignore_clock " << p.IgnoreClock() << '\n'
00429 << "[bool] ponder " << p.EnablePonder() << '\n'
00430 << "[bool] reuse_subtree " << p.ReuseSubtree() << '\n'
00431 << "[bool] use_root_filter " << p.UseRootFilter() << '\n'
00432 << "[string] max_games " << p.MaxGames() << '\n'
00433 << "[string] max_ponder_time " << p.MaxPonderTime() << '\n'
00434 << "[string] resign_min_games " << p.ResignMinGames() << '\n'
00435 << "[string] resign_threshold " << p.ResignThreshold() << '\n'
00436 << "[list/playout_policy/uct/one_ply] search_mode "
00437 << SearchModeToString(p.SearchMode()) << '\n';
00438 }
00439 else if (cmd.NuArg() >= 1 && cmd.NuArg() <= 2)
00440 {
00441 string name = cmd.Arg(0);
00442 if (name == "auto_param")
00443 p.SetAutoParam(cmd.Arg<bool>(1));
00444 else if (name == "early_pass")
00445 p.SetEarlyPass(cmd.Arg<bool>(1));
00446 else if (name == "forced_opening_moves")
00447 p.SetForcedOpeningMoves(cmd.Arg<bool>(1));
00448 else if (name == "ignore_clock")
00449 p.SetIgnoreClock(cmd.Arg<bool>(1));
00450 else if (name == "ponder")
00451 p.SetEnablePonder(cmd.Arg<bool>(1));
00452 else if (name == "reuse_subtree")
00453 p.SetReuseSubtree(cmd.Arg<bool>(1));
00454 else if (name == "use_root_filter")
00455 p.SetUseRootFilter(cmd.Arg<bool>(1));
00456 else if (name == "max_games")
00457 p.SetMaxGames(cmd.ArgMin<SgUctValue>(1, SgUctValue(1)));
00458 else if (name == "max_ponder_time")
00459 p.SetMaxPonderTime(cmd.ArgMin<SgUctValue>(1, 0));
00460 else if (name == "resign_min_games")
00461 p.SetResignMinGames(cmd.ArgMin<SgUctValue>(1, SgUctValue(0)));
00462 else if (name == "resign_threshold")
00463 p.SetResignThreshold(cmd.ArgMinMax<SgUctValue>(1, 0, 1));
00464 else if (name == "search_mode")
00465 p.SetSearchMode(SearchModeArg(cmd, 1));
00466 else
00467 throw GtpFailure() << "unknown parameter: " << name;
00468 }
00469 else
00470 throw GtpFailure() << "need 0 or 2 arguments";
00471 }
00472
00473
00474
00475
00476
00477
00478
00479
00480
00481
00482
00483 void GoUctCommands::CmdParamPolicy(GtpCommand& cmd)
00484 {
00485 cmd.CheckNuArgLessEqual(2);
00486 GoUctPlayoutPolicyParam& p = Player().m_playoutPolicyParam;
00487 if (cmd.NuArg() == 0)
00488 {
00489
00490
00491 cmd << "[bool] nakade_heuristic " << p.m_useNakadeHeuristic << '\n'
00492 << "[bool] statistics_enabled " << p.m_statisticsEnabled << '\n'
00493 << "fillboard_tries " << p.m_fillboardTries << '\n';
00494 }
00495 else if (cmd.NuArg() == 2)
00496 {
00497 string name = cmd.Arg(0);
00498 if (name == "nakade_heuristic")
00499 p.m_useNakadeHeuristic = cmd.Arg<bool>(1);
00500 else if (name == "statistics_enabled")
00501 p.m_statisticsEnabled = cmd.Arg<bool>(1);
00502 else if (name == "fillboard_tries")
00503 p.m_fillboardTries = cmd.Arg<int>(1);
00504 else
00505 throw GtpFailure() << "unknown parameter: " << name;
00506 }
00507 else
00508 throw GtpFailure() << "need 0 or 2 arguments";
00509 }
00510
00511
00512
00513
00514
00515
00516 void GoUctCommands::CmdParamRootFilter(GtpCommand& cmd)
00517 {
00518 cmd.CheckNuArgLessEqual(2);
00519 GoUctDefaultRootFilter* f =
00520 dynamic_cast<GoUctDefaultRootFilter*>(&Player().RootFilter());
00521 if (f == 0)
00522 throw GtpFailure("root filter is not GoUctDefaultRootFilter");
00523 if (cmd.NuArg() == 0)
00524 {
00525
00526
00527 cmd << "[bool] check_ladders " << f->CheckLadders() << '\n';
00528 }
00529 else if (cmd.NuArg() == 2)
00530 {
00531 string name = cmd.Arg(0);
00532 if (name == "check_ladders")
00533 f->SetCheckLadders(cmd.Arg<bool>(1));
00534 else
00535 throw GtpFailure() << "unknown parameter: " << name;
00536 }
00537 else
00538 throw GtpFailure() << "need 0 or 2 arguments";
00539 }
00540
00541
00542
00543
00544
00545
00546
00547
00548
00549
00550
00551
00552
00553
00554
00555
00556
00557
00558
00559
00560
00561
00562
00563
00564
00565 void GoUctCommands::CmdParamSearch(GtpCommand& cmd)
00566 {
00567 cmd.CheckNuArgLessEqual(2);
00568 GoUctSearch& s = Search();
00569 if (cmd.NuArg() == 0)
00570 {
00571
00572
00573 cmd << "[bool] check_float_precision " << s.CheckFloatPrecision()
00574 << '\n'
00575 << "[bool] keep_games " << s.KeepGames() << '\n'
00576 << "[bool] lock_free " << s.LockFree() << '\n'
00577 << "[bool] log_games " << s.LogGames() << '\n'
00578 << "[bool] prune_full_tree " << s.PruneFullTree() << '\n'
00579 << "[bool] rave " << s.Rave() << '\n'
00580 << "[bool] virtual_loss " << s.VirtualLoss() << '\n'
00581 << "[bool] weight_rave_updates " << s.WeightRaveUpdates() << '\n'
00582 << "[string] bias_term_constant " << s.BiasTermConstant() << '\n'
00583 << "[string] expand_threshold " << s.ExpandThreshold() << '\n'
00584 << "[string] first_play_urgency " << s.FirstPlayUrgency() << '\n'
00585 << "[string] knowledge_threshold "
00586 << KnowledgeThresholdToString(s.KnowledgeThreshold()) << '\n'
00587 << "[list/none/counts/sequence] live_gfx "
00588 << LiveGfxToString(s.LiveGfx()) << '\n'
00589 << "[string] live_gfx_interval " << s.LiveGfxInterval() << '\n'
00590 << "[string] max_nodes " << s.MaxNodes() << '\n'
00591 << "[list/value/count/bound/estimate] move_select "
00592 << MoveSelectToString(s.MoveSelect()) << '\n'
00593 << "[string] number_threads " << s.NumberThreads() << '\n'
00594 << "[string] number_playouts " << s.NumberPlayouts() << '\n'
00595 << "[string] prune_min_count " << s.PruneMinCount() << '\n'
00596 << "[string] randomize_rave_frequency "
00597 << s.RandomizeRaveFrequency() << '\n'
00598 << "[string] rave_weight_final " << s.RaveWeightFinal() << '\n'
00599 << "[string] rave_weight_initial "
00600 << s.RaveWeightInitial() << '\n';
00601
00602 }
00603 else if (cmd.NuArg() == 2)
00604 {
00605 string name = cmd.Arg(0);
00606 if (name == "check_float_precision")
00607 s.SetCheckFloatPrecision(cmd.Arg<bool>(1));
00608 else if (name == "keep_games")
00609 s.SetKeepGames(cmd.Arg<bool>(1));
00610 else if (name == "knowledge_threshold")
00611 s.SetKnowledgeThreshold(KnowledgeThresholdFromString(cmd.Arg(1)));
00612 else if (name == "lock_free")
00613 s.SetLockFree(cmd.Arg<bool>(1));
00614 else if (name == "log_games")
00615 s.SetLogGames(cmd.Arg<bool>(1));
00616 else if (name == "prune_full_tree")
00617 s.SetPruneFullTree(cmd.Arg<bool>(1));
00618 else if (name == "randomize_rave_frequency")
00619 s.SetRandomizeRaveFrequency(cmd.ArgMin<int>(1, 0));
00620 else if (name == "rave")
00621 s.SetRave(cmd.Arg<bool>(1));
00622 else if (name == "weight_rave_updates")
00623 s.SetWeightRaveUpdates(cmd.Arg<bool>(1));
00624 else if (name == "virtual_loss")
00625 s.SetVirtualLoss(cmd.Arg<bool>(1));
00626 else if (name == "bias_term_constant")
00627 s.SetBiasTermConstant(cmd.Arg<float>(1));
00628 else if (name == "expand_threshold")
00629 s.SetExpandThreshold(cmd.ArgMin<SgUctValue>(1, 0));
00630 else if (name == "first_play_urgency")
00631 s.SetFirstPlayUrgency(cmd.Arg<SgUctValue>(1));
00632 else if (name == "live_gfx")
00633 s.SetLiveGfx(LiveGfxArg(cmd, 1));
00634 else if (name == "live_gfx_interval")
00635 s.SetLiveGfxInterval(cmd.ArgMin<SgUctValue>(1, 1));
00636 else if (name == "max_nodes")
00637 s.SetMaxNodes(cmd.ArgMin<size_t>(1, 1));
00638 else if (name == "move_select")
00639 s.SetMoveSelect(MoveSelectArg(cmd, 1));
00640 else if (name == "number_threads")
00641 s.SetNumberThreads(cmd.ArgMin<unsigned int>(1, 1));
00642 else if (name == "number_playouts")
00643 s.SetNumberPlayouts(cmd.ArgMin<int>(1, 1));
00644 else if (name == "prune_min_count")
00645 s.SetPruneMinCount(cmd.ArgMin<SgUctValue>(1, SgUctValue(1)));
00646 else if (name == "rave_weight_final")
00647 s.SetRaveWeightFinal(cmd.Arg<float>(1));
00648 else if (name == "rave_weight_initial")
00649 s.SetRaveWeightInitial(cmd.Arg<float>(1));
00650 else
00651 throw GtpFailure() << "unknown parameter: " << name;
00652 }
00653 else
00654 throw GtpFailure() << "need 0 or 2 arguments";
00655 }
00656
00657
00658
00659
00660 void GoUctCommands::CmdPatterns(GtpCommand& cmd)
00661 {
00662 cmd.CheckArgNone();
00663 GoUctPatterns<GoBoard> patterns(m_bd);
00664 for (GoBoard::Iterator it(m_bd); it; ++it)
00665 if (m_bd.IsEmpty(*it) && patterns.MatchAny(*it))
00666 cmd << SgWritePoint(*it) << ' ';
00667 }
00668
00669
00670
00671
00672
00673 void GoUctCommands::CmdPolicyMoves(GtpCommand& cmd)
00674 {
00675 cmd.CheckArgNone();
00676 GoUctPlayoutPolicy<GoBoard> policy(m_bd, Player().m_playoutPolicyParam);
00677 policy.StartPlayout();
00678 policy.GenerateMove();
00679 cmd << GoUctPlayoutPolicyTypeStr(policy.MoveType());
00680 GoPointList moves = policy.GetEquivalentBestMoves();
00681
00682
00683
00684
00685 moves.Sort();
00686 for (int i = 0; i < moves.Length(); ++i)
00687 cmd << ' ' << SgWritePoint(moves[i]);
00688 }
00689
00690
00691
00692
00693
00694 void GoUctCommands::CmdPriorKnowledge(GtpCommand& cmd)
00695 {
00696 cmd.CheckNuArgLessEqual(1);
00697 SgUctValue count = 0;
00698 if (cmd.NuArg() == 1)
00699 count = SgUctValue(cmd.ArgMin<size_t>(0, 0));
00700 GoUctGlobalSearchState<GoUctPlayoutPolicy<GoUctBoard> >& state
00701 = ThreadState(0);
00702 state.StartSearch();
00703 vector<SgUctMoveInfo> moves;
00704 SgUctProvenType provenType;
00705 state.GenerateAllMoves(count, moves, provenType);
00706
00707 cmd << "INFLUENCE ";
00708 for (size_t i = 0; i < moves.size(); ++i)
00709 {
00710 SgMove move = moves[i].m_move;
00711 SgUctValue value = SgUctSearch::InverseEval(moves[i].m_value);
00712 SgUctValue count = moves[i].m_count;
00713 if (count > 0)
00714 {
00715 SgUctValue scaledValue = (value * 2 - 1);
00716 if (m_bd.ToPlay() != SG_BLACK)
00717 scaledValue *= -1;
00718 cmd << ' ' << SgWritePoint(move) << ' ' << scaledValue;
00719 }
00720 }
00721 cmd << "\nLABEL ";
00722 for (size_t i = 0; i < moves.size(); ++i)
00723 {
00724 SgMove move = moves[i].m_move;
00725 SgUctValue count = moves[i].m_count;
00726 if (count > 0)
00727 cmd << ' ' << SgWritePoint(move) << ' ' << count;
00728 }
00729 cmd << '\n';
00730 }
00731
00732
00733
00734
00735
00736 void GoUctCommands::CmdRaveValues(GtpCommand& cmd)
00737 {
00738 cmd.CheckArgNone();
00739 const GoUctSearch& search = Search();
00740 if (! search.Rave())
00741 throw GtpFailure("RAVE not enabled");
00742 SgPointArray<string> array("\"\"");
00743 const SgUctTree& tree = search.Tree();
00744 for (SgUctChildIterator it(tree, tree.Root()); it; ++it)
00745 {
00746 const SgUctNode& child = *it;
00747 SgPoint p = child.Move();
00748 if (p == SG_PASS || ! child.HasRaveValue())
00749 continue;
00750 ostringstream out;
00751 out << fixed << setprecision(2) << child.RaveValue();
00752 array[p] = out.str();
00753 }
00754 cmd << '\n'
00755 << SgWritePointArray<string>(array, m_bd.Size());
00756 }
00757
00758
00759
00760 void GoUctCommands::CmdRootFilter(GtpCommand& cmd)
00761 {
00762 cmd.CheckArgNone();
00763 cmd << SgWritePointList(Player().RootFilter().Get(), "", false);
00764 }
00765
00766
00767
00768
00769
00770
00771
00772 void GoUctCommands::CmdSaveTree(GtpCommand& cmd)
00773 {
00774 if (Search().MpiSynchronizer()->IsRootProcess())
00775 {
00776 cmd.CheckNuArgLessEqual(2);
00777 string fileName = cmd.Arg(0);
00778 int maxDepth = -1;
00779 if (cmd.NuArg() == 2)
00780 maxDepth = cmd.ArgMin<int>(1, 0);
00781 ofstream out(fileName.c_str());
00782 if (! out)
00783 throw GtpFailure() << "Could not open " << fileName;
00784 Search().SaveTree(out, maxDepth);
00785 }
00786 }
00787
00788
00789
00790
00791 void GoUctCommands::CmdSaveGames(GtpCommand& cmd)
00792 {
00793 string fileName = cmd.Arg();
00794 try
00795 {
00796 Search().SaveGames(fileName);
00797 }
00798 catch (const SgException& e)
00799 {
00800 throw GtpFailure(e.what());
00801 }
00802 }
00803
00804
00805
00806
00807
00808 void GoUctCommands::CmdScore(GtpCommand& cmd)
00809 {
00810 cmd.CheckArgNone();
00811 try
00812 {
00813 float komi = m_bd.Rules().Komi().ToFloat();
00814 cmd << GoBoardUtil::ScoreSimpleEndPosition(m_bd, komi);
00815 }
00816 catch (const SgException& e)
00817 {
00818 throw GtpFailure(e.what());
00819 }
00820 }
00821
00822
00823
00824
00825
00826
00827
00828
00829 void GoUctCommands::CmdSequence(GtpCommand& cmd)
00830 {
00831 cmd.CheckArgNone();
00832 GoUctUtil::GfxSequence(Search(), Search().ToPlay(), cmd);
00833 }
00834
00835
00836
00837
00838 void GoUctCommands::CmdStatPlayer(GtpCommand& cmd)
00839 {
00840 cmd.CheckArgNone();
00841 Player().GetStatistics().Write(cmd);
00842 }
00843
00844
00845
00846
00847 void GoUctCommands::CmdStatPlayerClear(GtpCommand& cmd)
00848 {
00849 cmd.CheckArgNone();
00850 Player().ClearStatistics();
00851 }
00852
00853
00854
00855
00856
00857
00858
00859 void GoUctCommands::CmdStatPolicy(GtpCommand& cmd)
00860 {
00861 cmd.CheckArgNone();
00862 if (! Player().m_playoutPolicyParam.m_statisticsEnabled)
00863 SgWarning() << "statistics not enabled in policy parameters\n";
00864 cmd << "Black Statistics:\n";
00865 Policy(0).Statistics(SG_BLACK).Write(cmd);
00866 cmd << "\nWhite Statistics:\n";
00867 Policy(0).Statistics(SG_WHITE).Write(cmd);
00868 }
00869
00870
00871
00872
00873
00874 void GoUctCommands::CmdStatPolicyClear(GtpCommand& cmd)
00875 {
00876 cmd.CheckArgNone();
00877 Policy(0).ClearStatistics();
00878 }
00879
00880
00881
00882
00883 void GoUctCommands::CmdStatSearch(GtpCommand& cmd)
00884 {
00885 cmd.CheckArgNone();
00886 const GoUctSearch& search = Search();
00887 SgUctTreeStatistics treeStatistics;
00888 treeStatistics.Compute(search.Tree());
00889 cmd << "SearchStatistics:\n";
00890 search.WriteStatistics(cmd);
00891 cmd << "TreeStatistics:\n"
00892 << treeStatistics;
00893 }
00894
00895
00896
00897
00898
00899
00900
00901 void GoUctCommands::CmdStatTerritory(GtpCommand& cmd)
00902 {
00903 cmd.CheckArgNone();
00904 SgPointArray<SgUctStatistics> territoryStatistics
00905 = ThreadState(0).m_territoryStatistics;
00906 SgPointArray<SgUctValue> array;
00907 for (GoBoard::Iterator it(m_bd); it; ++it)
00908 {
00909 if (territoryStatistics[*it].Count() == 0)
00910 throw GtpFailure("no statistics available: enable them and run search first");
00911 array[*it] = territoryStatistics[*it].Mean() * 2 - 1;
00912 }
00913 cmd << '\n'
00914 << SgWritePointArrayFloat<SgUctValue>(array, m_bd.Size(), true, 3);
00915 }
00916
00917
00918
00919 void GoUctCommands::CmdValue(GtpCommand& cmd)
00920 {
00921 cmd.CheckArgNone();
00922 cmd << Search().Tree().Root().Mean();
00923 }
00924
00925
00926
00927 void GoUctCommands::CmdValueBlack(GtpCommand& cmd)
00928 {
00929 cmd.CheckArgNone();
00930 SgUctValue value = Search().Tree().Root().Mean();
00931 if (Search().ToPlay() == SG_WHITE)
00932 value = SgUctSearch::InverseEval(value);
00933 cmd << value;
00934 }
00935
00936
00937
00938
00939 SgPointSet GoUctCommands::DoFinalStatusSearch()
00940 {
00941 SgPointSet deadStones;
00942 if (GoBoardUtil::TwoPasses(m_bd) && m_bd.Rules().CaptureDead())
00943
00944 return deadStones;
00945
00946 const size_t MAX_GAMES = 10000;
00947 SgDebug() << "GoUctCommands::DoFinalStatusSearch: doing a search with "
00948 << MAX_GAMES << " games to determine final status\n";
00949 GoUctGlobalSearch<GoUctPlayoutPolicy<GoUctBoard>,
00950 GoUctPlayoutPolicyFactory<GoUctBoard> >&
00951 search = GlobalSearch();
00952 SgRestorer<bool> restorer(&search.m_param.m_territoryStatistics);
00953 search.m_param.m_territoryStatistics = true;
00954
00955
00956 int nuUndoPass = 0;
00957 SgBlackWhite toPlay = m_bd.ToPlay();
00958 GoModBoard modBoard(m_bd);
00959 GoBoard& bd = modBoard.Board();
00960 while (bd.GetLastMove() == SG_PASS)
00961 {
00962 bd.Undo();
00963 toPlay = SgOppBW(toPlay);
00964 ++nuUndoPass;
00965 }
00966 m_player->UpdateSubscriber();
00967 if (nuUndoPass > 0)
00968 SgDebug() << "Undoing " << nuUndoPass << " passes\n";
00969 vector<SgMove> sequence;
00970 search.Search(MAX_GAMES, numeric_limits<double>::max(), sequence);
00971 SgDebug() << SgWriteLabel("Sequence")
00972 << SgWritePointList(sequence, "", false);
00973 for (int i = 0; i < nuUndoPass; ++i)
00974 {
00975 bd.Play(SG_PASS, toPlay);
00976 toPlay = SgOppBW(toPlay);
00977 }
00978 m_player->UpdateSubscriber();
00979
00980 SgPointArray<SgUctStatistics> territoryStatistics =
00981 ThreadState(0).m_territoryStatistics;
00982 GoSafetySolver safetySolver(bd);
00983 SgBWSet safe;
00984 safetySolver.FindSafePoints(&safe);
00985 for (GoBlockIterator it(bd); it; ++it)
00986 {
00987 SgBlackWhite c = bd.GetStone(*it);
00988 bool isDead = safe[SgOppBW(c)].Contains(*it);
00989 if (! isDead && ! safe[c].Contains(*it))
00990 {
00991 SgStatistics<SgUctValue,int> averageStatus;
00992 for (GoBoard::StoneIterator it2(bd, *it); it2; ++it2)
00993 {
00994 if (territoryStatistics[*it2].Count() == 0)
00995
00996
00997 return deadStones;
00998 averageStatus.Add(territoryStatistics[*it2].Mean());
00999 }
01000 const float threshold = 0.3f;
01001 isDead =
01002 ((c == SG_BLACK && averageStatus.Mean() < threshold)
01003 || (c == SG_WHITE && averageStatus.Mean() > 1 - threshold));
01004 }
01005 if (isDead)
01006 for (GoBoard::StoneIterator it2(bd, *it); it2; ++it2)
01007 deadStones.Include(*it2);
01008 }
01009 return deadStones;
01010 }
01011
01012 GoUctGlobalSearch<GoUctPlayoutPolicy<GoUctBoard>,
01013 GoUctPlayoutPolicyFactory<GoUctBoard> >&
01014 GoUctCommands::GlobalSearch()
01015 {
01016 return Player().GlobalSearch();
01017 }
01018
01019 GoUctPlayerType& GoUctCommands::Player()
01020 {
01021 if (m_player == 0)
01022 throw GtpFailure("player not GoUctPlayer");
01023 try
01024 {
01025 return dynamic_cast<GoUctPlayerType&>(*m_player);
01026 }
01027 catch (const bad_cast&)
01028 {
01029 throw GtpFailure("player not GoUctPlayer");
01030 }
01031 }
01032
01033 GoUctPlayoutPolicy<GoUctBoard>&
01034 GoUctCommands::Policy(unsigned int threadId)
01035 {
01036 GoUctPlayoutPolicy<GoUctBoard>* policy =
01037 dynamic_cast<GoUctPlayoutPolicy<GoUctBoard>*>(
01038 ThreadState(threadId).Policy());
01039 if (policy == 0)
01040 throw GtpFailure("player has no GoUctPlayoutPolicy");
01041 return *policy;
01042 }
01043
01044 void GoUctCommands::Register(GtpEngine& e)
01045 {
01046 Register(e, "final_score", &GoUctCommands::CmdFinalScore);
01047 Register(e, "final_status_list", &GoUctCommands::CmdFinalStatusList);
01048 Register(e, "uct_bounds", &GoUctCommands::CmdBounds);
01049 Register(e, "uct_default_policy", &GoUctCommands::CmdDefaultPolicy);
01050 Register(e, "uct_estimator_stat", &GoUctCommands::CmdEstimatorStat);
01051 Register(e, "uct_gfx", &GoUctCommands::CmdGfx);
01052 Register(e, "uct_max_memory", &GoUctCommands::CmdMaxMemory);
01053 Register(e, "uct_moves", &GoUctCommands::CmdMoves);
01054 Register(e, "uct_param_globalsearch",
01055 &GoUctCommands::CmdParamGlobalSearch);
01056 Register(e, "uct_param_policy", &GoUctCommands::CmdParamPolicy);
01057 Register(e, "uct_param_player", &GoUctCommands::CmdParamPlayer);
01058 Register(e, "uct_param_rootfilter", &GoUctCommands::CmdParamRootFilter);
01059 Register(e, "uct_param_search", &GoUctCommands::CmdParamSearch);
01060 Register(e, "uct_patterns", &GoUctCommands::CmdPatterns);
01061 Register(e, "uct_policy_moves", &GoUctCommands::CmdPolicyMoves);
01062 Register(e, "uct_prior_knowledge", &GoUctCommands::CmdPriorKnowledge);
01063 Register(e, "uct_rave_values", &GoUctCommands::CmdRaveValues);
01064 Register(e, "uct_root_filter", &GoUctCommands::CmdRootFilter);
01065 Register(e, "uct_savegames", &GoUctCommands::CmdSaveGames);
01066 Register(e, "uct_savetree", &GoUctCommands::CmdSaveTree);
01067 Register(e, "uct_sequence", &GoUctCommands::CmdSequence);
01068 Register(e, "uct_score", &GoUctCommands::CmdScore);
01069 Register(e, "uct_stat_player", &GoUctCommands::CmdStatPlayer);
01070 Register(e, "uct_stat_player_clear", &GoUctCommands::CmdStatPlayerClear);
01071 Register(e, "uct_stat_policy", &GoUctCommands::CmdStatPolicy);
01072 Register(e, "uct_stat_policy_clear", &GoUctCommands::CmdStatPolicyClear);
01073 Register(e, "uct_stat_search", &GoUctCommands::CmdStatSearch);
01074 Register(e, "uct_stat_territory", &GoUctCommands::CmdStatTerritory);
01075 Register(e, "uct_value", &GoUctCommands::CmdValue);
01076 Register(e, "uct_value_black", &GoUctCommands::CmdValueBlack);
01077 }
01078
01079 void GoUctCommands::Register(GtpEngine& engine, const std::string& command,
01080 GtpCallback<GoUctCommands>::Method method)
01081 {
01082 engine.Register(command, new GtpCallback<GoUctCommands>(this, method));
01083 }
01084
01085 GoUctSearch& GoUctCommands::Search()
01086 {
01087 try
01088 {
01089 GoUctObjectWithSearch& object =
01090 dynamic_cast<GoUctObjectWithSearch&>(*m_player);
01091 return object.Search();
01092 }
01093 catch (const bad_cast&)
01094 {
01095 throw GtpFailure("player is not a GoUctObjectWithSearch");
01096 }
01097 }
01098
01099
01100
01101
01102 GoUctGlobalSearchState<GoUctPlayoutPolicy<GoUctBoard> >&
01103 GoUctCommands::ThreadState(unsigned int threadId)
01104 {
01105 GoUctSearch& search = Search();
01106 if (! search.ThreadsCreated())
01107 search.CreateThreads();
01108 try
01109 {
01110 return dynamic_cast<
01111 GoUctGlobalSearchState<GoUctPlayoutPolicy<GoUctBoard> >&>(
01112 search.ThreadState(threadId));
01113 }
01114 catch (const bad_cast&)
01115 {
01116 throw GtpFailure("player has no GoUctGlobalSearchState");
01117 }
01118 }
01119
01120