00001 //---------------------------------------------------------------------------- 00002 /** @file SgUctTreeUtil.cpp */ 00003 //---------------------------------------------------------------------------- 00004 00005 #include "SgSystem.h" 00006 #include "SgUctTreeUtil.h" 00007 00008 #include <iomanip> 00009 #include "SgUctSearch.h" 00010 #include "SgWrite.h" 00011 00012 using namespace std; 00013 00014 //---------------------------------------------------------------------------- 00015 00016 SgUctTreeStatistics::SgUctTreeStatistics() 00017 { 00018 Clear(); 00019 } 00020 00021 void SgUctTreeStatistics::Clear() 00022 { 00023 m_nuNodes = 0; 00024 for (size_t i = 0; i < (size_t)MAX_MOVECOUNT; ++i) 00025 m_moveCounts[i] = 0; 00026 m_biasRave.Clear(); 00027 } 00028 00029 void SgUctTreeStatistics::Compute(const SgUctTree& tree) 00030 { 00031 Clear(); 00032 for (SgUctTreeIterator it(tree); it; ++it) 00033 { 00034 const SgUctNode& node = *it; 00035 ++m_nuNodes; 00036 SgUctValue count = node.MoveCount(); 00037 if (count < (SgUctValue)SgUctTreeStatistics::MAX_MOVECOUNT) 00038 ++m_moveCounts[(size_t)count]; 00039 if (! node.HasChildren()) 00040 continue; 00041 for (SgUctChildIterator it(tree, node); it; ++it) 00042 { 00043 const SgUctNode& child = *it; 00044 if (child.HasRaveValue() && child.HasMean()) 00045 { 00046 SgUctValue childValue = 00047 SgUctSearch::InverseEstimate(child.Mean()); 00048 SgUctValue biasRave = child.RaveValue() - childValue; 00049 m_biasRave.Add(biasRave); 00050 } 00051 } 00052 } 00053 } 00054 00055 void SgUctTreeStatistics::Write(ostream& out) const 00056 { 00057 out << SgWriteLabel("NuNodes") << m_nuNodes << '\n'; 00058 for (size_t i = 0; i < MAX_MOVECOUNT; ++i) 00059 { 00060 ostringstream label; 00061 label << "MoveCount[" << i << ']'; 00062 size_t percent = m_moveCounts[i] * 100 / m_nuNodes; 00063 out << SgWriteLabel(label.str()) << setw(2) << right << percent 00064 << "%\n"; 00065 } 00066 out << SgWriteLabel("BiasRave"); 00067 m_biasRave.Write(out); 00068 out << '\n'; 00069 } 00070 00071 std::ostream& operator<<(ostream& out, const SgUctTreeStatistics& stat) 00072 { 00073 stat.Write(out); 00074 return out; 00075 } 00076 00077 //---------------------------------------------------------------------------- 00078 00079 void SgUctTreeUtil::ExtractSubtree(const SgUctTree& tree, SgUctTree& target, 00080 const std::vector<SgMove>& sequence, 00081 bool warnTruncate, double maxTime, 00082 SgUctValue minCount) 00083 { 00084 target.Clear(); 00085 const SgUctNode* node = &tree.Root(); 00086 for (vector<SgMove>::const_iterator it = sequence.begin(); 00087 it != sequence.end(); ++it) 00088 { 00089 SgMove mv = *it; 00090 node = SgUctTreeUtil::FindChildWithMove(tree, *node, mv); 00091 if (node == 0) 00092 return; 00093 } 00094 tree.ExtractSubtree(target, *node, warnTruncate, maxTime, minCount); 00095 } 00096 00097 const SgUctNode* SgUctTreeUtil::FindChildWithMove(const SgUctTree& tree, 00098 const SgUctNode& node, 00099 SgMove move) 00100 { 00101 if (! node.HasChildren()) 00102 return 0; 00103 for (SgUctChildIterator it(tree, node); it; ++it) 00104 { 00105 const SgUctNode& child = *it; 00106 if (child.Move() == move) 00107 return &child; 00108 } 00109 return 0; 00110 } 00111 00112 //----------------------------------------------------------------------------