00001
00002
00003
00004
00005
00006 #ifndef SG_UCTTREE_H
00007 #define SG_UCTTREE_H
00008
#include <iosfwd>
#include <limits>
#include <new>
#include <stack>
#include <string>
#include <vector>
#include <boost/shared_ptr.hpp>
#include "SgMove.h"
#include "SgStatistics.h"
#include "SgStatisticsVlt.h"
#include "SgUctValue.h"
00016
00017 class SgTimer;
00018
00019
00020
/** Statistics with float values and std::size_t counts.
    NOTE(review): the node class in this header stores its statistics as
    SgUctStatisticsVolatile (not declared in this header, presumably in
    SgUctValue.h), not as these typedefs. These may be legacy names kept
    for external users; confirm before removing. */
typedef SgStatisticsBase<float,std::size_t> SgUctStatisticsBase;

/** Volatile-member variant of SgUctStatisticsBase (see note above). */
typedef SgStatisticsVltBase<float,std::size_t> SgUctStatisticsBaseVolatile;
00024
00025
00026
00027
/** Move plus initial statistics used to construct an SgUctNode.
    SgUctNode's constructor copies m_move and initializes its move-value
    statistics from (m_value, m_count) and its RAVE statistics from
    (m_raveValue, m_raveCount). */
struct SgUctMoveInfo
{
    /** The move. */
    SgMove m_move;

    /** Initial value of the move-value statistics. */
    SgUctValue m_value;

    /** Initial count of the move-value statistics.
        SgUctAllocator::Create() returns the sum of these over all moves. */
    SgUctValue m_count;

    /** Initial RAVE value. */
    SgUctValue m_raveValue;

    /** Initial RAVE count. */
    SgUctValue m_raveCount;

    /** Default constructor; all values and counts are zero. */
    SgUctMoveInfo();

    /** Move with all initial values and counts zero.
        Not explicit: this header relies on the implicit conversion
        SgMove -> SgUctMoveInfo (e.g. SgUctNode(move) in SgUctAllocator). */
    SgUctMoveInfo(SgMove move);

    /** Move with explicit initial move-value and RAVE statistics. */
    SgUctMoveInfo(SgMove move, SgUctValue value, SgUctValue count,
                  SgUctValue raveValue, SgUctValue raveCount);
};
00057
00058 inline SgUctMoveInfo::SgUctMoveInfo()
00059 : m_value(0),
00060 m_count(0),
00061 m_raveValue(0),
00062 m_raveCount(0)
00063 {
00064 }
00065
00066 inline SgUctMoveInfo::SgUctMoveInfo(SgMove move)
00067 : m_move(move),
00068 m_value(0),
00069 m_count(0),
00070 m_raveValue(0),
00071 m_raveCount(0)
00072 {
00073 }
00074
00075 inline SgUctMoveInfo::SgUctMoveInfo(SgMove move, SgUctValue value, SgUctValue count,
00076 SgUctValue raveValue, SgUctValue raveCount)
00077 : m_move(move),
00078 m_value(value),
00079 m_count(count),
00080 m_raveValue(raveValue),
00081 m_raveCount(raveCount)
00082 {
00083 }
00084
00085
00086
00087
/** Proof status of a node (see SgUctNode::IsProven()). */
enum SgUctProvenType
{
    /** Node value is not proven. */
    SG_NOT_PROVEN,

    /** Node is proven to be a win. */
    SG_PROVEN_WIN,

    /** Node is proven to be a loss. */
    SG_PROVEN_LOSS
};
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
/** Node of the UCT search tree.
    Holds move-value statistics, RAVE statistics, a position count, a
    knowledge count, a proof status and a virtual-loss counter. Children are
    stored contiguously: FirstChild() points to the first of NuChildren()
    consecutive nodes (this is how SgUctChildIterator walks them).
    Most fields are volatile. NOTE(review): this suggests concurrent
    lock-free access from several threads; volatile alone gives no atomicity,
    so confirm the intended threading model. */
class SgUctNode
{
public:
    /** Construct a node without children from move info. */
    SgUctNode(const SgUctMoveInfo& info);

    /** Add one game result with value eval to the move-value statistics. */
    void AddGameResult(SgUctValue eval);

    /** Add count game results, each with value eval. */
    void AddGameResults(SgUctValue eval, SgUctValue count);

