Index   Main   Namespaces   Classes   Hierarchy   Annotated   Files   Compound   Global   Pages  

SgUctSearch.cpp

Go to the documentation of this file.
00001 //----------------------------------------------------------------------------
00002 /** @file SgUctSearch.cpp */
00003 //----------------------------------------------------------------------------
00004 
00005 #include "SgSystem.h"
00006 #include "SgUctSearch.h"
00007 
00008 #include <algorithm>
00009 #include <cmath>
00010 #include <iomanip>
00011 #include <boost/format.hpp>
00012 #include <boost/io/ios_state.hpp>
00013 #include <boost/version.hpp>
00014 #include "SgDebug.h"
00015 #include "SgHashTable.h"
00016 #include "SgMath.h"
00017 #include "SgPlatform.h"
00018 #include "SgWrite.h"
00019 
00020 using namespace std;
00021 using boost::barrier;
00022 using boost::condition;
00023 using boost::format;
00024 using boost::mutex;
00025 using boost::shared_ptr;
00026 using boost::io::ios_all_saver;
00027 
00028 #define BOOST_VERSION_MAJOR (BOOST_VERSION / 100000)
00029 #define BOOST_VERSION_MINOR (BOOST_VERSION / 100 % 1000)
00030 
00031 //----------------------------------------------------------------------------
00032 
00033 namespace {
00034 
00035 const bool DEBUG_THREADS = false;
00036 
/** Decide whether lock-free mode is enabled by default.
    Lock-free mode is only enabled by default on IA-32/Intel-64 hosts, or
    when the macro ENABLE_CACHE_SYNC (from Fuego's configure script) is
    defined. The host architecture is read from the macro HOST_CPU, also
    provided by the configure script. Windows builds are always treated as
    running on an Intel architecture.
    @return true if lock-free mode should be on by default. */
bool GetLockFreeDefault()
{
#if defined(WIN32) || defined(ENABLE_CACHE_SYNC)
    return true;
#elif defined(HOST_CPU)
    // CPU names for which lock-free mode is enabled by default
    static const char* const intelCpus[] =
        { "i386", "i486", "i586", "i686", "x86_64" };
    const std::string hostCpu(HOST_CPU);
    const std::size_t nuCpus = sizeof(intelCpus) / sizeof(intelCpus[0]);
    for (std::size_t i = 0; i < nuCpus; ++i)
        if (hostCpu == intelCpus[i])
            return true;
    return false;
#else
    return false;
#endif
}
00054 
00055 /** Get a default value for the tree size.
00056     The default value is that both trees used by SgUctSearch take no more than
00057     half of the total amount of memory on the system (but no less than
00058     384 MB and not more than 1 GB). */
00059 size_t GetMaxNodesDefault()
00060 {
00061     size_t totalMemory = SgPlatform::TotalMemory();
00062     SgDebug() << "SgUctSearch: system memory ";
00063     if (totalMemory == 0)
00064         SgDebug() << "unknown";
00065     else
00066         SgDebug() << totalMemory;
00067     // Use half of the physical memory by default but at least 284K and not
00068     // more than 1G
00069     size_t searchMemory = totalMemory / 2;
00070     if (searchMemory < 384000000)
00071         searchMemory = 384000000;
00072     if (searchMemory > 1000000000)
00073         searchMemory = 1000000000;
00074     size_t memoryPerTree = searchMemory / 2;
00075     size_t nodesPerTree = memoryPerTree / sizeof(SgUctNode);
00076     SgDebug() << ", using " << searchMemory << " (" << nodesPerTree
00077               << " nodes)\n";
00078     return nodesPerTree;
00079 }
00080 
00081 void Notify(mutex& aMutex, condition& aCondition)
00082 {
00083     mutex::scoped_lock lock(aMutex);
00084     aCondition.notify_all();
00085 }
00086 
00087 } // namespace
00088 
00089 //----------------------------------------------------------------------------
00090 
00091 void SgUctGameInfo::Clear(std::size_t numberPlayouts)
00092 {
00093     m_nodes.clear();
00094     m_inTreeSequence.clear();
00095     if (numberPlayouts != m_sequence.size())
00096     {
00097         m_sequence.resize(numberPlayouts);
00098         m_skipRaveUpdate.resize(numberPlayouts);
00099         m_eval.resize(numberPlayouts);
00100         m_aborted.resize(numberPlayouts);
00101     }
00102     for (size_t i = 0; i < numberPlayouts; ++i)
00103     {
00104         m_sequence[i].clear();
00105         m_skipRaveUpdate[i].clear();
00106     }
00107 }
00108 
00109 //----------------------------------------------------------------------------
00110 
00111 SgUctThreadState::SgUctThreadState(unsigned int threadId, int moveRange)
00112     : m_threadId(threadId),
00113       m_isSearchInitialized(false),
00114       m_isTreeOutOfMem(false)
00115 {
00116     if (moveRange > 0)
00117     {
00118         m_firstPlay.reset(new size_t[moveRange]);
00119         m_firstPlayOpp.reset(new size_t[moveRange]);
00120     }
00121 }
00122 
00123 SgUctThreadState::~SgUctThreadState()
00124 {
00125 }
00126 
00127 void SgUctThreadState::EndPlayout()
00128 {
00129     // Default implementation does nothing
00130 }
00131 
00132 void SgUctThreadState::GameStart()
00133 {
00134     // Default implementation does nothing
00135 }
00136 
00137 void SgUctThreadState::StartPlayout()
00138 {
00139     // Default implementation does nothing
00140 }
00141 
00142 void SgUctThreadState::StartPlayouts()
00143 {
00144     // Default implementation does nothing
00145 }
00146 
00147 //----------------------------------------------------------------------------
00148 
00149 SgUctThreadStateFactory::~SgUctThreadStateFactory()
00150 {
00151 }
00152 
00153 //----------------------------------------------------------------------------
00154 
00155 SgUctSearch::Thread::Function::Function(Thread& thread)
00156     : m_thread(thread)
00157 {
00158 }
00159 
00160 void SgUctSearch::Thread::Function::operator()()
00161 {
00162     m_thread();
00163 }
00164 
00165 SgUctSearch::Thread::Thread(SgUctSearch& search,
00166                             auto_ptr<SgUctThreadState> state)
00167     : m_state(state),
00168       m_search(search),
00169       m_quit(false),
00170       m_threadReady(2),
00171       m_playFinishedLock(m_playFinishedMutex),
00172 #if BOOST_VERSION_MAJOR == 1 && BOOST_VERSION_MINOR <= 34
00173       m_globalLock(search.m_globalMutex, false),
00174 #else
00175       m_globalLock(search.m_globalMutex, boost::defer_lock),
00176 #endif
00177       m_thread(Function(*this))
00178 {
00179     m_threadReady.wait();
00180 }
00181 
00182 SgUctSearch::Thread::~Thread()
00183 {
00184     m_quit = true;
00185     StartPlay();
00186     m_thread.join();
00187 }
00188 
00189 void SgUctSearch::Thread::operator()()
00190 {
00191     if (DEBUG_THREADS)
00192         SgDebug() << "SgUctSearch::Thread: starting thread "
00193                   << m_state->m_threadId << '\n';
00194     mutex::scoped_lock lock(m_startPlayMutex);
00195     m_threadReady.wait();
00196     while (true)
00197     {
00198         m_startPlay.wait(lock);
00199         if (m_quit)
00200             break;
00201         m_search.SearchLoop(*m_state, &m_globalLock);
00202         Notify(m_playFinishedMutex, m_playFinished);
00203     }
00204     if (DEBUG_THREADS)
00205         SgDebug() << "SgUctSearch::Thread: finishing thread "
00206                   << m_state->m_threadId << '\n';
00207 }
00208 
00209 void SgUctSearch::Thread::StartPlay()
00210 {
00211     Notify(m_startPlayMutex, m_startPlay);
00212 }
00213 
00214 void SgUctSearch::Thread::WaitPlayFinished()
00215 {
00216     m_playFinished.wait(m_playFinishedLock);
00217 }
00218 
00219 //----------------------------------------------------------------------------
00220 
00221 void SgUctSearchStat::Clear()
00222 {
00223     m_time = 0;
00224     m_knowledge = 0;
00225     m_gamesPerSecond = 0;
00226     m_gameLength.Clear();
00227     m_movesInTree.Clear();
00228     m_aborted.Clear();
00229 }
00230 
00231 void SgUctSearchStat::Write(std::ostream& out) const
00232 {
00233     ios_all_saver saver(out);
00234     out << SgWriteLabel("Time") << setprecision(2) << m_time << '\n'
00235         << SgWriteLabel("GameLen") << fixed << setprecision(1);
00236     m_gameLength.Write(out);
00237     out << '\n'
00238         << SgWriteLabel("InTree");
00239     m_movesInTree.Write(out);
00240     out << '\n'
00241         << SgWriteLabel("Aborted")
00242         << static_cast<int>(100 * m_aborted.Mean()) << "%\n"
00243         << SgWriteLabel("Games/s") << fixed << setprecision(1)
00244         << m_gamesPerSecond << '\n';
00245 }
00246 
00247 //----------------------------------------------------------------------------
00248 
00249 SgUctSearch::SgUctSearch(SgUctThreadStateFactory* threadStateFactory,
00250                          int moveRange)
00251     : m_threadStateFactory(threadStateFactory),
00252       m_logGames(false),
00253       m_rave(false),
00254       m_knowledgeThreshold(),
00255       m_moveSelect(SG_UCTMOVESELECT_COUNT),
00256       m_raveCheckSame(false),
00257       m_randomizeRaveFrequency(20),
00258       m_lockFree(GetLockFreeDefault()),
00259       m_weightRaveUpdates(true),
00260       m_pruneFullTree(true),
00261       m_checkFloatPrecision(true),
00262       m_numberThreads(1),
00263       m_numberPlayouts(1),
00264       m_maxNodes(GetMaxNodesDefault()),
00265       m_pruneMinCount(16),
00266       m_moveRange(moveRange),
00267       m_maxGameLength(numeric_limits<size_t>::max()),
00268       m_expandThreshold(numeric_limits<SgUctValue>::is_integer ? (SgUctValue)1 : numeric_limits<SgUctValue>::epsilon()),
00269       m_biasTermConstant(0.7f),
00270       m_firstPlayUrgency(10000),
00271       m_raveWeightInitial(0.9f),
00272       m_raveWeightFinal(20000),
00273       m_virtualLoss(false),
00274       m_logFileName("uctsearch.log"),
00275       m_fastLog(10),
00276       m_mpiSynchronizer(SgMpiNullSynchronizer::Create())
00277 {
00278     // Don't create thread states here, because the factory passes the search
00279     // (which is not fully constructed here, because the subclass constructors
00280     // are not called yet) as an argument to the Create() function
00281 }
00282 
00283 SgUctSearch::~SgUctSearch()
00284 {
00285     DeleteThreads();
00286 }
00287 
00288 void SgUctSearch::ApplyRootFilter(vector<SgUctMoveInfo>& moves)
00289 {
00290     // Filter without changing the order of the unfiltered moves
00291     vector<SgUctMoveInfo> filteredMoves;
00292     for (vector<SgUctMoveInfo>::const_iterator it = moves.begin();
00293          it != moves.end(); ++it)
00294         if (find(m_rootFilter.begin(), m_rootFilter.end(), it->m_move)
00295             == m_rootFilter.end())
00296             filteredMoves.push_back(*it);
00297     moves = filteredMoves;
00298 }
00299 
00300 SgUctValue SgUctSearch::GamesPlayed() const
00301 {
00302     return m_tree.Root().MoveCount() - m_startRootMoveCount;
00303 }
00304 
00305 bool SgUctSearch::CheckAbortSearch(SgUctThreadState& state)
00306 {
00307     if (SgUserAbort())
00308     {
00309         Debug(state, "SgUctSearch: abort flag");
00310         return true;
00311     }
00312     const SgUctNode& root = m_tree.Root();
00313     if (! SgUctValueUtil::IsPrecise(root.MoveCount()) && m_checkFloatPrecision)
00314     {
00315         Debug(state, "SgUctSearch: floating point type precision reached");
00316         return true;
00317     }
00318     SgUctValue rootCount = root.MoveCount();
00319     if (rootCount >= m_maxGames)
00320     {
00321         Debug(state, "SgUctSearch: max games reached");
00322         return true;
00323     }
00324     if (root.IsProven())
00325     {
00326         if (root.IsProvenWin())
00327             Debug(state, "SgUctSearch: root is proven win!");
00328         else 
00329             Debug(state, "SgUctSearch: root is proven loss!");
00330         return true;
00331     }
00332     const bool isEarlyAbort = CheckEarlyAbort();
00333     if (   isEarlyAbort
00334         && m_earlyAbort->m_reductionFactor * rootCount >= m_maxGames
00335        )
00336     {
00337         Debug(state, "SgUctSearch: max games reached (early abort)");
00338         m_wasEarlyAbort = true;
00339         return true;
00340     }
00341     if (m_numberGames >= m_nextCheckTime)
00342     {
00343         m_nextCheckTime = m_numberGames + m_checkTimeInterval;
00344         double time = m_timer.GetTime();
00345         if (time > m_maxTime)
00346         {
00347             Debug(state, "SgUctSearch: max time reached");
00348             return true;
00349         }
00350         if (isEarlyAbort
00351             && m_earlyAbort->m_reductionFactor * time > m_maxTime)
00352         {
00353             Debug(state, "SgUctSearch: max time reached (early abort)");
00354             m_wasEarlyAbort = true;
00355             return true;
00356         }
00357         UpdateCheckTimeInterval(time);
00358         if (m_moveSelect == SG_UCTMOVESELECT_COUNT)
00359         {
00360             double remainingGamesDouble = m_maxGames - rootCount - 1;
00361             // Use time based count abort, only if time > 1, otherwise
00362             // m_gamesPerSecond is unreliable
00363             if (time > 1.)
00364             {
00365                 double remainingTime = m_maxTime - time;
00366                 remainingGamesDouble =
00367                     min(remainingGamesDouble,
00368                         remainingTime * m_statistics.m_gamesPerSecond);
00369             }
00370             SgUctValue uctCountMax = numeric_limits<SgUctValue>::max();
00371             SgUctValue remainingGames;
00372             if (remainingGamesDouble >= static_cast<double>(uctCountMax - 1))
00373                 remainingGames = uctCountMax;
00374             else
00375                 remainingGames = SgUctValue(remainingGamesDouble);
00376             if (CheckCountAbort(state, remainingGames))
00377             {
00378                 Debug(state, "SgUctSearch: move cannot change anymore");
00379                 return true;
00380             }
00381         }
00382     }
00383     return false;
00384 }
00385 
00386 bool SgUctSearch::CheckCountAbort(SgUctThreadState& state,
00387                                   SgUctValue remainingGames) const
00388 {
00389     const SgUctNode& root = m_tree.Root();
00390     const SgUctNode* bestChild = FindBestChild(root);
00391     if (bestChild == 0)
00392         return false;
00393     SgUctValue bestCount = bestChild->MoveCount();
00394     vector<SgMove>& excludeMoves = state.m_excludeMoves;
00395     excludeMoves.clear();
00396     excludeMoves.push_back(bestChild->Move());
00397     const SgUctNode* secondBestChild = FindBestChild(root, &excludeMoves);
00398     if (secondBestChild == 0)
00399         return false;
00400     SgUctValue secondBestCount = secondBestChild->MoveCount();
00401     SG_ASSERT(secondBestCount <= bestCount || m_numberThreads > 1);
00402     return (remainingGames <= bestCount - secondBestCount);
00403 }
00404 
00405 bool SgUctSearch::CheckEarlyAbort() const
00406 {
00407     const SgUctNode& root = m_tree.Root();
00408     return   m_earlyAbort.get() != 0
00409           && root.HasMean()
00410           && root.MoveCount() > m_earlyAbort->m_minGames
00411           && root.Mean() > m_earlyAbort->m_threshold;
00412 }
00413 
00414 void SgUctSearch::CreateThreads()
00415 {
00416     DeleteThreads();
00417     for (unsigned int i = 0; i < m_numberThreads; ++i)
00418     {
00419         auto_ptr<SgUctThreadState> state(
00420                                       m_threadStateFactory->Create(i, *this));
00421         shared_ptr<Thread> thread(new Thread(*this, state));
00422         m_threads.push_back(thread);
00423     }
00424     m_tree.CreateAllocators(m_numberThreads);
00425     m_tree.SetMaxNodes(m_maxNodes);
00426 
00427     m_searchLoopFinished.reset(new barrier(m_numberThreads));
00428 }
00429 
00430 /** Write a debugging line of text from within a thread.
00431     Prepends the line with the thread number if number of threads is greater
00432     than one. Also ensures that the line is written as a single string to
00433     avoid intermingling of text lines from different threads.
00434     @param state The state of the thread (only used for state.m_threadId)
00435     @param textLine The line of text without trailing newline character. */
00436 void SgUctSearch::Debug(const SgUctThreadState& state,
00437                         const std::string& textLine)
00438 {
00439     if (m_numberThreads > 1)
00440     {
00441         // SgDebug() is not necessarily thread-safe
00442         GlobalLock lock(m_globalMutex);
00443         SgDebug() << (format("[%1%] %2%\n") % state.m_threadId % textLine);
00444     }
00445     else
00446         SgDebug() << (format("%1%\n") % textLine);
00447 }
00448 
00449 void SgUctSearch::DeleteThreads()
00450 {
00451     m_threads.clear();
00452 }
00453 
00454 /** Expand a node.
00455     @param state The thread state with state.m_moves already computed.
00456     @param node The node to expand. */
00457 void SgUctSearch::ExpandNode(SgUctThreadState& state, const SgUctNode& node)
00458 {
00459     unsigned int threadId = state.m_threadId;
00460     if (! m_tree.HasCapacity(threadId, state.m_moves.size()))
00461     {
00462         Debug(state, str(format("SgUctSearch: maximum tree size %1% reached")
00463                          % m_tree.MaxNodes()));
00464         state.m_isTreeOutOfMem = true;
00465         m_isTreeOutOfMemory = true;
00466         SgSynchronizeThreadMemory();
00467         return;
00468     }
00469     m_tree.CreateChildren(threadId, node, state.m_moves);
00470 }
00471 
00472 const SgUctNode*
00473 SgUctSearch::FindBestChild(const SgUctNode& node,
00474                            const vector<SgMove>* excludeMoves) const
00475 {
00476     if (! node.HasChildren())
00477         return 0;
00478     const SgUctNode* bestChild = 0;
00479     SgUctValue bestValue = 0;
00480     for (SgUctChildIterator it(m_tree, node); it; ++it)
00481     {
00482         const SgUctNode& child = *it;
00483         if (excludeMoves != 0)
00484         {
00485             vector<SgMove>::const_iterator begin = excludeMoves->begin();
00486             vector<SgMove>::const_iterator end = excludeMoves->end();
00487             if (find(begin, end, child.Move()) != end)
00488                 continue;
00489         }
00490         if (  ! child.HasMean()
00491            && ! (  (  m_moveSelect == SG_UCTMOVESELECT_BOUND
00492                    || m_moveSelect == SG_UCTMOVESELECT_ESTIMATE
00493                    )
00494                 && m_rave
00495                 && child.HasRaveValue()
00496                 )
00497             )
00498             continue;
00499         if (child.IsProvenLoss()) // Always choose winning move!
00500         {
00501             bestChild = &child;
00502             break;
00503         }
00504         SgUctValue value;
00505         switch (m_moveSelect)
00506         {
00507         case SG_UCTMOVESELECT_VALUE:
00508             value = InverseEstimate((SgUctValue)child.Mean());
00509             break;
00510         case SG_UCTMOVESELECT_COUNT:
00511             value = child.MoveCount();
00512             break;
00513         case SG_UCTMOVESELECT_BOUND:
00514             value = GetBound(m_rave, node, child);
00515             break;
00516         case SG_UCTMOVESELECT_ESTIMATE:
00517             value = GetValueEstimate(m_rave, child);
00518             break;
00519         default:
00520             SG_ASSERT(false);
00521             value = SG_UCTMOVESELECT_VALUE;
00522         }
00523         if (bestChild == 0 || value > bestValue)
00524         {
00525             bestChild = &child;
00526             bestValue = value;
00527         }
00528     }
00529     return bestChild;
00530 }
00531 
00532 void SgUctSearch::FindBestSequence(vector<SgMove>& sequence) const
00533 {
00534     sequence.clear();
00535     const SgUctNode* current = &m_tree.Root();
00536     while (true)
00537     {
00538         current = FindBestChild(*current);
00539         if (current == 0)
00540             break;
00541         sequence.push_back(current->Move());
00542         if (! current->HasChildren())
00543             break;
00544     }
00545 }
00546 
00547 void SgUctSearch::GenerateAllMoves(std::vector<SgUctMoveInfo>& moves)
00548 {
00549     if (m_threads.size() == 0)
00550         CreateThreads();
00551     moves.clear();
00552     OnStartSearch();
00553     SgUctThreadState& state = ThreadState(0);
00554     state.StartSearch();
00555     SgUctProvenType type;
00556     state.GenerateAllMoves(0, moves, type);
00557 }
00558 
00559 SgUctValue SgUctSearch::GetBound(bool useRave, const SgUctNode& node,
00560                                  const SgUctNode& child) const
00561 {
00562     SgUctValue posCount = node.PosCount();
00563     int virtualLossCount = node.VirtualLossCount();
00564     if (virtualLossCount > 0)
00565     {
00566         posCount += SgUctValue(virtualLossCount);
00567     }
00568     return GetBound(useRave, Log(posCount), child);
00569 }
00570 
00571 SgUctValue SgUctSearch::GetBound(bool useRave, SgUctValue logPosCount,
00572                                  const SgUctNode& child) const
00573 {
00574     SgUctValue value;
00575     if (useRave)
00576         value = GetValueEstimateRave(child);
00577     else
00578         value = GetValueEstimate(false, child);
00579     if (m_biasTermConstant == 0.0)
00580         return value;
00581     else
00582     {
00583         SgUctValue moveCount = static_cast<SgUctValue>(child.MoveCount());
00584         SgUctValue bound =
00585             value + m_biasTermConstant * sqrt(logPosCount / (moveCount + 1));
00586         return bound;
00587     }
00588 }
00589 
00590 SgUctTree& SgUctSearch::GetTempTree()
00591 {
00592     m_tempTree.Clear();
00593     // Use NumberThreads() (not m_tree.NuAllocators()) and MaxNodes() (not
00594     // m_tree.MaxNodes()), because of the delayed thread (and thereby
00595     // allocator) creation in SgUctSearch
00596     if (m_tempTree.NuAllocators() != NumberThreads())
00597     {
00598         m_tempTree.CreateAllocators(NumberThreads());
00599         m_tempTree.SetMaxNodes(MaxNodes());
00600     }
00601     else if (m_tempTree.MaxNodes() != MaxNodes())
00602     {
00603         m_tempTree.SetMaxNodes(MaxNodes());
00604     }
00605     return m_tempTree;
00606 }
00607 
00608 SgUctValue SgUctSearch::GetValueEstimate(bool useRave, const SgUctNode& child) const
00609 {
00610     SgUctValue value = 0;
00611     SgUctValue weightSum = 0;
00612     bool hasValue = false;
00613 
00614     SgUctStatistics uctStats;
00615     if (child.HasMean())
00616     {
00617         uctStats.Initialize(child.Mean(), child.MoveCount());
00618     }
00619     int virtualLossCount = child.VirtualLossCount();
00620     if (virtualLossCount > 0)
00621     {
00622         uctStats.Add(InverseEstimate(0), SgUctValue(virtualLossCount));
00623     }
00624 
00625     if (uctStats.IsDefined())
00626     {
00627         SgUctValue weight = static_cast<SgUctValue>(uctStats.Count());
00628         value += weight * InverseEstimate((SgUctValue)uctStats.Mean());
00629         weightSum += weight;
00630         hasValue = true;
00631     }
00632 
00633     if (useRave)
00634     {
00635         SgUctStatistics raveStats;
00636         if (child.HasRaveValue())
00637         {
00638             raveStats.Initialize(child.RaveValue(), child.RaveCount());
00639         }
00640         if (virtualLossCount > 0)
00641         {
00642             raveStats.Add(0, SgUctValue(virtualLossCount));
00643         }
00644         if (raveStats.IsDefined())
00645         {
00646             SgUctValue raveCount = raveStats.Count();
00647             SgUctValue weight =
00648                 raveCount
00649                 / (  m_raveWeightParam1
00650                      + m_raveWeightParam2 * raveCount
00651                      );
00652             value += weight * raveStats.Mean();
00653             weightSum += weight;
00654             hasValue = true;
00655         }
00656     }
00657     if (hasValue)
00658         return value / weightSum;
00659     else
00660         return m_firstPlayUrgency;
00661 }
00662 
00663 /** Optimized version of GetValueEstimate() if RAVE and not other
00664     estimators are used.
00665     Previously there were more estimators than move value and RAVE value,
00666     and in the future there may be again. GetValueEstimate() is easier to
00667     extend, this function is more optimized for the special case. */
00668 SgUctValue SgUctSearch::GetValueEstimateRave(const SgUctNode& child) const
00669 {
00670     SG_ASSERT(m_rave);
00671     SgUctValue value;
00672     SgUctStatistics uctStats;
00673     if (child.HasMean())
00674     {
00675         uctStats.Initialize(child.Mean(), child.MoveCount());
00676     }
00677     SgUctStatistics raveStats;
00678     if (child.HasRaveValue())
00679     {
00680         raveStats.Initialize(child.RaveValue(), child.RaveCount());
00681     }
00682     int virtualLossCount = child.VirtualLossCount();
00683     if (virtualLossCount > 0)
00684     {
00685         uctStats.Add(InverseEstimate(0), SgUctValue(virtualLossCount));
00686         raveStats.Add(0, SgUctValue(virtualLossCount));
00687     }
00688     bool hasRave = raveStats.IsDefined();
00689     
00690     if (uctStats.IsDefined())
00691     {
00692         SgUctValue moveValue = InverseEstimate((SgUctValue)uctStats.Mean());
00693         if (hasRave)
00694         {
00695             SgUctValue moveCount = uctStats.Count();
00696             SgUctValue raveCount = raveStats.Count();
00697             SgUctValue weight =
00698                 raveCount
00699                 / (moveCount
00700                    * (m_raveWeightParam1 + m_raveWeightParam2 * raveCount)
00701                    + raveCount);
00702             value = weight * raveStats.Mean() + (1.f - weight) * moveValue;
00703         }
00704         else
00705         {
00706             // This can happen only in lock-free multi-threading. Normally,
00707             // each move played in a position should also cause a RAVE value
00708             // to be added. But in lock-free multi-threading it can happen
00709             // that the move value was already updated but the RAVE value not
00710             SG_ASSERT(m_numberThreads > 1 && m_lockFree);
00711             value = moveValue;
00712         }
00713     }
00714     else if (hasRave)
00715         value = raveStats.Mean();
00716     else
00717         value = m_firstPlayUrgency;
00718     SG_ASSERT(m_numberThreads > 1
00719               || fabs(value - GetValueEstimate(m_rave, child)) < 1e-3/*epsilon*/);
00720     return value;
00721 }
00722 
00723 string SgUctSearch::LastGameSummaryLine() const
00724 {
00725     return SummaryLine(LastGameInfo());
00726 }
00727 
00728 SgUctValue SgUctSearch::Log(SgUctValue x) const
00729 {
00730 #if SG_UCTFASTLOG
00731     return SgUctValue(m_fastLog.Log(float(x)));
00732 #else
00733     return log(x);
00734 #endif
00735 }
00736 
00737 /** Creates the children with the given moves and merges with existing
00738     children in the tree. */
00739 void SgUctSearch::CreateChildren(SgUctThreadState& state, 
00740                                  const SgUctNode& node,
00741                                  bool deleteChildTrees)
00742 {
00743     unsigned int threadId = state.m_threadId;
00744     if (! m_tree.HasCapacity(threadId, state.m_moves.size()))
00745     {
00746         Debug(state, str(format("SgUctSearch: maximum tree size %1% reached")
00747                          % m_tree.MaxNodes()));
00748         state.m_isTreeOutOfMem = true;
00749         m_isTreeOutOfMemory = true;
00750         SgSynchronizeThreadMemory();
00751         return;
00752     }
00753     m_tree.MergeChildren(threadId, node, state.m_moves, deleteChildTrees);
00754 }
00755 
00756 bool SgUctSearch::NeedToComputeKnowledge(const SgUctNode* current)
00757 {
00758     if (m_knowledgeThreshold.empty())
00759         return false;
00760     for (std::size_t i = 0; i < m_knowledgeThreshold.size(); ++i)
00761     {
00762         const SgUctValue threshold = m_knowledgeThreshold[i];
00763         if (current->KnowledgeCount() < threshold)
00764         {
00765             if (current->MoveCount() >= threshold)
00766             {
00767                 // Mark knowledge computed immediately so other
00768                 // threads fall through and do not waste time
00769                 // re-computing this knowledge.
00770                 m_tree.SetKnowledgeCount(*current, threshold);
00771                 SG_ASSERT(current->MoveCount());
00772                 return true;
00773             }
00774             return false;
00775         }
00776     }
00777     return false;
00778 }
00779 
00780 void SgUctSearch::OnStartSearch()
00781 {
00782     m_mpiSynchronizer->OnStartSearch(*this);
00783 }
00784 
00785 void SgUctSearch::OnEndSearch()
00786 {
00787     m_mpiSynchronizer->OnEndSearch(*this);
00788 }
00789 
00790 /** Print time, mean, nodes searched, and PV */
00791 void SgUctSearch::PrintSearchProgress(double currTime) const
00792 {
00793     const int MAX_SEQ_PRINT_LENGTH = 15;
00794     const SgUctValue MIN_MOVE_COUNT = 10;
00795     SgUctValue rootMoveCount = m_tree.Root().MoveCount();
00796     SgUctValue rootMean = m_tree.Root().Mean();
00797     ostringstream out;
00798     const SgUctNode* current = &m_tree.Root();
00799     out << (format("%s | %.3f | %.0f ")
00800             % SgTime::Format(currTime, true) % rootMean % rootMoveCount);
00801     for (int i = 0; i <= MAX_SEQ_PRINT_LENGTH && current->HasChildren(); ++i)
00802     {
00803         current = FindBestChild(*current);
00804         if (current == 0 || current->MoveCount() < MIN_MOVE_COUNT)
00805             break;
00806         if (i == 0)
00807             out << "|";
00808         if (i < MAX_SEQ_PRINT_LENGTH)
00809             out << " " << MoveString(current->Move());
00810         else
00811             out << " *";
00812     }
00813     SgDebug() << out.str() << endl;
00814 }
00815 
00816 void SgUctSearch::OnSearchIteration(SgUctValue gameNumber,
00817                                     unsigned int threadId,
00818                                     const SgUctGameInfo& info)
00819 {
00820     const int DISPLAY_INTERVAL = 5;
00821 
00822     m_mpiSynchronizer->OnSearchIteration(*this, gameNumber, threadId, info);
00823     double currTime = m_timer.GetTime();
00824 
00825     if (threadId == 0 && currTime - m_lastScoreDisplayTime > DISPLAY_INTERVAL)
00826     {
00827         PrintSearchProgress(currTime);
00828         m_lastScoreDisplayTime = currTime;
00829     }
00830 }
00831 
00832 void SgUctSearch::PlayGame(SgUctThreadState& state, GlobalLock* lock)
00833 {
00834     state.m_isTreeOutOfMem = false;
00835     state.GameStart();
00836     SgUctGameInfo& info = state.m_gameInfo;
00837     info.Clear(m_numberPlayouts);
00838     bool isTerminal;
00839     bool abortInTree = ! PlayInTree(state, isTerminal);
00840 
00841     // The playout phase is always unlocked
00842     if (lock != 0)
00843         lock->unlock();
00844 
00845     if (!info.m_nodes.empty() && isTerminal)
00846     {
00847         const SgUctNode& terminalNode = *info.m_nodes.back();
00848         SgUctValue eval = state.Evaluate();
00849         if (eval > 0.6) 
00850             m_tree.SetProvenType(terminalNode, SG_PROVEN_WIN);
00851         else if (eval < 0.6)
00852             m_tree.SetProvenType(terminalNode, SG_PROVEN_LOSS);
00853         PropagateProvenStatus(info.m_nodes);
00854     }
00855 
00856     size_t nuMovesInTree = info.m_inTreeSequence.size();
00857 
00858     // Play some "fake" playouts if node is a proven node
00859     if (! info.m_nodes.empty() && info.m_nodes.back()->IsProven())
00860     {
00861         for (size_t i = 0; i < m_numberPlayouts; ++i)
00862         {
00863             info.m_sequence[i] = info.m_inTreeSequence;
00864             info.m_skipRaveUpdate[i].assign(nuMovesInTree, false);
00865             SgUctValue eval = info.m_nodes.back()->IsProvenWin() ? 1 : 0;
00866             size_t nuMoves = info.m_sequence[i].size();
00867             if (nuMoves % 2 != 0)
00868                 eval = InverseEval(eval);
00869             info.m_aborted[i] = abortInTree || state.m_isTreeOutOfMem;
00870             info.m_eval[i] = eval;
00871         }
00872     }
00873     else 
00874     {
00875         state.StartPlayouts();
00876         for (size_t i = 0; i < m_numberPlayouts; ++i)
00877         {
00878             state.StartPlayout();
00879             info.m_sequence[i] = info.m_inTreeSequence;
00880             // skipRaveUpdate only used in playout phase
00881             info.m_skipRaveUpdate[i].assign(nuMovesInTree, false);
00882             bool abort = abortInTree || state.m_isTreeOutOfMem;
00883             if (! abort && ! isTerminal)
00884                 abort = ! PlayoutGame(state, i);
00885             SgUctValue eval;
00886             if (abort)
00887                 eval = UnknownEval();
00888             else
00889                 eval = state.Evaluate();
00890             size_t nuMoves = info.m_sequence[i].size();
00891             if (nuMoves % 2 != 0)
00892                 eval = InverseEval(eval);
00893             info.m_aborted[i] = abort;
00894             info.m_eval[i] = eval;
00895             state.EndPlayout();
00896             state.TakeBackPlayout(nuMoves - nuMovesInTree);
00897         }
00898     }
00899     state.TakeBackInTree(nuMovesInTree);
00900 
00901     // End of unlocked part if ! m_lockFree
00902     if (lock != 0)
00903         lock->lock();
00904 
00905     UpdateTree(info);
00906     if (m_rave)
00907         UpdateRaveValues(state);
00908     UpdateStatistics(info);
00909 }
00910 
00911 /** Backs up proven information. Last node of nodes is the newly
00912     proven node. */
00913 void SgUctSearch::PropagateProvenStatus(const vector<const SgUctNode*>& nodes)
00914 {
00915     if (nodes.size() <= 1) 
00916         return;
00917     size_t i = nodes.size() - 2;
00918     while (true)
00919     {
00920         const SgUctNode& parent = *nodes[i];
00921         SgUctProvenType type = SG_PROVEN_LOSS;
00922         for (SgUctChildIterator it(m_tree, parent); it; ++it)
00923         {
00924             const SgUctNode& child = *it;
00925             if (!child.IsProven())
00926                 type = SG_NOT_PROVEN;
00927             else if (child.IsProvenLoss())
00928             {
00929                 type = SG_PROVEN_WIN;
00930                 break;
00931             }
00932         }
00933         if (type == SG_NOT_PROVEN)
00934             break;
00935         else
00936             m_tree.SetProvenType(parent, type);
00937         if (i == 0)
00938             break;
00939         --i;
00940     }
00941 }
00942 
00943 /** Play game until it leaves the tree.
00944     @param state
00945     @param[out] isTerminal Was the sequence terminated because of a real
00946     terminal position (GenerateAllMoves() returned an empty list)?
00947     @return @c false, if game was aborted due to maximum length */
00948 bool SgUctSearch::PlayInTree(SgUctThreadState& state, bool& isTerminal)
00949 {
00950     vector<SgMove>& sequence = state.m_gameInfo.m_inTreeSequence;
00951     vector<const SgUctNode*>& nodes = state.m_gameInfo.m_nodes;
00952     const SgUctNode* root = &m_tree.Root();
00953     const SgUctNode* current = root;
00954     if (m_virtualLoss && m_numberThreads > 1)
00955         m_tree.AddVirtualLoss(*current);
00956     nodes.push_back(current);
00957     bool breakAfterSelect = false;
00958     isTerminal = false;
00959     bool deepenTree = false;
00960     while (true)
00961     {
00962         if (sequence.size() == m_maxGameLength)
00963             return false;
00964         if (current->IsProven())
00965             break;
00966         if (! current->HasChildren())
00967         {
00968             state.m_moves.clear();
00969             SgUctProvenType provenType = SG_NOT_PROVEN;
00970             state.GenerateAllMoves(0, state.m_moves, provenType);
00971             if (current == root)
00972                 ApplyRootFilter(state.m_moves);
00973             if (provenType != SG_NOT_PROVEN)
00974             {
00975                 m_tree.SetProvenType(*current, provenType);
00976                 PropagateProvenStatus(nodes);
00977                 break;
00978             }
00979             if (state.m_moves.empty())
00980             {
00981                 isTerminal = true;
00982                 break;
00983             }
00984             if (  deepenTree
00985                || current->MoveCount() >= m_expandThreshold
00986                )
00987             {
00988                 deepenTree = false;
00989                 ExpandNode(state, *current);
00990                 if (state.m_isTreeOutOfMem)
00991                     return true;
00992                 if (! deepenTree)
00993                     breakAfterSelect = true;
00994             }
00995             else
00996                 break;
00997         }
00998         else if (NeedToComputeKnowledge(current))
00999         {
01000             m_statistics.m_knowledge++;
01001             deepenTree = false;
01002             SgUctProvenType provenType = SG_NOT_PROVEN;
01003             bool truncate = state.GenerateAllMoves(current->KnowledgeCount(), 
01004                                                    state.m_moves,
01005                                                    provenType);
01006             if (current == root)
01007                 ApplyRootFilter(state.m_moves);
01008             CreateChildren(state, *current, truncate);
01009             if (provenType != SG_NOT_PROVEN)
01010             {
01011                 m_tree.SetProvenType(*current, provenType);
01012                 PropagateProvenStatus(nodes);
01013                 break;
01014             }
01015             if (state.m_moves.empty())
01016             {
01017                 isTerminal = true;
01018                 break;
01019             }
01020             if (state.m_isTreeOutOfMem)
01021                 return true;
01022             if (! deepenTree)
01023                 breakAfterSelect = true;
01024         }
01025         current = &SelectChild(state.m_randomizeCounter, *current);
01026         if (m_virtualLoss && m_numberThreads > 1)
01027             m_tree.AddVirtualLoss(*current);
01028         nodes.push_back(current);
01029         SgMove move = current->Move();
01030         state.Execute(move);
01031         sequence.push_back(move);
01032         if (breakAfterSelect)
01033             break;
01034     }
01035     return true;
01036 }
01037 
01038 /** Finish the game using GeneratePlayoutMove().
01039     @param state The thread state.
01040     @param playout The number of the playout.
01041     @return @c false if game was aborted */
01042 bool SgUctSearch::PlayoutGame(SgUctThreadState& state, std::size_t playout)
01043 {
01044     SgUctGameInfo& info = state.m_gameInfo;
01045     vector<SgMove>& sequence = info.m_sequence[playout];
01046     vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout];
01047     while (true)
01048     {
01049         if (sequence.size() == m_maxGameLength)
01050             return false;
01051         bool skipRave = false;
01052         SgMove move = state.GeneratePlayoutMove(skipRave);
01053         if (move == SG_NULLMOVE)
01054             break;
01055         state.ExecutePlayout(move);
01056         sequence.push_back(move);
01057         skipRaveUpdate.push_back(skipRave);
01058     }
01059     return true;
01060 }
01061 
01062 SgUctValue SgUctSearch::Search(SgUctValue maxGames, double maxTime,
01063                                vector<SgMove>& sequence,
01064                                const vector<SgMove>& rootFilter,
01065                                SgUctTree* initTree,
01066                                SgUctEarlyAbortParam* earlyAbort)
01067 {
01068     m_timer.Start();
01069     m_rootFilter = rootFilter;
01070     if (m_logGames)
01071     {
01072         m_log.open(m_mpiSynchronizer->ToNodeFilename(m_logFileName).c_str());
01073         m_log << "StartSearch maxGames=" << maxGames << '\n';
01074     }
01075     m_maxGames = maxGames;
01076     m_maxTime = maxTime;
01077     m_earlyAbort.reset(0);
01078     if (earlyAbort != 0)
01079         m_earlyAbort.reset(new SgUctEarlyAbortParam(*earlyAbort));
01080 
01081     for (size_t i = 0; i < m_threads.size(); ++i)
01082     {
01083         m_threads[i]->m_state->m_isSearchInitialized = false;
01084     }
01085     StartSearch(rootFilter, initTree);
01086     SgUctValue pruneMinCount = m_pruneMinCount;
01087     while (true)
01088     {
01089         m_isTreeOutOfMemory = false;
01090         SgSynchronizeThreadMemory();
01091         for (size_t i = 0; i < m_threads.size(); ++i)
01092             m_threads[i]->StartPlay();
01093         for (size_t i = 0; i < m_threads.size(); ++i)
01094             m_threads[i]->WaitPlayFinished();
01095         if (m_aborted || ! m_pruneFullTree)
01096             break;
01097         else
01098         {
01099             double startPruneTime = m_timer.GetTime();
01100             SgDebug() << "SgUctSearch: pruning nodes with count < "
01101                   << pruneMinCount << " (at time " << fixed << setprecision(1)
01102                   << startPruneTime << ")\n";
01103             SgUctTree& tempTree = GetTempTree();
01104             m_tree.CopyPruneLowCount(tempTree, pruneMinCount, true);
01105             int prunedSizePercentage =
01106                 static_cast<int>(tempTree.NuNodes() * 100 / m_tree.NuNodes());
01107             SgDebug() << "SgUctSearch: pruned size: " << tempTree.NuNodes()
01108                       << " (" << prunedSizePercentage << "%) time: "
01109                       << (m_timer.GetTime() - startPruneTime) << "\n";
01110             if (prunedSizePercentage > 50)
01111                 pruneMinCount *= 2;
01112             else
01113                  pruneMinCount = m_pruneMinCount; 
01114             m_tree.Swap(tempTree);
01115         }
01116     }
01117     EndSearch();
01118     m_statistics.m_time = m_timer.GetTime();
01119     if (m_statistics.m_time > numeric_limits<double>::epsilon())
01120         m_statistics.m_gamesPerSecond = GamesPlayed() / m_statistics.m_time;
01121     if (m_logGames)
01122         m_log.close();
01123     FindBestSequence(sequence);
01124     return (m_tree.Root().MoveCount() > 0) ? (SgUctValue)m_tree.Root().Mean() : (SgUctValue)0.5;
01125 }
01126 
01127 /** Loop invoked by each thread for playing games. */
01128 void SgUctSearch::SearchLoop(SgUctThreadState& state, GlobalLock* lock)
01129 {
01130     if (! state.m_isSearchInitialized)
01131     {
01132         OnThreadStartSearch(state);
01133         state.m_isSearchInitialized = true;
01134     }
01135 
01136     if (NumberThreads() == 1 || m_lockFree)
01137         lock = 0;
01138     if (lock != 0)
01139         lock->lock();
01140     state.m_isTreeOutOfMem = false;
01141     while (! state.m_isTreeOutOfMem)
01142     {
01143         PlayGame(state, lock);
01144         OnSearchIteration(m_numberGames + 1, state.m_threadId,
01145                           state.m_gameInfo);
01146         if (m_logGames)
01147             m_log << SummaryLine(state.m_gameInfo) << '\n';
01148         ++m_numberGames;
01149         if (m_isTreeOutOfMemory)
01150             break;
01151         if (m_aborted || CheckAbortSearch(state))
01152         {
01153             m_aborted = true;
01154             SgSynchronizeThreadMemory();
01155             break;
01156         }
01157     }
01158     if (lock != 0)
01159         lock->unlock();
01160 
01161     m_searchLoopFinished->wait();
01162     if (m_aborted || ! m_pruneFullTree)
01163         OnThreadEndSearch(state);
01164 }
01165 
01166 void SgUctSearch::OnThreadStartSearch(SgUctThreadState& state)
01167 {
01168     m_mpiSynchronizer->OnThreadStartSearch(*this, state);
01169 }
01170 
01171 void SgUctSearch::OnThreadEndSearch(SgUctThreadState& state)
01172 {
01173     m_mpiSynchronizer->OnThreadEndSearch(*this, state);
01174 }
01175 
01176 SgPoint SgUctSearch::SearchOnePly(SgUctValue maxGames, double maxTime,
01177                                   SgUctValue& value)
01178 {
01179     if (m_threads.size() == 0)
01180         CreateThreads();
01181     OnStartSearch();
01182     // SearchOnePly is not multi-threaded.
01183     // It uses the state of the first thread.
01184     SgUctThreadState& state = ThreadState(0);
01185     state.StartSearch();
01186     vector<SgUctMoveInfo> moves;
01187     SgUctProvenType provenType;
01188     state.GameStart();
01189     state.GenerateAllMoves(0, moves, provenType);
01190     vector<SgUctStatistics> statistics(moves.size());
01191     SgUctValue games = 0;
01192     m_timer.Start();
01193     SgUctGameInfo& info = state.m_gameInfo;
01194     while (games < maxGames && m_timer.GetTime() < maxTime && ! SgUserAbort())
01195     {
01196         for (size_t i = 0; i < moves.size(); ++i)
01197         {
01198             state.GameStart();
01199             info.Clear(1);
01200             SgMove move = moves[i].m_move;
01201             state.Execute(move);
01202             info.m_inTreeSequence.push_back(move);
01203             info.m_sequence[0].push_back(move);
01204             info.m_skipRaveUpdate[0].push_back(false);
01205             state.StartPlayouts();
01206             state.StartPlayout();
01207             bool abortGame = ! PlayoutGame(state, 0);
01208             SgUctValue eval;
01209             if (abortGame)
01210                 eval = UnknownEval();
01211             else
01212                 eval = state.Evaluate();
01213             state.EndPlayout();
01214             state.TakeBackPlayout(info.m_sequence[0].size() - 1);
01215             state.TakeBackInTree(1);
01216             statistics[i].Add(info.m_sequence[0].size() % 2 == 0 ?
01217                               eval : InverseEval(eval));
01218             OnSearchIteration(games + 1, 0, info);
01219             games += 1;
01220         }
01221     }
01222     SgMove bestMove = SG_NULLMOVE;
01223     for (size_t i = 0; i < moves.size(); ++i)
01224     {
01225         SgDebug() << MoveString(moves[i].m_move) 
01226                   << ' ' << statistics[i].Mean() << '\n';
01227         if (bestMove == SG_NULLMOVE || statistics[i].Mean() > value)
01228         {
01229             bestMove = moves[i].m_move;
01230             value = statistics[i].Mean();
01231         }
01232     }
01233     return bestMove;
01234 }
01235 
01236 const SgUctNode& SgUctSearch::SelectChild(int& randomizeCounter, 
01237                                           const SgUctNode& node)
01238 {
01239     bool useRave = m_rave;
01240     if (m_randomizeRaveFrequency > 0 && --randomizeCounter == 0)
01241     {
01242         useRave = false;
01243         randomizeCounter = m_randomizeRaveFrequency;
01244     }
01245     SG_ASSERT(node.HasChildren());
01246     SgUctValue posCount = node.PosCount();
01247     int virtualLossCount = node.VirtualLossCount();
01248     if (virtualLossCount > 1)
01249     {
01250         // Note: must remove the virtual loss already added to
01251         // node for the current thread.
01252         posCount += SgUctValue(virtualLossCount - 1);
01253     }
01254 
01255     if (posCount == 0)
01256         // If position count is zero, return first child
01257         return *SgUctChildIterator(m_tree, node);
01258     SgUctValue logPosCount = Log(posCount);
01259     const SgUctNode* bestChild = 0;
01260     SgUctValue bestUpperBound = 0;
01261     const SgUctValue epsilon = SgUctValue(1e-7);
01262     for (SgUctChildIterator it(m_tree, node); it; ++it)
01263     {
01264         const SgUctNode& child = *it;
01265         if (!child.IsProvenWin()) // Avoid losing moves
01266         {
01267             SgUctValue bound = GetBound(useRave, logPosCount, child);
01268             // Compare bound to best bound using a not too small epsilon
01269             // because the unit tests rely on the fact that the first child is
01270             // chosen if children have the same bounds and on some platforms
01271             // the result of the comparison is not well-defined and depends on
01272             // the compiler settings and the type of SgUctValue even if count
01273             // and value of the children are exactly the same.
01274             if (bestChild == 0 || bound > bestUpperBound + epsilon)
01275             {
01276                 bestChild = &child;
01277                 bestUpperBound = bound;
01278             }
01279         }
01280     }
01281     if (bestChild != 0)
01282         return *bestChild;
01283     // It can happen with multiple threads that all children are losing
01284     // in this state but this thread got in here before that information
01285     // was propagated up the tree. So just return the first child
01286     // in this case.
01287     return *node.FirstChild();
01288 }
01289 
01290 void SgUctSearch::SetNumberThreads(unsigned int n)
01291 {
01292     SG_ASSERT(n >= 1);
01293     if (m_numberThreads == n)
01294         return;
01295     m_numberThreads = n;
01296     CreateThreads();
01297 }
01298 
01299 void SgUctSearch::SetRave(bool enable)
01300 {
01301     if (enable && m_moveRange <= 0)
01302         throw SgException("RAVE not supported for this game");
01303     m_rave = enable;
01304 }
01305 
01306 void SgUctSearch::SetThreadStateFactory(SgUctThreadStateFactory* factory)
01307 {
01308     SG_ASSERT(m_threadStateFactory.get() == 0);
01309     m_threadStateFactory.reset(factory);
01310     DeleteThreads();
01311     // Don't create states here, because this function could be called in the
01312     // constructor of the subclass, and the factory passes the search (which
01313     // is not fully constructed) as an argument to the Create() function
01314 }
01315 
01316 void SgUctSearch::StartSearch(const vector<SgMove>& rootFilter,
01317                               SgUctTree* initTree)
01318 {
01319     if (m_threads.size() == 0)
01320         CreateThreads();
01321     if (m_numberThreads > 1 && SgTime::DefaultMode() == SG_TIME_CPU)
01322         // Using CPU time with multiple threads makes the measured time
01323         // and games/sec not very meaningful; the total cputime is not equal
01324         // to the total real time, even if there is no other load on the
01325         // machine, because the time, while threads are waiting for a lock
01326         // does not contribute to the cputime.
01327         SgWarning() << "SgUctSearch: using cpu time with multiple threads\n";
01328     m_raveWeightParam1 = (SgUctValue)(1.0 / m_raveWeightInitial);
01329     m_raveWeightParam2 = (SgUctValue)(1.0 / m_raveWeightFinal);
01330     if (initTree == 0)
01331         m_tree.Clear();
01332     else
01333     {
01334         m_tree.Swap(*initTree);
01335         if (m_tree.HasCapacity(0, m_tree.Root().NuChildren()))
01336             m_tree.ApplyFilter(0, m_tree.Root(), rootFilter);
01337         else
01338             SgWarning() <<
01339                 "SgUctSearch: "
01340                 "root filter not applied (tree reached maximum size)\n";
01341     }
01342     m_statistics.Clear();
01343     m_aborted = false;
01344     m_wasEarlyAbort = false;
01345     m_checkTimeInterval = 1;
01346     m_numberGames = 0;
01347     m_lastScoreDisplayTime = m_timer.GetTime();
01348     OnStartSearch();
01349     
01350     m_nextCheckTime = (SgUctValue)m_checkTimeInterval;
01351     m_startRootMoveCount = m_tree.Root().MoveCount();
01352 
01353     for (unsigned int i = 0; i < m_threads.size(); ++i)
01354     {
01355         SgUctThreadState& state = ThreadState(i);
01356         state.m_randomizeCounter = m_randomizeRaveFrequency;
01357         state.StartSearch();
01358     }
01359 }
01360 
01361 void SgUctSearch::EndSearch()
01362 {
01363     OnEndSearch();
01364 }
01365 
01366 string SgUctSearch::SummaryLine(const SgUctGameInfo& info) const
01367 {
01368     ostringstream buffer;
01369     const vector<const SgUctNode*>& nodes = info.m_nodes;
01370     for (size_t i = 1; i < nodes.size(); ++i)
01371     {
01372         const SgUctNode* node = nodes[i];
01373         SgMove move = node->Move();
01374         buffer << ' ' << MoveString(move) << " (" << fixed << setprecision(2)
01375                << node->Mean() << ',' << node->MoveCount() << ')';
01376     }
01377     for (size_t i = 0; i < info.m_eval.size(); ++i)
01378         buffer << ' ' << fixed << setprecision(2) << info.m_eval[i];
01379     return buffer.str();
01380 }
01381 
01382 void SgUctSearch::UpdateCheckTimeInterval(double time)
01383 {
01384     if (time < numeric_limits<double>::epsilon())
01385         return;
01386     // Dynamically update m_checkTimeInterval (see comment at definition of
01387     // m_checkTimeInterval)
01388     double wantedTimeDiff = (m_maxTime > 1 ? 0.1 : 0.1 * m_maxTime);
01389     if (time < wantedTimeDiff / 10)
01390     {
01391         // Computing games per second might be unreliable for small times
01392         m_checkTimeInterval *= 2;
01393         return;
01394     }
01395     m_statistics.m_gamesPerSecond = GamesPlayed() / time;
01396     double gamesPerSecondPerThread =
01397         m_statistics.m_gamesPerSecond / double(m_numberThreads);
01398     m_checkTimeInterval = SgUctValue(wantedTimeDiff * gamesPerSecondPerThread);
01399     if (m_checkTimeInterval == 0)
01400         m_checkTimeInterval = 1;
01401 }
01402 
01403 /** Update the RAVE values in the tree for both players after a game was
01404     played.
01405     @see SgUctSearch::Rave() */
01406 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state)
01407 {
01408     for (size_t i = 0; i < m_numberPlayouts; ++i)
01409         UpdateRaveValues(state, i);
01410 }
01411 
01412 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state,
01413                                    std::size_t playout)
01414 {
01415     SgUctGameInfo& info = state.m_gameInfo;
01416     const vector<SgMove>& sequence = info.m_sequence[playout];
01417     if (sequence.size() == 0)
01418         return;
01419     SG_ASSERT(m_moveRange > 0);
01420     size_t* firstPlay = state.m_firstPlay.get();
01421     size_t* firstPlayOpp = state.m_firstPlayOpp.get();
01422     fill_n(firstPlay, m_moveRange, numeric_limits<size_t>::max());
01423     fill_n(firstPlayOpp, m_moveRange, numeric_limits<size_t>::max());
01424     const vector<const SgUctNode*>& nodes = info.m_nodes;
01425     const vector<bool>& skipRaveUpdate = info.m_skipRaveUpdate[playout];
01426     SgUctValue eval = info.m_eval[playout];
01427     SgUctValue invEval = InverseEval(eval);
01428     size_t nuNodes = nodes.size();
01429     size_t i = sequence.size() - 1;
01430     bool opp = (i % 2 != 0);
01431 
01432     // Update firstPlay, firstPlayOpp arrays using playout moves
01433     for ( ; i >= nuNodes; --i)
01434     {
01435         SG_ASSERT(i < skipRaveUpdate.size());
01436         SG_ASSERT(i < sequence.size());
01437         if (! skipRaveUpdate[i])
01438         {
01439             SgMove mv = sequence[i];
01440             size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]);
01441             if (i < first)
01442                 first = i;
01443         }
01444         opp = ! opp;
01445     }
01446 
01447     while (true)
01448     {
01449         SG_ASSERT(i < skipRaveUpdate.size());
01450         SG_ASSERT(i < sequence.size());
01451         // skipRaveUpdate currently not used in in-tree phase
01452         SG_ASSERT(i >= info.m_inTreeSequence.size() || ! skipRaveUpdate[i]);
01453         if (! skipRaveUpdate[i])
01454         {
01455             SgMove mv = sequence[i];
01456             size_t& first = (opp ? firstPlayOpp[mv] : firstPlay[mv]);
01457             if (i < first)
01458                 first = i;
01459             if (opp)
01460                 UpdateRaveValues(state, playout, invEval, i,
01461                                  firstPlayOpp, firstPlay);
01462             else
01463                 UpdateRaveValues(state, playout, eval, i,
01464                                  firstPlay, firstPlayOpp);
01465         }
01466         if (i == 0)
01467             break;
01468         --i;
01469         opp = ! opp;
01470     }
01471 }
01472 
01473 void SgUctSearch::UpdateRaveValues(SgUctThreadState& state,
01474                                    std::size_t playout, SgUctValue eval,
01475                                    std::size_t i,
01476                                    const std::size_t firstPlay[],
01477                                    const std::size_t firstPlayOpp[])
01478 {
01479     SG_ASSERT(i < state.m_gameInfo.m_nodes.size());
01480     const SgUctNode* node = state.m_gameInfo.m_nodes[i];
01481     if (! node->HasChildren())
01482         return;
01483     size_t len = state.m_gameInfo.m_sequence[playout].size();
01484     for (SgUctChildIterator it(m_tree, *node); it; ++it)
01485     {
01486         const SgUctNode& child = *it;
01487         SgMove mv = child.Move();
01488         size_t first = firstPlay[mv];
01489         SG_ASSERT(first >= i);
01490         if (first == numeric_limits<size_t>::max())
01491             continue;
01492         if  (m_raveCheckSame && SgUtil::InRange(firstPlayOpp[mv], i, first))
01493             continue;
01494         SgUctValue weight;
01495         if (m_weightRaveUpdates)
01496             weight = 2 - SgUctValue(first - i) / SgUctValue(len - i);
01497         else
01498             weight = 1;
01499         m_tree.AddRaveValue(child, eval, weight);
01500     }
01501 }
01502 
01503 void SgUctSearch::UpdateStatistics(const SgUctGameInfo& info)
01504 {
01505     m_statistics.m_movesInTree.Add(
01506                             static_cast<float>(info.m_inTreeSequence.size()));
01507     for (size_t i = 0; i < m_numberPlayouts; ++i)
01508     {
01509         m_statistics.m_gameLength.Add(
01510                                static_cast<float>(info.m_sequence[i].size()));
01511         m_statistics.m_aborted.Add(info.m_aborted[i] ? 1.f : 0.f);
01512     }
01513 }
01514 
01515 void SgUctSearch::UpdateTree(const SgUctGameInfo& info)
01516 {
01517     SgUctValue eval = 0;
01518     for (size_t i = 0; i < m_numberPlayouts; ++i)
01519         eval += info.m_eval[i];
01520     eval /= SgUctValue(m_numberPlayouts);
01521     SgUctValue inverseEval = InverseEval(eval);
01522     const vector<const SgUctNode*>& nodes = info.m_nodes;
01523     SgUctValue count = SgUctValue(m_numberPlayouts);
01524     for (size_t i = 0; i < nodes.size(); ++i)
01525     {
01526         const SgUctNode& node = *nodes[i];
01527         const SgUctNode* father = (i > 0 ? nodes[i - 1] : 0);
01528         m_tree.AddGameResults(node, father, i % 2 == 0 ? eval : inverseEval,
01529                               count);
01530         // Remove the virtual loss
01531         if (m_virtualLoss && m_numberThreads > 1)
01532             m_tree.RemoveVirtualLoss(node);
01533     }
01534 }
01535 
01536 void SgUctSearch::WriteStatistics(ostream& out) const
01537 {
01538     out << SgWriteLabel("Count") << m_tree.Root().MoveCount() << '\n'
01539         << SgWriteLabel("GamesPlayed") << GamesPlayed() << '\n'
01540         << SgWriteLabel("Nodes") << m_tree.NuNodes() << '\n';
01541     if (!m_knowledgeThreshold.empty())
01542         out << SgWriteLabel("Knowledge") 
01543             << m_statistics.m_knowledge << " (" << fixed << setprecision(1) 
01544             << m_statistics.m_knowledge * 100.0 / m_tree.Root().MoveCount()
01545             << "%)\n";
01546     m_statistics.Write(out);
01547     m_mpiSynchronizer->WriteStatistics(out);
01548 }
01549 
01550 //----------------------------------------------------------------------------


Sun Mar 13 2011 Doxygen 1.7.1