00001
00002
00003
00004
00005
00006 #include "SgSystem.h"
00007 #include "SgUctTree.h"
00008
00009 #include <boost/format.hpp>
00010 #include "SgDebug.h"
00011 #include "SgTimer.h"
00012
00013 using namespace std;
00014 using boost::format;
00015 using boost::shared_ptr;
00016
00017
00018
00019 SgUctAllocator::~SgUctAllocator()
00020 {
00021 if (m_start != 0)
00022 {
00023 Clear();
00024 std::free(m_start);
00025 }
00026 }
00027
00028 bool SgUctAllocator::Contains(const SgUctNode& node) const
00029 {
00030 return (&node >= m_start && &node < m_finish);
00031 }
00032
00033 void SgUctAllocator::Swap(SgUctAllocator& allocator)
00034 {
00035 swap(m_start, allocator.m_start);
00036 swap(m_finish, allocator.m_finish);
00037 swap(m_endOfStorage, allocator.m_endOfStorage);
00038 }
00039
00040 void SgUctAllocator::SetMaxNodes(std::size_t maxNodes)
00041 {
00042 if (m_start != 0)
00043 {
00044 Clear();
00045 std::free(m_start);
00046 }
00047 void* ptr = std::malloc(maxNodes * sizeof(SgUctNode));
00048 if (ptr == 0)
00049 throw std::bad_alloc();
00050 m_start = static_cast<SgUctNode*>(ptr);
00051 m_finish = m_start;
00052 m_endOfStorage = m_start + maxNodes;
00053 }
00054
00055
00056
00057 SgUctTree::SgUctTree()
00058 : m_maxNodes(0),
00059 m_root(SG_NULLMOVE)
00060 {
00061 }
00062
00063 void SgUctTree::ApplyFilter(std::size_t allocatorId, const SgUctNode& node,
00064 const vector<SgMove>& rootFilter)
00065 {
00066 SG_ASSERT(Contains(node));
00067 SG_ASSERT(Allocator(allocatorId).HasCapacity(node.NuChildren()));
00068 if (! node.HasChildren())
00069 return;
00070
00071 SgUctAllocator& allocator = Allocator(allocatorId);
00072 const SgUctNode* firstChild = allocator.Finish();
00073
00074 int nuChildren = 0;
00075 for (SgUctChildIterator it(*this, node); it; ++it)
00076 {
00077 SgMove move = (*it).Move();
00078 if (find(rootFilter.begin(), rootFilter.end(), move)
00079 == rootFilter.end())
00080 {
00081 SgUctNode* child = allocator.CreateOne(move);
00082 child->CopyDataFrom(*it);
00083 int childNuChildren = (*it).NuChildren();
00084 child->SetNuChildren(childNuChildren);
00085 if (childNuChildren > 0)
00086 child->SetFirstChild((*it).FirstChild());
00087 ++nuChildren;
00088 }
00089 }
00090
00091 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
00092
00093
00094 SgSynchronizeThreadMemory();
00095 nonConstNode.SetFirstChild(firstChild);
00096 SgSynchronizeThreadMemory();
00097 nonConstNode.SetNuChildren(nuChildren);
00098 }
00099
00100 void SgUctTree::SetChildren(std::size_t allocatorId, const SgUctNode& node,
00101 const vector<SgMove>& moves)
00102 {
00103 SG_ASSERT(Contains(node));
00104 SG_ASSERT(Allocator(allocatorId).HasCapacity(moves.size()));
00105 SG_ASSERT(node.HasChildren());
00106
00107 SgUctAllocator& allocator = Allocator(allocatorId);
00108 const SgUctNode* firstChild = allocator.Finish();
00109
00110 int nuChildren = 0;
00111 for (size_t i = 0; i < moves.size(); ++i)
00112 {
00113 bool found = false;
00114 for (SgUctChildIterator it(*this, node); it; ++it)
00115 {
00116 SgMove move = (*it).Move();
00117 if (move == moves[i])
00118 {
00119 found = true;
00120 SgUctNode* child = allocator.CreateOne(move);
00121 child->CopyDataFrom(*it);
00122 int childNuChildren = (*it).NuChildren();
00123 child->SetNuChildren(childNuChildren);
00124 if (childNuChildren > 0)
00125 child->SetFirstChild((*it).FirstChild());
00126 ++nuChildren;
00127 break;
00128 }
00129 }
00130 if (! found)
00131 {
00132 allocator.CreateOne(moves[i]);
00133 ++nuChildren;
00134 }
00135 }
00136 SG_ASSERT((size_t)nuChildren == moves.size());
00137
00138 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
00139
00140
00141 SgSynchronizeThreadMemory();
00142 nonConstNode.SetFirstChild(firstChild);
00143 SgSynchronizeThreadMemory();
00144 nonConstNode.SetNuChildren(nuChildren);
00145 }
00146
00147 void SgUctTree::CheckConsistency() const
00148 {
00149 for (SgUctTreeIterator it(*this); it; ++it)
00150 if (! Contains(*it))
00151 ThrowConsistencyError(str(format("! Contains(%1%)") % &(*it)));
00152 }
00153
00154 void SgUctTree::Clear()
00155 {
00156 for (size_t i = 0; i < NuAllocators(); ++i)
00157 Allocator(i).Clear();
00158 m_root = SgUctNode(SG_NULLMOVE);
00159 }
00160
00161
00162
00163 bool SgUctTree::Contains(const SgUctNode& node) const
00164 {
00165 if (&node == &m_root)
00166 return true;
00167 for (size_t i = 0; i < NuAllocators(); ++i)
00168 if (Allocator(i).Contains(node))
00169 return true;
00170 return false;
00171 }
00172
00173 void SgUctTree::CopyPruneLowCount(SgUctTree& target, SgUctValue minCount,
00174 bool warnTruncate, double maxTime) const
00175 {
00176 size_t allocatorId = 0;
00177 SgTimer timer;
00178 bool abort = false;
00179 CopySubtree(target, target.m_root, m_root, minCount, allocatorId,
00180 warnTruncate, abort, timer, maxTime,
00181 false);
00182 SgSynchronizeThreadMemory();
00183 }
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200
00201
00202 bool SgUctTree::CopySubtree(SgUctTree& target, SgUctNode& targetNode,
00203 const SgUctNode& node, SgUctValue minCount,
00204 std::size_t& currentAllocatorId,
00205 bool warnTruncate, bool& abort, SgTimer& timer,
00206 double maxTime, bool alwaysKeepProven) const
00207
00208 {
00209 SG_ASSERT(Contains(node));
00210 SG_ASSERT(target.Contains(targetNode));
00211 targetNode.CopyDataFrom(node);
00212
00213 if (! node.HasChildren())
00214 return true;
00215
00216 if (node.IsProven())
00217 {
00218 if (!alwaysKeepProven && node.MoveCount() < minCount)
00219 {
00220 targetNode.SetProvenType(SG_NOT_PROVEN);
00221 return false;
00222 }
00223 }
00224 else
00225 {
00226 if (node.MoveCount() < minCount)
00227 return false;
00228 }
00229
00230 SgUctAllocator& targetAllocator = target.Allocator(currentAllocatorId);
00231 int nuChildren = node.NuChildren();
00232 if (! abort)
00233 {
00234 if (! targetAllocator.HasCapacity(nuChildren))
00235 {
00236
00237
00238 if (warnTruncate)
00239 SgDebug() <<
00240 "SgUctTree::CopySubtree: Truncated (allocator capacity)\n";
00241 abort = true;
00242 }
00243 if (timer.IsTimeOut(maxTime, 10000))
00244 {
00245 if (warnTruncate)
00246 SgDebug() << "SgUctTree::CopySubtree: Truncated (max time)\n";
00247 abort = true;
00248 }
00249 if (SgUserAbort())
00250 {
00251 if (warnTruncate)
00252 SgDebug() << "SgUctTree::CopySubtree: Truncated (aborted)\n";
00253 abort = true;
00254 }
00255 }
00256 if (abort)
00257 {
00258
00259
00260 targetNode.SetPosCount(0);
00261 if (targetNode.IsProven())
00262 targetNode.SetProvenType(SG_NOT_PROVEN);
00263 return false;
00264 }
00265
00266 SgUctNode* firstTargetChild = targetAllocator.Finish();
00267 targetNode.SetFirstChild(firstTargetChild);
00268 targetNode.SetNuChildren(nuChildren);
00269
00270
00271 targetAllocator.CreateN(nuChildren);
00272
00273
00274 bool copiedCompleteTree = true;
00275 SgUctNode* targetChild = firstTargetChild;
00276 for (SgUctChildIterator it(*this, node); it; ++it, ++targetChild)
00277 {
00278 const SgUctNode& child = *it;
00279 ++currentAllocatorId;
00280 if (currentAllocatorId >= target.NuAllocators())
00281 currentAllocatorId = 0;
00282 copiedCompleteTree &= CopySubtree(target, *targetChild, child,
00283 minCount, currentAllocatorId,
00284 warnTruncate, abort, timer,
00285 maxTime, alwaysKeepProven);
00286 }
00287 if (!copiedCompleteTree && targetNode.IsProven())
00288 targetNode.SetProvenType(SG_NOT_PROVEN);
00289 return copiedCompleteTree;
00290 }
00291
00292 void SgUctTree::CreateAllocators(std::size_t nuThreads)
00293 {
00294 Clear();
00295 m_allocators.clear();
00296 for (size_t i = 0; i < nuThreads; ++i)
00297 {
00298 boost::shared_ptr<SgUctAllocator> allocator(new SgUctAllocator());
00299 m_allocators.push_back(allocator);
00300 }
00301 }
00302
00303 void SgUctTree::DumpDebugInfo(std::ostream& out) const
00304 {
00305 out << "Root " << &m_root << '\n';
00306 for (size_t i = 0; i < NuAllocators(); ++i)
00307 out << "Allocator " << i
00308 << " size=" << Allocator(i).NuNodes()
00309 << " start=" << Allocator(i).Start()
00310 << " finish=" << Allocator(i).Finish() << '\n';
00311 }
00312
00313 void SgUctTree::ExtractSubtree(SgUctTree& target, const SgUctNode& node,
00314 bool warnTruncate, double maxTime,
00315 SgUctValue minCount) const
00316 {
00317 SG_ASSERT(Contains(node));
00318 SG_ASSERT(&target != this);
00319 SG_ASSERT(target.MaxNodes() == MaxNodes());
00320 target.Clear();
00321 size_t allocatorId = 0;
00322 SgTimer timer;
00323 bool abort = false;
00324 CopySubtree(target, target.m_root, node, minCount, allocatorId, warnTruncate,
00325 abort, timer, maxTime, true);
00326 SgSynchronizeThreadMemory();
00327 }
00328
00329 void SgUctTree::MergeChildren(std::size_t allocatorId, const SgUctNode& node,
00330 const std::vector<SgUctMoveInfo>& moves,
00331 bool deleteChildTrees)
00332 {
00333 SG_ASSERT(Contains(node));
00334
00335
00336 SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
00337 SG_ASSERT(moves.size() <= std::size_t(std::numeric_limits<int>::max()));
00338 int nuNewChildren = int(moves.size());
00339
00340 if (nuNewChildren == 0)
00341 {
00342
00343 nonConstNode.SetNuChildren(0);
00344 SgSynchronizeThreadMemory();
00345 nonConstNode.SetFirstChild(0);
00346 return;
00347 }
00348
00349 SgUctAllocator& allocator = Allocator(allocatorId);
00350 SG_ASSERT(allocator.HasCapacity(nuNewChildren));
00351
00352 const SgUctNode* newFirstChild = allocator.Finish();
00353 SgUctValue parentCount = allocator.Create(moves);
00354
00355
00356 for (std::size_t i = 0; i < moves.size(); ++i)
00357 {
00358 SgUctNode* newChild = const_cast<SgUctNode*>(&newFirstChild[i]);
00359 for (SgUctChildIterator it(*this, node); it; ++it)
00360 {
00361 const SgUctNode& oldChild = *it;
00362 if (oldChild.Move() == moves[i].m_move)
00363 {
00364 newChild->MergeResults(oldChild);
00365 newChild->SetKnowledgeCount(oldChild.KnowledgeCount());
00366 if (! deleteChildTrees)
00367 {
00368 newChild->SetPosCount(oldChild.PosCount());
00369 parentCount += oldChild.MoveCount();
00370 if (oldChild.HasChildren())
00371 {
00372 newChild->SetFirstChild(oldChild.FirstChild());
00373 newChild->SetNuChildren(oldChild.NuChildren());
00374 }
00375 }
00376 break;
00377 }
00378 }
00379 }
00380 nonConstNode.SetPosCount(parentCount);
00381
00382
00383
00384
00385
00386 SgSynchronizeThreadMemory();
00387 if (nonConstNode.NuChildren() < nuNewChildren)
00388 {
00389 nonConstNode.SetFirstChild(newFirstChild);
00390 SgSynchronizeThreadMemory();
00391 nonConstNode.SetNuChildren(nuNewChildren);
00392 }
00393 else
00394 {
00395 nonConstNode.SetNuChildren(nuNewChildren);
00396 SgSynchronizeThreadMemory();
00397 nonConstNode.SetFirstChild(newFirstChild);
00398 }
00399 }
00400
00401 std::size_t SgUctTree::NuNodes() const
00402 {
00403 size_t nuNodes = 1;
00404 for (size_t i = 0; i < NuAllocators(); ++i)
00405 nuNodes += Allocator(i).NuNodes();
00406 return nuNodes;
00407 }
00408
00409 void SgUctTree::SetMaxNodes(std::size_t maxNodes)
00410 {
00411 Clear();
00412 size_t nuAllocators = NuAllocators();
00413 if (nuAllocators == 0)
00414 {
00415 SgDebug() << "SgUctTree::SetMaxNodes: no allocators registered\n";
00416 SG_ASSERT(false);
00417 return;
00418 }
00419 m_maxNodes = maxNodes;
00420 size_t maxNodesPerAlloc = maxNodes / nuAllocators;
00421 for (size_t i = 0; i < NuAllocators(); ++i)
00422 Allocator(i).SetMaxNodes(maxNodesPerAlloc);
00423 }
00424
00425 void SgUctTree::Swap(SgUctTree& tree)
00426 {
00427 SG_ASSERT(MaxNodes() == tree.MaxNodes());
00428 SG_ASSERT(NuAllocators() == tree.NuAllocators());
00429 swap(m_root, tree.m_root);
00430 for (size_t i = 0; i < NuAllocators(); ++i)
00431 Allocator(i).Swap(tree.Allocator(i));
00432 }
00433
00434 void SgUctTree::ThrowConsistencyError(const string& message) const
00435 {
00436 DumpDebugInfo(SgDebug());
00437 throw SgException("SgUctTree::ThrowConsistencyError: " + message);
00438 }
00439
00440
00441
00442 SgUctTreeIterator::SgUctTreeIterator(const SgUctTree& tree)
00443 : m_tree(tree),
00444 m_current(&tree.Root())
00445 {
00446 }
00447
00448 const SgUctNode& SgUctTreeIterator::operator*() const
00449 {
00450 return *m_current;
00451 }
00452
00453 void SgUctTreeIterator::operator++()
00454 {
00455 if (m_current->HasChildren())
00456 {
00457 SgUctChildIterator* it = new SgUctChildIterator(m_tree, *m_current);
00458 m_stack.push(shared_ptr<SgUctChildIterator>(it));
00459 m_current = &(**it);
00460 return;
00461 }
00462 while (! m_stack.empty())
00463 {
00464 SgUctChildIterator& it = *m_stack.top();
00465 SG_ASSERT(it);
00466 ++it;
00467 if (it)
00468 {
00469 m_current = &(*it);
00470 return;
00471 }
00472 else
00473 {
00474 m_stack.pop();
00475 m_current = 0;
00476 }
00477 }
00478 m_current = 0;
00479 }
00480
00481 SgUctTreeIterator::operator bool() const
00482 {
00483 return (m_current != 0);
00484 }
00485
00486