    /** Add the move-value and RAVE statistics of another node to this one.
        Statistics of `node` that are not defined are skipped. */
    void MergeResults(const SgUctNode& node);

    /** Remove one game result with value eval from the move-value
        statistics. */
    void RemoveGameResult(SgUctValue eval);

    /** Remove count game results, each with value eval. */
    void RemoveGameResults(SgUctValue eval, SgUctValue count);

    /** Position count.
        Incremented on a parent when SgUctTree::AddGameResult() records a
        result in one of its children; set by SgUctTree::CreateChildren() to
        the sum of the children's initial counts. */
    SgUctValue PosCount() const;

    /** Number of results in the move-value statistics. */
    SgUctValue MoveCount() const;

    /** First child. Requires HasChildren(). */
    const SgUctNode* FirstChild() const;

    /** True if NuChildren() > 0. */
    bool HasChildren() const;

    /** Mean of the move-value statistics. */
    SgUctValue Mean() const;

    /** True if the move-value statistics are defined. */
    bool HasMean() const;

    /** Number of children (children are contiguous from FirstChild()). */
    int NuChildren() const;

    /** Set the pointer to the first child. */
    void SetFirstChild(const SgUctNode* child);

    /** Set the number of children; must be non-negative. */
    void SetNuChildren(int nuChildren);

    /** Increment the position count by one. */
    void IncPosCount();

    /** Increment the position count by count. */
    void IncPosCount(SgUctValue count);

    /** Decrement the position count by one; no-op unless it is positive. */
    void DecPosCount();

    /** Decrement the position count by count; no-op if it is smaller
        than count. */
    void DecPosCount(SgUctValue count);

    /** Overwrite the position count. */
    void SetPosCount(SgUctValue value);

    /** Reset the move-value statistics to (value, count). */
    void InitializeValue(SgUctValue value, SgUctValue count);

    /** Copy all data fields from `node`, except the child links
        (first child and number of children). */
    void CopyDataFrom(const SgUctNode& node);

    /** The node's move. Must not be SG_NULLMOVE (asserted). */
    SgMove Move() const;

    /** Count of the RAVE statistics. */
    SgUctValue RaveCount() const;

    /** Mean of the RAVE statistics. */
    SgUctValue RaveValue() const;

    /** True if the RAVE statistics are defined. */
    bool HasRaveValue() const;

    /** Add a RAVE value with the given weight. */
    void AddRaveValue(SgUctValue value, SgUctValue weight);

    /** Remove one RAVE value. */
    void RemoveRaveValue(SgUctValue value);

    /** Remove a RAVE value with the given weight. */
    void RemoveRaveValue(SgUctValue value, SgUctValue weight);

    /** Reset the RAVE statistics to (value, count). */
    void InitializeRaveValue(SgUctValue value, SgUctValue count);

    /** Current virtual-loss counter (may become negative; see
        RemoveVirtualLoss()). */
    int VirtualLossCount() const;

    /** Increment the virtual-loss counter. */
    void AddVirtualLoss();

    /** Decrement the virtual-loss counter (no lower bound check). */
    void RemoveVirtualLoss();

    /** Knowledge count (starts at 0; semantics defined by callers). */
    SgUctValue KnowledgeCount() const;

    /** Set the knowledge count. */
    void SetKnowledgeCount(SgUctValue count);

    /** True unless the proof status is SG_NOT_PROVEN. */
    bool IsProven() const;

    /** True if the proof status is SG_PROVEN_WIN. */
    bool IsProvenWin() const;

    /** True if the proof status is SG_PROVEN_LOSS. */
    bool IsProvenLoss() const;

    /** Current proof status. */
    SgUctProvenType ProvenType() const;

    /** Set the proof status. */
    void SetProvenType(SgUctProvenType type);

private:
    /** Move-value statistics. */
    SgUctStatisticsVolatile m_statistics;

    /** First of the contiguous children; valid only if m_nuChildren > 0. */
    const SgUctNode* volatile m_firstChild;

    /** Number of children. */
    volatile int m_nuChildren;

    /** The node's move. */
    volatile SgMove m_move;

    /** RAVE statistics. */
    SgUctStatisticsVolatile m_raveValue;

    /** Position count (see PosCount()). */
    volatile SgUctValue m_posCount;

    /** Knowledge count (see KnowledgeCount()). */
    volatile SgUctValue m_knowledgeCount;

    /** Proof status. */
    volatile SgUctProvenType m_provenType;

