00001 //---------------------------------------------------------------------------- 00002 /** @file GoAutoBook.cpp */ 00003 //---------------------------------------------------------------------------- 00004 00005 #include "SgSystem.h" 00006 #include "GoAutoBook.h" 00007 00008 //---------------------------------------------------------------------------- 00009 00010 GoAutoBookState::GoAutoBookState(const GoBoard& brd) 00011 : m_synchronizer(brd) 00012 { 00013 m_synchronizer.SetSubscriber(m_brd[0]); 00014 } 00015 00016 GoAutoBookState::~GoAutoBookState() 00017 { 00018 } 00019 00020 SgHashCode GoAutoBookState::GetHashCode() const 00021 { 00022 return m_hash; 00023 } 00024 00025 void GoAutoBookState::Synchronize() 00026 { 00027 m_synchronizer.UpdateSubscriber(); 00028 int size = m_brd[0].Size(); 00029 int numMoves = m_brd[0].MoveNumber(); 00030 for (int rot = 1; rot < 8; ++rot) 00031 { 00032 m_brd[rot].Init(size, size); 00033 for (int i = 0; i < numMoves; ++i) 00034 { 00035 SgMove move = m_brd[0].Move(i).Point(); 00036 m_brd[rot].Play(SgPointUtil::Rotate(rot, move, size)); 00037 } 00038 } 00039 ComputeHashCode(); 00040 } 00041 00042 void GoAutoBookState::Play(SgMove move) 00043 { 00044 m_hash = m_brd[0].GetHashCodeInclToPlay(); 00045 for (int rot = 0; rot < 8; ++rot) 00046 m_brd[rot].Play(SgPointUtil::Rotate(rot, move, m_brd[0].Size())); 00047 ComputeHashCode(); 00048 } 00049 00050 void GoAutoBookState::Undo() 00051 { 00052 for (int rot = 0; rot < 8; ++rot) 00053 m_brd[rot].Undo(); 00054 ComputeHashCode(); 00055 } 00056 00057 void GoAutoBookState::ComputeHashCode() 00058 { 00059 m_hash = m_brd[0].GetHashCodeInclToPlay(); 00060 for (int rot = 1; rot < 8; ++rot) 00061 { 00062 SgHashCode curHash = m_brd[rot].GetHashCodeInclToPlay(); 00063 if (curHash < m_hash) 00064 m_hash = curHash; 00065 } 00066 } 00067 00068 //---------------------------------------------------------------------------- 00069 00070 GoAutoBookParam::GoAutoBookParam() 00071 : m_usageCountThreshold(0), 00072 m_selectType(GO_AUTOBOOK_SELECT_VALUE) 00073 { 00074 } 00075 00076 //---------------------------------------------------------------------------- 00077 00078 GoAutoBook::GoAutoBook(const std::string& filename, 00079 const GoAutoBookParam& param) 00080 : m_param(param), 00081 m_filename(filename) 00082 { 00083 std::ifstream is(filename.c_str()); 00084 if (!is) 00085 { 00086 std::ofstream of(filename.c_str()); 00087 if (!of) 00088 throw SgException("Invalid file name!"); 00089 of.close(); 00090 } 00091 else 00092 { 00093 while (is) 00094 { 00095 std::string line; 00096 std::getline(is, line); 00097 if (line.size() < 19) 00098 continue; 00099 std::string str; 00100 std::istringstream iss(line); 00101 iss >> str; 00102 SgHashCode hash; 00103 hash.FromString(str); 00104 SgBookNode node(line.substr(19)); 00105 m_data[hash] = node; 00106 } 00107 SgDebug() << "GoAutoBook: Parsed " << m_data.size() << " lines.\n"; 00108 } 00109 } 00110 00111 GoAutoBook::~GoAutoBook() 00112 { 00113 } 00114 00115 bool GoAutoBook::Get(const GoAutoBookState& state, SgBookNode& node) const 00116 { 00117 Map::const_iterator it = m_data.find(state.GetHashCode()); 00118 if (it != m_data.end()) 00119 { 00120 node = it->second; 00121 return true; 00122 } 00123 return false; 00124 } 00125 00126 void GoAutoBook::Put(const GoAutoBookState& state, const SgBookNode& node) 00127 { 00128 m_data[state.GetHashCode()] = node; 00129 } 00130 00131 void GoAutoBook::Flush() 00132 { 00133 Save(m_filename); 00134 } 00135 00136 void GoAutoBook::Save(const std::string& filename) const 00137 { 00138 std::ofstream out(filename.c_str()); 00139 for (Map::const_iterator it = m_data.begin(); it != m_data.end(); ++it) 00140 { 00141 out << it->first.ToString() << '\t' 00142 << it->second.ToString() << '\n'; 00143 } 00144 out.close(); 00145 } 00146 00147 void GoAutoBook::Merge(const GoAutoBook& other) 00148 { 00149 SgDebug() << "GoAutoBook::Merge()\n"; 00150 std::size_t newLeafs = 0; 00151 std::size_t newInternal = 0; 00152 std::size_t leafsInCommon = 0; 00153 std::size_t internalInCommon = 0; 00154 std::size_t leafToInternal = 0; 00155 for (Map::const_iterator it = other.m_data.begin(); 00156 it != other.m_data.end(); ++it) 00157 { 00158 Map::iterator mine = m_data.find(it->first); 00159 SgBookNode newNode(it->second); 00160 if (mine == m_data.end()) 00161 { 00162 m_data[it->first] = it->second; 00163 if (newNode.IsLeaf()) 00164 newLeafs++; 00165 else 00166 newInternal++; 00167 } 00168 else 00169 { 00170 SgBookNode oldNode(mine->second); 00171 if (newNode.IsLeaf() && oldNode.IsLeaf()) 00172 { 00173 newNode.m_heurValue = 0.5f * (newNode.m_heurValue 00174 + oldNode.m_heurValue); 00175 m_data[it->first] = newNode; 00176 leafsInCommon++; 00177 } 00178 else if (!newNode.IsLeaf()) 00179 { 00180 // Take the max of the count; can't just add them 00181 // together because then merging a book with itself 00182 // doubles the counts of everything, which doesn't 00183 // make sense. Need the parent of these books and do a 00184 // three-way merge if we want the counts to be 00185 // accurate after the merge. I don't think it matters 00186 // that much. 00187 newNode.m_count = std::max(newNode.m_count, oldNode.m_count); 00188 m_data[it->first] = newNode; 00189 if (!oldNode.IsLeaf()) 00190 internalInCommon++; 00191 else 00192 leafToInternal++; 00193 } 00194 } 00195 } 00196 SgDebug() << "Statistics\n" 00197 << "New Leafs " << newLeafs << '\n' 00198 << "New Internal " << newInternal << '\n' 00199 << "Common Leafs " << leafsInCommon << '\n' 00200 << "Common Internal " << internalInCommon << '\n' 00201 << "Leaf to Internal " << leafToInternal << '\n'; 00202 } 00203 00204 //---------------------------------------------------------------------------- 00205 00206 void GoAutoBook::TruncateByDepth(int depth, GoAutoBookState& state, 00207 GoAutoBook& other) const 00208 { 00209 std::set<SgHashCode> seen; 00210 TruncateByDepth(depth, state, other, seen); 00211 } 00212 00213 void GoAutoBook::TruncateByDepth(int depth, GoAutoBookState& state, 00214 GoAutoBook& other, 00215 std::set<SgHashCode>& seen) const 00216 { 00217 if (seen.count(state.GetHashCode())) 00218 return; 00219 SgBookNode node; 00220 if (!Get(state, node)) 00221 return; 00222 seen.insert(state.GetHashCode()); 00223 if (depth == 0) 00224 { 00225 // Set this node to be a leaf: copy its heuristic value into 00226 // its propagated value and set count to 0. 00227 node.m_count = 0; 00228 node.m_priority = SgBookNode::LEAF_PRIORITY; 00229 node.m_value = node.m_heurValue; 00230 other.Put(state, node); 00231 return; 00232 } 00233 other.Put(state, node); 00234 if (node.IsLeaf() || node.IsTerminal()) 00235 return; 00236 for (GoBoard::Iterator it(state.Board()); it; ++it) 00237 { 00238 if (state.Board().IsLegal(*it)) 00239 { 00240 state.Play(*it); 00241 TruncateByDepth(depth - 1, state, other, seen); 00242 state.Undo(); 00243 } 00244 } 00245 } 00246 00247 //---------------------------------------------------------------------------- 00248 00249 void GoAutoBook::ImportHashValuePairs(std::istream& in) 00250 { 00251 std::size_t count = 0; 00252 while (in) 00253 { 00254 SgHashCode hash; 00255 std::string hashStr; 00256 in >> hashStr; 00257 hash.FromString(hashStr); 00258 float value; 00259 if (!in) 00260 break; 00261 in >> value; 00262 if (m_data.count(hash) == 0) 00263 { 00264 std::ostringstream os; 00265 os << "Unknown hash: " << hash << '\n'; 00266 throw SgException(os.str()); 00267 } 00268 SgBookNode node(m_data[hash]); 00269 node.m_heurValue = value; 00270 node.m_value = value; 00271 m_data[hash] = node; 00272 count++; 00273 } 00274 SgDebug() << "GoAutoBook::ImportHashValue: imported " 00275 << count << " values.\n"; 00276 } 00277 00278 //---------------------------------------------------------------------------- 00279 00280 SgMove GoAutoBook::FindBestChild(GoAutoBookState& state) const 00281 { 00282 std::size_t bestCount = 0; 00283 SgMove bestMove = SG_NULLMOVE; 00284 float bestScore = 100.0f; 00285 SgBookNode node; 00286 if (!Get(state, node)) 00287 return SG_NULLMOVE; 00288 if (node.IsLeaf()) 00289 return SG_NULLMOVE; 00290 for (GoBoard::Iterator it(state.Board()); it; ++it) 00291 { 00292 if (state.Board().IsLegal(*it)) 00293 { 00294 state.Play(*it); 00295 if (m_disabled.count(state.GetHashCode()) > 0) 00296 SgDebug() << "Ignoring disabled move " 00297 << SgWritePoint(*it) << '\n'; 00298 // NOTE: Terminal nodes aren't supported at this time, so 00299 // we ignore them here. 00300 else if (Get(state, node) 00301 && !node.IsTerminal() 00302 && node.m_count >= m_param.m_usageCountThreshold) 00303 { 00304 if (m_param.m_selectType == GO_AUTOBOOK_SELECT_COUNT) 00305 { 00306 // Select by count, tiebreak by value. 00307 if (node.m_count > bestCount) 00308 { 00309 bestCount = node.m_count; 00310 bestMove = *it; 00311 bestScore = node.m_value; 00312 } 00313 // NOTE: do not have access to inverse function, 00314 // so we're minimizing here as a temporary solution. 00315 else if (node.m_count == bestCount 00316 && node.m_value < bestScore) 00317 { 00318 bestMove = *it; 00319 bestScore = node.m_value; 00320 } 00321 } 00322 else if (m_param.m_selectType == GO_AUTOBOOK_SELECT_VALUE) 00323 { 00324 // NOTE: do not have access to inverse function, 00325 // so we're minimizing here as a temporary solution. 00326 if (node.m_value < bestScore) 00327 { 00328 bestMove = *it; 00329 bestScore = node.m_value; 00330 } 00331 } 00332 } 00333 state.Undo(); 00334 } 00335 } 00336 return bestMove; 00337 } 00338 00339 SgMove GoAutoBook::LookupMove(const GoBoard& brd) const 00340 { 00341 GoAutoBookState state(brd); 00342 state.Synchronize(); 00343 return FindBestChild(state); 00344 } 00345 00346 //---------------------------------------------------------------------------- 00347 00348 void GoAutoBook::ExportToOldFormat(GoAutoBookState& state, std::ostream& out, 00349 std::set<SgHashCode>& seen) const 00350 { 00351 if (seen.count(state.GetHashCode())) 00352 return; 00353 SgBookNode node; 00354 if (!Get(state, node)) 00355 return; 00356 if (node.IsTerminal() || node.IsLeaf()) 00357 return; 00358 seen.insert(state.GetHashCode()); 00359 SgPoint move = FindBestChild(state); 00360 // If no move to play here, do not include it in the book 00361 if (move == SG_NULLMOVE) 00362 return; 00363 const GoBoard& brd = state.Board(); 00364 out << brd.Size() << ' '; 00365 for (int i = 0; i < brd.MoveNumber(); ++i) 00366 out << ' ' << SgWritePoint(brd.Move(i).Point()); 00367 out << " | " << SgWritePoint(move); 00368 out << '\n'; 00369 for (GoBoard::Iterator it(brd); it; ++it) 00370 if (brd.IsLegal(*it)) 00371 { 00372 state.Play(*it); 00373 ExportToOldFormat(state, out, seen); 00374 state.Undo(); 00375 } 00376 } 00377 00378 void GoAutoBook::ExportToOldFormat(GoAutoBookState& state, 00379 std::ostream& os) const 00380 { 00381 std::set<SgHashCode> seen; 00382 ExportToOldFormat(state, os, seen); 00383 } 00384 00385 //---------------------------------------------------------------------------- 00386 00387 std::vector< std::vector<SgMove> > GoAutoBook::ParseWorkList(std::istream& in) 00388 { 00389 std::vector< std::vector<SgMove> > ret; 00390 while (in) 00391 { 00392 std::string line; 00393 std::getline(in, line); 00394 if (line == "") 00395 continue; 00396 std::vector<SgMove> var; 00397 std::istringstream in2(line); 00398 while (true) 00399 { 00400 std::string s; 00401 in2 >> s; 00402 if (! in2 || s == "|") 00403 break; 00404 std::istringstream in3(s); 00405 SgPoint p; 00406 in3 >> SgReadPoint(p); 00407 if (! in3) 00408 throw SgException("Invalid point"); 00409 var.push_back(p); 00410 } 00411 ret.push_back(var); 00412 } 00413 SgDebug() << "GoAutoBook::ParseWorkList: Read " << ret.size() 00414 << " variations.\n"; 00415 return ret; 00416 } 00417 00418 //----------------------------------------------------------------------------