00001
00002
00003
00004
00005 #include "SgSystem.h"
00006 #include "SgUctSearch.h"
00007
00008 #include <algorithm>
00009 #include <cmath>
00010 #include <iomanip>
00011 #include <boost/format.hpp>
00012 #include <boost/io/ios_state.hpp>
00013 #include <boost/version.hpp>
00014 #include "SgDebug.h"
00015 #include "SgHashTable.h"
00016 #include "SgMath.h"
00017 #include "SgPlatform.h"
00018 #include "SgWrite.h"
00019
00020 using namespace std;
00021 using boost::barrier;
00022 using boost::condition;
00023 using boost::format;
00024 using boost::mutex;
00025 using boost::shared_ptr;
00026 using boost::io::ios_all_saver;
00027
00028 #define BOOST_VERSION_MAJOR (BOOST_VERSION / 100000)
00029 #define BOOST_VERSION_MINOR (BOOST_VERSION / 100 % 1000)
00030
00031
00032
00033 namespace {
00034
/** Compile-time switch: when true, SgUctSearch::Thread::operator()() writes
    a message to SgDebug() when a worker thread starts and finishes. */
const bool DEBUG_THREADS = false;
00036
00037
00038
00039
00040
00041
/** Determine the default of the lock-free setting from build-time macros.
    @return true if WIN32 or ENABLE_CACHE_SYNC is defined; otherwise, if
    HOST_CPU is defined, true only if it names one of the x86-family CPUs
    listed below; false if none of these macros is defined. */
bool GetLockFreeDefault()
{
#if defined(WIN32) || defined(ENABLE_CACHE_SYNC)
    return true;
#elif defined(HOST_CPU)
    static const char* const x86Names[] =
        { "i386", "i486", "i586", "i686", "x86_64" };
    const std::size_t numNames = sizeof(x86Names) / sizeof(x86Names[0]);
    const std::string cpuName(HOST_CPU);
    for (std::size_t i = 0; i < numNames; ++i)
        if (cpuName == x86Names[i])
            return true;
    return false;
#else
    return false;
#endif
}
00054
00055
00056
00057
00058
00059 size_t GetMaxNodesDefault()
00060 {
00061 size_t totalMemory = SgPlatform::TotalMemory();
00062 SgDebug() << "SgUctSearch: system memory ";
00063 if (totalMemory == 0)
00064 SgDebug() << "unknown";
00065 else
00066 SgDebug() << totalMemory;
00067
00068
00069 size_t searchMemory = totalMemory / 2;
00070 if (searchMemory < 384000000)
00071 searchMemory = 384000000;
00072 if (searchMemory > 1000000000)
00073 searchMemory = 1000000000;
00074 size_t memoryPerTree = searchMemory / 2;
00075 size_t nodesPerTree = memoryPerTree / sizeof(SgUctNode);
00076 SgDebug() << ", using " << searchMemory << " (" << nodesPerTree
00077 << " nodes)\n";
00078 return nodesPerTree;
00079 }
00080
00081 void Notify(mutex& aMutex, condition& aCondition)
00082 {
00083 mutex::scoped_lock lock(aMutex);
00084 aCondition.notify_all();
00085 }
00086
00087 }
00088
00089
00090
00091 void SgUctGameInfo::Clear(std::size_t numberPlayouts)
00092 {
00093 m_nodes.clear();
00094 m_inTreeSequence.clear();
00095 if (numberPlayouts != m_sequence.size())
00096 {
00097 m_sequence.resize(numberPlayouts);
00098 m_skipRaveUpdate.resize(numberPlayouts);
00099 m_eval.resize(numberPlayouts);
00100 m_aborted.resize(numberPlayouts);
00101 }
00102 for (size_t i = 0; i < numberPlayouts; ++i)
00103 {
00104 m_sequence[i].clear();
00105 m_skipRaveUpdate[i].clear();
00106 }
00107 }
00108
00109
00110
00111 SgUctThreadState::SgUctThreadState(unsigned int threadId, int moveRange)
00112 : m_threadId(threadId),
00113 m_isSearchInitialized(false),
00114 m_isTreeOutOfMem(false)
00115 {
00116 if (moveRange > 0)
00117 {
00118 m_firstPlay.reset(new size_t[moveRange]);
00119 m_firstPlayOpp.reset(new size_t[moveRange]);
00120 }
00121 }
00122
00123 SgUctThreadState::~SgUctThreadState()
00124 {
00125 }
00126
00127 void SgUctThreadState::EndPlayout()
00128 {
00129
00130 }
00131
00132 void SgUctThreadState::GameStart()
00133 {
00134
00135 }
00136
00137 void SgUctThreadState::StartPlayout()
00138 {
00139
00140 }
00141
00142 void SgUctThreadState::StartPlayouts()
00143 {
00144
00145 }
00146
00147
00148
00149 SgUctThreadStateFactory::~SgUctThreadStateFactory()
00150 {
00151 }
00152
00153
00154
00155 SgUctSearch::Thread::Function::Function(Thread& thread)
00156 : m_thread(thread)
00157 {
00158 }
00159
00160 void SgUctSearch::Thread::Function::operator()()
00161 {
00162 m_thread();
00163 }
00164
00165 SgUctSearch::Thread::Thread(SgUctSearch& search,
00166 auto_ptr<SgUctThreadState> state)
00167 : m_state(state),
00168 m_search(search),
00169 m_quit(false),
00170 m_threadReady(2),
00171 m_playFinishedLock(m_playFinishedMutex),
00172 #if BOOST_VERSION_MAJOR == 1 && BOOST_VERSION_MINOR <= 34
00173 m_globalLock(search.m_globalMutex, false),
00174 #else
00175 m_globalLock(search.m_globalMutex, boost::defer_lock),
00176 #endif
00177 m_thread(Function(*this))
00178 {
00179 m_threadReady.wait();
00180 }
00181
00182 SgUctSearch::Thread::~Thread()
00183 {
00184 m_quit = true;
00185 StartPlay();
00186 m_thread.join();
00187 }
00188
00189 void SgUctSearch::Thread::operator()()
00190 {
00191 if (DEBUG_THREADS)
00192 SgDebug() << "SgUctSearch::Thread: starting thread "
00193 << m_state->m_threadId << '\n';
00194 mutex::scoped_lock lock(m_startPlayMutex);
00195 m_threadReady.wait();
00196 while (true)
00197 {
00198 m_startPlay.wait(lock);
00199 if (m_quit)
00200 break;
00201 m_search.SearchLoop(*m_state, &m_globalLock);
00202 Notify(m_playFinishedMutex, m_playFinished);
00203 }
00204 if (DEBUG_THREADS)
00205 SgDebug() << "SgUctSearch::Thread: finishing thread "
00206 << m_state->m_threadId << '\n';
00207 }
00208
00209 void SgUctSearch::Thread::StartPlay()
00210 {
00211 Notify(m_startPlayMutex, m_startPlay);
00212 }
00213
00214 void SgUctSearch::Thread::WaitPlayFinished()
00215 {
00216 m_playFinished.wait(m_playFinishedLock);
00217 }
00218
00219
00220
00221 void SgUctSearchStat::Clear()
00222 {
00223 m_time = 0;
00224 m_knowledge = 0;
00225 m_gamesPerSecond = 0;
00226 m_gameLength.Clear();
00227 m_movesInTree.Clear();
00228 m_aborted.Clear();
00229 }
00230
00231 void SgUctSearchStat::Write(std::ostream& out) const
00232 {
00233 ios_all_saver saver(out);
00234 out << SgWriteLabel("Time") << setprecision(2) << m_time << '\n'
00235 << SgWriteLabel("GameLen") << fixed << setprecision(1);
00236 m_gameLength.Write(out);
00237 out << '\n'
00238 << SgWriteLabel("InTree");
00239 m_movesInTree.Write(out);
00240 out << '\n'
00241 << SgWriteLabel("Aborted")
00242 << static_cast<int>(100 * m_aborted.Mean()) << "%\n"
00243 << SgWriteLabel("Games/s") << fixed << setprecision(1)
00244 << m_gamesPerSecond << '\n';
00245 }
00246
00247
00248
00249 SgUctSearch::SgUctSearch(SgUctThreadStateFactory* threadStateFactory,
00250 int moveRange)
00251 : m_threadStateFactory(threadStateFactory),
00252 m_logGames(false),
00253 m_rave(false),
00254 m_knowledgeThreshold(),
00255 m_moveSelect(SG_UCTMOVESELECT_COUNT),
00256 m_raveCheckSame(false),
00257 m_randomizeRaveFrequency(20),
00258 m_lockFree(GetLockFreeDefault()),
00259 m_weightRaveUpdates(true),
00260 m_pruneFullTree(true),
00261 m_checkFloatPrecision(true),
00262 m_numberThreads(1),
00263 m_numberPlayouts(1),
00264 m_maxNodes(GetMaxNodesDefault()),
00265 m_pruneMinCount(16),
00266 m_moveRange(moveRange),
00267 m_maxGameLength(numeric_limits<size_t>::max()),
00268 m_expandThreshold(numeric_limits<SgUctValue>::is_integer ? (SgUctValue)1 : numeric_limits<SgUctValue>::epsilon()),
00269 m_biasTermConstant(0.7f),
00270 m_firstPlayUrgency(10000),
00271 m_raveWeightInitial(0.9f),
00272 m_raveWeightFinal(20000),
00273 m_virtualLoss(false),
00274 m_logFileName("uctsearch.log"),
00275 m_fastLog(10),
00276 m_mpiSynchronizer(SgMpiNullSynchronizer::Create())
00277 {
00278
00279
00280
00281 }
00282
00283 SgUctSearch::~SgUctSearch()
00284 {
00285 DeleteThreads();
00286 }
00287
00288 void SgUctSearch::ApplyRootFilter(vector<SgUctMoveInfo>& moves)
00289 {
00290
00291 vector<SgUctMoveInfo> filteredMoves;
00292 for (vector<SgUctMoveInfo>::const_iterator it = moves.begin();
00293 it != moves.end(); ++it)
00294 if (find(m_rootFilter.begin(), m_rootFilter.end(), it->m_move)
00295 == m_rootFilter.end())
00296 filteredMoves.push_back(*it);
00297 moves = filteredMoves;
00298 }
00299
00300 SgUctValue SgUctSearch::GamesPlayed() const
00301 {
00302 return m_tree.Root().MoveCount() - m_startRootMoveCount;
00303 }
00304
00305 bool SgUctSearch::CheckAbortSearch(SgUctThreadState& state)
00306 {
00307 if (SgUserAbort())
00308 {
00309 Debug(state, "SgUctSearch: abort flag");
00310 return true;
00311 }
00312 const SgUctNode& root = m_tree.Root();
00313 if (! SgUctValueUtil::IsPrecise(root.MoveCount()) && m_checkFloatPrecision)
00314 {
00315 Debug(state, "SgUctSearch: floating point type precision reached");
00316 return true;
00317 }
00318 SgUctValue rootCount = root.MoveCount();
00319 if (rootCount >= m_maxGames)
00320 {
00321 Debug(state, "SgUctSearch: max games reached");
00322 return true;
00323 }
00324 if (root.IsProven())
00325 {
00326 if (root.IsProvenWin())
00327 Debug(state, "SgUctSearch: root is proven win!");
00328 else
00329 Debug(state, "SgUctSearch: root is proven loss!");
00330 return true;
00331 }
00332 const bool isEarlyAbort = CheckEarlyAbort();
00333 if ( isEarlyAbort
00334 && m_earlyAbort->m_reductionFactor * rootCount >= m_maxGames
00335 )
00336 {
00337 Debug(state, "SgUctSearch: max games reached (early abort)");
00338 m_wasEarlyAbort = true;
00339 return true;
00340 }
00341 if (m_numberGames >= m_nextCheckTime)
00342 {
00343 m_nextCheckTime = m_numberGames + m_checkTimeInterval;
00344 double time = m_timer.GetTime();
00345 if (time > m_maxTime)
00346 {
00347 Debug(state, "SgUctSearch: max time reached");
00348 return true;
00349 }
00350 if (isEarlyAbort
00351 && m_earlyAbort->m_reductionFactor * time > m_maxTime)
00352 {
00353 Debug(state, "SgUctSearch: max time reached (early abort)");
00354 m_wasEarlyAbort = true;
00355 return true;
00356 }
00357 UpdateCheckTimeInterval(time);
00358 if (m_moveSelect == SG_UCTMOVESELECT_COUNT)
00359 {
00360 double remainingGamesDouble = m_maxGames - rootCount - 1;
00361
00362
00363 if (time > 1.)
00364 {
00365 double remainingTime = m_maxTime - time;
00366 remainingGamesDouble =
00367 min(remainingGamesDouble,
00368 remainingTime * m_statistics.m_gamesPerSecond);
00369 }
00370 SgUctValue uctCountMax = numeric_limits<SgUctValue>::max();
00371 SgUctValue remainingGames;
00372 if (remainingGamesDouble >= static_cast<double>(uctCountMax - 1))
00373 remainingGames = uctCountMax;
00374 else
00375 remainingGames = SgUctValue(remainingGamesDouble);
00376 if (CheckCountAbort(state, remainingGames))
00377 {
00378 Debug(state, "SgUctSearch: move cannot change anymore");
00379 return true;
00380 }
00381 }
00382 }
00383 return false;
00384 }
00385
00386 bool SgUctSearch::CheckCountAbort(SgUctThreadState& state,
00387 SgUctValue remainingGames) const
00388 {
00389 const SgUctNode& root = m_tree.Root();
00390 const SgUctNode* bestChild = FindBestChild(root);
00391 if (bestChild == 0)
00392 return false;
00393 SgUctValue bestCount = bestChild->MoveCount();
00394 vector<SgMove>& excludeMoves = state.m_excludeMoves;
00395 excludeMoves.clear();
00396 excludeMoves.push_back(bestChild->Move());
00397 const SgUctNode* secondBestChild = FindBestChild(root, &excludeMoves);
00398 if (secondBestChild == 0)
00399 return false;
00400 SgUctValue secondBestCount = secondBestChild->MoveCount();
00401 SG_ASSERT(secondBestCount <= bestCount || m_numberThreads > 1);
00402 return (remainingGames <= bestCount - secondBestCount);
00403 }
00404
00405 bool SgUctSearch::CheckEarlyAbort() const
00406 {
00407 const SgUctNode& root = m_tree.Root();
00408 return m_earlyAbort.get() != 0
00409 && root.HasMean()
00410 && root.MoveCount() > m_earlyAbort->m_minGames
00411 && root.Mean() > m_earlyAbort->m_threshold;
00412 }
00413
00414 void SgUctSearch::CreateThreads()
00415 {
00416 DeleteThreads();
00417 for (unsigned int i = 0; i < m_numberThreads; ++i)
00418 {
00419 auto_ptr<SgUctThreadState> state(
00420 m_threadStateFactory->Create(i, *this));
00421 shared_ptr<Thread> thread(new Thread(*this, state));
00422 m_threads.push_back(thread);
00423 }
00424 m_tree.CreateAllocators(m_numberThreads);
00425 m_tree.SetMaxNodes(m_maxNodes);
00426
00427 m_searchLoopFinished.reset(new barrier(m_numberThreads));
00428 }
00429
00430
00431
00432
00433
00434
00435
00436 void SgUctSearch::Debug(const SgUctThreadState& state,
00437 const std::string& textLine)
00438 {
00439 if (m_numberThreads > 1)
00440 {
00441
00442 GlobalLock lock(m_globalMutex);
00443 SgDebug() << (format("[%1%] %2%\n") % state.m_threadId % textLine);
00444 }
00445 else
00446 SgDebug() << (format("%1%\n") % textLine);
00447 }
00448
00449 void SgUctSearch::DeleteThreads()
00450 {
00451 m_threads.clear();
00452 }
00453
00454
00455
00456
00457 void SgUctSearch::ExpandNode(SgUctThreadState& state, const SgUctNode& node)
00458 {
00459 unsigned int threadId = state.m_threadId;
00460 if (! m_tree.HasCapacity(threadId, state.m_moves.size()))
00461 {
00462 Debug(state, str(format("SgUctSearch: maximum tree size %1% reached")
00463 % m_tree.MaxNodes()));
00464 state.m_isTreeOutOfMem = true;
00465 m_isTreeOutOfMemory = true;
00466 SgSynchronizeThreadMemory();
00467 return;
00468 }
00469 m_tree.CreateChildren(threadId, node, state.m_moves);
00470 }
00471
00472 const SgUctNode*
00473 SgUctSearch::FindBestChild(const SgUctNode& node,
00474 const vector<SgMove>* excludeMoves) const
00475 {
00476 if (! node.HasChildren())
00477 return 0;
00478 const SgUctNode* bestChild = 0;
00479 SgUctValue bestValue = 0;
00480 for (SgUctChildIterator it(m_tree, node); it; ++it)
00481 {
00482 const SgUctNode& child = *it;
00483 if (excludeMoves != 0)
00484 {
00485 vector<SgMove>::const_iterator begin = excludeMoves->begin();
00486 vector<SgMove>::const_iterator end = excludeMoves->end();
00487 if (find(begin, end, child.Move()) != end)
00488 continue;
00489 }
00490 if ( ! child.HasMean()
00491 && ! ( ( m_moveSelect == SG_UCTMOVESELECT_BOUND
00492 || m_moveSelect == SG_UCTMOVESELECT_ESTIMATE
00493 )
00494 && m_rave
00495 && child.HasRaveValue()
00496 )
00497 )
00498 continue;
00499 if (child.IsProvenLoss())
00500 {
00501 bestChild = &child;
00502 break;
00503 }
00504 SgUctValue value;
00505 switch (m_moveSelect)
00506 {
00507 case SG_UCTMOVESELECT_VALUE:
00508 value = InverseEstimate((SgUctValue)child.Mean());
00509 break;
00510 case SG_UCTMOVESELECT_COUNT:
00511 value = child.MoveCount();
00512 break;
00513 case SG_UCTMOVESELECT_BOUND:
00514 value = GetBound(m_rave, node, child);
00515 break;
00516 case SG_UCTMOVESELECT_ESTIMATE:
00517 value = GetValueEstimate(m_rave, child);
00518 break;
00519 default:
00520 SG_ASSERT(false);
00521 value = SG_UCTMOVESELECT_VALUE;
00522 }
00523 if (bestChild == 0 || value > bestValue)
00524 {
00525 bestChild = &child;
00526 bestValue = value;
00527 }
00528 }
00529 return bestChild;
00530 }
00531
00532 void SgUctSearch::FindBestSequence(vector<SgMove>& sequence) const
00533 {
00534 sequence.clear();
00535 const SgUctNode* current = &m_tree.Root();
00536 while (true)
00537 {
00538 current = FindBestChild(*current);
00539 if (current == 0)
00540 break;
00541 sequence.push_back(current->Move());
00542 if (! current->HasChildren())
00543 break;
00544 }
00545 }
00546
00547 void SgUctSearch::GenerateAllMoves(std::vector<SgUctMoveInfo>& moves)
00548 {
00549 if (m_threads.size() == 0)
00550 CreateThreads();
00551 moves.clear();
00552 OnStartSearch();
00553 SgUctThreadState& state = ThreadState(0);
00554 state.StartSearch();
00555 SgUctProvenType type;
00556 state.GenerateAllMoves(0, moves, type);
00557 }
00558
00559 SgUctValue SgUctSearch::GetBound(bool useRave, const SgUctNode& node,
00560 const SgUctNode& child) const
00561 {
00562 SgUctValue posCount = node.PosCount();
00563 int virtualLossCount = node.VirtualLossCount();
00564 if (virtualLossCount > 0)
00565 {
00566 posCount += SgUctValue(virtualLossCount);
00567 }
00568 return GetBound(useRave, Log(posCount), child);
00569 }
00570
00571 SgUctValue SgUctSearch::GetBound(bool useRave, SgUctValue logPosCount,
00572 const SgUctNode& child) const
00573 {
00574 SgUctValue value;
00575 if (useRave)
00576 value = GetValueEstimateRave(child);
00577 else
00578 value = GetValueEstimate(false, child);
00579 if (m_biasTermConstant == 0.0)
00580 return value;
00581 else
00582 {
00583 SgUctValue moveCount = static_cast<SgUctValue>(child.MoveCount());
00584 SgUctValue bound =
00585 value + m_biasTermConstant * sqrt(logPosCount / (moveCount + 1));
00586 return bound;
00587 }
00588 }
00589
00590 SgUctTree& SgUctSearch::GetTempTree()
00591 {
00592 m_tempTree.Clear();
00593
00594
00595
00596 if (m_tempTree.NuAllocators() != NumberThreads())
00597 {
00598 m_tempTree.CreateAllocators(NumberThreads());
00599 m_tempTree.SetMaxNodes(MaxNodes());
00600 }
00601 else if (m_tempTree.MaxNodes() != MaxNodes())
00602 {
00603 m_tempTree.SetMaxNodes(MaxNodes());
00604 }
00605 return m_tempTree;
00606 }
00607
00608 SgUctValue SgUctSearch::GetValueEstimate(bool useRave, const SgUctNode& child) const
00609 {
00610 SgUctValue value = 0;
00611 SgUctValue weightSum = 0;
00612 bool hasValue = false;
00613
00614 SgUctStatistics uctStats;
00615 if (child.HasMean())
00616 {
00617 uctStats.Initialize(child.Mean(), child.MoveCount());
00618 }
00619 int virtualLossCount = child.VirtualLossCount();
00620 if (virtualLossCount > 0)
00621 {
00622 uctStats.Add(InverseEstimate(0), SgUctValue(virtualLossCount));
00623 }
00624
00625 if (uctStats.IsDefined())
00626 {
00627 SgUctValue weight = static_cast<SgUctValue>(uctStats.Count());
00628 value += weight * InverseEstimate((SgUctValue)uctStats.Mean());
00629 weightSum += weight;
00630 hasValue = true;
00631 }
00632
00633 if (useRave)
00634 {
00635 SgUctStatistics raveStats;
00636 if (child.HasRaveValue())
00637 {
00638 raveStats.Initialize(child.RaveValue(), child.RaveCount());
00639 }
00640 if (virtualLossCount > 0)
00641 {
00642 raveStats.Add(0, SgUctValue(virtualLossCount));
00643 }
00644 if (raveStats.IsDefined())
00645 {
00646 SgUctValue raveCount = raveStats.Count();
00647 SgUctValue weight =
00648 raveCount
00649 / ( m_raveWeightParam1
00650 + m_raveWeightParam2 * raveCount
00651 );
00652 value += weight * raveStats.Mean();
00653 weightSum += weight;
00654 hasValue = true;
00655 }
00656 }
00657 if (hasValue)
00658 return value / weightSum;
00659 else
00660 return m_firstPlayUrgency;
00661 }
00662
00663
00664
00665
00666
00667
00668 SgUctValue SgUctSearch::GetValueEstimateRave(const SgUctNode& child) const
00669 {
00670 SG_ASSERT(m_rave);
00671 SgUctValue value;
00672 SgUctStatistics uctStats;
00673 if (child.HasMean())
00674 {
00675 uctStats.Initialize(child.Mean(), child.MoveCount());
00676 }
00677 SgUctStatistics raveStats;
00678 if (child.HasRaveValue())
00679 {
00680 raveStats.Initialize(child.RaveValue(), child.RaveCount());
00681 }
00682 int virtualLossCount = child.VirtualLossCount();
00683 if (virtualLossCount > 0)
00684 {
00685 uctStats.Add(InverseEstimate(0), SgUctValue(virtualLossCount));
00686 raveStats.Add(0, SgUctValue(virtualLossCount));
00687 }
00688 bool hasRave = raveStats.IsDefined();
00689
00690 if (uctStats.IsDefined())
00691 {
00692 SgUctValue moveValue = InverseEstimate((SgUctValue)uctStats.Mean());
00693 if (hasRave)
00694 {
00695 SgUctValue moveCount = uctStats.Count();
00696 SgUctValue raveCount = raveStats.Count();
00697 SgUctValue weight =
00698 raveCount
00699 / (moveCount
00700 * (m_raveWeightParam1 + m_raveWeightParam2 * raveCount)
00701 + raveCount);
00702 value = weight * raveStats.Mean() + (1.f - weight) * moveValue;
00703 }
00704 else
00705 {
00706
00707
00708
00709
00710 SG_ASSERT(m_numberThreads > 1 && m_lockFree);
00711 value = moveValue;
00712 }
00713 }
00714 else if (hasRave)
00715 value = raveStats.Mean();
00716 else
00717 value = m_firstPlayUrgency;
00718 SG_ASSERT(m_numberThreads > 1
00719 || fabs(value - GetValueEstimate(m_rave, child)) < 1e-3);
00720 return value;
00721 }
00722
00723 string SgUctSearch::LastGameSummaryLine() const
00724 {
00725 return SummaryLine(LastGameInfo());
00726 }
00727
00728 SgUctValue SgUctSearch::Log(SgUctValue x) const
00729 {
00730 #if SG_UCTFASTLOG
00731 return SgUctValue(m_fastLog.Log(float(x)));
00732 #else
00733 return log(x);
00734 #endif
00735 }
00736
00737
00738
00739 void SgUctSearch::CreateChildren(SgUctThreadState& state,
00740 const SgUctNode& node,
00741 bool deleteChildTrees)
00742 {
00743 unsigned int threadId = state.m_threadId;
00744 if (! m_tree.HasCapacity(threadId, state.m_moves.size()))
00745 {
00746 Debug(state, str(format("SgUctSearch: maximum tree size %1% reached")
00747 % m_tree.MaxNodes()));
00748 state.m_isTreeOutOfMem = true;
00749 m_isTreeOutOfMemory = true;
00750 SgSynchronizeThreadMemory();
00751 return;
00752 }
00753 m_tree.MergeChildren(threadId, node, state.m_moves, deleteChildTrees);
00754 }
00755
00756 bool SgUctSearch::NeedToComputeKnowledge(const SgUctNode* current)
00757 {
00758 if (m_knowledgeThreshold.empty())
00759 return false;
00760 for (std::size_t i = 0; i < m_knowledgeThreshold.size(); ++i)
00761 {
00762 const SgUctValue threshold = m_knowledgeThreshold[i];
00763 if (current->KnowledgeCount() < threshold)
00764 {
00765 if (current->MoveCount() >= threshold)
00766 {
00767
00768
00769
00770 m_tree.SetKnowledgeCount(*current, threshold);
00771 SG_ASSERT(current->MoveCount());
00772 return true;
00773 }
00774 return false;
00775 }
00776 }
00777 return false;
00778 }
00779
00780 void SgUctSearch::OnStartSearch()
00781 {
00782 m_mpiSynchronizer->OnStartSearch(*this);
00783 }
00784
00785 void SgUctSearch::OnEndSearch()
00786 {
00787 m_mpiSynchronizer->OnEndSearch(*this);
00788 }
00789
00790
00791 void SgUctSearch::PrintSearchProgress(double currTime) const
00792 {
00793 const int MAX_SEQ_PRINT_LENGTH = 15;
00794 const SgUctValue MIN_MOVE_COUNT = 10;
00795 SgUctValue rootMoveCount = m_tree.Root().MoveCount();
00796 SgUctValue rootMean = m_tree.Root().Mean();
00797 ostringstream out;
00798 const SgUctNode* current = &m_tree.Root();
00799 out << (format("%s | %.3f | %.0f ")
00800 % SgTime::Format(currTime, true) % rootMean % rootMoveCount);
00801 for (int i = 0; i <= MAX_SEQ_PRINT_LENGTH && current->HasChildren(); ++i)
00802 {
00803 current = FindBestChild(*current);
00804 if (current == 0 || current->MoveCount() < MIN_MOVE_COUNT)
00805 break;
00806 if (i == 0)
00807 out << "|";
00808 if (i < MAX_SEQ_PRINT_LENGTH)
00809 out << " " << MoveString(current->Move());
00810 else
00811 out << " *";
00812 }
00813 SgDebug() << out.str() << endl;
00814 }
00815
00816 void SgUctSearch::OnSearchIteration(SgUctValue gameNumber,
00817 unsigned int threadId,
00818 const SgUctGameInfo& info)
00819 {
00820 const int DISPLAY_INTERVAL = 5;
00821
00822 m_mpiSynchronizer->OnSearchIteration(*this, gameNumber, threadId, info);
00823 double currTime = m_timer.GetTime();
00824
00825 if (threadId == 0 && currTime - m_lastScoreDisplayTime > DISPLAY_INTERVAL)
00826 {
00827 PrintSearchProgress(currTime);
00828 m_lastScoreDisplayTime = currTime;
00829 }
00830 }
00831
00832 void SgUctSearch::PlayGame(SgUctThreadState& state, GlobalLock* lock)
00833 {
00834 state.m_isTreeOutOfMem = false;
00835 state.GameStart();
00836 SgUctGameInfo& info = state.m_gameInfo;
00837 info.Clear(m_numberPlayouts);
00838 bool isTerminal;
00839 bool abortInTree = ! PlayInTree(state, isTerminal);
00840
00841
00842 if (lock != 0)
00843 lock->unlock();
00844
00845 if (!info.m_nodes.empty() && isTerminal)
00846 {
00847 const SgUctNode& terminalNode = *info.m_nodes.back();
00848 SgUctValue eval = state.Evaluate();
00849 if (eval > 0.6)
00850 m_tree.SetProvenType(terminalNode, SG_PROVEN_WIN);
00851 else if (eval < 0.6)
00852 m_tree.SetProvenType(terminalNode, SG_PROVEN_LOSS);
00853 PropagateProvenStatus(info.m_nodes);
00854 }
00855
00856 size_t nuMovesInTree = info.m_inTreeSequence.size();
00857
00858
00859 if (! info.m_nodes.empty() && info.m_nodes.back()->IsProven())
00860 {
00861 for (size_t i = 0; i < m_numberPlayouts; ++i)
00862 {
00863 info.m_sequence[i] = info.m_inTreeSequence;
00864 info.m_skipRaveUpdate[i].assign(nuMovesInTree, false);
00865 SgUctValue eval = info.m_nodes.back()->IsProvenWin() ? 1 : 0;
00866 size_t nuMoves = info.m_sequence[i].size();
00867 if (nuMoves % 2 != 0)
00868 eval = InverseEval(eval);
00869 info.m_aborted[i] = abortInTree || state.m_isTreeOutOfMem;
00870 info.m_eval[i] = eval;
00871 }
00872 }
00873 else
00874 {
00875 state.StartPlayouts();
00876 for (size_t i = 0; i < m_numberPlayouts; ++i)
00877 {
00878 state.StartPlayout();
00879 info.m_sequence[i] = info.m_inTreeSequence;
00880
00881 info.m_skipRaveUpdate[i].assign(nuMovesInTree, false);
00882 bool abort = abortInTree || state.m_isTreeOutOfMem;
00883 if (! abort && ! isTerminal)
00884 abort = ! PlayoutGame(state, i);
00885 SgUctValue eval;
00886 if (abort)
00887 eval = UnknownEval();
00888 else
00889 eval = state.Evaluate();
00890 size_t nuMoves = info.m_sequence[i].size();
00891 if (nuMoves % 2 != 0)
00892 eval = InverseEval(eval);
00893 info.m_aborted[i] = abort;
00894 info.m_eval[i] = eval;
00895 state.EndPlayout();
00896 state.TakeBackPlayout(nuMoves - nuMovesInTree);
00897 }
00898 }
00899 state.TakeBackInTree(nuMovesInTree);
00900
00901
00902 if (lock != 0)
00903 lock->lock();
00904
00905 UpdateTree(info);
00906 if (m_rave)
00907 UpdateRaveValues(state);
00908 UpdateStatistics(info);
00909 }
00910
00911
00912
00913 void SgUctSearch::PropagateProvenStatus(const vector<const SgUctNode*>& nodes)
00914 {
00915 if (nodes.size() <= 1)
00916 return;
00917 size_t i = nodes.size() - 2;
00918 while (true)
00919 {
00920 const SgUctNode& parent = *nodes[i];
00921 SgUctProvenType type = SG_PROVEN_LOSS;
00922 for (SgUctChildIterator it(m_tree, parent); it; ++it)
00923 {
00924 const SgUctNode& child = *it;
00925 if (!child.IsProven())
00926 type = SG_NOT_PROVEN;
00927 else if (child.IsProvenLoss())
00928 {
00929 type = SG_PROVEN_WIN;
00930 break;
00931 }
00932 }
00933 if (type == SG_NOT_PROVEN)
00934 break;
00935 else
00936 m_tree.SetProvenType(parent, type);
00937 if (i == 0)
00938 break;
00939 --i;
00940 }
00941 }
00942
00943
00944
00945
00946
00947
00948 bool SgUctSearch::PlayInTree(SgUctThreadState& state, bool& isTerminal)
00949 {
00950 vector<SgMove>& sequence = state.m_gameInfo.m_inTreeSequence;
00951 vector<const SgUctNode*>& nodes = state.m_gameInfo.m_nodes;
00952 const SgUctNode* root = &m_tree.Root();
00953 const SgUctNode* current = root;
00954 if (m_virtualLoss && m_numberThreads > 1)
00955 m_tree.AddVirtualLoss(*current);
00956 nodes.push_back(current);
00957 bool breakAfterSelect = false;
00958 isTerminal = false;
00959 bool deepenTree = false;
00960 while (true)
00961 {
00962 if (sequence.size() == m_maxGameLength)
00963 return false;
00964 if (current->IsProven())
00965 break;
00966 if (! current->HasChildren())
00967 {
00968 state.m_moves.clear();
00969 SgUctProvenType provenType = SG_NOT_PROVEN;
00970 state.GenerateAllMoves(0, state.m_moves, provenType);
00971 if (current == root)
00972 ApplyRootFilter(state.m_moves);
00973 if (provenType != SG_NOT_PROVEN)
00974 {
00975 m_tree.SetProvenType(*current, provenType);
00976 PropagateProvenStatus(nodes);
00977 break;
00978 }
00979 if (state.m_moves.empty())
00980 {
00981 isTerminal = true;
00982 break;
00983 }
00984 if ( deepenTree
00985 || current->MoveCount() >= m_expandThreshold
00986 )
00987 {
00988 deepenTree = false;
00989 ExpandNode(state, *current);
00990 if (state.m_isTreeOutOfMem)
00991 return true;
00992 if (! deepenTree)
00993 breakAfterSelect = true;
00994 }
00995 else
00996 break;
00997 }
00998 else if (NeedToComputeKnowledge(current))
00999 {
01000 m_statistics.m_knowledge++;
01001 deepenTree = false;
01002 SgUctProvenType provenType = SG_NOT_PROVEN;
01003 bool truncate = state.GenerateAllMoves(current->KnowledgeCount(),
01004 state.m_moves,
01005 provenType);
01006 if (current == root)
01007 ApplyRootFilter(state.m_moves);
01008 CreateChildren(state, *current, truncate);
01009 if (provenType != SG_NOT_PROVEN)
01010 {
01011 m_tree.SetProvenType(*current, provenType);
01012 PropagateProvenStatus(nodes);
01013 break;
01014 }
01015 if (state.m_moves.empty())
01016 {
01017 isTerminal = true;
01018 break;
01019 }
01020 if (state.m_isTreeOutOfMem)
01021 return true;
01022 if (! deepenTree)
01023 breakAfterSelect = true;
01024 }
01025 current = &SelectChild(state.m_randomizeCounter, *current);
01026 if (m_virtualLoss && m_numberThreads > 1)
01027 m_tree.AddVirtualLoss(*current);
01028 nodes.push_back(current);
01029 SgMove move = current->Move();
01030 state.Execute(move);
01031 sequence.push_back(move);
01032 if (breakAfterSelect)
01033 break;
01034 }
01035 return true;
01036 }
01037
01038
01039
01040
01041
01042 bool SgUctSearch::PlayoutGame(SgUctThreadState& state, std::size_t playout)
01043 {
01044 SgUctGameInfo& info = state.m_gameInfo;
01045 vector<SgMove>& sequence = info.m_sequence[playout];
01046 vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout];
01047 while (true)
01048 {
01049 if (sequence.size() == m_maxGameLength)
01050 return false;
01051 bool skipRave = false;
01052 SgMove move = state.GeneratePlayoutMove(skipRave);
01053 if (move == SG_NULLMOVE)
01054 break;
01055 state.ExecutePlayout(move);
01056 sequence.push_back(move);
01057 skipRaveUpdate.push_back(skipRave);
01058 }
01059 return true;
01060 }
01061
01062 SgUctValue SgUctSearch::Search(SgUctValue maxGames, double maxTime,
01063 vector<SgMove>& sequence,
01064 const vector<SgMove>& rootFilter,
01065 SgUctTree* initTree,
01066 SgUctEarlyAbortParam* earlyAbort)
01067 {
01068 m_timer.Start();
01069 m_rootFilter = rootFilter;
01070 if (m_logGames)
01071 {
01072 m_log.open(m_mpiSynchronizer->ToNodeFilename(m_logFileName).c_str());
01073 m_log << "StartSearch maxGames=" << maxGames << '\n';
01074 }
01075 m_maxGames = maxGames;
01076 m_maxTime = maxTime;
01077 m_earlyAbort.reset(0);
01078 if (earlyAbort != 0)
01079 m_earlyAbort.reset(new SgUctEarlyAbortParam(*earlyAbort));
01080
01081 for (size_t i = 0; i < m_threads.size(); ++i)
01082 {
01083 m_threads[i]->m_state->m_isSearchInitialized = false;
01084 }
01085 StartSearch(rootFilter, initTree);
01086 SgUctValue pruneMinCount = m_pruneMinCount;
01087 while (true)
01088 {
01089 m_isTreeOutOfMemory = false;
01090 SgSynchronizeThreadMemory();
01091 for (size_t i = 0; i < m_threads.size(); ++i)
01092 m_threads[i]->StartPlay();
01093 for (size_t i = 0; i < m_threads.size(); ++i)
01094 m_threads[i]->WaitPlayFinished();
01095 if (m_aborted || ! m_pruneFullTree)
01096 break;
01097 else
01098 {
01099 double startPruneTime = m_timer.GetTime();
01100 SgDebug() << "SgUctSearch: pruning nodes with count < "
01101 << pruneMinCount << " (at time " << fixed << setprecision(1)
01102 << startPruneTime << ")\n";
01103 SgUctTree& tempTree = GetTempTree();
01104 m_tree.CopyPruneLowCount(tempTree, pruneMinCount, true);
01105 int prunedSizePercentage =
01106 static_cast<int>(tempTree.NuNodes() * 100 / m_tree.NuNodes());
01107 SgDebug() << "SgUctSearch: pruned size: " << tempTree.NuNodes()
01108 << " (" << prunedSizePercentage << "%) time: "
01109 << (m_timer.GetTime() - startPruneTime) << "\n";
01110 if (prunedSizePercentage > 50)
01111 pruneMinCount *= 2;
01112 else
01113 pruneMinCount = m_pruneMinCount;
01114 m_tree.Swap(tempTree);
01115 }
01116 }
01117 EndSearch();
01118 m_statistics.m_time = m_timer.GetTime();
01119 if (m_statistics.m_time > numeric_limits<double>::epsilon())
01120 m_statistics.m_gamesPerSecond = GamesPlayed() / m_statistics.m_time;
01121 if (m_logGames)
01122 m_log.close();
01123 FindBestSequence(sequence);
01124 return (m_tree.Root().MoveCount() > 0) ? (SgUctValue)m_tree.Root().Mean() : (SgUctValue)0.5;
01125 }
01126
01127
01128 void SgUctSearch::SearchLoop(SgUctThreadState& state, GlobalLock* lock)
01129 {
01130 if (! state.m_isSearchInitialized)
01131 {
01132 OnThreadStartSearch(state);
01133 state.m_isSearchInitialized = true;
01134 }
01135
01136 if (NumberThreads() == 1 || m_lockFree)
01137 lock = 0;
01138 if (lock != 0)
01139 lock->lock();
01140 state.m_isTreeOutOfMem = false;
01141 while (! state.m_isTreeOutOfMem)
01142 {
01143 PlayGame(state, lock);
01144 OnSearchIteration(m_numberGames + 1, state.m_threadId,
01145 state.m_gameInfo);
01146 if (m_logGames)
01147 m_log << SummaryLine(state.m_gameInfo) << '\n';
01148 ++m_numberGames;
01149 if (m_isTreeOutOfMemory)
01150 break;
01151 if (m_aborted || CheckAbortSearch(state))
01152 {
01153 m_aborted = true;
01154 SgSynchronizeThreadMemory();
01155 break;
01156 }
01157 }
01158 if (lock != 0)
01159 lock->unlock();
01160
01161 m_searchLoopFinished->wait();
01162 if (m_aborted || ! m_pruneFullTree)
01163 OnThreadEndSearch(state);
01164 }
01165
01166 void SgUctSearch::OnThreadStartSearch(SgUctThreadState& state)
01167 {
01168 m_mpiSynchronizer->OnThreadStartSearch(*this, state);
01169 }
01170
01171 void SgUctSearch::OnThreadEndSearch(SgUctThreadState& state)
01172 {
01173 m_mpiSynchronizer->OnThreadEndSearch(*this, state);
01174 }
01175
01176 SgPoint SgUctSearch::SearchOnePly(SgUctValue maxGames, double maxTime,
01177 SgUctValue& value)
01178 {
01179 if (m_threads.size() == 0)
01180 CreateThreads();
01181 OnStartSearch();
01182
01183
01184 SgUctThreadState& state = ThreadState(0);
01185 state.StartSearch();
01186 vector<SgUctMoveInfo> moves;
01187 SgUctProvenType provenType;
01188 state.GameStart();
01189 state.GenerateAllMoves(0, moves, provenType);
01190 vector<SgUctStatistics> statistics(moves.size());
01191 SgUctValue games = 0;
01192 m_timer.Start();
01193 SgUctGameInfo& info = state.m_gameInfo;
01194 while (games < maxGames && m_timer.GetTime() < maxTime && ! SgUserAbort())
01195 {
01196 for (size_t i = 0; i < moves.size(); ++i)
01197 {
01198 state.GameStart();
01199 info.Clear(1);
01200 SgMove move = moves[i].m_move;
01201 state.Execute(move);
01202 info.m_inTreeSequence.push_back(move);
01203 info.m_sequence[0].push_back(move);
01204 info.m_skipRaveUpdate[0].push_back(false);
01205 state.StartPlayouts();
01206 state.StartPlayout();
01207 bool abortGame = ! PlayoutGame(state, 0);
01208 SgUctValue eval;
01209 if (abortGame)
01210 eval = UnknownEval();
01211 else
01212 eval = state.Evaluate();
01213 state.EndPlayout();
01214 state.TakeBackPlayout(info.m_sequence[0].size() - 1);
01215 state.TakeBackInTree(1);
01216 statistics[i].Add(info.m_sequence[0].size() % 2 == 0 ?
01217 eval : InverseEval(eval));
01218 OnSearchIteration(games + 1, 0, info);
01219 games += 1;
01220 }
01221 }
01222 SgMove bestMove = SG_NULLMOVE;
01223 for (size_t i = 0; i < moves.size(); ++i)
01224 {
01225 SgDebug() << MoveString(moves[i].m_move)
01226 << ' ' << statistics[i].Mean() << '\n';
01227 if (bestMove == SG_NULLMOVE || statistics[i].Mean() > value)
01228 {
01229 bestMove = moves[i].m_move;
01230 value = statistics[i].Mean();
01231 }
01232 }
01233 return bestMove;
01234 }
01235
01236 const SgUctNode& SgUctSearch::SelectChild(int& randomizeCounter,
01237 const SgUctNode& node)
01238 {
01239 bool useRave = m_rave;
01240 if (m_randomizeRaveFrequency > 0 && --randomizeCounter == 0)
01241 {
01242 useRave = false;
01243 randomizeCounter = m_randomizeRaveFrequency;
01244 }
01245 SG_ASSERT(node.HasChildren());
01246 SgUctValue posCount = node.PosCount();
01247 int virtualLossCount = node.VirtualLossCount();
01248 if (virtualLossCount > 1)
01249 {
01250
01251
01252 posCount += SgUctValue(virtualLossCount - 1);
01253 }
01254
01255 if (posCount == 0)
01256
01257 return *SgUctChildIterator(m_tree, node);
01258 SgUctValue logPosCount = Log(posCount);
01259 const SgUctNode* bestChild = 0;
01260 SgUctValue bestUpperBound = 0;
01261 const SgUctValue epsilon = SgUctValue(1e-7);
01262 for (SgUctChildIterator it(m_tree, node); it; ++it)
01263 {
01264 const SgUctNode& child = *it;
01265 if (!child.IsProvenWin())
01266 {
01267 SgUctValue bound = GetBound(useRave, logPosCount, child);
01268
01269
01270
01271
01272
01273
01274 if (bestChild == 0 || bound > bestUpperBound + epsilon)
01275 {
01276 bestChild = &child;
01277 bestUpperBound = bound;
01278 }
01279 }
01280 }
01281 if (bestChild != 0)
01282 return *bestChild;
01283
01284
01285
01286
01287 return *node.FirstChild();
01288 }
01289
01290 void SgUctSearch::SetNumberThreads(unsigned int n)
01291 {
01292 SG_ASSERT(n >= 1);
01293 if (m_numberThreads == n)
01294 return;
01295 m_numberThreads = n;
01296 CreateThreads();
01297 }
01298
01299 void SgUctSearch::SetRave(bool enable)
01300 {
01301 if (enable && m_moveRange <= 0)
01302 throw SgException("RAVE not supported for this game");
01303 m_rave = enable;
01304 }
01305
01306 void SgUctSearch::SetThreadStateFactory(SgUctThreadStateFactory* factory)
01307 {
01308 SG_ASSERT(m_threadStateFactory.get() == 0);
01309 m_threadStateFactory.reset(factory);
01310 DeleteThreads();
01311
01312
01313
01314 }
01315
01316 void SgUctSearch::StartSearch(const vector<SgMove>& rootFilter,
01317 SgUctTree* initTree)
01318 {
01319 if (m_threads.size() == 0)
01320 CreateThreads();
01321 if (m_numberThreads > 1 && SgTime::DefaultMode() == SG_TIME_CPU)
01322
01323
01324
01325
01326
01327 SgWarning() << "SgUctSearch: using cpu time with multiple threads\n";
01328 m_raveWeightParam1 = (SgUctValue)(1.0 / m_raveWeightInitial);
01329 m_raveWeightParam2 = (SgUctValue)(1.0 / m_raveWeightFinal);
01330 if (initTree == 0)
01331 m_tree.Clear();
01332 else
01333 {
01334 m_tree.Swap(*initTree);
01335 if (m_tree.HasCapacity(0, m_tree.Root().NuChildren()))
01336 m_tree.ApplyFilter(0, m_tree.Root(), rootFilter);
01337 else
01338 SgWarning() <<
01339 "SgUctSearch: "
01340 "root filter not applied (tree reached maximum size)\n";
01341 }
01342 m_statistics.Clear();
01343 m_aborted = false;
01344 m_wasEarlyAbort = false;
01345 m_checkTimeInterval = 1;
01346 m_numberGames = 0;
01347 m_lastScoreDisplayTime = m_timer.GetTime();
01348 OnStartSearch();
01349
01350 m_nextCheckTime = (SgUctValue)m_checkTimeInterval;
01351 m_startRootMoveCount = m_tree.Root().MoveCount();
01352
01353 for (unsigned int i = 0; i < m_threads.size(); ++i)
01354 {
01355 SgUctThreadState& state = ThreadState(i);
01356 state.m_randomizeCounter = m_randomizeRaveFrequency;
01357 state.StartSearch();
01358 }
01359 }
01360
01361 void SgUctSearch::EndSearch()
01362 {
01363 OnEndSearch();
01364 }
01365
01366 string SgUctSearch::SummaryLine(const SgUctGameInfo& info) const
01367 {
01368 ostringstream buffer;
01369 const vector<const SgUctNode*>& nodes = info.m_nodes;
01370 for (size_t i = 1; i < nodes.size(); ++i)
01371 {
01372 const SgUctNode* node = nodes[i];
01373 SgMove move = node->Move();
01374 buffer << ' ' << MoveString(move) << " (" << fixed << setprecision(2)
01375 << node->Mean() << ',' << node->MoveCount() << ')';
01376 }
01377 for (size_t i = 0; i < info.m_eval.size(); ++i)
01378 buffer << ' ' << fixed << setprecision(2) << info.m_eval[i];
01379 return buffer.str();
01380 }
01381
01382 void SgUctSearch::UpdateCheckTimeInterval(double time)
01383 {
01384 if (time < numeric_limits<double>::epsilon())
01385 return;
01386
01387
01388 double wantedTimeDiff = (m_maxTime > 1 ? 0.1 : 0.1 * m_maxTime);
01389 if (time < wantedTimeDiff / 10)
01390 {
01391
01392 m_checkTimeInterval *= 2;
01393 return;
01394 }
01395 m_statistics.m_gamesPerSecond = GamesPlayed() / time;
01396 double gamesPerSecondPerThread =
01397 m_statistics.m_gamesPerSecond / double(m_numberThreads);
01398 m_checkTimeInterval = SgUctValue(wantedTimeDiff * gamesPerSecondPerThread);
01399 if (m_checkTimeInterval == 0)
01400 m_checkTimeInterval = 1;
01401 }
01402
01403
01404
01405
01406 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state)
01407 {
01408 for (size_t i = 0; i < m_numberPlayouts; ++i)
01409 UpdateRaveValues(state, i);
01410 }
01411
01412 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state,
01413 std::size_t playout)
01414 {
01415 SgUctGameInfo& info = state.m_gameInfo;
01416 const vector<SgMove>& sequence = info.m_sequence[playout];
01417 if (sequence.size() == 0)
01418 return;
01419 SG_ASSERT(m_moveRange > 0);
01420 size_t* firstPlay = state.m_firstPlay.get();
01421 size_t* firstPlayOpp = state.m_firstPlayOpp.get();
01422 fill_n(firstPlay, m_moveRange, numeric_limits<size_t>::max());
01423 fill_n(firstPlayOpp, m_moveRange, numeric_limits<size_t>::max());
01424 const vector<const SgUctNode*>& nodes = info.m_nodes;
01425 const vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout];
01426 SgUctValue eval = info.m_eval[playout];
01427 SgUctValue invEval = InverseEval(eval);
01428 size_t nuNodes = nodes.size();
01429 size_t i = sequence.size() - 1;
01430 bool opp = (i % 2 != 0);
01431
01432
01433 for ( ; i >= nuNodes; --i)
01434 {
01435 SG_ASSERT(i < skipRaveUpdate.size());
01436 SG_ASSERT(i < sequence.size());
01437 if (! skipRaveUpdate[i])
01438 {
01439 SgMove mv = sequence[i];
01440 size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]);
01441 if (i < first)
01442 first = i;
01443 }
01444 opp = ! opp;
01445 }
01446
01447 while (true)
01448 {
01449 SG_ASSERT(i < skipRaveUpdate.size());
01450 SG_ASSERT(i < sequence.size());
01451
01452 SG_ASSERT(i >= info.m_inTreeSequence.size() || ! skipRaveUpdate[i]);
01453 if (! skipRaveUpdate[i])
01454 {
01455 SgMove mv = sequence[i];
01456 size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]);
01457 if (i < first)
01458 first = i;
01459 if (opp)
01460 UpdateRaveValues(state, playout, invEval, i,
01461 firstPlayOpp, firstPlay);
01462 else
01463 UpdateRaveValues(state, playout, eval, i,
01464 firstPlay, firstPlayOpp);
01465 }
01466 if (i == 0)
01467 break;
01468 --i;
01469 opp = ! opp;
01470 }
01471 }
01472
01473 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state,
01474 std::size_t playout, SgUctValue eval,
01475 std::size_t i,
01476 const std::size_t firstPlay[],
01477 const std::size_t firstPlayOpp[])
01478 {
01479 SG_ASSERT(i < state.m_gameInfo.m_nodes.size());
01480 const SgUctNode* node = state.m_gameInfo.m_nodes[i];
01481 if (! node->HasChildren())
01482 return;
01483 size_t len = state.m_gameInfo.m_sequence[playout].size();
01484 for (SgUctChildIterator it(m_tree, *node); it; ++it)
01485 {
01486 const SgUctNode& child = *it;
01487 SgMove mv = child.Move();
01488 size_t first = firstPlay[mv];
01489 SG_ASSERT(first >= i);
01490 if (first == numeric_limits<size_t>::max())
01491 continue;
01492 if (m_raveCheckSame && SgUtil::InRange(firstPlayOpp[mv], i, first))
01493 continue;
01494 SgUctValue weight;
01495 if (m_weightRaveUpdates)
01496 weight = 2 - SgUctValue(first - i) / SgUctValue(len - i);
01497 else
01498 weight = 1;
01499 m_tree.AddRaveValue(child, eval, weight);
01500 }
01501 }
01502
01503 void SgUctSearch::UpdateStatistics(const SgUctGameInfo& info)
01504 {
01505 m_statistics.m_movesInTree.Add(
01506 static_cast<float>(info.m_inTreeSequence.size()));
01507 for (size_t i = 0; i < m_numberPlayouts; ++i)
01508 {
01509 m_statistics.m_gameLength.Add(
01510 static_cast<float>(info.m_sequence[i].size()));
01511 m_statistics.m_aborted.Add(info.m_aborted[i] ? 1.f : 0.f);
01512 }
01513 }
01514
01515 void SgUctSearch::UpdateTree(const SgUctGameInfo& info)
01516 {
01517 SgUctValue eval = 0;
01518 for (size_t i = 0; i < m_numberPlayouts; ++i)
01519 eval += info.m_eval[i];
01520 eval /= SgUctValue(m_numberPlayouts);
01521 SgUctValue inverseEval = InverseEval(eval);
01522 const vector<const SgUctNode*>& nodes = info.m_nodes;
01523 SgUctValue count = SgUctValue(m_numberPlayouts);
01524 for (size_t i = 0; i < nodes.size(); ++i)
01525 {
01526 const SgUctNode& node = *nodes[i];
01527 const SgUctNode* father = (i > 0 ? nodes[i - 1] : 0);
01528 m_tree.AddGameResults(node, father, i % 2 == 0 ? eval : inverseEval,
01529 count);
01530
01531 if (m_virtualLoss && m_numberThreads > 1)
01532 m_tree.RemoveVirtualLoss(node);
01533 }
01534 }
01535
01536 void SgUctSearch::WriteStatistics(ostream& out) const
01537 {
01538 out << SgWriteLabel("Count") << m_tree.Root().MoveCount() << '\n'
01539 << SgWriteLabel("GamesPlayed") << GamesPlayed() << '\n'
01540 << SgWriteLabel("Nodes") << m_tree.NuNodes() << '\n';
01541 if (!m_knowledgeThreshold.empty())
01542 out << SgWriteLabel("Knowledge")
01543 << m_statistics.m_knowledge << " (" << fixed << setprecision(1)
01544 << m_statistics.m_knowledge * 100.0 / m_tree.Root().MoveCount()
01545 << "%)\n";
01546 m_statistics.Write(out);
01547 m_mpiSynchronizer->WriteStatistics(out);
01548 }
01549
01550