    /** Virtual-loss counter. */
    volatile int m_virtualLossCount;
};
00269
00270 inline SgUctNode::SgUctNode(const SgUctMoveInfo& info)
00271 : m_statistics(info.m_value, info.m_count),
00272 m_nuChildren(0),
00273 m_move(info.m_move),
00274 m_raveValue(info.m_raveValue, info.m_raveCount),
00275 m_posCount(0),
00276 m_knowledgeCount(0),
00277 m_provenType(SG_NOT_PROVEN),
00278 m_virtualLossCount(0)
00279 {
00280
00281 }
00282
00283 inline void SgUctNode::AddGameResult(SgUctValue eval)
00284 {
00285 m_statistics.Add(eval);
00286 }
00287
00288 inline void SgUctNode::AddGameResults(SgUctValue eval, SgUctValue count)
00289 {
00290 m_statistics.Add(eval, count);
00291 }
00292
00293 inline void SgUctNode::MergeResults(const SgUctNode& node)
00294 {
00295 if (node.m_statistics.IsDefined())
00296 m_statistics.Add(node.m_statistics.Mean(), node.m_statistics.Count());
00297 if (node.m_raveValue.IsDefined())
00298 m_raveValue.Add(node.m_raveValue.Mean(), node.m_raveValue.Count());
00299 }
00300
00301 inline void SgUctNode::RemoveGameResult(SgUctValue eval)
00302 {
00303 m_statistics.Remove(eval);
00304 }
00305
00306 inline void SgUctNode::RemoveGameResults(SgUctValue eval, SgUctValue count)
00307 {
00308 m_statistics.Remove(eval, count);
00309 }
00310
00311 inline void SgUctNode::AddRaveValue(SgUctValue value, SgUctValue weight)
00312 {
00313 m_raveValue.Add(value, weight);
00314 }
00315
00316 inline void SgUctNode::RemoveRaveValue(SgUctValue value)
00317 {
00318 m_raveValue.Remove(value);
00319 }
00320
00321 inline void SgUctNode::RemoveRaveValue(SgUctValue value, SgUctValue weight)
00322 {
00323 m_raveValue.Remove(value, weight);
00324 }
00325
/** Copy all data fields from `node`.
    The child links (m_firstChild, m_nuChildren) are NOT copied. The caller
    is responsible for setting them, e.g. via SetFirstChild() and
    SetNuChildren(). The assignments are to volatile fields and are kept in
    declaration order. */
