00001 //---------------------------------------------------------------------------- 00002 /** @file SgUctTree.h 00003 Class SgUctTree and strongly related classes. */ 00004 //---------------------------------------------------------------------------- 00005 00006 #ifndef SG_UCTTREE_H 00007 #define SG_UCTTREE_H 00008 00009 #include <limits> 00010 #include <stack> 00011 #include <boost/shared_ptr.hpp> 00012 #include "SgMove.h" 00013 #include "SgStatistics.h" 00014 #include "SgStatisticsVlt.h" 00015 #include "SgUctValue.h" 00016 00017 class SgTimer; 00018 00019 //---------------------------------------------------------------------------- 00020 00021 typedef SgStatisticsBase<float,std::size_t> SgUctStatisticsBase; 00022 00023 typedef SgStatisticsVltBase<float,std::size_t> SgUctStatisticsBaseVolatile; 00024 00025 //---------------------------------------------------------------------------- 00026 00027 /** Used for node creation. */ 00028 struct SgUctMoveInfo 00029 { 00030 /** Move for the child. */ 00031 SgMove m_move; 00032 00033 /** Value of node after node is created. 00034 Value is from child's perspective, so the value stored here 00035 must be the inverse of the evaluation from the parent's 00036 perspective. */ 00037 SgUctValue m_value; 00038 00039 /** Count of node after node is created. */ 00040 SgUctValue m_count; 00041 00042 /** Rave value of move after node is created from viewpoint of 00043 parent node. 00044 Value should not be inverted to child's perspective. */ 00045 SgUctValue m_raveValue; 00046 00047 /** Rave count of move after node is created. */ 00048 SgUctValue m_raveCount; 00049 00050 SgUctMoveInfo(); 00051 00052 SgUctMoveInfo(SgMove move); 00053 00054 SgUctMoveInfo(SgMove move, SgUctValue value, SgUctValue count, 00055 SgUctValue raveValue, SgUctValue raveCount); 00056 }; 00057 00058 inline SgUctMoveInfo::SgUctMoveInfo() 00059 : m_value(0), 00060 m_count(0), 00061 m_raveValue(0), 00062 m_raveCount(0) 00063 { 00064 } 00065 00066 inline SgUctMoveInfo::SgUctMoveInfo(SgMove move) 00067 : m_move(move), 00068 m_value(0), 00069 m_count(0), 00070 m_raveValue(0), 00071 m_raveCount(0) 00072 { 00073 } 00074 00075 inline SgUctMoveInfo::SgUctMoveInfo(SgMove move, SgUctValue value, SgUctValue count, 00076 SgUctValue raveValue, SgUctValue raveCount) 00077 : m_move(move), 00078 m_value(value), 00079 m_count(count), 00080 m_raveValue(raveValue), 00081 m_raveCount(raveCount) 00082 { 00083 } 00084 00085 //---------------------------------------------------------------------------- 00086 00087 /** Types of proven nodes. */ 00088 typedef enum 00089 { 00090 /** Node is not a proven win or loss. */ 00091 SG_NOT_PROVEN, 00092 00093 /** Node is a proven win. */ 00094 SG_PROVEN_WIN, 00095 00096 /** Node is a proven loss. */ 00097 SG_PROVEN_LOSS 00098 00099 } SgUctProvenType; 00100 00101 //---------------------------------------------------------------------------- 00102 00103 /** Node used in SgUctTree. 00104 All data members are declared as volatile to avoid that the compiler 00105 re-orders writes, which can break assumptions made by SgUctSearch in 00106 lock-free mode (see @ref sguctsearchlockfree). For example, the search 00107 relies on the fact that m_firstChild is valid, if m_nuChildren is greater 00108 zero or that the mean value of the move and RAVE value statistics is valid 00109 if the corresponding count is greater zero. 00110 @ingroup sguctgroup */ 00111 class SgUctNode 00112 { 00113 public: 00114 /** Initializes node with given move, value and count. */ 00115 SgUctNode(const SgUctMoveInfo& info); 00116 00117 /** Add game result. 00118 @param eval The game result (e.g. score or 0/1 for win loss) */ 00119 void AddGameResult(SgUctValue eval); 00120 00121 /** Adds a game result count times. */ 00122 void AddGameResults(SgUctValue eval, SgUctValue count); 00123 00124 /** Add other nodes results to this node's. */ 00125 void MergeResults(const SgUctNode& node); 00126 00127 /** Removes a game result. 00128 @param eval The game result (e.g. score or 0/1 for win loss) */ 00129 void RemoveGameResult(SgUctValue eval); 00130 00131 /** Removes a game result count times. */ 00132 void RemoveGameResults(SgUctValue eval, SgUctValue count); 00133 00134 /** Number of times this node was visited. 00135 This corresponds to the sum of MoveCount() of all children. 00136 It can be different from MoveCount() of this position, if prior 00137 knowledge initialization of the children is used. */ 00138 SgUctValue PosCount() const; 00139 00140 /** Number of times the move leading to this position was chosen. 00141 This count will be different from PosCount(), if prior knowledge 00142 initialization is used. */ 00143 SgUctValue MoveCount() const; 00144 00145 /** Get first child. 00146 @note This information is an implementation detail of how SgUctTree 00147 manages nodes. Use SgUctChildIterator to access children nodes. */ 00148 const SgUctNode* FirstChild() const; 00149 00150 /** Does the node have at least one child? */ 00151 bool HasChildren() const; 00152 00153 /** Average game result. 00154 Requires: HasMean() */ 00155 SgUctValue Mean() const; 00156 00157 /** True, if mean value is defined (move count not zero) */ 00158 bool HasMean() const; 00159 00160 /** Get number of children. 00161 @note This information is an implementation detail of how SgUctTree 00162 manages nodes. Use SgUctChildIterator to access children nodes. */ 00163 int NuChildren() const; 00164 00165 /** See FirstChild() */ 00166 void SetFirstChild(const SgUctNode* child); 00167 00168 /** See NuChildren() */ 00169 void SetNuChildren(int nuChildren); 00170 00171 /** Increment the position count. 00172 See PosCount() */ 00173 void IncPosCount(); 00174 00175 /** Increment the position count. 00176 See PosCount() */ 00177 void IncPosCount(SgUctValue count); 00178 00179 /** Decrement the position count. 00180 See PosCount() */ 00181 void DecPosCount(); 00182 00183 /** Decrement the position count. 00184 See PosCount() */ 00185 void DecPosCount(SgUctValue count); 00186 00187 void SetPosCount(SgUctValue value); 00188 00189 /** Initialize value with prior knowledge. */ 00190 void InitializeValue(SgUctValue value, SgUctValue count); 00191 00192 /** Copy data from other node. 00193 Copies all data, apart from the children information (first child 00194 and number of children). */ 00195 void CopyDataFrom(const SgUctNode& node); 00196 00197 /** Get move. 00198 Requires: Node has a move (is not root node) */ 00199 SgMove Move() const; 00200 00201 /** Get RAVE count. 00202 @see SgUctSearch::Rave(). */ 00203 SgUctValue RaveCount() const; 00204 00205 /** Get RAVE mean value. 00206 Requires: HasRaveValue() 00207 @see SgUctSearch::Rave(). */ 00208 SgUctValue RaveValue() const; 00209 00210 bool HasRaveValue() const; 00211 00212 /** Add a game result value to the RAVE value. 00213 @see SgUctSearch::Rave(). */ 00214 void AddRaveValue(SgUctValue value, SgUctValue weight); 00215 00216 /** Removes a rave result. */ 00217 void RemoveRaveValue(SgUctValue value); 00218 00219 void RemoveRaveValue(SgUctValue value, SgUctValue weight); 00220 00221 /** Initialize RAVE value with prior knowledge. */ 00222 void InitializeRaveValue(SgUctValue value, SgUctValue count); 00223 00224 int VirtualLossCount() const; 00225 00226 void AddVirtualLoss(); 00227 00228 void RemoveVirtualLoss(); 00229 00230 /** Returns the last time knowledge was computed. */ 00231 SgUctValue KnowledgeCount() const; 00232 00233 /** Set that knowledge has been computed at count. */ 00234 void SetKnowledgeCount(SgUctValue count); 00235 00236 /** Returns true if node is a proven node. */ 00237 bool IsProven() const; 00238 00239 bool IsProvenWin() const; 00240 00241 bool IsProvenLoss() const; 00242 00243 SgUctProvenType ProvenType() const; 00244 00245 void SetProvenType(SgUctProvenType type); 00246 00247 private: 00248 SgUctStatisticsVolatile m_statistics; 00249 00250 const SgUctNode* volatile m_firstChild; 00251 00252 volatile int m_nuChildren; 00253 00254 volatile SgMove m_move; 00255 00256 /** RAVE statistics. 00257 Uses double for count to allow adding fractional values if RAVE 00258 updates are weighted. */ 00259 SgUctStatisticsVolatile m_raveValue; 00260 00261 volatile SgUctValue m_posCount; 00262 00263 volatile SgUctValue m_knowledgeCount; 00264 00265 volatile SgUctProvenType m_provenType; 00266 00267 volatile int m_virtualLossCount; 00268 }; 00269 00270 inline SgUctNode::SgUctNode(const SgUctMoveInfo& info) 00271 : m_statistics(info.m_value, info.m_count), 00272 m_nuChildren(0), 00273 m_move(info.m_move), 00274 m_raveValue(info.m_raveValue, info.m_raveCount), 00275 m_posCount(0), 00276 m_knowledgeCount(0), 00277 m_provenType(SG_NOT_PROVEN), 00278 m_virtualLossCount(0) 00279 { 00280 // m_firstChild is not initialized, only defined if m_nuChildren > 0 00281 } 00282 00283 inline void SgUctNode::AddGameResult(SgUctValue eval) 00284 { 00285 m_statistics.Add(eval); 00286 } 00287 00288 inline void SgUctNode::AddGameResults(SgUctValue eval, SgUctValue count) 00289 { 00290 m_statistics.Add(eval, count); 00291 } 00292 00293 inline void SgUctNode::MergeResults(const SgUctNode& node) 00294 { 00295 if (node.m_statistics.IsDefined()) 00296 m_statistics.Add(node.m_statistics.Mean(), node.m_statistics.Count()); 00297 if (node.m_raveValue.IsDefined()) 00298 m_raveValue.Add(node.m_raveValue.Mean(), node.m_raveValue.Count()); 00299 } 00300 00301 inline void SgUctNode::RemoveGameResult(SgUctValue eval) 00302 { 00303 m_statistics.Remove(eval); 00304 } 00305 00306 inline void SgUctNode::RemoveGameResults(SgUctValue eval, SgUctValue count) 00307 { 00308 m_statistics.Remove(eval, count); 00309 } 00310 00311 inline void SgUctNode::AddRaveValue(SgUctValue value, SgUctValue weight) 00312 { 00313 m_raveValue.Add(value, weight); 00314 } 00315 00316 inline void SgUctNode::RemoveRaveValue(SgUctValue value) 00317 { 00318 m_raveValue.Remove(value); 00319 } 00320 00321 inline void SgUctNode::RemoveRaveValue(SgUctValue value, SgUctValue weight) 00322 { 00323 m_raveValue.Remove(value, weight); 00324 } 00325 00326 inline void SgUctNode::CopyDataFrom(const SgUctNode& node) 00327 { 00328 m_statistics = node.m_statistics; 00329 m_move = node.m_move; 00330 m_raveValue = node.m_raveValue; 00331 m_posCount = node.m_posCount; 00332 m_knowledgeCount = node.m_knowledgeCount; 00333 m_provenType = node.m_provenType; 00334 m_virtualLossCount = node.m_virtualLossCount; 00335 } 00336 00337 inline const SgUctNode* SgUctNode::FirstChild() const 00338 { 00339 SG_ASSERT(HasChildren()); // Otherwise m_firstChild is undefined 00340 return m_firstChild; 00341 } 00342 00343 inline bool SgUctNode::HasChildren() const 00344 { 00345 return (m_nuChildren > 0); 00346 } 00347 00348 inline bool SgUctNode::HasMean() const 00349 { 00350 return m_statistics.IsDefined(); 00351 } 00352 00353 inline bool SgUctNode::HasRaveValue() const 00354 { 00355 return m_raveValue.IsDefined(); 00356 } 00357 00358 inline int SgUctNode::VirtualLossCount() const 00359 { 00360 return m_virtualLossCount; 00361 } 00362 00363 inline void SgUctNode::AddVirtualLoss() 00364 { 00365 m_virtualLossCount++; 00366 } 00367 00368 inline void SgUctNode::RemoveVirtualLoss() 00369 { 00370 // May become negative with lock-free multithreading. Negative 00371 // values are allowed so that errors introduced by multithreading 00372 // will tend to average out. 00373 m_virtualLossCount--; 00374 } 00375 00376 inline void SgUctNode::IncPosCount() 00377 { 00378 ++m_posCount; 00379 } 00380 00381 inline void SgUctNode::IncPosCount(SgUctValue count) 00382 { 00383 m_posCount += count; 00384 } 00385 00386 inline void SgUctNode::DecPosCount() 00387 { 00388 SgUctValue posCount = m_posCount; 00389 if (posCount > 0) 00390 { 00391 m_posCount = posCount - 1; 00392 } 00393 } 00394 00395 inline void SgUctNode::DecPosCount(SgUctValue count) 00396 { 00397 SgUctValue posCount = m_posCount; 00398 if (posCount >= count) 00399 { 00400 m_posCount = posCount - count; 00401 } 00402 } 00403 00404 inline void SgUctNode::InitializeValue(SgUctValue value, SgUctValue count) 00405 { 00406 m_statistics.Initialize(value, count); 00407 } 00408 00409 inline void SgUctNode::InitializeRaveValue(SgUctValue value, SgUctValue count) 00410 { 00411 m_raveValue.Initialize(value, count); 00412 } 00413 00414 inline SgUctValue SgUctNode::Mean() const 00415 { 00416 return m_statistics.Mean(); 00417 } 00418 00419 inline SgMove SgUctNode::Move() const 00420 { 00421 SG_ASSERT(m_move != SG_NULLMOVE); 00422 return m_move; 00423 } 00424 00425 inline SgUctValue SgUctNode::MoveCount() const 00426 { 00427 return m_statistics.Count(); 00428 } 00429 00430 inline int SgUctNode::NuChildren() const 00431 { 00432 return m_nuChildren; 00433 } 00434 00435 inline SgUctValue SgUctNode::PosCount() const 00436 { 00437 return m_posCount; 00438 } 00439 00440 inline SgUctValue SgUctNode::RaveCount() const 00441 { 00442 return m_raveValue.Count(); 00443 } 00444 00445 inline SgUctValue SgUctNode::RaveValue() const 00446 { 00447 return m_raveValue.Mean(); 00448 } 00449 00450 inline void SgUctNode::SetFirstChild(const SgUctNode* child) 00451 { 00452 m_firstChild = child; 00453 } 00454 00455 inline void SgUctNode::SetNuChildren(int nuChildren) 00456 { 00457 SG_ASSERT(nuChildren >= 0); 00458 m_nuChildren = nuChildren; 00459 } 00460 00461 inline void SgUctNode::SetPosCount(SgUctValue value) 00462 { 00463 m_posCount = value; 00464 } 00465 00466 inline SgUctValue SgUctNode::KnowledgeCount() const 00467 { 00468 return m_knowledgeCount; 00469 } 00470 00471 inline void SgUctNode::SetKnowledgeCount(SgUctValue count) 00472 { 00473 m_knowledgeCount = count; 00474 } 00475 00476 inline bool SgUctNode::IsProven() const 00477 { 00478 return m_provenType != SG_NOT_PROVEN; 00479 } 00480 00481 inline bool SgUctNode::IsProvenWin() const 00482 { 00483 return m_provenType == SG_PROVEN_WIN; 00484 } 00485 00486 inline bool SgUctNode::IsProvenLoss() const 00487 { 00488 return m_provenType == SG_PROVEN_LOSS; 00489 } 00490 00491 inline SgUctProvenType SgUctNode::ProvenType() const 00492 { 00493 return m_provenType; 00494 } 00495 00496 inline void SgUctNode::SetProvenType(SgUctProvenType type) 00497 { 00498 m_provenType = type; 00499 } 00500 00501 //---------------------------------------------------------------------------- 00502 00503 /** Allocater for nodes used in the implementation of SgUctTree. 00504 Each thread has its own node allocator to allow lock-free usage of 00505 SgUctTree. 00506 @ingroup sguctgroup */ 00507 class SgUctAllocator 00508 { 00509 public: 00510 SgUctAllocator(); 00511 00512 ~SgUctAllocator(); 00513 00514 void Clear(); 00515 00516 /** Does the allocator have the capacity for n more nodes? */ 00517 bool HasCapacity(std::size_t n) const; 00518 00519 std::size_t NuNodes() const; 00520 00521 std::size_t MaxNodes() const; 00522 00523 void SetMaxNodes(std::size_t maxNodes); 00524 00525 /** Check if allocator contains node. 00526 This function uses pointer comparisons. Since the result of 00527 comparisons for pointers to elements in different containers 00528 is platform-dependent, it is only guaranteed that it returns true, 00529 if not node belongs to the allocator, but not that it returns false 00530 for nodes not in the allocator. */ 00531 bool Contains(const SgUctNode& node) const; 00532 00533 const SgUctNode* Start() const; 00534 00535 SgUctNode* Finish(); 00536 00537 const SgUctNode* Finish() const; 00538 00539 /** Create a new node at the end of the storage. 00540 REQUIRES: HasCapacity(1) 00541 @param move The constructor argument. 00542 @return A pointer to new newly created node. */ 00543 SgUctNode* CreateOne(SgMove move); 00544 00545 /** Create a number of new nodes with a given list of moves at the end of 00546 the storage. Returns the sum of counts of moves. 00547 REQUIRES: HasCapacity(moves.size()) 00548 @param moves The list of moves. */ 00549 SgUctValue Create(const std::vector<SgUctMoveInfo>& moves); 00550 00551 /** Create a number of new nodes at the end of the storage. 00552 REQUIRES: HasCapacity(n) 00553 @param n The number of nodes to create. */ 00554 void CreateN(std::size_t n); 00555 00556 void Swap(SgUctAllocator& allocator); 00557 00558 private: 00559 SgUctNode* m_start; 00560 00561 SgUctNode* m_finish; 00562 00563 SgUctNode* m_endOfStorage; 00564 00565 /** Not implemented. 00566 Cannot be copied because array contains pointers to elements. 00567 Use Swap() instead. */ 00568 SgUctAllocator& operator=(const SgUctAllocator& tree); 00569 }; 00570 00571 inline SgUctAllocator::SgUctAllocator() 00572 { 00573 m_start = 0; 00574 } 00575 00576 inline void SgUctAllocator::Clear() 00577 { 00578 if (m_start != 0) 00579 { 00580 for (SgUctNode* it = m_start; it != m_finish; ++it) 00581 it->~SgUctNode(); 00582 m_finish = m_start; 00583 } 00584 } 00585 00586 inline SgUctNode* SgUctAllocator::CreateOne(SgMove move) 00587 { 00588 SG_ASSERT(HasCapacity(1)); 00589 new(m_finish) SgUctNode(move); 00590 return (m_finish++); 00591 } 00592 00593 inline SgUctValue SgUctAllocator::Create( 00594 const std::vector<SgUctMoveInfo>& moves) 00595 { 00596 SG_ASSERT(HasCapacity(moves.size())); 00597 SgUctValue count = 0; 00598 for (std::vector<SgUctMoveInfo>::const_iterator it = moves.begin(); 00599 it != moves.end(); ++it, ++m_finish) 00600 { 00601 new(m_finish) SgUctNode(*it); 00602 count += it->m_count; 00603 } 00604 return count; 00605 } 00606 00607 inline void SgUctAllocator::CreateN(std::size_t n) 00608 { 00609 SG_ASSERT(HasCapacity(n)); 00610 SgUctNode* newFinish = m_finish + n; 00611 for ( ; m_finish != newFinish; ++m_finish) 00612 new(m_finish) SgUctNode(SG_NULLMOVE); 00613 } 00614 00615 inline SgUctNode* SgUctAllocator::Finish() 00616 { 00617 return m_finish; 00618 } 00619 00620 inline const SgUctNode* SgUctAllocator::Finish() const 00621 { 00622 return m_finish; 00623 } 00624 00625 inline bool SgUctAllocator::HasCapacity(std::size_t n) const 00626 { 00627 return (m_finish + n <= m_endOfStorage); 00628 } 00629 00630 inline std::size_t SgUctAllocator::MaxNodes() const 00631 { 00632 return m_endOfStorage - m_start; 00633 } 00634 00635 inline std::size_t SgUctAllocator::NuNodes() const 00636 { 00637 return m_finish - m_start; 00638 } 00639 00640 inline const SgUctNode* SgUctAllocator::Start() const 00641 { 00642 return m_start; 00643 } 00644 00645 //---------------------------------------------------------------------------- 00646 00647 /** Tree used in SgUctSearch. 00648 The nodes can be accessed only by getting non-const references or modified 00649 through accessor functions of SgUctTree, therefore SgUctTree can guarantee 00650 the integrity of the tree structure. 00651 The tree can be used in a lock-free way during a search (see 00652 @ref sguctsearchlockfree). 00653 @ingroup sguctgroup */ 00654 class SgUctTree 00655 { 00656 public: 00657 friend class SgUctChildIterator; 00658 00659 /** Constructor. 00660 Construct a tree. Before using the tree, CreateAllocators() and 00661 SetMaxNodes() must be called (in this order). */ 00662 SgUctTree(); 00663 00664 /** Create node allocators for threads. */ 00665 void CreateAllocators(std::size_t nuThreads); 00666 00667 /** Add a game result. 00668 @param node The node. 00669 @param father The father (if not root) to update the position count. 00670 @param eval */ 00671 void AddGameResult(const SgUctNode& node, const SgUctNode* father, 00672 SgUctValue eval); 00673 00674 /** Adds a game result count times. */ 00675 void AddGameResults(const SgUctNode& node, const SgUctNode* father, 00676 SgUctValue eval, SgUctValue count); 00677 00678 /** Removes a game result. 00679 @param node The node. 00680 @param father The father (if not root) to update the position count. 00681 @param eval */ 00682 void RemoveGameResult(const SgUctNode& node, const SgUctNode* father, 00683 SgUctValue eval); 00684 00685 /** Removes a game result count times. */ 00686 void RemoveGameResults(const SgUctNode& node, const SgUctNode* father, 00687 SgUctValue eval, SgUctValue count); 00688 00689 /** Adds a virtual loss to the given node. */ 00690 void AddVirtualLoss(const SgUctNode &node); 00691 00692 /** Removes a virtual loss to the given node. */ 00693 void RemoveVirtualLoss(const SgUctNode &node); 00694 00695 void SetProvenType(const SgUctNode& node, SgUctProvenType type); 00696 00697 void SetKnowledgeCount(const SgUctNode& node, SgUctValue count); 00698 00699 void Clear(); 00700 00701 /** Return the current maximum number of nodes. 00702 This returns the maximum number of nodes as set by SetMaxNodes(). 00703 See SetMaxNodes() why the real maximum number of nodes can be higher 00704 or lower. */ 00705 std::size_t MaxNodes() const; 00706 00707 /** Change maximum number of nodes. 00708 Also clears the tree. This will call SetMaxNodes() at each registered 00709 allocator with maxNodes / numberAllocators as an argument. The real 00710 maximum number of nodes can be higher (because the root node is 00711 owned by this class, not an allocator) or lower (if maxNodes is not 00712 a multiple of the number of allocators). 00713 @param maxNodes Maximum number of nodes */ 00714 void SetMaxNodes(std::size_t maxNodes); 00715 00716 /** Swap content with another tree. 00717 The other tree must have the same number of allocators and 00718 the same maximum number of nodes. */ 00719 void Swap(SgUctTree& tree); 00720 00721 bool HasCapacity(std::size_t allocatorId, std::size_t n) const; 00722 00723 /** Create children nodes. 00724 Requires: Allocator(allocatorId).HasCapacity(moves.size()) */ 00725 void CreateChildren(std::size_t allocatorId, const SgUctNode& node, 00726 const std::vector<SgUctMoveInfo>& moves); 00727 00728 /** Merge new children with old. 00729 Requires: Allocator(allocatorId).HasCapacity(moves.size()) */ 00730 void MergeChildren(std::size_t allocatorId, const SgUctNode& node, 00731 const std::vector<SgUctMoveInfo>& moves, 00732 bool deleteChildTrees); 00733 00734 /** Extract subtree to a different tree. 00735 The tree will be truncated if one of the allocators overflows (can 00736 happen due to reassigning nodes to different allocators), the given 00737 max time is exceeded or on SgUserAbort(). 00738 @param[out] target The resulting subtree. Must have the same maximum 00739 number of nodes. Will be cleared before using. 00740 @param node The start node of the subtree. 00741 @param warnTruncate Print warning to SgDebug() if tree was truncated 00742 @param maxTime Truncate the tree, if the extraction takes longer than 00743 the given time 00744 @param minCount */ 00745 void ExtractSubtree(SgUctTree& target, const SgUctNode& node, 00746 bool warnTruncate, 00747 double maxTime = std::numeric_limits<double>::max(), 00748 SgUctValue minCount = 0) const; 00749 00750 /** Get a copy of the tree with low count nodes pruned. 00751 The tree will be truncated if one of the allocators overflows (can 00752 happen due to reassigning nodes to different allocators), the given 00753 max time is exceeded or on SgUserAbort(). 00754 @param[out] target The resulting tree. Must have the same maximum 00755 number of nodes. Will be cleared before using. 00756 @param minCount The minimum count (SgUctNode::MoveCount()) 00757 @param warnTruncate Print warning to SgDebug() if tree was truncated 00758 @param maxTime Truncate the tree, if the extraction takes longer than 00759 the given time */ 00760 void CopyPruneLowCount(SgUctTree& target, SgUctValue minCount, 00761 bool warnTruncate, 00762 double maxTime = std::numeric_limits<double>::max()) const; 00763 00764 const SgUctNode& Root() const; 00765 00766 std::size_t NuAllocators() const; 00767 00768 /** Total number of nodes. 00769 Includes the sum of nodes in all allocators plus the root node. */ 00770 std::size_t NuNodes() const; 00771 00772 /** Number of nodes in one of the allocators. */ 00773 std::size_t NuNodes(std::size_t allocatorId) const; 00774 00775 /** Add a game result value to the RAVE value of a node. 00776 @param node The node with the move 00777 @param value 00778 @param weight 00779 @see SgUctSearch::Rave(). */ 00780 void AddRaveValue(const SgUctNode& node, SgUctValue value, SgUctValue weight); 00781 00782 /** Remove a game result from the RAVE value of a node. 00783 @param node The node with the move 00784 @param value 00785 @param weight 00786 @see SgUctSearch::Rave(). */ 00787 void RemoveRaveValue(const SgUctNode& node, SgUctValue value, SgUctValue weight); 00788 00789 /** Initialize the value and count of a node. */ 00790 void InitializeValue(const SgUctNode& node, SgUctValue value, 00791 SgUctValue count); 00792 00793 void SetPosCount(const SgUctNode& node, SgUctValue posCount); 00794 00795 /** Initialize the rave value and count of a move node with prior 00796 knowledge. */ 00797 void InitializeRaveValue(const SgUctNode& node, SgUctValue value, SgUctValue count); 00798 00799 /** Remove some children of a node according to a list of filtered moves. 00800 Requires: Allocator(allocatorId).HasCapacity(node.NuChildren()) <br> 00801 For efficiency, no reorganization of the tree is done to remove 00802 the dead subtrees (and NuNodes() will not report the real number of 00803 nodes in the tree). This function can be used in lock-free mode. */ 00804 void ApplyFilter(std::size_t allocatorId, const SgUctNode& node, 00805 const std::vector<SgMove>& rootFilter); 00806 00807 /** Sets the children under node to be exactly those in moves, 00808 reusing the old children if possible. Children not in moves 00809 are pruned, children missing from moves are added as leaves. 00810 Requires: Allocator(allocatorId).HasCapacity(moves.size()) */ 00811 void SetChildren(std::size_t allocatorId, const SgUctNode& node, 00812 const vector<SgMove>& moves); 00813 00814 /** @name Functions for debugging */ 00815 // @{ 00816 00817 /** Do some consistency checks. 00818 @throws SgException if inconsistencies are detected. */ 00819 void CheckConsistency() const; 00820 00821 /** Check if tree contains node. 00822 This function uses pointer comparisons. Since the result of 00823 comparisons for pointers to elements in different containers 00824 is platform-dependent, it is only guaranteed that it returns true, 00825 if not node belongs to the allocator, but not that it returns false 00826 for nodes not in the tree. */ 00827 bool Contains(const SgUctNode& node) const; 00828 00829 void DumpDebugInfo(std::ostream& out) const; 00830 00831 // @} // @name 00832 00833 private: 00834 std::size_t m_maxNodes; 00835 00836 SgUctNode m_root; 00837 00838 /** Allocators. 00839 The elements are owned by the vector (shared_ptr is only used because 00840 auto_ptr should not be used with standard containers) */ 00841 std::vector<boost::shared_ptr<SgUctAllocator> > m_allocators; 00842 00843 /** Not implemented. 00844 Cannot be copied because allocators contain pointers to elements. 00845 Use SgUctTree::Swap instead. */ 00846 SgUctTree& operator=(const SgUctTree& tree); 00847 00848 SgUctAllocator& Allocator(std::size_t i); 00849 00850 const SgUctAllocator& Allocator(std::size_t i) const; 00851 00852 bool CopySubtree(SgUctTree& target, SgUctNode& targetNode, 00853 const SgUctNode& node, SgUctValue minCount, 00854 std::size_t& currentAllocatorId, bool warnTruncate, 00855 bool& abort, SgTimer& timer, double maxTime, 00856 bool alwaysKeepProven) const; 00857 00858 void ThrowConsistencyError(const std::string& message) const; 00859 }; 00860 00861 inline void SgUctTree::AddGameResult(const SgUctNode& node, 00862 const SgUctNode* father, SgUctValue eval) 00863 { 00864 SG_ASSERT(Contains(node)); 00865 // Parameters are const-references, because only the tree is allowed 00866 // to modify nodes 00867 if (father != 0) 00868 const_cast<SgUctNode*>(father)->IncPosCount(); 00869 const_cast<SgUctNode&>(node).AddGameResult(eval); 00870 } 00871 00872 inline void SgUctTree::AddGameResults(const SgUctNode& node, 00873 const SgUctNode* father, SgUctValue eval, 00874 SgUctValue count) 00875 { 00876 00877 SG_ASSERT(Contains(node)); 00878 // Parameters are const-references, because only the tree is allowed 00879 // to modify nodes 00880 if (father != 0) 00881 const_cast<SgUctNode*>(father)->IncPosCount(count); 00882 const_cast<SgUctNode&>(node).AddGameResults(eval, count); 00883 } 00884 00885 inline void SgUctTree::CreateChildren(std::size_t allocatorId, 00886 const SgUctNode& node, 00887 const std::vector<SgUctMoveInfo>& moves) 00888 { 00889 SG_ASSERT(Contains(node)); 00890 // Parameters are const-references, because only the tree is allowed 00891 // to modify nodes 00892 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node); 00893 SG_ASSERT(moves.size() <= std::size_t(std::numeric_limits<int>::max())); 00894 int nuChildren = int(moves.size()); 00895 SG_ASSERT(nuChildren > 0); 00896 SgUctAllocator& allocator = Allocator(allocatorId); 00897 SG_ASSERT(allocator.HasCapacity(nuChildren)); 00898 00899 // In lock-free multi-threading, a node can be expanded multiple times 00900 // (the later thread overwrites the children information of the previous 00901 // thread) 00902 SG_ASSERT(NuAllocators() > 1 || ! node.HasChildren()); 00903 00904 const SgUctNode* firstChild = allocator.Finish(); 00905 00906 SgUctValue parentCount = allocator.Create(moves); 00907 00908 // Write order dependency: SgUctSearch in lock-free mode assumes that 00909 // m_firstChild is valid if m_nuChildren is greater zero 00910 nonConstNode.SetPosCount(parentCount); 00911 SgSynchronizeThreadMemory(); 00912 nonConstNode.SetFirstChild(firstChild); 00913 SgSynchronizeThreadMemory(); 00914 nonConstNode.SetNuChildren(nuChildren); 00915 } 00916 00917 inline void SgUctTree::RemoveGameResult(const SgUctNode& node, 00918 const SgUctNode* father, SgUctValue eval) 00919 { 00920 SG_ASSERT(Contains(node)); 00921 // Parameters are const-references, because only the tree is allowed 00922 // to modify nodes 00923 if (father != 0) 00924 const_cast<SgUctNode*>(father)->DecPosCount(); 00925 const_cast<SgUctNode&>(node).RemoveGameResult(eval); 00926 } 00927 00928 inline void SgUctTree::RemoveGameResults(const SgUctNode& node, 00929 const SgUctNode* father, SgUctValue eval, 00930 SgUctValue count) 00931 { 00932 SG_ASSERT(Contains(node)); 00933 // Parameters are const-references, because only the tree is allowed 00934 // to modify nodes 00935 if (father != 0) 00936 const_cast<SgUctNode*>(father)->DecPosCount(count); 00937 const_cast<SgUctNode&>(node).RemoveGameResults(eval, count); 00938 } 00939 00940 inline void SgUctTree::AddVirtualLoss(const SgUctNode& node) 00941 { 00942 const_cast<SgUctNode&>(node).AddVirtualLoss(); 00943 } 00944 00945 inline void SgUctTree::RemoveVirtualLoss(const SgUctNode& node) 00946 { 00947 const_cast<SgUctNode&>(node).RemoveVirtualLoss(); 00948 } 00949 00950 inline void SgUctTree::AddRaveValue(const SgUctNode& node, SgUctValue value, 00951 SgUctValue weight) 00952 { 00953 SG_ASSERT(Contains(node)); 00954 // Parameters are const-references, because only the tree is allowed 00955 // to modify nodes 00956 const_cast<SgUctNode&>(node).AddRaveValue(value, weight); 00957 } 00958 00959 inline void SgUctTree::RemoveRaveValue(const SgUctNode& node, SgUctValue value, 00960 SgUctValue weight) 00961 { 00962 SG_UNUSED(weight); 00963 SG_ASSERT(Contains(node)); 00964 // Parameters are const-references, because only the tree is allowed 00965 // to modify nodes 00966 const_cast<SgUctNode&>(node).RemoveRaveValue(value, weight); 00967 } 00968 00969 inline SgUctAllocator& SgUctTree::Allocator(std::size_t i) 00970 { 00971 SG_ASSERT(i < m_allocators.size()); 00972 return *m_allocators[i]; 00973 } 00974 00975 inline const SgUctAllocator& SgUctTree::Allocator(std::size_t i) const 00976 { 00977 SG_ASSERT(i < m_allocators.size()); 00978 return *m_allocators[i]; 00979 } 00980 00981 inline bool SgUctTree::HasCapacity(std::size_t allocatorId, 00982 std::size_t n) const 00983 { 00984 return Allocator(allocatorId).HasCapacity(n); 00985 } 00986 00987 inline void SgUctTree::InitializeValue(const SgUctNode& node, 00988 SgUctValue value, SgUctValue count) 00989 { 00990 SG_ASSERT(Contains(node)); 00991 // Parameter is const-reference, because only the tree is allowed 00992 // to modify nodes 00993 const_cast<SgUctNode&>(node).InitializeValue(value, count); 00994 } 00995 00996 inline void SgUctTree::InitializeRaveValue(const SgUctNode& node, 00997 SgUctValue value, SgUctValue count) 00998 { 00999 SG_ASSERT(Contains(node)); 01000 // Parameters are const-references, because only the tree is allowed 01001 // to modify nodes 01002 const_cast<SgUctNode&>(node).InitializeRaveValue(value, count); 01003 } 01004 01005 inline std::size_t SgUctTree::MaxNodes() const 01006 { 01007 return m_maxNodes; 01008 } 01009 01010 inline std::size_t SgUctTree::NuAllocators() const 01011 { 01012 return m_allocators.size(); 01013 } 01014 01015 inline std::size_t SgUctTree::NuNodes(std::size_t allocatorId) const 01016 { 01017 return Allocator(allocatorId).NuNodes(); 01018 } 01019 01020 inline const SgUctNode& SgUctTree::Root() const 01021 { 01022 return m_root; 01023 } 01024 01025 inline void SgUctTree::SetKnowledgeCount(const SgUctNode& node, 01026 SgUctValue count) 01027 { 01028 SG_ASSERT(Contains(node)); 01029 // Parameters are const-references, because only the tree is allowed 01030 // to modify nodes 01031 const_cast<SgUctNode&>(node).SetKnowledgeCount(count); 01032 } 01033 01034 inline void SgUctTree::SetPosCount(const SgUctNode& node, 01035 SgUctValue posCount) 01036 { 01037 SG_ASSERT(Contains(node)); 01038 // Parameters are const-references, because only the tree is allowed 01039 // to modify nodes 01040 const_cast<SgUctNode&>(node).SetPosCount(posCount); 01041 } 01042 01043 inline void SgUctTree::SetProvenType(const SgUctNode &node, 01044 SgUctProvenType type) 01045 { 01046 SG_ASSERT(Contains(node)); 01047 // Parameters are const-references, because only the tree is allowed 01048 // to modify nodes 01049 const_cast<SgUctNode&>(node).SetProvenType(type); 01050 } 01051 01052 //---------------------------------------------------------------------------- 01053 01054 /** Iterator over all children of a node. 01055 It was intentionally implemented to be used only, if at least one child 01056 exists (checked with an assertion), since in many use cases, the case 01057 of no children needs to be handled specially and should be checked 01058 before doing a loop over all children. 01059 @ingroup sguctgroup */ 01060 class SgUctChildIterator 01061 { 01062 public: 01063 /** Constructor. 01064 Requires: node.HasChildren() */ 01065 SgUctChildIterator(const SgUctTree& tree, const SgUctNode& node); 01066 01067 const SgUctNode& operator*() const; 01068 01069 void operator++(); 01070 01071 operator bool() const; 01072 01073 private: 01074 const SgUctNode* m_current; 01075 01076 const SgUctNode* m_last; 01077 }; 01078 01079 inline SgUctChildIterator::SgUctChildIterator(const SgUctTree& tree, 01080 const SgUctNode& node) 01081 { 01082 SG_DEBUG_ONLY(tree); 01083 SG_ASSERT(tree.Contains(node)); 01084 SG_ASSERT(node.HasChildren()); 01085 m_current = node.FirstChild(); 01086 m_last = m_current + node.NuChildren(); 01087 } 01088 01089 inline const SgUctNode& SgUctChildIterator::operator*() const 01090 { 01091 return *m_current; 01092 } 01093 01094 inline void SgUctChildIterator::operator++() 01095 { 01096 ++m_current; 01097 } 01098 01099 inline SgUctChildIterator::operator bool() const 01100 { 01101 return (m_current < m_last); 01102 } 01103 01104 //---------------------------------------------------------------------------- 01105 01106 /** Iterator for traversing a tree depth-first. 01107 @ingroup sguctgroup */ 01108 class SgUctTreeIterator 01109 { 01110 public: 01111 SgUctTreeIterator(const SgUctTree& tree); 01112 01113 const SgUctNode& operator*() const; 01114 01115 void operator++(); 01116 01117 operator bool() const; 01118 01119 private: 01120 const SgUctTree& m_tree; 01121 01122 const SgUctNode* m_current; 01123 01124 /** Stack of child iterators. 01125 The elements are owned by the stack (shared_ptr is only used because 01126 auto_ptr should not be used with standard containers) */ 01127 std::stack<boost::shared_ptr<SgUctChildIterator> > m_stack; 01128 }; 01129 01130 //---------------------------------------------------------------------------- 01131 01132 #endif // SG_UCTTREE_H