00001
00002
00003
00004
00005 #ifndef GOUCT_GLOBALSEARCH_H
00006 #define GOUCT_GLOBALSEARCH_H
00007
00008 #include <limits>
00009 #include <boost/scoped_ptr.hpp>
00010 #include <boost/version.hpp>
00011 #include "GoBoard.h"
00012 #include "GoBoardUtil.h"
00013 #include "GoEyeUtil.h"
00014 #include "GoRegionBoard.h"
00015 #include "GoSafetySolver.h"
00016 #include "GoUctDefaultPriorKnowledge.h"
00017 #include "GoUctSearch.h"
00018 #include "GoUctUtil.h"
00019
/** Major component of the Boost version, decoded from BOOST_VERSION
    (BOOST_VERSION is divided by 100000).
*/
#define BOOST_VERSION_MAJOR ((BOOST_VERSION) / 100000)

/** Minor component of the Boost version, decoded from BOOST_VERSION
    (BOOST_VERSION is divided by 100, then taken modulo 1000).
*/
#define BOOST_VERSION_MINOR (((BOOST_VERSION) / 100) % 1000)
00022
00023
00024
00025
00026
00027
00028
00029
/** Compile-time switch for the safety solver.
    When true, GoUctGlobalSearch::OnStartSearch() runs GoSafetySolver to fill
    the safe-point sets; when false those sets are only cleared.
*/
const bool GOUCT_USE_SAFETY_SOLVER = false;
00031
00032
00033
00034
00035 struct GoUctGlobalSearchStateParam
00036 {
00037
00038
00039
00040 bool m_mercyRule;
00041
00042
00043 bool m_territoryStatistics;
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055 SgUctValue m_lengthModification;
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068 SgUctValue m_scoreModification;
00069
00070 GoUctGlobalSearchStateParam();
00071 };
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102 template<class POLICY>
00103 class GoUctGlobalSearchState
00104 : public GoUctState
00105 {
00106 public:
00107 const SgBWSet& m_safe;
00108
00109 const SgPointArray<bool>& m_allSafe;
00110
00111
00112
00113 SgPointArray<SgUctStatistics> m_territoryStatistics;
00114
00115
00116
00117
00118
00119
00120
00121
00122
00123
00124
00125
00126
00127 GoUctGlobalSearchState(unsigned int threadId, const GoBoard& bd,
00128 POLICY* policy,
00129 const GoUctGlobalSearchStateParam& param,
00130 const GoUctPlayoutPolicyParam& policyParam,
00131 const SgBWSet& safe,
00132 const SgPointArray<bool>& allSafe);
00133
00134 ~GoUctGlobalSearchState();
00135
00136 SgUctValue Evaluate();
00137
00138 bool GenerateAllMoves(SgUctValue count, std::vector<SgUctMoveInfo>& moves,
00139 SgUctProvenType& provenType);
00140
00141 SgMove GeneratePlayoutMove(bool& skipRaveUpdate);
00142
00143 void ExecutePlayout(SgMove move);
00144
00145 void GameStart();
00146
00147 void EndPlayout();
00148
00149 void StartPlayout();
00150
00151 void StartPlayouts();
00152
00153 void StartSearch();
00154
00155 POLICY* Policy();
00156
00157
00158
00159 void SetPolicy(POLICY* policy);
00160
00161 void ClearTerritoryStatistics();
00162
00163 private:
00164 const GoUctGlobalSearchStateParam& m_param;
00165
00166 const GoUctPlayoutPolicyParam& m_policyParam;
00167
00168
00169 bool m_mercyRuleTriggered;
00170
00171
00172 int m_passMovesPlayoutPhase;
00173
00174
00175 int m_mercyRuleThreshold;
00176
00177
00178
00179 int m_stoneDiff;
00180
00181
00182 int m_initialMoveNumber;
00183
00184
00185 GoPointList m_area;
00186
00187
00188 SgUctValue m_mercyRuleResult;
00189
00190
00191
00192 SgUctValue m_invMaxScore;
00193
00194 SgRandom m_random;
00195
00196 GoUctDefaultPriorKnowledge m_priorKnowledge;
00197
00198 boost::scoped_ptr<POLICY> m_policy;
00199
00200
00201 GoUctGlobalSearchState(const GoUctGlobalSearchState& search);
00202
00203
00204 GoUctGlobalSearchState& operator=(const GoUctGlobalSearchState& search);
00205
00206 bool CheckMercyRule();
00207
00208 template<class BOARD>
00209 SgUctValue EvaluateBoard(const BOARD& bd, float komi);
00210
00211
00212 void GenerateLegalMoves(std::vector<SgUctMoveInfo>& moves);
00213
00214 float GetKomi() const;
00215 };
00216
00217 template<class POLICY>
00218 GoUctGlobalSearchState<POLICY>::GoUctGlobalSearchState(unsigned int threadId,
00219 const GoBoard& bd, POLICY* policy,
00220 const GoUctGlobalSearchStateParam& param,
00221 const GoUctPlayoutPolicyParam& policyParam,
00222 const SgBWSet& safe, const SgPointArray<bool>& allSafe)
00223 : GoUctState(threadId, bd),
00224 m_safe(safe),
00225 m_allSafe(allSafe),
00226 m_param(param),
00227 m_policyParam(policyParam),
00228 m_priorKnowledge(Board(), m_policyParam),
00229 m_policy(policy)
00230 {
00231 ClearTerritoryStatistics();
00232 }
00233
00234 template<class POLICY>
00235 GoUctGlobalSearchState<POLICY>::~GoUctGlobalSearchState()
00236 {
00237 }
00238
00239
00240 template<class POLICY>
00241 bool GoUctGlobalSearchState<POLICY>::CheckMercyRule()
00242 {
00243 SG_ASSERT(m_param.m_mercyRule);
00244
00245 SG_ASSERT(IsInPlayout());
00246 if (m_stoneDiff >= m_mercyRuleThreshold)
00247 {
00248 m_mercyRuleTriggered = true;
00249 m_mercyRuleResult = (UctBoard().ToPlay() == SG_BLACK ? 1 : 0);
00250 }
00251 else if (m_stoneDiff <= -m_mercyRuleThreshold)
00252 {
00253 m_mercyRuleTriggered = true;
00254 m_mercyRuleResult = (UctBoard().ToPlay() == SG_WHITE ? 1 : 0);
00255 }
00256 else
00257 SG_ASSERT(! m_mercyRuleTriggered);
00258 return m_mercyRuleTriggered;
00259 }
00260
00261 template<class POLICY>
00262 void GoUctGlobalSearchState<POLICY>::ClearTerritoryStatistics()
00263 {
00264 for (SgPointArray<SgUctStatistics>::NonConstIterator
00265 it(m_territoryStatistics); it; ++it)
00266 (*it).Clear();
00267 }
00268
00269 template<class POLICY>
00270 void GoUctGlobalSearchState<POLICY>::EndPlayout()
00271 {
00272 GoUctState::EndPlayout();
00273 m_policy->EndPlayout();
00274 }
00275
00276 template<class POLICY>
00277 SgUctValue GoUctGlobalSearchState<POLICY>::Evaluate()
00278 {
00279 float komi = GetKomi();
00280 if (IsInPlayout())
00281 return EvaluateBoard(UctBoard(), komi);
00282 else
00283 return EvaluateBoard(Board(), komi);
00284 }
00285
00286 template<class POLICY>
00287 template<class BOARD>
00288 SgUctValue GoUctGlobalSearchState<POLICY>::EvaluateBoard(const BOARD& bd,
00289 float komi)
00290 {
00291 SgUctValue score;
00292 SgPointArray<SgEmptyBlackWhite> scoreBoard;
00293 SgPointArray<SgEmptyBlackWhite>* scoreBoardPtr;
00294 if (m_param.m_territoryStatistics)
00295 scoreBoardPtr = &scoreBoard;
00296 else
00297 scoreBoardPtr = 0;
00298 if (m_param.m_mercyRule && m_mercyRuleTriggered)
00299 return m_mercyRuleResult;
00300 else if (m_passMovesPlayoutPhase < 2)
00301
00302 score = (SgUctValue)GoBoardUtil::TrompTaylorScore(bd, komi, scoreBoardPtr);
00303 else
00304 {
00305 score =
00306 SgUctValue(GoBoardUtil::ScoreSimpleEndPosition(bd, komi, m_safe,
00307 false,
00308 scoreBoardPtr));
00309 }
00310 if (m_param.m_territoryStatistics)
00311 for (typename BOARD::Iterator it(bd); it; ++it)
00312 switch (scoreBoard[*it])
00313 {
00314 case SG_BLACK:
00315 m_territoryStatistics[*it].Add(1);
00316 break;
00317 case SG_WHITE:
00318 m_territoryStatistics[*it].Add(0);
00319 break;
00320 case SG_EMPTY:
00321 m_territoryStatistics[*it].Add(0.5);
00322 break;
00323 }
00324 if (bd.ToPlay() != SG_BLACK)
00325 score *= -1;
00326 SgUctValue lengthMod =
00327 SgUctValue(GameLength()) * m_param.m_lengthModification;
00328 if (lengthMod > 0.5)
00329 lengthMod = 0.5;
00330 if (score > std::numeric_limits<SgUctValue>::epsilon())
00331 return
00332 (1 - m_param.m_scoreModification)
00333 + m_param.m_scoreModification * score * m_invMaxScore
00334 - lengthMod;
00335 else if (score < -std::numeric_limits<SgUctValue>::epsilon())
00336 return
00337 m_param.m_scoreModification
00338 + m_param.m_scoreModification * score * m_invMaxScore
00339 + lengthMod;
00340 else
00341
00342 return 0;
00343 }
00344
00345 template<class POLICY>
00346 void GoUctGlobalSearchState<POLICY>::ExecutePlayout(SgMove move)
00347 {
00348 GoUctState::ExecutePlayout(move);
00349 const GoUctBoard& bd = UctBoard();
00350 if (bd.ToPlay() == SG_BLACK)
00351 m_stoneDiff -= bd.NuCapturedStones();
00352 else
00353 m_stoneDiff += bd.NuCapturedStones();
00354 m_policy->OnPlay();
00355 }
00356
00357 template<class POLICY>
00358 void GoUctGlobalSearchState<POLICY>::GameStart()
00359 {
00360 GoUctState::GameStart();
00361 m_passMovesPlayoutPhase = 0;
00362 m_mercyRuleTriggered = false;
00363 }
00364
00365 template<class POLICY>
00366 void GoUctGlobalSearchState<POLICY>::GenerateLegalMoves(
00367 std::vector<SgUctMoveInfo>& moves)
00368 {
00369
00370 const GoBoard& bd = Board();
00371 SG_ASSERT(! bd.Rules().AllowSuicide());
00372
00373 if (GoBoardUtil::TwoPasses(bd))
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383 if (bd.Rules().CaptureDead()
00384 || bd.MoveNumber() - m_initialMoveNumber >= 2)
00385 return;
00386
00387 SgBlackWhite toPlay = bd.ToPlay();
00388 for (GoBoard::Iterator it(bd); it; ++it)
00389 {
00390 SgPoint p = *it;
00391 if (bd.IsEmpty(p)
00392 && ! GoEyeUtil::IsSimpleEye(bd, p, toPlay)
00393 && ! m_allSafe[p]
00394 && bd.IsLegal(p, toPlay))
00395 moves.push_back(SgUctMoveInfo(p));
00396 }
00397
00398
00399
00400
00401
00402
00403 if (moves.size() > 1)
00404 std::swap(moves[0], moves[m_random.SmallInt(moves.size())]);
00405 moves.push_back(SgUctMoveInfo(SG_PASS));
00406 }
00407
00408 template<class POLICY>
00409 bool GoUctGlobalSearchState<POLICY>::GenerateAllMoves(SgUctValue count,
00410 std::vector<SgUctMoveInfo>& moves,
00411 SgUctProvenType& provenType)
00412 {
00413 provenType = SG_NOT_PROVEN;
00414 moves.clear();
00415 GenerateLegalMoves(moves);
00416 if (! moves.empty())
00417 {
00418 if (count == 0)
00419 m_priorKnowledge.ProcessPosition(moves);
00420 }
00421 return false;
00422 }
00423
00424 template<class POLICY>
00425 SgMove GoUctGlobalSearchState<POLICY>::GeneratePlayoutMove(
00426 bool& skipRaveUpdate)
00427 {
00428 SG_ASSERT(IsInPlayout());
00429 if (m_param.m_mercyRule && CheckMercyRule())
00430 return SG_NULLMOVE;
00431 SgPoint move = m_policy->GenerateMove();
00432 SG_ASSERT(move != SG_NULLMOVE);
00433 #ifndef NDEBUG
00434
00435
00436 if (move == SG_PASS)
00437 {
00438 const GoUctBoard& bd = UctBoard();
00439 SgBalancer balancer(100);
00440 for (GoUctBoard::Iterator it(bd); it; ++it)
00441 SG_ASSERT( bd.Occupied(*it)
00442 || m_safe.OneContains(*it)
00443 || GoBoardUtil::SelfAtari(bd, *it)
00444 || ! GoUctUtil::GeneratePoint(bd, balancer,
00445 *it, bd.ToPlay())
00446 );
00447 }
00448 else
00449 SG_ASSERT(! m_safe.OneContains(move));
00450 #endif
00451
00452
00453
00454
00455 if (move == SG_PASS)
00456 {
00457 skipRaveUpdate = true;
00458 if (m_passMovesPlayoutPhase < 2)
00459 ++m_passMovesPlayoutPhase;
00460 else
00461 return SG_NULLMOVE;
00462 }
00463 else
00464 m_passMovesPlayoutPhase = 0;
00465 return move;
00466 }
00467
00468
00469 template<class POLICY>
00470 float GoUctGlobalSearchState<POLICY>::GetKomi() const
00471 {
00472 const GoRules& rules = Board().Rules();
00473 float komi = rules.Komi().ToFloat();
00474 if (rules.ExtraHandicapKomi())
00475 komi += float(rules.Handicap());
00476 return komi;
00477 }
00478
00479 template<class POLICY>
00480 inline POLICY* GoUctGlobalSearchState<POLICY>::Policy()
00481 {
00482 return m_policy.get();
00483 }
00484
00485 template<class POLICY>
00486 void GoUctGlobalSearchState<POLICY>::SetPolicy(POLICY* policy)
00487 {
00488 m_policy.reset(policy);
00489 }
00490
00491 template<class POLICY>
00492 void GoUctGlobalSearchState<POLICY>::StartPlayout()
00493 {
00494 GoUctState::StartPlayout();
00495 m_passMovesPlayoutPhase = 0;
00496 m_mercyRuleTriggered = false;
00497 const GoBoard& bd = Board();
00498 m_stoneDiff = bd.All(SG_BLACK).Size() - bd.All(SG_WHITE).Size();
00499 m_policy->StartPlayout();
00500 }
00501
00502 template<class POLICY>
00503 void GoUctGlobalSearchState<POLICY>::StartPlayouts()
00504 {
00505 GoUctState::StartPlayouts();
00506 }
00507
00508 template<class POLICY>
00509 void GoUctGlobalSearchState<POLICY>::StartSearch()
00510 {
00511 GoUctState::StartSearch();
00512 const GoBoard& bd = Board();
00513 int size = bd.Size();
00514 float maxScore = float(size * size) + GetKomi();
00515 m_invMaxScore = (SgUctValue)(1 / maxScore);
00516 m_initialMoveNumber = bd.MoveNumber();
00517 m_mercyRuleThreshold = static_cast<int>(0.3 * size * size);
00518 ClearTerritoryStatistics();
00519 }
00520
00521
00522
00523
00524
00525
00526 template<class POLICY, class FACTORY>
00527 class GoUctGlobalSearchStateFactory
00528 : public SgUctThreadStateFactory
00529 {
00530 public:
00531
00532
00533
00534
00535
00536
00537
00538
00539 GoUctGlobalSearchStateFactory(GoBoard& bd,
00540 FACTORY& playoutPolicyFactory,
00541 const GoUctPlayoutPolicyParam& policyParam,
00542 const SgBWSet& safe,
00543 const SgPointArray<bool>& allSafe);
00544
00545 SgUctThreadState* Create(unsigned int threadId, const SgUctSearch& search);
00546
00547 private:
00548 GoBoard& m_bd;
00549
00550 FACTORY& m_playoutPolicyFactory;
00551
00552 const GoUctPlayoutPolicyParam& m_policyParam;
00553
00554 const SgBWSet& m_safe;
00555
00556 const SgPointArray<bool>& m_allSafe;
00557 };
00558
00559 template<class POLICY, class FACTORY>
00560 GoUctGlobalSearchStateFactory<POLICY,FACTORY>
00561 ::GoUctGlobalSearchStateFactory(GoBoard& bd,
00562 FACTORY& playoutPolicyFactory,
00563 const GoUctPlayoutPolicyParam& policyParam,
00564 const SgBWSet& safe,
00565 const SgPointArray<bool>& allSafe)
00566 : m_bd(bd),
00567 m_playoutPolicyFactory(playoutPolicyFactory),
00568 m_policyParam(policyParam),
00569 m_safe(safe),
00570 m_allSafe(allSafe)
00571 {
00572 }
00573
00574
00575
00576
00577
00578
00579 template<class POLICY, class FACTORY>
00580 class GoUctGlobalSearch
00581 : public GoUctSearch
00582 {
00583 public:
00584 GoUctGlobalSearchStateParam m_param;
00585
00586
00587
00588
00589
00590
00591
00592
00593 GoUctGlobalSearch(GoBoard& bd,
00594 FACTORY* playoutPolicyFactory,
00595 const GoUctPlayoutPolicyParam& policyParam);
00596
00597
00598
00599
00600 SgUctValue UnknownEval() const;
00601
00602
00603
00604
00605
00606
00607
00608 void OnStartSearch();
00609
00610 void DisplayGfx();
00611
00612
00613
00614
00615
00616 void SetDefaultParameters(int boardSize);
00617
00618
00619
00620
00621
00622
00623
00624 bool GlobalSearchLiveGfx() const;
00625
00626
00627 void SetGlobalSearchLiveGfx(bool enable);
00628
00629 private:
00630 SgBWSet m_safe;
00631
00632 SgPointArray<bool> m_allSafe;
00633
00634 boost::scoped_ptr<FACTORY> m_playoutPolicyFactory;
00635
00636 GoRegionBoard m_regions;
00637
00638
00639 bool m_globalSearchLiveGfx;
00640 };
00641
00642 template<class POLICY, class FACTORY>
00643 GoUctGlobalSearch<POLICY,FACTORY>::GoUctGlobalSearch(GoBoard& bd,
00644 FACTORY* playoutFactory,
00645 const GoUctPlayoutPolicyParam& policyParam)
00646 : GoUctSearch(bd, 0),
00647 m_playoutPolicyFactory(playoutFactory),
00648 m_regions(bd),
00649 m_globalSearchLiveGfx(GOUCT_LIVEGFX_NONE)
00650 {
00651 SgUctThreadStateFactory* stateFactory =
00652 new GoUctGlobalSearchStateFactory<POLICY,FACTORY>(bd,
00653 *playoutFactory,
00654 policyParam,
00655 m_safe, m_allSafe);
00656 SetThreadStateFactory(stateFactory);
00657 SetDefaultParameters(bd.Size());
00658
00659
00660
00661
00662
00663
00664
00665 #if BOOST_VERSION_MAJOR == 1 && BOOST_VERSION_MINOR >= 35
00666 if (LockFree())
00667 {
00668 unsigned int nuThreads = boost::thread::hardware_concurrency();
00669 if (nuThreads > 4)
00670 nuThreads = 4;
00671 SgDebug() << "GoUctGlobalSearch: setting default number of threads to "
00672 << nuThreads << '\n';
00673 SetNumberThreads(nuThreads);
00674 }
00675 #endif
00676 }
00677
00678 template<class POLICY, class FACTORY>
00679 inline bool GoUctGlobalSearch<POLICY,FACTORY>::GlobalSearchLiveGfx() const
00680 {
00681 return m_globalSearchLiveGfx;
00682 }
00683
00684 template<class POLICY, class FACTORY>
00685 void GoUctGlobalSearch<POLICY,FACTORY>::DisplayGfx()
00686 {
00687 GoUctSearch::DisplayGfx();
00688 if (m_globalSearchLiveGfx)
00689 {
00690 const GoUctGlobalSearchState<POLICY>& state =
00691 dynamic_cast<GoUctGlobalSearchState<POLICY>&>(ThreadState(0));
00692 SgDebug() << "gogui-gfx:\n";
00693 GoUctUtil::GfxBestMove(*this, ToPlay(), SgDebug());
00694 GoUctUtil::GfxTerritoryStatistics(state.m_territoryStatistics,
00695 Board(), SgDebug());
00696 GoUctUtil::GfxStatus(*this, SgDebug());
00697 SgDebug() << '\n';
00698 }
00699 }
00700
00701 template<class POLICY, class FACTORY>
00702 void GoUctGlobalSearch<POLICY,FACTORY>::OnStartSearch()
00703 {
00704 GoUctSearch::OnStartSearch();
00705 m_safe.Clear();
00706 m_allSafe.Fill(false);
00707 if (GOUCT_USE_SAFETY_SOLVER)
00708 {
00709 GoBoard& bd = Board();
00710 GoSafetySolver solver(bd, &m_regions);
00711 solver.FindSafePoints(&m_safe);
00712 for (GoBoard::Iterator it(bd); it; ++it)
00713 m_allSafe[*it] = m_safe.OneContains(*it);
00714 }
00715 if (m_globalSearchLiveGfx && ! m_param.m_territoryStatistics)
00716 SgWarning() <<
00717 "GoUctGlobalSearch: "
00718 "live graphics need territory statistics enabled\n";
00719 }
00720
00721 template<class POLICY, class FACTORY>
00722 void GoUctGlobalSearch<POLICY,FACTORY>::SetDefaultParameters(int boardSize)
00723 {
00724 SetFirstPlayUrgency(1);
00725 SetMoveSelect(SG_UCTMOVESELECT_COUNT);
00726 SetRave(true);
00727 SetExpandThreshold(std::numeric_limits<SgUctValue>::is_integer ? (SgUctValue)1 : std::numeric_limits<SgUctValue>::epsilon());
00728 SetVirtualLoss(true);
00729 SetBiasTermConstant(0.0);
00730 SetExpandThreshold(3);
00731 if (boardSize < 15)
00732 {
00733
00734
00735 SetRaveWeightInitial(1.0);
00736 SetRaveWeightFinal(5000);
00737 m_param.m_lengthModification = 0;
00738 }
00739 else
00740 {
00741
00742
00743 SetRaveWeightInitial(0.9f);
00744 SetRaveWeightFinal(5000);
00745 m_param.m_lengthModification = 0.00028f;
00746 }
00747 }
00748
00749 template<class POLICY, class FACTORY>
00750 inline void GoUctGlobalSearch<POLICY,FACTORY>::SetGlobalSearchLiveGfx(
00751 bool enable)
00752 {
00753 m_globalSearchLiveGfx = enable;
00754 }
00755
00756 template<class POLICY, class FACTORY>
00757 SgUctValue GoUctGlobalSearch<POLICY,FACTORY>::UnknownEval() const
00758 {
00759
00760
00761 return (SgUctValue)0.5;
00762 }
00763
00764
00765
00766 template<class POLICY, class FACTORY>
00767 SgUctThreadState* GoUctGlobalSearchStateFactory<POLICY,FACTORY>::Create(
00768 unsigned int threadId, const SgUctSearch& search)
00769 {
00770 const GoUctGlobalSearch<POLICY,FACTORY>& globalSearch =
00771 dynamic_cast<const GoUctGlobalSearch<POLICY,FACTORY>&>(search);
00772 GoUctGlobalSearchState<POLICY>* state =
00773 new GoUctGlobalSearchState<POLICY>(threadId, globalSearch.Board(), 0,
00774 globalSearch.m_param,
00775 m_policyParam,
00776 m_safe, m_allSafe);
00777 POLICY* policy = m_playoutPolicyFactory.Create(state->UctBoard());
00778 state->SetPolicy(policy);
00779 return state;
00780 }
00781
00782
00783
00784 #endif // GOUCT_GLOBALSEARCH_H