inline void SgUctNode::CopyDataFrom(const SgUctNode& node)
{
    m_statistics = node.m_statistics;
    m_move = node.m_move;
    m_raveValue = node.m_raveValue;
    m_posCount = node.m_posCount;
    m_knowledgeCount = node.m_knowledgeCount;
    m_provenType = node.m_provenType;
    m_virtualLossCount = node.m_virtualLossCount;
}
00336
00337 inline const SgUctNode* SgUctNode::FirstChild() const
00338 {
00339 SG_ASSERT(HasChildren());
00340 return m_firstChild;
00341 }
00342
00343 inline bool SgUctNode::HasChildren() const
00344 {
00345 return (m_nuChildren > 0);
00346 }
00347
00348 inline bool SgUctNode::HasMean() const
00349 {
00350 return m_statistics.IsDefined();
00351 }
00352
00353 inline bool SgUctNode::HasRaveValue() const
00354 {
00355 return m_raveValue.IsDefined();
00356 }
00357
00358 inline int SgUctNode::VirtualLossCount() const
00359 {
00360 return m_virtualLossCount;
00361 }
00362
00363 inline void SgUctNode::AddVirtualLoss()
00364 {
00365 m_virtualLossCount++;
00366 }
00367
00368 inline void SgUctNode::RemoveVirtualLoss()
00369 {
00370
00371
00372
00373 m_virtualLossCount--;
00374 }
00375
00376 inline void SgUctNode::IncPosCount()
00377 {
00378 ++m_posCount;
00379 }
00380
00381 inline void SgUctNode::IncPosCount(SgUctValue count)
00382 {
00383 m_posCount += count;
00384 }
00385
00386 inline void SgUctNode::DecPosCount()
00387 {
00388 SgUctValue posCount = m_posCount;
00389 if (posCount > 0)
00390 {
00391 m_posCount = posCount - 1;
00392 }
00393 }
00394
00395 inline void SgUctNode::DecPosCount(SgUctValue count)
00396 {
00397 SgUctValue posCount = m_posCount;
00398 if (posCount >= count)
00399 {
00400 m_posCount = posCount - count;
00401 }
00402 }
00403
00404 inline void SgUctNode::InitializeValue(SgUctValue value, SgUctValue count)
00405 {
00406 m_statistics.Initialize(value, count);
00407 }
00408
00409 inline void SgUctNode::InitializeRaveValue(SgUctValue value, SgUctValue count)
00410 {
00411 m_raveValue.Initialize(value, count);
00412 }
00413
00414 inline SgUctValue SgUctNode::Mean() const
00415 {
00416 return m_statistics.Mean();
00417 }
00418
00419 inline SgMove SgUctNode::Move() const
00420 {
00421 SG_ASSERT(m_move != SG_NULLMOVE);
00422 return m_move;
00423 }
00424
00425 inline SgUctValue SgUctNode::MoveCount() const
00426 {
00427 return m_statistics.Count();
00428 }
00429
00430 inline int SgUctNode::NuChildren() const
00431 {
00432 return m_nuChildren;
00433 }
00434
00435 inline SgUctValue SgUctNode::PosCount() const
00436 {
00437 return m_posCount;
00438 }
00439
00440 inline SgUctValue SgUctNode::RaveCount() const
00441 {
00442 return m_raveValue.Count();
00443 }
00444
00445 inline SgUctValue SgUctNode::RaveValue() const
00446 {
00447 return m_raveValue.Mean();
00448 }
00449
00450 inline void SgUctNode::SetFirstChild(const SgUctNode* child)
00451 {
00452 m_firstChild = child;
00453 }
00454
00455 inline void SgUctNode::SetNuChildren(int nuChildren)
00456 {
00457 SG_ASSERT(nuChildren >= 0);
00458 m_nuChildren = nuChildren;
00459 }
00460
00461 inline void SgUctNode::SetPosCount(SgUctValue value)
00462 {
00463 m_posCount = value;
00464 }
00465
00466 inline SgUctValue SgUctNode::KnowledgeCount() const
00467 {
00468 return m_knowledgeCount;
00469 }
00470
00471 inline void SgUctNode::SetKnowledgeCount(SgUctValue count)
00472 {
00473 m_knowledgeCount = count;
00474 }
00475
00476 inline bool SgUctNode::IsProven() const
00477 {
00478 return m_provenType != SG_NOT_PROVEN;
00479 }
00480
00481 inline bool SgUctNode::IsProvenWin() const
00482 {
00483 return m_provenType == SG_PROVEN_WIN;
00484 }
00485
00486 inline bool SgUctNode::IsProvenLoss() const
00487 {
00488 return m_provenType == SG_PROVEN_LOSS;
00489 }
00490
00491 inline SgUctProvenType SgUctNode::ProvenType() const
00492 {
00493 return m_provenType;
00494 }
00495
00496 inline void SgUctNode::SetProvenType(SgUctProvenType type)
00497 {
00498 m_provenType = type;
00499 }
00500
00501
00502
00503
00504
00505
00506
00507 class SgUctAllocator
00508 {
00509 public:
00510 SgUctAllocator();
00511
00512 ~SgUctAllocator();
00513
00514 void Clear();
00515
00516
00517 bool HasCapacity(std::size_t n) const;
00518
00519 std::size_t NuNodes() const;
00520
00521 std::size_t MaxNodes() const;
00522
00523 void SetMaxNodes(std::size_t maxNodes);
00524
00525
00526
00527
00528
00529
00530
00531 bool Contains(const SgUctNode& node) const;
00532
00533 const SgUctNode* Start() const;
00534
00535 SgUctNode* Finish();
00536
00537 const SgUctNode* Finish() const;
00538
00539
00540
00541
00542
00543 SgUctNode* CreateOne(SgMove move);
00544
00545
00546
00547
00548
00549 SgUctValue Create(const std::vector<SgUctMoveInfo>& moves);
00550
00551
00552
00553
00554 void CreateN(std::size_t n);
00555
00556 void Swap(SgUctAllocator& allocator);
00557
00558 private:
00559 SgUctNode* m_start;
00560
00561 SgUctNode* m_finish;
00562
00563 SgUctNode* m_endOfStorage;
00564
00565
00566
00567
00568 SgUctAllocator& operator=(const SgUctAllocator& tree);
00569 };
00570
00571 inline SgUctAllocator::SgUctAllocator()
00572 {
00573 m_start = 0;
00574 }
00575
00576 inline void SgUctAllocator::Clear()
00577 {
00578 if (m_start != 0)
00579 {
00580 for (SgUctNode* it = m_start; it != m_finish; ++it)
00581 it->~SgUctNode();
00582 m_finish = m_start;
00583 }
00584 }
00585
00586 inline SgUctNode* SgUctAllocator::CreateOne(SgMove move)
00587 {
00588 SG_ASSERT(HasCapacity(1));
00589 new(m_finish) SgUctNode(move);
00590 return (m_finish++);
00591 }
00592
00593 inline SgUctValue SgUctAllocator::Create(
00594 const std::vector<SgUctMoveInfo>& moves)
00595 {
00596 SG_ASSERT(HasCapacity(moves.size()));
00597 SgUctValue count = 0;
00598 for (std::vector<SgUctMoveInfo>::const_iterator it = moves.begin();
00599 it != moves.end(); ++it, ++m_finish)
00600 {
00601 new(m_finish) SgUctNode(*it);
00602 count += it->m_count;
00603 }
00604 return count;
00605 }
00606
00607 inline void SgUctAllocator::CreateN(std::size_t n)
00608 {
00609 SG_ASSERT(HasCapacity(n));
00610 SgUctNode* newFinish = m_finish + n;
00611 for ( ; m_finish != newFinish; ++m_finish)
00612 new(m_finish) SgUctNode(SG_NULLMOVE);
00613 }
00614
00615 inline SgUctNode* SgUctAllocator::Finish()
00616 {
00617 return m_finish;
00618 }
00619
00620 inline const SgUctNode* SgUctAllocator::Finish() const
00621 {
00622 return m_finish;
00623 }
00624
00625 inline bool SgUctAllocator::HasCapacity(std::size_t n) const
00626 {
00627 return (m_finish + n <= m_endOfStorage);
00628 }
00629
00630 inline std::size_t SgUctAllocator::MaxNodes() const
00631 {
00632 return m_endOfStorage - m_start;
00633 }
00634
00635 inline std::size_t SgUctAllocator::NuNodes() const
00636 {
00637 return m_finish - m_start;
00638 }
00639
00640 inline const SgUctNode* SgUctAllocator::Start() const
00641 {
00642 return m_start;
00643 }
00644
00645
00646
00647
00648
00649
00650
00651
00652
00653
00654 class SgUctTree
00655 {
00656 public:
00657 friend class SgUctChildIterator;
00658
00659
00660
00661
00662 SgUctTree();
00663
00664
00665 void CreateAllocators(std::size_t nuThreads);
00666
00667
00668
00669
00670
00671 void AddGameResult(const SgUctNode& node, const SgUctNode* father,
00672 SgUctValue eval);
00673
00674
00675 void AddGameResults(const SgUctNode& node, const SgUctNode* father,
00676 SgUctValue eval, SgUctValue count);
00677
00678
00679
00680
00681
00682 void RemoveGameResult(const SgUctNode& node, const SgUctNode* father,
00683 SgUctValue eval);
00684
00685
00686 void RemoveGameResults(const SgUctNode& node, const SgUctNode* father,
00687 SgUctValue eval, SgUctValue count);
00688
00689
00690 void AddVirtualLoss(const SgUctNode &node);
00691
00692
00693 void RemoveVirtualLoss(const SgUctNode &node);
00694
00695 void SetProvenType(const SgUctNode& node, SgUctProvenType type);
00696
00697 void SetKnowledgeCount(const SgUctNode& node, SgUctValue count);
00698
00699 void Clear();
00700
00701
00702
00703
00704
00705 std::size_t MaxNodes() const;
00706
00707
00708
00709
00710
00711
00712
00713
00714 void SetMaxNodes(std::size_t maxNodes);
00715
00716
00717
00718
00719 void Swap(SgUctTree& tree);
00720
00721 bool HasCapacity(std::size_t allocatorId, std::size_t n) const;
00722
00723
00724
00725 void CreateChildren(std::size_t allocatorId, const SgUctNode& node,
00726 const std::vector<SgUctMoveInfo>& moves);
00727
00728
00729
00730 void MergeChildren(std::size_t allocatorId, const SgUctNode& node,
00731 const std::vector<SgUctMoveInfo>& moves,
00732 bool deleteChildTrees);
00733
00734
00735
00736
00737
00738
00739
00740
00741
00742
00743
00744
00745 void ExtractSubtree(SgUctTree& target, const SgUctNode& node,
00746 bool warnTruncate,
00747 double maxTime = std::numeric_limits<double>::max(),
00748 SgUctValue minCount = 0) const;
00749
00750
00751
00752
00753
00754
00755
00756
00757
00758
00759
00760 void CopyPruneLowCount(SgUctTree& target, SgUctValue minCount,
00761 bool warnTruncate,
00762 double maxTime = std::numeric_limits<double>::max()) const;
00763
00764 const SgUctNode& Root() const;
00765
00766 std::size_t NuAllocators() const;
00767
00768
00769
00770 std::size_t NuNodes() const;
00771
00772
00773 std::size_t NuNodes(std::size_t allocatorId) const;
00774
00775
00776
00777
00778
00779
00780 void AddRaveValue(const SgUctNode& node, SgUctValue value, SgUctValue weight);
00781
00782
00783
00784
00785
00786
00787 void RemoveRaveValue(const SgUctNode& node, SgUctValue value, SgUctValue weight);
00788
00789
00790 void InitializeValue(const SgUctNode& node, SgUctValue value,
00791 SgUctValue count);
00792
00793 void SetPosCount(const SgUctNode& node, SgUctValue posCount);
00794
00795
00796
00797 void InitializeRaveValue(const SgUctNode& node, SgUctValue value, SgUctValue count);
00798
00799
00800
00801
00802
00803
00804 void ApplyFilter(std::size_t allocatorId, const SgUctNode& node,
00805 const std::vector<SgMove>& rootFilter);
00806
00807
00808
00809
00810
00811 void SetChildren(std::size_t allocatorId, const SgUctNode& node,
00812 const vector<SgMove>& moves);
00813
00814
00815
00816
00817
00818
00819 void CheckConsistency() const;
00820
00821
00822
00823
00824
00825
00826
00827 bool Contains(const SgUctNode& node) const;
00828
00829 void DumpDebugInfo(std::ostream& out) const;
00830
00831
00832
00833 private:
00834 std::size_t m_maxNodes;
00835
00836 SgUctNode m_root;
00837
00838
00839
00840
00841 std::vector<boost::shared_ptr<SgUctAllocator> > m_allocators;
00842
00843
00844
00845
00846 SgUctTree& operator=(const SgUctTree& tree);
00847
00848 SgUctAllocator& Allocator(std::size_t i);
00849
00850 const SgUctAllocator& Allocator(std::size_t i) const;
00851
00852 bool CopySubtree(SgUctTree& target, SgUctNode& targetNode,
00853 const SgUctNode& node, SgUctValue minCount,
00854 std::size_t& currentAllocatorId, bool warnTruncate,
00855 bool& abort, SgTimer& timer, double maxTime,
00856 bool alwaysKeepProven) const;
00857
00858 void ThrowConsistencyError(const std::string& message) const;
00859 };
00860
00861 inline void SgUctTree::AddGameResult(const SgUctNode& node,
00862 const SgUctNode* father, SgUctValue eval)
00863 {
00864 SG_ASSERT(Contains(node));
00865
00866
00867 if (father != 0)
00868 const_cast<SgUctNode*>(father)->IncPosCount();
00869 const_cast<SgUctNode&>(node).AddGameResult(eval);
00870 }
00871
00872 inline void SgUctTree::AddGameResults(const SgUctNode& node,
00873 const SgUctNode* father, SgUctValue eval,
00874 SgUctValue count)
00875 {
00876
00877 SG_ASSERT(Contains(node));
00878
00879
00880 if (father != 0)
00881 const_cast<SgUctNode*>(father)->IncPosCount(count);
00882 const_cast<SgUctNode&>(node).AddGameResults(eval, count);
00883 }
00884
/** Create the children of `node`, one per entry of `moves`.
    The children are constructed contiguously in allocator allocatorId. The
    position count of `node` is set to the sum of the m_count fields of the
    moves (the return value of SgUctAllocator::Create()).
    @param allocatorId Allocator that must have room for moves.size() nodes
    @param node Node of this tree to expand
    @param moves Non-empty list of move infos */
