00001 //---------------------------------------------------------------------------- 00002 /** @file SgUctSearch.cpp */ 00003 //---------------------------------------------------------------------------- 00004 00005 #include "SgSystem.h" 00006 #include "SgUctSearch.h" 00007 00008 #include <algorithm> 00009 #include <cmath> 00010 #include <iomanip> 00011 #include <boost/format.hpp> 00012 #include <boost/io/ios_state.hpp> 00013 #include <boost/version.hpp> 00014 #include "SgDebug.h" 00015 #include "SgHashTable.h" 00016 #include "SgMath.h" 00017 #include "SgPlatform.h" 00018 #include "SgWrite.h" 00019 00020 using namespace std; 00021 using boost::barrier; 00022 using boost::condition; 00023 using boost::format; 00024 using boost::mutex; 00025 using boost::shared_ptr; 00026 using boost::io::ios_all_saver; 00027 00028 #define BOOST_VERSION_MAJOR (BOOST_VERSION / 100000) 00029 #define BOOST_VERSION_MINOR (BOOST_VERSION / 100 % 1000) 00030 00031 //---------------------------------------------------------------------------- 00032 00033 namespace { 00034 00035 const bool DEBUG_THREADS = false; 00036 00037 /** Get a default value for lock-free mode. 00038 Lock-free mode works only on IA-32/Intel-64 architectures or if the macro 00039 ENABLE_CACHE_SYNC from Fuego's configure script is defined. The 00040 architecture is determined by using the macro HOST_CPU from Fuego's 00041 configure script. On Windows, an Intel architecture is always assumed. */ 00042 bool GetLockFreeDefault() 00043 { 00044 #if defined(WIN32) || defined(ENABLE_CACHE_SYNC) 00045 return true; 00046 #elif defined(HOST_CPU) 00047 string hostCpu(HOST_CPU); 00048 return hostCpu == "i386" || hostCpu == "i486" || hostCpu == "i586" 00049 || hostCpu == "i686" || hostCpu == "x86_64"; 00050 #else 00051 return false; 00052 #endif 00053 } 00054 00055 /** Get a default value for the tree size. 00056 The default value is that both trees used by SgUctSearch take no more than 00057 half of the total amount of memory on the system (but no less than 00058 384 MB and not more than 1 GB). */ 00059 size_t GetMaxNodesDefault() 00060 { 00061 size_t totalMemory = SgPlatform::TotalMemory(); 00062 SgDebug() << "SgUctSearch: system memory "; 00063 if (totalMemory == 0) 00064 SgDebug() << "unknown"; 00065 else 00066 SgDebug() << totalMemory; 00067 // Use half of the physical memory by default but at least 284K and not 00068 // more than 1G 00069 size_t searchMemory = totalMemory / 2; 00070 if (searchMemory < 384000000) 00071 searchMemory = 384000000; 00072 if (searchMemory > 1000000000) 00073 searchMemory = 1000000000; 00074 size_t memoryPerTree = searchMemory / 2; 00075 size_t nodesPerTree = memoryPerTree / sizeof(SgUctNode); 00076 SgDebug() << ", using " << searchMemory << " (" << nodesPerTree 00077 << " nodes)\n"; 00078 return nodesPerTree; 00079 } 00080 00081 void Notify(mutex& aMutex, condition& aCondition) 00082 { 00083 mutex::scoped_lock lock(aMutex); 00084 aCondition.notify_all(); 00085 } 00086 00087 } // namespace 00088 00089 //---------------------------------------------------------------------------- 00090 00091 void SgUctGameInfo::Clear(std::size_t numberPlayouts) 00092 { 00093 m_nodes.clear(); 00094 m_inTreeSequence.clear(); 00095 if (numberPlayouts != m_sequence.size()) 00096 { 00097 m_sequence.resize(numberPlayouts); 00098 m_skipRaveUpdate.resize(numberPlayouts); 00099 m_eval.resize(numberPlayouts); 00100 m_aborted.resize(numberPlayouts); 00101 } 00102 for (size_t i = 0; i < numberPlayouts; ++i) 00103 { 00104 m_sequence[i].clear(); 00105 m_skipRaveUpdate[i].clear(); 00106 } 00107 } 00108 00109 //---------------------------------------------------------------------------- 00110 00111 SgUctThreadState::SgUctThreadState(unsigned int threadId, int moveRange) 00112 : m_threadId(threadId), 00113 m_isSearchInitialized(false), 00114 m_isTreeOutOfMem(false) 00115 { 00116 if (moveRange > 0) 00117 { 00118 m_firstPlay.reset(new size_t[moveRange]); 00119 m_firstPlayOpp.reset(new size_t[moveRange]); 00120 } 00121 } 00122 00123 SgUctThreadState::~SgUctThreadState() 00124 { 00125 } 00126 00127 void SgUctThreadState::EndPlayout() 00128 { 00129 // Default implementation does nothing 00130 } 00131 00132 void SgUctThreadState::GameStart() 00133 { 00134 // Default implementation does nothing 00135 } 00136 00137 void SgUctThreadState::StartPlayout() 00138 { 00139 // Default implementation does nothing 00140 } 00141 00142 void SgUctThreadState::StartPlayouts() 00143 { 00144 // Default implementation does nothing 00145 } 00146 00147 //---------------------------------------------------------------------------- 00148 00149 SgUctThreadStateFactory::~SgUctThreadStateFactory() 00150 { 00151 } 00152 00153 //---------------------------------------------------------------------------- 00154 00155 SgUctSearch::Thread::Function::Function(Thread& thread) 00156 : m_thread(thread) 00157 { 00158 } 00159 00160 void SgUctSearch::Thread::Function::operator()() 00161 { 00162 m_thread(); 00163 } 00164 00165 SgUctSearch::Thread::Thread(SgUctSearch& search, 00166 auto_ptr<SgUctThreadState> state) 00167 : m_state(state), 00168 m_search(search), 00169 m_quit(false), 00170 m_threadReady(2), 00171 m_playFinishedLock(m_playFinishedMutex), 00172 #if BOOST_VERSION_MAJOR == 1 && BOOST_VERSION_MINOR <= 34 00173 m_globalLock(search.m_globalMutex, false), 00174 #else 00175 m_globalLock(search.m_globalMutex, boost::defer_lock), 00176 #endif 00177 m_thread(Function(*this)) 00178 { 00179 m_threadReady.wait(); 00180 } 00181 00182 SgUctSearch::Thread::~Thread() 00183 { 00184 m_quit = true; 00185 StartPlay(); 00186 m_thread.join(); 00187 } 00188 00189 void SgUctSearch::Thread::operator()() 00190 { 00191 if (DEBUG_THREADS) 00192 SgDebug() << "SgUctSearch::Thread: starting thread " 00193 << m_state->m_threadId << '\n'; 00194 mutex::scoped_lock lock(m_startPlayMutex); 00195 m_threadReady.wait(); 00196 while (true) 00197 { 00198 m_startPlay.wait(lock); 00199 if (m_quit) 00200 break; 00201 m_search.SearchLoop(*m_state, &m_globalLock); 00202 Notify(m_playFinishedMutex, m_playFinished); 00203 } 00204 if (DEBUG_THREADS) 00205 SgDebug() << "SgUctSearch::Thread: finishing thread " 00206 << m_state->m_threadId << '\n'; 00207 } 00208 00209 void SgUctSearch::Thread::StartPlay() 00210 { 00211 Notify(m_startPlayMutex, m_startPlay); 00212 } 00213 00214 void SgUctSearch::Thread::WaitPlayFinished() 00215 { 00216 m_playFinished.wait(m_playFinishedLock); 00217 } 00218 00219 //---------------------------------------------------------------------------- 00220 00221 void SgUctSearchStat::Clear() 00222 { 00223 m_time = 0; 00224 m_knowledge = 0; 00225 m_gamesPerSecond = 0; 00226 m_gameLength.Clear(); 00227 m_movesInTree.Clear(); 00228 m_aborted.Clear(); 00229 } 00230 00231 void SgUctSearchStat::Write(std::ostream& out) const 00232 { 00233 ios_all_saver saver(out); 00234 out << SgWriteLabel("Time") << setprecision(2) << m_time << '\n' 00235 << SgWriteLabel("GameLen") << fixed << setprecision(1); 00236 m_gameLength.Write(out); 00237 out << '\n' 00238 << SgWriteLabel("InTree"); 00239 m_movesInTree.Write(out); 00240 out << '\n' 00241 << SgWriteLabel("Aborted") 00242 << static_cast<int>(100 * m_aborted.Mean()) << "%\n" 00243 << SgWriteLabel("Games/s") << fixed << setprecision(1) 00244 << m_gamesPerSecond << '\n'; 00245 } 00246 00247 //---------------------------------------------------------------------------- 00248 00249 SgUctSearch::SgUctSearch(SgUctThreadStateFactory* threadStateFactory, 00250 int moveRange) 00251 : m_threadStateFactory(threadStateFactory), 00252 m_logGames(false), 00253 m_rave(false), 00254 m_knowledgeThreshold(), 00255 m_moveSelect(SG_UCTMOVESELECT_COUNT), 00256 m_raveCheckSame(false), 00257 m_randomizeRaveFrequency(20), 00258 m_lockFree(GetLockFreeDefault()), 00259 m_weightRaveUpdates(true), 00260 m_pruneFullTree(true), 00261 m_checkFloatPrecision(true), 00262 m_numberThreads(1), 00263 m_numberPlayouts(1), 00264 m_maxNodes(GetMaxNodesDefault()), 00265 m_pruneMinCount(16), 00266 m_moveRange(moveRange), 00267 m_maxGameLength(numeric_limits<size_t>::max()), 00268 m_expandThreshold(numeric_limits<SgUctValue>::is_integer ? (SgUctValue)1 : numeric_limits<SgUctValue>::epsilon()), 00269 m_biasTermConstant(0.7f), 00270 m_firstPlayUrgency(10000), 00271 m_raveWeightInitial(0.9f), 00272 m_raveWeightFinal(20000), 00273 m_virtualLoss(false), 00274 m_logFileName("uctsearch.log"), 00275 m_fastLog(10), 00276 m_mpiSynchronizer(SgMpiNullSynchronizer::Create()) 00277 { 00278 // Don't create thread states here, because the factory passes the search 00279 // (which is not fully constructed here, because the subclass constructors 00280 // are not called yet) as an argument to the Create() function 00281 } 00282 00283 SgUctSearch::~SgUctSearch() 00284 { 00285 DeleteThreads(); 00286 } 00287 00288 void SgUctSearch::ApplyRootFilter(vector<SgUctMoveInfo>& moves) 00289 { 00290 // Filter without changing the order of the unfiltered moves 00291 vector<SgUctMoveInfo> filteredMoves; 00292 for (vector<SgUctMoveInfo>::const_iterator it = moves.begin(); 00293 it != moves.end(); ++it) 00294 if (find(m_rootFilter.begin(), m_rootFilter.end(), it->m_move) 00295 == m_rootFilter.end()) 00296 filteredMoves.push_back(*it); 00297 moves = filteredMoves; 00298 } 00299 00300 SgUctValue SgUctSearch::GamesPlayed() const 00301 { 00302 return m_tree.Root().MoveCount() - m_startRootMoveCount; 00303 } 00304 00305 bool SgUctSearch::CheckAbortSearch(SgUctThreadState& state) 00306 { 00307 if (SgUserAbort()) 00308 { 00309 Debug(state, "SgUctSearch: abort flag"); 00310 return true; 00311 } 00312 const SgUctNode& root = m_tree.Root(); 00313 if (! SgUctValueUtil::IsPrecise(root.MoveCount()) && m_checkFloatPrecision) 00314 { 00315 Debug(state, "SgUctSearch: floating point type precision reached"); 00316 return true; 00317 } 00318 SgUctValue rootCount = root.MoveCount(); 00319 if (rootCount >= m_maxGames) 00320 { 00321 Debug(state, "SgUctSearch: max games reached"); 00322 return true; 00323 } 00324 if (root.IsProven()) 00325 { 00326 if (root.IsProvenWin()) 00327 Debug(state, "SgUctSearch: root is proven win!"); 00328 else 00329 Debug(state, "SgUctSearch: root is proven loss!"); 00330 return true; 00331 } 00332 const bool isEarlyAbort = CheckEarlyAbort(); 00333 if ( isEarlyAbort 00334 && m_earlyAbort->m_reductionFactor * rootCount >= m_maxGames 00335 ) 00336 { 00337 Debug(state, "SgUctSearch: max games reached (early abort)"); 00338 m_wasEarlyAbort = true; 00339 return true; 00340 } 00341 if (m_numberGames >= m_nextCheckTime) 00342 { 00343 m_nextCheckTime = m_numberGames + m_checkTimeInterval; 00344 double time = m_timer.GetTime(); 00345 if (time > m_maxTime) 00346 { 00347 Debug(state, "SgUctSearch: max time reached"); 00348 return true; 00349 } 00350 if (isEarlyAbort 00351 && m_earlyAbort->m_reductionFactor * time > m_maxTime) 00352 { 00353 Debug(state, "SgUctSearch: max time reached (early abort)"); 00354 m_wasEarlyAbort = true; 00355 return true; 00356 } 00357 UpdateCheckTimeInterval(time); 00358 if (m_moveSelect == SG_UCTMOVESELECT_COUNT) 00359 { 00360 double remainingGamesDouble = m_maxGames - rootCount - 1; 00361 // Use time based count abort, only if time > 1, otherwise 00362 // m_gamesPerSecond is unreliable 00363 if (time > 1.) 00364 { 00365 double remainingTime = m_maxTime - time; 00366 remainingGamesDouble = 00367 min(remainingGamesDouble, 00368 remainingTime * m_statistics.m_gamesPerSecond); 00369 } 00370 SgUctValue uctCountMax = numeric_limits<SgUctValue>::max(); 00371 SgUctValue remainingGames; 00372 if (remainingGamesDouble >= static_cast<double>(uctCountMax - 1)) 00373 remainingGames = uctCountMax; 00374 else 00375 remainingGames = SgUctValue(remainingGamesDouble); 00376 if (CheckCountAbort(state, remainingGames)) 00377 { 00378 Debug(state, "SgUctSearch: move cannot change anymore"); 00379 return true; 00380 } 00381 } 00382 } 00383 return false; 00384 } 00385 00386 bool SgUctSearch::CheckCountAbort(SgUctThreadState& state, 00387 SgUctValue remainingGames) const 00388 { 00389 const SgUctNode& root = m_tree.Root(); 00390 const SgUctNode* bestChild = FindBestChild(root); 00391 if (bestChild == 0) 00392 return false; 00393 SgUctValue bestCount = bestChild->MoveCount(); 00394 vector<SgMove>& excludeMoves = state.m_excludeMoves; 00395 excludeMoves.clear(); 00396 excludeMoves.push_back(bestChild->Move()); 00397 const SgUctNode* secondBestChild = FindBestChild(root, &excludeMoves); 00398 if (secondBestChild == 0) 00399 return false; 00400 SgUctValue secondBestCount = secondBestChild->MoveCount(); 00401 SG_ASSERT(secondBestCount <= bestCount || m_numberThreads > 1); 00402 return (remainingGames <= bestCount - secondBestCount); 00403 } 00404 00405 bool SgUctSearch::CheckEarlyAbort() const 00406 { 00407 const SgUctNode& root = m_tree.Root(); 00408 return m_earlyAbort.get() != 0 00409 && root.HasMean() 00410 && root.MoveCount() > m_earlyAbort->m_minGames 00411 && root.Mean() > m_earlyAbort->m_threshold; 00412 } 00413 00414 void SgUctSearch::CreateThreads() 00415 { 00416 DeleteThreads(); 00417 for (unsigned int i = 0; i < m_numberThreads; ++i) 00418 { 00419 auto_ptr<SgUctThreadState> state( 00420 m_threadStateFactory->Create(i, *this)); 00421 shared_ptr<Thread> thread(new Thread(*this, state)); 00422 m_threads.push_back(thread); 00423 } 00424 m_tree.CreateAllocators(m_numberThreads); 00425 m_tree.SetMaxNodes(m_maxNodes); 00426 00427 m_searchLoopFinished.reset(new barrier(m_numberThreads)); 00428 } 00429 00430 /** Write a debugging line of text from within a thread. 00431 Prepends the line with the thread number if number of threads is greater 00432 than one. Also ensures that the line is written as a single string to 00433 avoid intermingling of text lines from different threads. 00434 @param state The state of the thread (only used for state.m_threadId) 00435 @param textLine The line of text without trailing newline character. */ 00436 void SgUctSearch::Debug(const SgUctThreadState& state, 00437 const std::string& textLine) 00438 { 00439 if (m_numberThreads > 1) 00440 { 00441 // SgDebug() is not necessarily thread-safe 00442 GlobalLock lock(m_globalMutex); 00443 SgDebug() << (format("[%1%] %2%\n") % state.m_threadId % textLine); 00444 } 00445 else 00446 SgDebug() << (format("%1%\n") % textLine); 00447 } 00448 00449 void SgUctSearch::DeleteThreads() 00450 { 00451 m_threads.clear(); 00452 } 00453 00454 /** Expand a node. 00455 @param state The thread state with state.m_moves already computed. 00456 @param node The node to expand. */ 00457 void SgUctSearch::ExpandNode(SgUctThreadState& state, const SgUctNode& node) 00458 { 00459 unsigned int threadId = state.m_threadId; 00460 if (! m_tree.HasCapacity(threadId, state.m_moves.size())) 00461 { 00462 Debug(state, str(format("SgUctSearch: maximum tree size %1% reached") 00463 % m_tree.MaxNodes())); 00464 state.m_isTreeOutOfMem = true; 00465 m_isTreeOutOfMemory = true; 00466 SgSynchronizeThreadMemory(); 00467 return; 00468 } 00469 m_tree.CreateChildren(threadId, node, state.m_moves); 00470 } 00471 00472 const SgUctNode* 00473 SgUctSearch::FindBestChild(const SgUctNode& node, 00474 const vector<SgMove>* excludeMoves) const 00475 { 00476 if (! node.HasChildren()) 00477 return 0; 00478 const SgUctNode* bestChild = 0; 00479 SgUctValue bestValue = 0; 00480 for (SgUctChildIterator it(m_tree, node); it; ++it) 00481 { 00482 const SgUctNode& child = *it; 00483 if (excludeMoves != 0) 00484 { 00485 vector<SgMove>::const_iterator begin = excludeMoves->begin(); 00486 vector<SgMove>::const_iterator end = excludeMoves->end(); 00487 if (find(begin, end, child.Move()) != end) 00488 continue; 00489 } 00490 if ( ! child.HasMean() 00491 && ! ( ( m_moveSelect == SG_UCTMOVESELECT_BOUND 00492 || m_moveSelect == SG_UCTMOVESELECT_ESTIMATE 00493 ) 00494 && m_rave 00495 && child.HasRaveValue() 00496 ) 00497 ) 00498 continue; 00499 if (child.IsProvenLoss()) // Always choose winning move! 00500 { 00501 bestChild = &child; 00502 break; 00503 } 00504 SgUctValue value; 00505 switch (m_moveSelect) 00506 { 00507 case SG_UCTMOVESELECT_VALUE: 00508 value = InverseEstimate((SgUctValue)child.Mean()); 00509 break; 00510 case SG_UCTMOVESELECT_COUNT: 00511 value = child.MoveCount(); 00512 break; 00513 case SG_UCTMOVESELECT_BOUND: 00514 value = GetBound(m_rave, node, child); 00515 break; 00516 case SG_UCTMOVESELECT_ESTIMATE: 00517 value = GetValueEstimate(m_rave, child); 00518 break; 00519 default: 00520 SG_ASSERT(false); 00521 value = SG_UCTMOVESELECT_VALUE; 00522 } 00523 if (bestChild == 0 || value > bestValue) 00524 { 00525 bestChild = &child; 00526 bestValue = value; 00527 } 00528 } 00529 return bestChild; 00530 } 00531 00532 void SgUctSearch::FindBestSequence(vector<SgMove>& sequence) const 00533 { 00534 sequence.clear(); 00535 const SgUctNode* current = &m_tree.Root(); 00536 while (true) 00537 { 00538 current = FindBestChild(*current); 00539 if (current == 0) 00540 break; 00541 sequence.push_back(current->Move()); 00542 if (! current->HasChildren()) 00543 break; 00544 } 00545 } 00546 00547 void SgUctSearch::GenerateAllMoves(std::vector<SgUctMoveInfo>& moves) 00548 { 00549 if (m_threads.size() == 0) 00550 CreateThreads(); 00551 moves.clear(); 00552 OnStartSearch(); 00553 SgUctThreadState& state = ThreadState(0); 00554 state.StartSearch(); 00555 SgUctProvenType type; 00556 state.GenerateAllMoves(0, moves, type); 00557 } 00558 00559 SgUctValue SgUctSearch::GetBound(bool useRave, const SgUctNode& node, 00560 const SgUctNode& child) const 00561 { 00562 SgUctValue posCount = node.PosCount(); 00563 int virtualLossCount = node.VirtualLossCount(); 00564 if (virtualLossCount > 0) 00565 { 00566 posCount += SgUctValue(virtualLossCount); 00567 } 00568 return GetBound(useRave, Log(posCount), child); 00569 } 00570 00571 SgUctValue SgUctSearch::GetBound(bool useRave, SgUctValue logPosCount, 00572 const SgUctNode& child) const 00573 { 00574 SgUctValue value; 00575 if (useRave) 00576 value = GetValueEstimateRave(child); 00577 else 00578 value = GetValueEstimate(false, child); 00579 if (m_biasTermConstant == 0.0) 00580 return value; 00581 else 00582 { 00583 SgUctValue moveCount = static_cast<SgUctValue>(child.MoveCount()); 00584 SgUctValue bound = 00585 value + m_biasTermConstant * sqrt(logPosCount / (moveCount + 1)); 00586 return bound; 00587 } 00588 } 00589 00590 SgUctTree& SgUctSearch::GetTempTree() 00591 { 00592 m_tempTree.Clear(); 00593 // Use NumberThreads() (not m_tree.NuAllocators()) and MaxNodes() (not 00594 // m_tree.MaxNodes()), because of the delayed thread (and thereby 00595 // allocator) creation in SgUctSearch 00596 if (m_tempTree.NuAllocators() != NumberThreads()) 00597 { 00598 m_tempTree.CreateAllocators(NumberThreads()); 00599 m_tempTree.SetMaxNodes(MaxNodes()); 00600 } 00601 else if (m_tempTree.MaxNodes() != MaxNodes()) 00602 { 00603 m_tempTree.SetMaxNodes(MaxNodes()); 00604 } 00605 return m_tempTree; 00606 } 00607 00608 SgUctValue SgUctSearch::GetValueEstimate(bool useRave, const SgUctNode& child) const 00609 { 00610 SgUctValue value = 0; 00611 SgUctValue weightSum = 0; 00612 bool hasValue = false; 00613 00614 SgUctStatistics uctStats; 00615 if (child.HasMean()) 00616 { 00617 uctStats.Initialize(child.Mean(), child.MoveCount()); 00618 } 00619 int virtualLossCount = child.VirtualLossCount(); 00620 if (virtualLossCount > 0) 00621 { 00622 uctStats.Add(InverseEstimate(0), SgUctValue(virtualLossCount)); 00623 } 00624 00625 if (uctStats.IsDefined()) 00626 { 00627 SgUctValue weight = static_cast<SgUctValue>(uctStats.Count()); 00628 value += weight * InverseEstimate((SgUctValue)uctStats.Mean()); 00629 weightSum += weight; 00630 hasValue = true; 00631 } 00632 00633 if (useRave) 00634 { 00635 SgUctStatistics raveStats; 00636 if (child.HasRaveValue()) 00637 { 00638 raveStats.Initialize(child.RaveValue(), child.RaveCount()); 00639 } 00640 if (virtualLossCount > 0) 00641 { 00642 raveStats.Add(0, SgUctValue(virtualLossCount)); 00643 } 00644 if (raveStats.IsDefined()) 00645 { 00646 SgUctValue raveCount = raveStats.Count(); 00647 SgUctValue weight = 00648 raveCount 00649 / ( m_raveWeightParam1 00650 + m_raveWeightParam2 * raveCount 00651 ); 00652 value += weight * raveStats.Mean(); 00653 weightSum += weight; 00654 hasValue = true; 00655 } 00656 } 00657 if (hasValue) 00658 return value / weightSum; 00659 else 00660 return m_firstPlayUrgency; 00661 } 00662 00663 /** Optimized version of GetValueEstimate() if RAVE and not other 00664 estimators are used. 00665 Previously there were more estimators than move value and RAVE value, 00666 and in the future there may be again. GetValueEstimate() is easier to 00667 extend, this function is more optimized for the special case. */ 00668 SgUctValue SgUctSearch::GetValueEstimateRave(const SgUctNode& child) const 00669 { 00670 SG_ASSERT(m_rave); 00671 SgUctValue value; 00672 SgUctStatistics uctStats; 00673 if (child.HasMean()) 00674 { 00675 uctStats.Initialize(child.Mean(), child.MoveCount()); 00676 } 00677 SgUctStatistics raveStats; 00678 if (child.HasRaveValue()) 00679 { 00680 raveStats.Initialize(child.RaveValue(), child.RaveCount()); 00681 } 00682 int virtualLossCount = child.VirtualLossCount(); 00683 if (virtualLossCount > 0) 00684 { 00685 uctStats.Add(InverseEstimate(0), SgUctValue(virtualLossCount)); 00686 raveStats.Add(0, SgUctValue(virtualLossCount)); 00687 } 00688 bool hasRave = raveStats.IsDefined(); 00689 00690 if (uctStats.IsDefined()) 00691 { 00692 SgUctValue moveValue = InverseEstimate((SgUctValue)uctStats.Mean()); 00693 if (hasRave) 00694 { 00695 SgUctValue moveCount = uctStats.Count(); 00696 SgUctValue raveCount = raveStats.Count(); 00697 SgUctValue weight = 00698 raveCount 00699 / (moveCount 00700 * (m_raveWeightParam1 + m_raveWeightParam2 * raveCount) 00701 + raveCount); 00702 value = weight * raveStats.Mean() + (1.f - weight) * moveValue; 00703 } 00704 else 00705 { 00706 // This can happen only in lock-free multi-threading. Normally, 00707 // each move played in a position should also cause a RAVE value 00708 // to be added. But in lock-free multi-threading it can happen 00709 // that the move value was already updated but the RAVE value not 00710 SG_ASSERT(m_numberThreads > 1 && m_lockFree); 00711 value = moveValue; 00712 } 00713 } 00714 else if (hasRave) 00715 value = raveStats.Mean(); 00716 else 00717 value = m_firstPlayUrgency; 00718 SG_ASSERT(m_numberThreads > 1 00719 || fabs(value - GetValueEstimate(m_rave, child)) < 1e-3/*epsilon*/); 00720 return value; 00721 } 00722 00723 string SgUctSearch::LastGameSummaryLine() const 00724 { 00725 return SummaryLine(LastGameInfo()); 00726 } 00727 00728 SgUctValue SgUctSearch::Log(SgUctValue x) const 00729 { 00730 #if SG_UCTFASTLOG 00731 return SgUctValue(m_fastLog.Log(float(x))); 00732 #else 00733 return log(x); 00734 #endif 00735 } 00736 00737 /** Creates the children with the given moves and merges with existing 00738 children in the tree. */ 00739 void SgUctSearch::CreateChildren(SgUctThreadState& state, 00740 const SgUctNode& node, 00741 bool deleteChildTrees) 00742 { 00743 unsigned int threadId = state.m_threadId; 00744 if (! m_tree.HasCapacity(threadId, state.m_moves.size())) 00745 { 00746 Debug(state, str(format("SgUctSearch: maximum tree size %1% reached") 00747 % m_tree.MaxNodes())); 00748 state.m_isTreeOutOfMem = true; 00749 m_isTreeOutOfMemory = true; 00750 SgSynchronizeThreadMemory(); 00751 return; 00752 } 00753 m_tree.MergeChildren(threadId, node, state.m_moves, deleteChildTrees); 00754 } 00755 00756 bool SgUctSearch::NeedToComputeKnowledge(const SgUctNode* current) 00757 { 00758 if (m_knowledgeThreshold.empty()) 00759 return false; 00760 for (std::size_t i = 0; i < m_knowledgeThreshold.size(); ++i) 00761 { 00762 const SgUctValue threshold = m_knowledgeThreshold[i]; 00763 if (current->KnowledgeCount() < threshold) 00764 { 00765 if (current->MoveCount() >= threshold) 00766 { 00767 // Mark knowledge computed immediately so other 00768 // threads fall through and do not waste time 00769 // re-computing this knowledge. 00770 m_tree.SetKnowledgeCount(*current, threshold); 00771 SG_ASSERT(current->MoveCount()); 00772 return true; 00773 } 00774 return false; 00775 } 00776 } 00777 return false; 00778 } 00779 00780 void SgUctSearch::OnStartSearch() 00781 { 00782 m_mpiSynchronizer->OnStartSearch(*this); 00783 } 00784 00785 void SgUctSearch::OnEndSearch() 00786 { 00787 m_mpiSynchronizer->OnEndSearch(*this); 00788 } 00789 00790 /** Print time, mean, nodes searched, and PV */ 00791 void SgUctSearch::PrintSearchProgress(double currTime) const 00792 { 00793 const int MAX_SEQ_PRINT_LENGTH = 15; 00794 const SgUctValue MIN_MOVE_COUNT = 10; 00795 SgUctValue rootMoveCount = m_tree.Root().MoveCount(); 00796 SgUctValue rootMean = m_tree.Root().Mean(); 00797 ostringstream out; 00798 const SgUctNode* current = &m_tree.Root(); 00799 out << (format("%s | %.3f | %.0f ") 00800 % SgTime::Format(currTime, true) % rootMean % rootMoveCount); 00801 for (int i = 0; i <= MAX_SEQ_PRINT_LENGTH && current->HasChildren(); ++i) 00802 { 00803 current = FindBestChild(*current); 00804 if (current == 0 || current->MoveCount() < MIN_MOVE_COUNT) 00805 break; 00806 if (i == 0) 00807 out << "|"; 00808 if (i < MAX_SEQ_PRINT_LENGTH) 00809 out << " " << MoveString(current->Move()); 00810 else 00811 out << " *"; 00812 } 00813 SgDebug() << out.str() << endl; 00814 } 00815 00816 void SgUctSearch::OnSearchIteration(SgUctValue gameNumber, 00817 unsigned int threadId, 00818 const SgUctGameInfo& info) 00819 { 00820 const int DISPLAY_INTERVAL = 5; 00821 00822 m_mpiSynchronizer->OnSearchIteration(*this, gameNumber, threadId, info); 00823 double currTime = m_timer.GetTime(); 00824 00825 if (threadId == 0 && currTime - m_lastScoreDisplayTime > DISPLAY_INTERVAL) 00826 { 00827 PrintSearchProgress(currTime); 00828 m_lastScoreDisplayTime = currTime; 00829 } 00830 } 00831 00832 void SgUctSearch::PlayGame(SgUctThreadState& state, GlobalLock* lock) 00833 { 00834 state.m_isTreeOutOfMem = false; 00835 state.GameStart(); 00836 SgUctGameInfo& info = state.m_gameInfo; 00837 info.Clear(m_numberPlayouts); 00838 bool isTerminal; 00839 bool abortInTree = ! PlayInTree(state, isTerminal); 00840 00841 // The playout phase is always unlocked 00842 if (lock != 0) 00843 lock->unlock(); 00844 00845 if (!info.m_nodes.empty() && isTerminal) 00846 { 00847 const SgUctNode& terminalNode = *info.m_nodes.back(); 00848 SgUctValue eval = state.Evaluate(); 00849 if (eval > 0.6) 00850 m_tree.SetProvenType(terminalNode, SG_PROVEN_WIN); 00851 else if (eval < 0.6) 00852 m_tree.SetProvenType(terminalNode, SG_PROVEN_LOSS); 00853 PropagateProvenStatus(info.m_nodes); 00854 } 00855 00856 size_t nuMovesInTree = info.m_inTreeSequence.size(); 00857 00858 // Play some "fake" playouts if node is a proven node 00859 if (! info.m_nodes.empty() && info.m_nodes.back()->IsProven()) 00860 { 00861 for (size_t i = 0; i < m_numberPlayouts; ++i) 00862 { 00863 info.m_sequence[i] = info.m_inTreeSequence; 00864 info.m_skipRaveUpdate[i].assign(nuMovesInTree, false); 00865 SgUctValue eval = info.m_nodes.back()->IsProvenWin() ? 1 : 0; 00866 size_t nuMoves = info.m_sequence[i].size(); 00867 if (nuMoves % 2 != 0) 00868 eval = InverseEval(eval); 00869 info.m_aborted[i] = abortInTree || state.m_isTreeOutOfMem; 00870 info.m_eval[i] = eval; 00871 } 00872 } 00873 else 00874 { 00875 state.StartPlayouts(); 00876 for (size_t i = 0; i < m_numberPlayouts; ++i) 00877 { 00878 state.StartPlayout(); 00879 info.m_sequence[i] = info.m_inTreeSequence; 00880 // skipRaveUpdate only used in playout phase 00881 info.m_skipRaveUpdate[i].assign(nuMovesInTree, false); 00882 bool abort = abortInTree || state.m_isTreeOutOfMem; 00883 if (! abort && ! isTerminal) 00884 abort = ! PlayoutGame(state, i); 00885 SgUctValue eval; 00886 if (abort) 00887 eval = UnknownEval(); 00888 else 00889 eval = state.Evaluate(); 00890 size_t nuMoves = info.m_sequence[i].size(); 00891 if (nuMoves % 2 != 0) 00892 eval = InverseEval(eval); 00893 info.m_aborted[i] = abort; 00894 info.m_eval[i] = eval; 00895 state.EndPlayout(); 00896 state.TakeBackPlayout(nuMoves - nuMovesInTree); 00897 } 00898 } 00899 state.TakeBackInTree(nuMovesInTree); 00900 00901 // End of unlocked part if ! m_lockFree 00902 if (lock != 0) 00903 lock->lock(); 00904 00905 UpdateTree(info); 00906 if (m_rave) 00907 UpdateRaveValues(state); 00908 UpdateStatistics(info); 00909 } 00910 00911 /** Backs up proven information. Last node of nodes is the newly 00912 proven node. */ 00913 void SgUctSearch::PropagateProvenStatus(const vector<const SgUctNode*>& nodes) 00914 { 00915 if (nodes.size() <= 1) 00916 return; 00917 size_t i = nodes.size() - 2; 00918 while (true) 00919 { 00920 const SgUctNode& parent = *nodes[i]; 00921 SgUctProvenType type = SG_PROVEN_LOSS; 00922 for (SgUctChildIterator it(m_tree, parent); it; ++it) 00923 { 00924 const SgUctNode& child = *it; 00925 if (!child.IsProven()) 00926 type = SG_NOT_PROVEN; 00927 else if (child.IsProvenLoss()) 00928 { 00929 type = SG_PROVEN_WIN; 00930 break; 00931 } 00932 } 00933 if (type == SG_NOT_PROVEN) 00934 break; 00935 else 00936 m_tree.SetProvenType(parent, type); 00937 if (i == 0) 00938 break; 00939 --i; 00940 } 00941 } 00942 00943 /** Play game until it leaves the tree. 00944 @param state 00945 @param[out] isTerminal Was the sequence terminated because of a real 00946 terminal position (GenerateAllMoves() returned an empty list)? 00947 @return @c false, if game was aborted due to maximum length */ 00948 bool SgUctSearch::PlayInTree(SgUctThreadState& state, bool& isTerminal) 00949 { 00950 vector<SgMove>& sequence = state.m_gameInfo.m_inTreeSequence; 00951 vector<const SgUctNode*>& nodes = state.m_gameInfo.m_nodes; 00952 const SgUctNode* root = &m_tree.Root(); 00953 const SgUctNode* current = root; 00954 if (m_virtualLoss && m_numberThreads > 1) 00955 m_tree.AddVirtualLoss(*current); 00956 nodes.push_back(current); 00957 bool breakAfterSelect = false; 00958 isTerminal = false; 00959 bool deepenTree = false; 00960 while (true) 00961 { 00962 if (sequence.size() == m_maxGameLength) 00963 return false; 00964 if (current->IsProven()) 00965 break; 00966 if (! current->HasChildren()) 00967 { 00968 state.m_moves.clear(); 00969 SgUctProvenType provenType = SG_NOT_PROVEN; 00970 state.GenerateAllMoves(0, state.m_moves, provenType); 00971 if (current == root) 00972 ApplyRootFilter(state.m_moves); 00973 if (provenType != SG_NOT_PROVEN) 00974 { 00975 m_tree.SetProvenType(*current, provenType); 00976 PropagateProvenStatus(nodes); 00977 break; 00978 } 00979 if (state.m_moves.empty()) 00980 { 00981 isTerminal = true; 00982 break; 00983 } 00984 if ( deepenTree 00985 || current->MoveCount() >= m_expandThreshold 00986 ) 00987 { 00988 deepenTree = false; 00989 ExpandNode(state, *current); 00990 if (state.m_isTreeOutOfMem) 00991 return true; 00992 if (! deepenTree) 00993 breakAfterSelect = true; 00994 } 00995 else 00996 break; 00997 } 00998 else if (NeedToComputeKnowledge(current)) 00999 { 01000 m_statistics.m_knowledge++; 01001 deepenTree = false; 01002 SgUctProvenType provenType = SG_NOT_PROVEN; 01003 bool truncate = state.GenerateAllMoves(current->KnowledgeCount(), 01004 state.m_moves, 01005 provenType); 01006 if (current == root) 01007 ApplyRootFilter(state.m_moves); 01008 CreateChildren(state, *current, truncate); 01009 if (provenType != SG_NOT_PROVEN) 01010 { 01011 m_tree.SetProvenType(*current, provenType); 01012 PropagateProvenStatus(nodes); 01013 break; 01014 } 01015 if (state.m_moves.empty()) 01016 { 01017 isTerminal = true; 01018 break; 01019 } 01020 if (state.m_isTreeOutOfMem) 01021 return true; 01022 if (! deepenTree) 01023 breakAfterSelect = true; 01024 } 01025 current = &SelectChild(state.m_randomizeCounter, *current); 01026 if (m_virtualLoss && m_numberThreads > 1) 01027 m_tree.AddVirtualLoss(*current); 01028 nodes.push_back(current); 01029 SgMove move = current->Move(); 01030 state.Execute(move); 01031 sequence.push_back(move); 01032 if (breakAfterSelect) 01033 break; 01034 } 01035 return true; 01036 } 01037 01038 /** Finish the game using GeneratePlayoutMove(). 01039 @param state The thread state. 01040 @param playout The number of the playout. 01041 @return @c false if game was aborted */ 01042 bool SgUctSearch::PlayoutGame(SgUctThreadState& state, std::size_t playout) 01043 { 01044 SgUctGameInfo& info = state.m_gameInfo; 01045 vector<SgMove>& sequence = info.m_sequence[playout]; 01046 vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout]; 01047 while (true) 01048 { 01049 if (sequence.size() == m_maxGameLength) 01050 return false; 01051 bool skipRave = false; 01052 SgMove move = state.GeneratePlayoutMove(skipRave); 01053 if (move == SG_NULLMOVE) 01054 break; 01055 state.ExecutePlayout(move); 01056 sequence.push_back(move); 01057 skipRaveUpdate.push_back(skipRave); 01058 } 01059 return true; 01060 } 01061 01062 SgUctValue SgUctSearch::Search(SgUctValue maxGames, double maxTime, 01063 vector<SgMove>& sequence, 01064 const vector<SgMove>& rootFilter, 01065 SgUctTree* initTree, 01066 SgUctEarlyAbortParam* earlyAbort) 01067 { 01068 m_timer.Start(); 01069 m_rootFilter = rootFilter; 01070 if (m_logGames) 01071 { 01072 m_log.open(m_mpiSynchronizer->ToNodeFilename(m_logFileName).c_str()); 01073 m_log << "StartSearch maxGames=" << maxGames << '\n'; 01074 } 01075 m_maxGames = maxGames; 01076 m_maxTime = maxTime; 01077 m_earlyAbort.reset(0); 01078 if (earlyAbort != 0) 01079 m_earlyAbort.reset(new SgUctEarlyAbortParam(*earlyAbort)); 01080 01081 for (size_t i = 0; i < m_threads.size(); ++i) 01082 { 01083 m_threads[i]->m_state->m_isSearchInitialized = false; 01084 } 01085 StartSearch(rootFilter, initTree); 01086 SgUctValue pruneMinCount = m_pruneMinCount; 01087 while (true) 01088 { 01089 m_isTreeOutOfMemory = false; 01090 SgSynchronizeThreadMemory(); 01091 for (size_t i = 0; i < m_threads.size(); ++i) 01092 m_threads[i]->StartPlay(); 01093 for (size_t i = 0; i < m_threads.size(); ++i) 01094 m_threads[i]->WaitPlayFinished(); 01095 if (m_aborted || ! m_pruneFullTree) 01096 break; 01097 else 01098 { 01099 double startPruneTime = m_timer.GetTime(); 01100 SgDebug() << "SgUctSearch: pruning nodes with count < " 01101 << pruneMinCount << " (at time " << fixed << setprecision(1) 01102 << startPruneTime << ")\n"; 01103 SgUctTree& tempTree = GetTempTree(); 01104 m_tree.CopyPruneLowCount(tempTree, pruneMinCount, true); 01105 int prunedSizePercentage = 01106 static_cast<int>(tempTree.NuNodes() * 100 / m_tree.NuNodes()); 01107 SgDebug() << "SgUctSearch: pruned size: " << tempTree.NuNodes() 01108 << " (" << prunedSizePercentage << "%) time: " 01109 << (m_timer.GetTime() - startPruneTime) << "\n"; 01110 if (prunedSizePercentage > 50) 01111 pruneMinCount *= 2; 01112 else 01113 pruneMinCount = m_pruneMinCount; 01114 m_tree.Swap(tempTree); 01115 } 01116 } 01117 EndSearch(); 01118 m_statistics.m_time = m_timer.GetTime(); 01119 if (m_statistics.m_time > numeric_limits<double>::epsilon()) 01120 m_statistics.m_gamesPerSecond = GamesPlayed() / m_statistics.m_time; 01121 if (m_logGames) 01122 m_log.close(); 01123 FindBestSequence(sequence); 01124 return (m_tree.Root().MoveCount() > 0) ? (SgUctValue)m_tree.Root().Mean() : (SgUctValue)0.5; 01125 } 01126 01127 /** Loop invoked by each thread for playing games. */ 01128 void SgUctSearch::SearchLoop(SgUctThreadState& state, GlobalLock* lock) 01129 { 01130 if (! state.m_isSearchInitialized) 01131 { 01132 OnThreadStartSearch(state); 01133 state.m_isSearchInitialized = true; 01134 } 01135 01136 if (NumberThreads() == 1 || m_lockFree) 01137 lock = 0; 01138 if (lock != 0) 01139 lock->lock(); 01140 state.m_isTreeOutOfMem = false; 01141 while (! state.m_isTreeOutOfMem) 01142 { 01143 PlayGame(state, lock); 01144 OnSearchIteration(m_numberGames + 1, state.m_threadId, 01145 state.m_gameInfo); 01146 if (m_logGames) 01147 m_log << SummaryLine(state.m_gameInfo) << '\n'; 01148 ++m_numberGames; 01149 if (m_isTreeOutOfMemory) 01150 break; 01151 if (m_aborted || CheckAbortSearch(state)) 01152 { 01153 m_aborted = true; 01154 SgSynchronizeThreadMemory(); 01155 break; 01156 } 01157 } 01158 if (lock != 0) 01159 lock->unlock(); 01160 01161 m_searchLoopFinished->wait(); 01162 if (m_aborted || ! m_pruneFullTree) 01163 OnThreadEndSearch(state); 01164 } 01165 01166 void SgUctSearch::OnThreadStartSearch(SgUctThreadState& state) 01167 { 01168 m_mpiSynchronizer->OnThreadStartSearch(*this, state); 01169 } 01170 01171 void SgUctSearch::OnThreadEndSearch(SgUctThreadState& state) 01172 { 01173 m_mpiSynchronizer->OnThreadEndSearch(*this, state); 01174 } 01175 01176 SgPoint SgUctSearch::SearchOnePly(SgUctValue maxGames, double maxTime, 01177 SgUctValue& value) 01178 { 01179 if (m_threads.size() == 0) 01180 CreateThreads(); 01181 OnStartSearch(); 01182 // SearchOnePly is not multi-threaded. 01183 // It uses the state of the first thread. 01184 SgUctThreadState& state = ThreadState(0); 01185 state.StartSearch(); 01186 vector<SgUctMoveInfo> moves; 01187 SgUctProvenType provenType; 01188 state.GameStart(); 01189 state.GenerateAllMoves(0, moves, provenType); 01190 vector<SgUctStatistics> statistics(moves.size()); 01191 SgUctValue games = 0; 01192 m_timer.Start(); 01193 SgUctGameInfo& info = state.m_gameInfo; 01194 while (games < maxGames && m_timer.GetTime() < maxTime && ! SgUserAbort()) 01195 { 01196 for (size_t i = 0; i < moves.size(); ++i) 01197 { 01198 state.GameStart(); 01199 info.Clear(1); 01200 SgMove move = moves[i].m_move; 01201 state.Execute(move); 01202 info.m_inTreeSequence.push_back(move); 01203 info.m_sequence[0].push_back(move); 01204 info.m_skipRaveUpdate[0].push_back(false); 01205 state.StartPlayouts(); 01206 state.StartPlayout(); 01207 bool abortGame = ! PlayoutGame(state, 0); 01208 SgUctValue eval; 01209 if (abortGame) 01210 eval = UnknownEval(); 01211 else 01212 eval = state.Evaluate(); 01213 state.EndPlayout(); 01214 state.TakeBackPlayout(info.m_sequence[0].size() - 1); 01215 state.TakeBackInTree(1); 01216 statistics[i].Add(info.m_sequence[0].size() % 2 == 0 ? 01217 eval : InverseEval(eval)); 01218 OnSearchIteration(games + 1, 0, info); 01219 games += 1; 01220 } 01221 } 01222 SgMove bestMove = SG_NULLMOVE; 01223 for (size_t i = 0; i < moves.size(); ++i) 01224 { 01225 SgDebug() << MoveString(moves[i].m_move) 01226 << ' ' << statistics[i].Mean() << '\n'; 01227 if (bestMove == SG_NULLMOVE || statistics[i].Mean() > value) 01228 { 01229 bestMove = moves[i].m_move; 01230 value = statistics[i].Mean(); 01231 } 01232 } 01233 return bestMove; 01234 } 01235 01236 const SgUctNode& SgUctSearch::SelectChild(int& randomizeCounter, 01237 const SgUctNode& node) 01238 { 01239 bool useRave = m_rave; 01240 if (m_randomizeRaveFrequency > 0 && --randomizeCounter == 0) 01241 { 01242 useRave = false; 01243 randomizeCounter = m_randomizeRaveFrequency; 01244 } 01245 SG_ASSERT(node.HasChildren()); 01246 SgUctValue posCount = node.PosCount(); 01247 int virtualLossCount = node.VirtualLossCount(); 01248 if (virtualLossCount > 1) 01249 { 01250 // Note: must remove the virtual loss already added to 01251 // node for the current thread. 01252 posCount += SgUctValue(virtualLossCount - 1); 01253 } 01254 01255 if (posCount == 0) 01256 // If position count is zero, return first child 01257 return *SgUctChildIterator(m_tree, node); 01258 SgUctValue logPosCount = Log(posCount); 01259 const SgUctNode* bestChild = 0; 01260 SgUctValue bestUpperBound = 0; 01261 const SgUctValue epsilon = SgUctValue(1e-7); 01262 for (SgUctChildIterator it(m_tree, node); it; ++it) 01263 { 01264 const SgUctNode& child = *it; 01265 if (!child.IsProvenWin()) // Avoid losing moves 01266 { 01267 SgUctValue bound = GetBound(useRave, logPosCount, child); 01268 // Compare bound to best bound using a not too small epsilon 01269 // because the unit tests rely on the fact that the first child is 01270 // chosen if children have the same bounds and on some platforms 01271 // the result of the comparison is not well-defined and depends on 01272 // the compiler settings and the type of SgUctValue even if count 01273 // and value of the children are exactly the same. 01274 if (bestChild == 0 || bound > bestUpperBound + epsilon) 01275 { 01276 bestChild = &child; 01277 bestUpperBound = bound; 01278 } 01279 } 01280 } 01281 if (bestChild != 0) 01282 return *bestChild; 01283 // It can happen with multiple threads that all children are losing 01284 // in this state but this thread got in here before that information 01285 // was propagated up the tree. So just return the first child 01286 // in this case. 01287 return *node.FirstChild(); 01288 } 01289 01290 void SgUctSearch::SetNumberThreads(unsigned int n) 01291 { 01292 SG_ASSERT(n >= 1); 01293 if (m_numberThreads == n) 01294 return; 01295 m_numberThreads = n; 01296 CreateThreads(); 01297 } 01298 01299 void SgUctSearch::SetRave(bool enable) 01300 { 01301 if (enable && m_moveRange <= 0) 01302 throw SgException("RAVE not supported for this game"); 01303 m_rave = enable; 01304 } 01305 01306 void SgUctSearch::SetThreadStateFactory(SgUctThreadStateFactory* factory) 01307 { 01308 SG_ASSERT(m_threadStateFactory.get() == 0); 01309 m_threadStateFactory.reset(factory); 01310 DeleteThreads(); 01311 // Don't create states here, because this function could be called in the 01312 // constructor of the subclass, and the factory passes the search (which 01313 // is not fully constructed) as an argument to the Create() function 01314 } 01315 01316 void SgUctSearch::StartSearch(const vector<SgMove>& rootFilter, 01317 SgUctTree* initTree) 01318 { 01319 if (m_threads.size() == 0) 01320 CreateThreads(); 01321 if (m_numberThreads > 1 && SgTime::DefaultMode() == SG_TIME_CPU) 01322 // Using CPU time with multiple threads makes the measured time 01323 // and games/sec not very meaningful; the total cputime is not equal 01324 // to the total real time, even if there is no other load on the 01325 // machine, because the time, while threads are waiting for a lock 01326 // does not contribute to the cputime. 01327 SgWarning() << "SgUctSearch: using cpu time with multiple threads\n"; 01328 m_raveWeightParam1 = (SgUctValue)(1.0 / m_raveWeightInitial); 01329 m_raveWeightParam2 = (SgUctValue)(1.0 / m_raveWeightFinal); 01330 if (initTree == 0) 01331 m_tree.Clear(); 01332 else 01333 { 01334 m_tree.Swap(*initTree); 01335 if (m_tree.HasCapacity(0, m_tree.Root().NuChildren())) 01336 m_tree.ApplyFilter(0, m_tree.Root(), rootFilter); 01337 else 01338 SgWarning() << 01339 "SgUctSearch: " 01340 "root filter not applied (tree reached maximum size)\n"; 01341 } 01342 m_statistics.Clear(); 01343 m_aborted = false; 01344 m_wasEarlyAbort = false; 01345 m_checkTimeInterval = 1; 01346 m_numberGames = 0; 01347 m_lastScoreDisplayTime = m_timer.GetTime(); 01348 OnStartSearch(); 01349 01350 m_nextCheckTime = (SgUctValue)m_checkTimeInterval; 01351 m_startRootMoveCount = m_tree.Root().MoveCount(); 01352 01353 for (unsigned int i = 0; i < m_threads.size(); ++i) 01354 { 01355 SgUctThreadState& state = ThreadState(i); 01356 state.m_randomizeCounter = m_randomizeRaveFrequency; 01357 state.StartSearch(); 01358 } 01359 } 01360 01361 void SgUctSearch::EndSearch() 01362 { 01363 OnEndSearch(); 01364 } 01365 01366 string SgUctSearch::SummaryLine(const SgUctGameInfo& info) const 01367 { 01368 ostringstream buffer; 01369 const vector<const SgUctNode*>& nodes = info.m_nodes; 01370 for (size_t i = 1; i < nodes.size(); ++i) 01371 { 01372 const SgUctNode* node = nodes[i]; 01373 SgMove move = node->Move(); 01374 buffer << ' ' << MoveString(move) << " (" << fixed << setprecision(2) 01375 << node->Mean() << ',' << node->MoveCount() << ')'; 01376 } 01377 for (size_t i = 0; i < info.m_eval.size(); ++i) 01378 buffer << ' ' << fixed << setprecision(2) << info.m_eval[i]; 01379 return buffer.str(); 01380 } 01381 01382 void SgUctSearch::UpdateCheckTimeInterval(double time) 01383 { 01384 if (time < numeric_limits<double>::epsilon()) 01385 return; 01386 // Dynamically update m_checkTimeInterval (see comment at definition of 01387 // m_checkTimeInterval) 01388 double wantedTimeDiff = (m_maxTime > 1 ? 0.1 : 0.1 * m_maxTime); 01389 if (time < wantedTimeDiff / 10) 01390 { 01391 // Computing games per second might be unreliable for small times 01392 m_checkTimeInterval *= 2; 01393 return; 01394 } 01395 m_statistics.m_gamesPerSecond = GamesPlayed() / time; 01396 double gamesPerSecondPerThread = 01397 m_statistics.m_gamesPerSecond / double(m_numberThreads); 01398 m_checkTimeInterval = SgUctValue(wantedTimeDiff * gamesPerSecondPerThread); 01399 if (m_checkTimeInterval == 0) 01400 m_checkTimeInterval = 1; 01401 } 01402 01403 /** Update the RAVE values in the tree for both players after a game was 01404 played. 01405 @see SgUctSearch::Rave() */ 01406 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state) 01407 { 01408 for (size_t i = 0; i < m_numberPlayouts; ++i) 01409 UpdateRaveValues(state, i); 01410 } 01411 01412 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state, 01413 std::size_t playout) 01414 { 01415 SgUctGameInfo& info = state.m_gameInfo; 01416 const vector<SgMove>& sequence = info.m_sequence[playout]; 01417 if (sequence.size() == 0) 01418 return; 01419 SG_ASSERT(m_moveRange > 0); 01420 size_t* firstPlay = state.m_firstPlay.get(); 01421 size_t* firstPlayOpp = state.m_firstPlayOpp.get(); 01422 fill_n(firstPlay, m_moveRange, numeric_limits<size_t>::max()); 01423 fill_n(firstPlayOpp, m_moveRange, numeric_limits<size_t>::max()); 01424 const vector<const SgUctNode*>& nodes = info.m_nodes; 01425 const vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout]; 01426 SgUctValue eval = info.m_eval[playout]; 01427 SgUctValue invEval = InverseEval(eval); 01428 size_t nuNodes = nodes.size(); 01429 size_t i = sequence.size() - 1; 01430 bool opp = (i % 2 != 0); 01431 01432 // Update firstPlay, firstPlayOpp arrays using playout moves 01433 for ( ; i >= nuNodes; --i) 01434 { 01435 SG_ASSERT(i < skipRaveUpdate.size()); 01436 SG_ASSERT(i < sequence.size()); 01437 if (! skipRaveUpdate[i]) 01438 { 01439 SgMove mv = sequence[i]; 01440 size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]); 01441 if (i < first) 01442 first = i; 01443 } 01444 opp = ! opp; 01445 } 01446 01447 while (true) 01448 { 01449 SG_ASSERT(i < skipRaveUpdate.size()); 01450 SG_ASSERT(i < sequence.size()); 01451 // skipRaveUpdate currently not used in in-tree phase 01452 SG_ASSERT(i >= info.m_inTreeSequence.size() || ! skipRaveUpdate[i]); 01453 if (! skipRaveUpdate[i]) 01454 { 01455 SgMove mv = sequence[i]; 01456 size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]); 01457 if (i < first) 01458 first = i; 01459 if (opp) 01460 UpdateRaveValues(state, playout, invEval, i, 01461 firstPlayOpp, firstPlay); 01462 else 01463 UpdateRaveValues(state, playout, eval, i, 01464 firstPlay, firstPlayOpp); 01465 } 01466 if (i == 0) 01467 break; 01468 --i; 01469 opp = ! opp; 01470 } 01471 } 01472 01473 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state, 01474 std::size_t playout, SgUctValue eval, 01475 std::size_t i, 01476 const std::size_t firstPlay[], 01477 const std::size_t firstPlayOpp[]) 01478 { 01479 SG_ASSERT(i < state.m_gameInfo.m_nodes.size()); 01480 const SgUctNode* node = state.m_gameInfo.m_nodes[i]; 01481 if (! node->HasChildren()) 01482 return; 01483 size_t len = state.m_gameInfo.m_sequence[playout].size(); 01484 for (SgUctChildIterator it(m_tree, *node); it; ++it) 01485 { 01486 const SgUctNode& child = *it; 01487 SgMove mv = child.Move(); 01488 size_t first = firstPlay[mv]; 01489 SG_ASSERT(first >= i); 01490 if (first == numeric_limits<size_t>::max()) 01491 continue; 01492 if (m_raveCheckSame && SgUtil::InRange(firstPlayOpp[mv], i, first)) 01493 continue; 01494 SgUctValue weight; 01495 if (m_weightRaveUpdates) 01496 weight = 2 - SgUctValue(first - i) / SgUctValue(len - i); 01497 else 01498 weight = 1; 01499 m_tree.AddRaveValue(child, eval, weight); 01500 } 01501 } 01502 01503 void SgUctSearch::UpdateStatistics(const SgUctGameInfo& info) 01504 { 01505 m_statistics.m_movesInTree.Add( 01506 static_cast<float>(info.m_inTreeSequence.size())); 01507 for (size_t i = 0; i < m_numberPlayouts; ++i) 01508 { 01509 m_statistics.m_gameLength.Add( 01510 static_cast<float>(info.m_sequence[i].size())); 01511 m_statistics.m_aborted.Add(info.m_aborted[i] ? 1.f : 0.f); 01512 } 01513 } 01514 01515 void SgUctSearch::UpdateTree(const SgUctGameInfo& info) 01516 { 01517 SgUctValue eval = 0; 01518 for (size_t i = 0; i < m_numberPlayouts; ++i) 01519 eval += info.m_eval[i]; 01520 eval /= SgUctValue(m_numberPlayouts); 01521 SgUctValue inverseEval = InverseEval(eval); 01522 const vector<const SgUctNode*>& nodes = info.m_nodes; 01523 SgUctValue count = SgUctValue(m_numberPlayouts); 01524 for (size_t i = 0; i < nodes.size(); ++i) 01525 { 01526 const SgUctNode& node = *nodes[i]; 01527 const SgUctNode* father = (i > 0 ? nodes[i - 1] : 0); 01528 m_tree.AddGameResults(node, father, i % 2 == 0 ? eval : inverseEval, 01529 count); 01530 // Remove the virtual loss 01531 if (m_virtualLoss && m_numberThreads > 1) 01532 m_tree.RemoveVirtualLoss(node); 01533 } 01534 } 01535 01536 void SgUctSearch::WriteStatistics(ostream& out) const 01537 { 01538 out << SgWriteLabel("Count") << m_tree.Root().MoveCount() << '\n' 01539 << SgWriteLabel("GamesPlayed") << GamesPlayed() << '\n' 01540 << SgWriteLabel("Nodes") << m_tree.NuNodes() << '\n'; 01541 if (!m_knowledgeThreshold.empty()) 01542 out << SgWriteLabel("Knowledge") 01543 << m_statistics.m_knowledge << " (" << fixed << setprecision(1) 01544 << m_statistics.m_knowledge * 100.0 / m_tree.Root().MoveCount() 01545 << "%)\n"; 01546 m_statistics.Write(out); 01547 m_mpiSynchronizer->WriteStatistics(out); 01548 } 01549 01550 //----------------------------------------------------------------------------