00001 //---------------------------------------------------------------------------- 00002 /** @file SgUctTree.cpp 00003 See SgUctTree.h */ 00004 //---------------------------------------------------------------------------- 00005 00006 #include "SgSystem.h" 00007 #include "SgUctTree.h" 00008 00009 #include <boost/format.hpp> 00010 #include "SgDebug.h" 00011 #include "SgTimer.h" 00012 00013 using namespace std; 00014 using boost::format; 00015 using boost::shared_ptr; 00016 00017 //---------------------------------------------------------------------------- 00018 00019 SgUctAllocator::~SgUctAllocator() 00020 { 00021 if (m_start != 0) 00022 { 00023 Clear(); 00024 std::free(m_start); 00025 } 00026 } 00027 00028 bool SgUctAllocator::Contains(const SgUctNode& node) const 00029 { 00030 return (&node >= m_start && &node < m_finish); 00031 } 00032 00033 void SgUctAllocator::Swap(SgUctAllocator& allocator) 00034 { 00035 swap(m_start, allocator.m_start); 00036 swap(m_finish, allocator.m_finish); 00037 swap(m_endOfStorage, allocator.m_endOfStorage); 00038 } 00039 00040 void SgUctAllocator::SetMaxNodes(std::size_t maxNodes) 00041 { 00042 if (m_start != 0) 00043 { 00044 Clear(); 00045 std::free(m_start); 00046 } 00047 void* ptr = std::malloc(maxNodes * sizeof(SgUctNode)); 00048 if (ptr == 0) 00049 throw std::bad_alloc(); 00050 m_start = static_cast<SgUctNode*>(ptr); 00051 m_finish = m_start; 00052 m_endOfStorage = m_start + maxNodes; 00053 } 00054 00055 //---------------------------------------------------------------------------- 00056 00057 SgUctTree::SgUctTree() 00058 : m_maxNodes(0), 00059 m_root(SG_NULLMOVE) 00060 { 00061 } 00062 00063 void SgUctTree::ApplyFilter(std::size_t allocatorId, const SgUctNode& node, 00064 const vector<SgMove>& rootFilter) 00065 { 00066 SG_ASSERT(Contains(node)); 00067 SG_ASSERT(Allocator(allocatorId).HasCapacity(node.NuChildren())); 00068 if (! node.HasChildren()) 00069 return; 00070 00071 SgUctAllocator& allocator = Allocator(allocatorId); 00072 const SgUctNode* firstChild = allocator.Finish(); 00073 00074 int nuChildren = 0; 00075 for (SgUctChildIterator it(*this, node); it; ++it) 00076 { 00077 SgMove move = (*it).Move(); 00078 if (find(rootFilter.begin(), rootFilter.end(), move) 00079 == rootFilter.end()) 00080 { 00081 SgUctNode* child = allocator.CreateOne(move); 00082 child->CopyDataFrom(*it); 00083 int childNuChildren = (*it).NuChildren(); 00084 child->SetNuChildren(childNuChildren); 00085 if (childNuChildren > 0) 00086 child->SetFirstChild((*it).FirstChild()); 00087 ++nuChildren; 00088 } 00089 } 00090 00091 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node); 00092 // Write order dependency: SgUctSearch in lock-free mode assumes that 00093 // m_firstChild is valid if m_nuChildren is greater zero 00094 SgSynchronizeThreadMemory(); 00095 nonConstNode.SetFirstChild(firstChild); 00096 SgSynchronizeThreadMemory(); 00097 nonConstNode.SetNuChildren(nuChildren); 00098 } 00099 00100 void SgUctTree::SetChildren(std::size_t allocatorId, const SgUctNode& node, 00101 const vector<SgMove>& moves) 00102 { 00103 SG_ASSERT(Contains(node)); 00104 SG_ASSERT(Allocator(allocatorId).HasCapacity(moves.size())); 00105 SG_ASSERT(node.HasChildren()); 00106 00107 SgUctAllocator& allocator = Allocator(allocatorId); 00108 const SgUctNode* firstChild = allocator.Finish(); 00109 00110 int nuChildren = 0; 00111 for (size_t i = 0; i < moves.size(); ++i) 00112 { 00113 bool found = false; 00114 for (SgUctChildIterator it(*this, node); it; ++it) 00115 { 00116 SgMove move = (*it).Move(); 00117 if (move == moves[i]) 00118 { 00119 found = true; 00120 SgUctNode* child = allocator.CreateOne(move); 00121 child->CopyDataFrom(*it); 00122 int childNuChildren = (*it).NuChildren(); 00123 child->SetNuChildren(childNuChildren); 00124 if (childNuChildren > 0) 00125 child->SetFirstChild((*it).FirstChild()); 00126 ++nuChildren; 00127 break; 00128 } 00129 } 00130 if (! found) 00131 { 00132 allocator.CreateOne(moves[i]); 00133 ++nuChildren; 00134 } 00135 } 00136 SG_ASSERT((size_t)nuChildren == moves.size()); 00137 00138 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node); 00139 // Write order dependency: SgUctSearch in lock-free mode assumes that 00140 // m_firstChild is valid if m_nuChildren is greater zero 00141 SgSynchronizeThreadMemory(); 00142 nonConstNode.SetFirstChild(firstChild); 00143 SgSynchronizeThreadMemory(); 00144 nonConstNode.SetNuChildren(nuChildren); 00145 } 00146 00147 void SgUctTree::CheckConsistency() const 00148 { 00149 for (SgUctTreeIterator it(*this); it; ++it) 00150 if (! Contains(*it)) 00151 ThrowConsistencyError(str(format("! Contains(%1%)") % &(*it))); 00152 } 00153 00154 void SgUctTree::Clear() 00155 { 00156 for (size_t i = 0; i < NuAllocators(); ++i) 00157 Allocator(i).Clear(); 00158 m_root = SgUctNode(SG_NULLMOVE); 00159 } 00160 00161 /** Check if node is in tree. 00162 Only used for assertions. May not be available in future implementations. */ 00163 bool SgUctTree::Contains(const SgUctNode& node) const 00164 { 00165 if (&node == &m_root) 00166 return true; 00167 for (size_t i = 0; i < NuAllocators(); ++i) 00168 if (Allocator(i).Contains(node)) 00169 return true; 00170 return false; 00171 } 00172 00173 void SgUctTree::CopyPruneLowCount(SgUctTree& target, SgUctValue minCount, 00174 bool warnTruncate, double maxTime) const 00175 { 00176 size_t allocatorId = 0; 00177 SgTimer timer; 00178 bool abort = false; 00179 CopySubtree(target, target.m_root, m_root, minCount, allocatorId, 00180 warnTruncate, abort, timer, maxTime, 00181 /* alwaysKeepProven */ false); 00182 SgSynchronizeThreadMemory(); 00183 } 00184 00185 /** Recursive function used by SgUctTree::ExtractSubtree and 00186 SgUctTree::CopyPruneLowCount. 00187 @param target The target tree. 00188 @param targetNode The target node; it is already created but the content 00189 not yet copied 00190 @param node The node in the source tree to be copied. 00191 @param minCount The minimum count (SgUctNode::MoveCount()) of a non-root 00192 node in the source tree to copy 00193 @param currentAllocatorId The current node allocator. Will be incremented 00194 in each call to CopySubtree to use node allocators of target tree evenly. 00195 @param warnTruncate Print warning to SgDebug() if tree was 00196 truncated (e.g due to reassigning nodes to different allocators) 00197 @param[in,out] abort Flag to abort copying. Must be initialized to false 00198 by top-level caller 00199 @param timer 00200 @param maxTime See ExtractSubtree() 00201 @param alwaysKeepProven Copy proven nodes even if below minCount */ 00202 bool SgUctTree::CopySubtree(SgUctTree& target, SgUctNode& targetNode, 00203 const SgUctNode& node, SgUctValue minCount, 00204 std::size_t& currentAllocatorId, 00205 bool warnTruncate, bool& abort, SgTimer& timer, 00206 double maxTime, bool alwaysKeepProven) const 00207 00208 { 00209 SG_ASSERT(Contains(node)); 00210 SG_ASSERT(target.Contains(targetNode)); 00211 targetNode.CopyDataFrom(node); 00212 00213 if (! node.HasChildren()) 00214 return true; 00215 00216 if (node.IsProven()) 00217 { 00218 if (!alwaysKeepProven && node.MoveCount() < minCount) 00219 { 00220 targetNode.SetProvenType(SG_NOT_PROVEN); 00221 return false; 00222 } 00223 } 00224 else 00225 { 00226 if (node.MoveCount() < minCount) 00227 return false; 00228 } 00229 00230 SgUctAllocator& targetAllocator = target.Allocator(currentAllocatorId); 00231 int nuChildren = node.NuChildren(); 00232 if (! abort) 00233 { 00234 if (! targetAllocator.HasCapacity(nuChildren)) 00235 { 00236 // This can happen even if target tree has same maximum number of 00237 // nodes, because allocators are used differently. 00238 if (warnTruncate) 00239 SgDebug() << 00240 "SgUctTree::CopySubtree: Truncated (allocator capacity)\n"; 00241 abort = true; 00242 } 00243 if (timer.IsTimeOut(maxTime, 10000)) 00244 { 00245 if (warnTruncate) 00246 SgDebug() << "SgUctTree::CopySubtree: Truncated (max time)\n"; 00247 abort = true; 00248 } 00249 if (SgUserAbort()) 00250 { 00251 if (warnTruncate) 00252 SgDebug() << "SgUctTree::CopySubtree: Truncated (aborted)\n"; 00253 abort = true; 00254 } 00255 } 00256 if (abort) 00257 { 00258 // Don't copy the children and set the pos count to zero (should 00259 // reflect the sum of children move counts) 00260 targetNode.SetPosCount(0); 00261 if (targetNode.IsProven()) 00262 targetNode.SetProvenType(SG_NOT_PROVEN); 00263 return false; 00264 } 00265 00266 SgUctNode* firstTargetChild = targetAllocator.Finish(); 00267 targetNode.SetFirstChild(firstTargetChild); 00268 targetNode.SetNuChildren(nuChildren); 00269 00270 // Create target nodes first (must be contiguous in the target tree) 00271 targetAllocator.CreateN(nuChildren); 00272 00273 // Recurse 00274 bool copiedCompleteTree = true; 00275 SgUctNode* targetChild = firstTargetChild; 00276 for (SgUctChildIterator it(*this, node); it; ++it, ++targetChild) 00277 { 00278 const SgUctNode& child = *it; 00279 ++currentAllocatorId; // Cycle to use allocators uniformly 00280 if (currentAllocatorId >= target.NuAllocators()) 00281 currentAllocatorId = 0; 00282 copiedCompleteTree &= CopySubtree(target, *targetChild, child, 00283 minCount, currentAllocatorId, 00284 warnTruncate, abort, timer, 00285 maxTime, alwaysKeepProven); 00286 } 00287 if (!copiedCompleteTree && targetNode.IsProven()) 00288 targetNode.SetProvenType(SG_NOT_PROVEN); 00289 return copiedCompleteTree; 00290 } 00291 00292 void SgUctTree::CreateAllocators(std::size_t nuThreads) 00293 { 00294 Clear(); 00295 m_allocators.clear(); 00296 for (size_t i = 0; i < nuThreads; ++i) 00297 { 00298 boost::shared_ptr<SgUctAllocator> allocator(new SgUctAllocator()); 00299 m_allocators.push_back(allocator); 00300 } 00301 } 00302 00303 void SgUctTree::DumpDebugInfo(std::ostream& out) const 00304 { 00305 out << "Root " << &m_root << '\n'; 00306 for (size_t i = 0; i < NuAllocators(); ++i) 00307 out << "Allocator " << i 00308 << " size=" << Allocator(i).NuNodes() 00309 << " start=" << Allocator(i).Start() 00310 << " finish=" << Allocator(i).Finish() << '\n'; 00311 } 00312 00313 void SgUctTree::ExtractSubtree(SgUctTree& target, const SgUctNode& node, 00314 bool warnTruncate, double maxTime, 00315 SgUctValue minCount) const 00316 { 00317 SG_ASSERT(Contains(node)); 00318 SG_ASSERT(&target != this); 00319 SG_ASSERT(target.MaxNodes() == MaxNodes()); 00320 target.Clear(); 00321 size_t allocatorId = 0; 00322 SgTimer timer; 00323 bool abort = false; 00324 CopySubtree(target, target.m_root, node, minCount, allocatorId, warnTruncate, 00325 abort, timer, maxTime, /* alwaysKeepProven */ true); 00326 SgSynchronizeThreadMemory(); 00327 } 00328 00329 void SgUctTree::MergeChildren(std::size_t allocatorId, const SgUctNode& node, 00330 const std::vector<SgUctMoveInfo>& moves, 00331 bool deleteChildTrees) 00332 { 00333 SG_ASSERT(Contains(node)); 00334 // Parameters are const-references, because only the tree is allowed 00335 // to modify nodes 00336 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node); 00337 SG_ASSERT(moves.size() <= std::size_t(std::numeric_limits<int>::max())); 00338 int nuNewChildren = int(moves.size()); 00339 00340 if (nuNewChildren == 0) 00341 { 00342 // Write order dependency 00343 nonConstNode.SetNuChildren(0); 00344 SgSynchronizeThreadMemory(); 00345 nonConstNode.SetFirstChild(0); 00346 return; 00347 } 00348 00349 SgUctAllocator& allocator = Allocator(allocatorId); 00350 SG_ASSERT(allocator.HasCapacity(nuNewChildren)); 00351 00352 const SgUctNode* newFirstChild = allocator.Finish(); 00353 SgUctValue parentCount = allocator.Create(moves); 00354 00355 // Update new children with data in old children 00356 for (std::size_t i = 0; i < moves.size(); ++i) 00357 { 00358 SgUctNode* newChild = const_cast<SgUctNode*>(&newFirstChild[i]); 00359 for (SgUctChildIterator it(*this, node); it; ++it) 00360 { 00361 const SgUctNode& oldChild = *it; 00362 if (oldChild.Move() == moves[i].m_move) 00363 { 00364 newChild->MergeResults(oldChild); 00365 newChild->SetKnowledgeCount(oldChild.KnowledgeCount()); 00366 if (! deleteChildTrees) 00367 { 00368 newChild->SetPosCount(oldChild.PosCount()); 00369 parentCount += oldChild.MoveCount(); 00370 if (oldChild.HasChildren()) 00371 { 00372 newChild->SetFirstChild(oldChild.FirstChild()); 00373 newChild->SetNuChildren(oldChild.NuChildren()); 00374 } 00375 } 00376 break; 00377 } 00378 } 00379 } 00380 nonConstNode.SetPosCount(parentCount); 00381 00382 // Write order dependency: We do not want an SgUctChildIterator to 00383 // run past the end of a node's children, which can happen if one 00384 // is created between the two statements below. We modify node in 00385 // such a way so as to avoid that. 00386 SgSynchronizeThreadMemory(); 00387 if (nonConstNode.NuChildren() < nuNewChildren) 00388 { 00389 nonConstNode.SetFirstChild(newFirstChild); 00390 SgSynchronizeThreadMemory(); 00391 nonConstNode.SetNuChildren(nuNewChildren); 00392 } 00393 else 00394 { 00395 nonConstNode.SetNuChildren(nuNewChildren); 00396 SgSynchronizeThreadMemory(); 00397 nonConstNode.SetFirstChild(newFirstChild); 00398 } 00399 } 00400 00401 std::size_t SgUctTree::NuNodes() const 00402 { 00403 size_t nuNodes = 1; // Count root node 00404 for (size_t i = 0; i < NuAllocators(); ++i) 00405 nuNodes += Allocator(i).NuNodes(); 00406 return nuNodes; 00407 } 00408 00409 void SgUctTree::SetMaxNodes(std::size_t maxNodes) 00410 { 00411 Clear(); 00412 size_t nuAllocators = NuAllocators(); 00413 if (nuAllocators == 0) 00414 { 00415 SgDebug() << "SgUctTree::SetMaxNodes: no allocators registered\n"; 00416 SG_ASSERT(false); 00417 return; 00418 } 00419 m_maxNodes = maxNodes; 00420 size_t maxNodesPerAlloc = maxNodes / nuAllocators; 00421 for (size_t i = 0; i < NuAllocators(); ++i) 00422 Allocator(i).SetMaxNodes(maxNodesPerAlloc); 00423 } 00424 00425 void SgUctTree::Swap(SgUctTree& tree) 00426 { 00427 SG_ASSERT(MaxNodes() == tree.MaxNodes()); 00428 SG_ASSERT(NuAllocators() == tree.NuAllocators()); 00429 swap(m_root, tree.m_root); 00430 for (size_t i = 0; i < NuAllocators(); ++i) 00431 Allocator(i).Swap(tree.Allocator(i)); 00432 } 00433 00434 void SgUctTree::ThrowConsistencyError(const string& message) const 00435 { 00436 DumpDebugInfo(SgDebug()); 00437 throw SgException("SgUctTree::ThrowConsistencyError: " + message); 00438 } 00439 00440 //---------------------------------------------------------------------------- 00441 00442 SgUctTreeIterator::SgUctTreeIterator(const SgUctTree& tree) 00443 : m_tree(tree), 00444 m_current(&tree.Root()) 00445 { 00446 } 00447 00448 const SgUctNode& SgUctTreeIterator::operator*() const 00449 { 00450 return *m_current; 00451 } 00452 00453 void SgUctTreeIterator::operator++() 00454 { 00455 if (m_current->HasChildren()) 00456 { 00457 SgUctChildIterator* it = new SgUctChildIterator(m_tree, *m_current); 00458 m_stack.push(shared_ptr<SgUctChildIterator>(it)); 00459 m_current = &(**it); 00460 return; 00461 } 00462 while (! m_stack.empty()) 00463 { 00464 SgUctChildIterator& it = *m_stack.top(); 00465 SG_ASSERT(it); 00466 ++it; 00467 if (it) 00468 { 00469 m_current = &(*it); 00470 return; 00471 } 00472 else 00473 { 00474 m_stack.pop(); 00475 m_current = 0; 00476 } 00477 } 00478 m_current = 0; 00479 } 00480 00481 SgUctTreeIterator::operator bool() const 00482 { 00483 return (m_current != 0); 00484 } 00485 00486 //----------------------------------------------------------------------------