inline void SgUctTree::CreateChildren(std::size_t allocatorId,
                                      const SgUctNode& node,
                                      const std::vector<SgUctMoveInfo>& moves)
{
    SG_ASSERT(Contains(node));
    // Nodes are handed out as const; the tree casts constness away to
    // update them.
    SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
    // The child count is stored as int.
    SG_ASSERT(moves.size() <= std::size_t(std::numeric_limits<int>::max()));
    int nuChildren = int(moves.size());
    SG_ASSERT(nuChildren > 0);
    SgUctAllocator& allocator = Allocator(allocatorId);
    SG_ASSERT(allocator.HasCapacity(nuChildren));

    // With a single allocator the node must not already be expanded. With
    // several allocators this is not asserted.
    // NOTE(review): presumably because another thread may have expanded the
    // same node concurrently; confirm.
    SG_ASSERT(NuAllocators() > 1 || ! node.HasChildren());

    const SgUctNode* firstChild = allocator.Finish();

    SgUctValue parentCount = allocator.Create(moves);

    // Publish in this order, separated by memory barriers: position count,
    // then child pointer, then child count. HasChildren() only tests the
    // child count, so this presumably ensures a concurrent reader that sees
    // NuChildren() > 0 also sees a valid FirstChild(). Confirm against
    // SgSynchronizeThreadMemory's guarantees.
    nonConstNode.SetPosCount(parentCount);
    SgSynchronizeThreadMemory();
    nonConstNode.SetFirstChild(firstChild);
    SgSynchronizeThreadMemory();
    nonConstNode.SetNuChildren(nuChildren);
}
00916
00917 inline void SgUctTree::RemoveGameResult(const SgUctNode& node,
00918 const SgUctNode* father, SgUctValue eval)
00919 {
00920 SG_ASSERT(Contains(node));
00921
00922
00923 if (father != 0)
00924 const_cast<SgUctNode*>(father)->DecPosCount();
00925 const_cast<SgUctNode&>(node).RemoveGameResult(eval);
00926 }
00927
00928 inline void SgUctTree::RemoveGameResults(const SgUctNode& node,
00929 const SgUctNode* father, SgUctValue eval,
00930 SgUctValue count)
00931 {
00932 SG_ASSERT(Contains(node));
00933
00934
00935 if (father != 0)
00936 const_cast<SgUctNode*>(father)->DecPosCount(count);
00937 const_cast<SgUctNode&>(node).RemoveGameResults(eval, count);
00938 }
00939
00940 inline void SgUctTree::AddVirtualLoss(const SgUctNode& node)
00941 {
00942 const_cast<SgUctNode&>(node).AddVirtualLoss();
00943 }
00944
00945 inline void SgUctTree::RemoveVirtualLoss(const SgUctNode& node)
00946 {
00947 const_cast<SgUctNode&>(node).RemoveVirtualLoss();
00948 }
00949
00950 inline void SgUctTree::AddRaveValue(const SgUctNode& node, SgUctValue value,
00951 SgUctValue weight)
00952 {
00953 SG_ASSERT(Contains(node));
00954
00955
00956 const_cast<SgUctNode&>(node).AddRaveValue(value, weight);
00957 }
00958
00959 inline void SgUctTree::RemoveRaveValue(const SgUctNode& node, SgUctValue value,
00960 SgUctValue weight)
00961 {
00962 SG_UNUSED(weight);
00963 SG_ASSERT(Contains(node));
00964
00965
00966 const_cast<SgUctNode&>(node).RemoveRaveValue(value, weight);
00967 }
00968
00969 inline SgUctAllocator& SgUctTree::Allocator(std::size_t i)
00970 {
00971 SG_ASSERT(i < m_allocators.size());
00972 return *m_allocators[i];
00973 }
00974
00975 inline const SgUctAllocator& SgUctTree::Allocator(std::size_t i) const
00976 {
00977 SG_ASSERT(i < m_allocators.size());
00978 return *m_allocators[i];
00979 }
00980
00981 inline bool SgUctTree::HasCapacity(std::size_t allocatorId,
00982 std::size_t n) const
00983 {
00984 return Allocator(allocatorId).HasCapacity(n);
00985 }
00986
00987 inline void SgUctTree::InitializeValue(const SgUctNode& node,
00988 SgUctValue value, SgUctValue count)
00989 {
00990 SG_ASSERT(Contains(node));
00991
00992
00993 const_cast<SgUctNode&>(node).InitializeValue(value, count);
00994 }
00995
00996 inline void SgUctTree::InitializeRaveValue(const SgUctNode& node,
00997 SgUctValue value, SgUctValue count)
00998 {
00999 SG_ASSERT(Contains(node));
01000
01001
01002 const_cast<SgUctNode&>(node).InitializeRaveValue(value, count);
01003 }
01004
01005 inline std::size_t SgUctTree::MaxNodes() const
01006 {
01007 return m_maxNodes;
01008 }
01009
01010 inline std::size_t SgUctTree::NuAllocators() const
01011 {
01012 return m_allocators.size();
01013 }
01014
01015 inline std::size_t SgUctTree::NuNodes(std::size_t allocatorId) const
01016 {
01017 return Allocator(allocatorId).NuNodes();
01018 }
01019
01020 inline const SgUctNode& SgUctTree::Root() const
01021 {
01022 return m_root;
01023 }
01024
01025 inline void SgUctTree::SetKnowledgeCount(const SgUctNode& node,
01026 SgUctValue count)
01027 {
01028 SG_ASSERT(Contains(node));
01029
01030
01031 const_cast<SgUctNode&>(node).SetKnowledgeCount(count);
01032 }
01033
01034 inline void SgUctTree::SetPosCount(const SgUctNode& node,
01035 SgUctValue posCount)
01036 {
01037 SG_ASSERT(Contains(node));
01038
01039
01040 const_cast<SgUctNode&>(node).SetPosCount(posCount);
01041 }
01042
01043 inline void SgUctTree::SetProvenType(const SgUctNode &node,
01044 SgUctProvenType type)
01045 {
01046 SG_ASSERT(Contains(node));
01047
01048
01049 const_cast<SgUctNode&>(node).SetProvenType(type);
01050 }
01051
01052
01053
01054
01055
01056
01057
01058
01059
/** Iterator over the children of a node.
    Walks the contiguous range [FirstChild(), FirstChild() + NuChildren()).
    The node must have children. Usage: for (SgUctChildIterator it(tree, node);
    it; ++it) { const SgUctNode& child = *it; ... } */
class SgUctChildIterator
{
public:
    /** Start at the first child of `node`; `node` must have children and
        belong to `tree` (both asserted in debug builds). */
    SgUctChildIterator(const SgUctTree& tree, const SgUctNode& node);

    /** Current child. */
    const SgUctNode& operator*() const;

    /** Advance to the next child. */
    void operator++();

    /** True while the iterator points to a child.
        NOTE(review): implicit conversion to bool; an explicit operator
        would need C++11. */
    operator bool() const;

private:
    /** Current child. */
    const SgUctNode* m_current;

    /** One past the last child. */
    const SgUctNode* m_last;
};
01078
01079 inline SgUctChildIterator::SgUctChildIterator(const SgUctTree& tree,
01080 const SgUctNode& node)
01081 {
01082 SG_DEBUG_ONLY(tree);
01083 SG_ASSERT(tree.Contains(node));
01084 SG_ASSERT(node.HasChildren());
01085 m_current = node.FirstChild();
01086 m_last = m_current + node.NuChildren();
01087 }
01088
01089 inline const SgUctNode& SgUctChildIterator::operator*() const
01090 {
01091 return *m_current;
01092 }
01093
01094 inline void SgUctChildIterator::operator++()
01095 {
01096 ++m_current;
01097 }
01098
01099 inline SgUctChildIterator::operator bool() const
01100 {
01101 return (m_current < m_last);
01102 }
01103
01104
01105
01106
01107
/** Iterator over all nodes of an SgUctTree.
    The member functions are defined outside this header. The traversal keeps
    a stack of child iterators. NOTE(review): presumably a depth-first
    traversal starting at the root; confirm against the definition. */
class SgUctTreeIterator
{
public:
    /** Start iterating over `tree`. Not explicit (changing that could break
        copy-initialization by callers). */
    SgUctTreeIterator(const SgUctTree& tree);

    /** Current node. */
    const SgUctNode& operator*() const;

    /** Advance to the next node. */
    void operator++();

    /** True while the iterator points to a node. */
    operator bool() const;

private:
    /** The tree being iterated (must outlive the iterator). */
    const SgUctTree& m_tree;

    /** Current node. */
    const SgUctNode* m_current;

    /** Child iterators of the nodes on the current path. */
    std::stack<boost::shared_ptr<SgUctChildIterator> > m_stack;
};
01129
01130
01131
01132 #endif // SG_UCTTREE_H