Source: mcts.js

/**
 * @file モンテカルロツリー探索の実装です。
 * このコードはPyaqの移植コードです。
 * @see {@link https://github.com/ymgaq/Pyaq}
 */
/*
 * @author 市川雄二
 * @copyright 2018 ICHIKAWA, Yuji (New 3 Rs)
 * @license MIT
 */
import { argsort, argmax, printProb } from './utils.js';
import { IntersectionState } from './board_constants.js';
import { Board } from './board.js';

const NODES_MAX_LENGTH = 16384;
const EXPAND_CNT = 8;
const COLLISION_DETECT = false;

/** MCTSのノードクラスです。 */
class Node {
    /**
     * MCTSのノードを生成します。
     * @param {BoardConstants} constants
     */
    constructor(constants) {
        this.C = constants;
        /** 着手候補数です。(名前のedgeはグラフ理論の枝のことです。) */
        this.edgeLength = 0;
        //頻繁なメモリアロケーションを避けるため、枝情報に必要な最大メモリを予め確保します。
        /** ポリシー確率の高い順並んだ着手候補です。 */
        this.moves = new Uint16Array(this.C.BVCNT + 1); 
        /** moves要素に対応するポリシー確率です。 */
        this.probabilities = new Float32Array(this.C.BVCNT + 1); 
        /** moves要素に対応するバリューです。ただしノードの手番から見ての値です。 */
        this.values = new Float32Array(this.C.BVCNT + 1);
        /** moves要素に対応する総アクションバリューです。ただしノードの手番から見ての値です。 */
        this.totalActionValues = new Float32Array(this.C.BVCNT + 1);
        /** moves要素に対応する訪問回数です。 */
        this.visitCounts = new Uint32Array(this.C.BVCNT + 1);
        /** moves要素に対応するノードIDです。 */
        this.nodeIds = new Int16Array(this.C.BVCNT + 1);
        /** moves要素に対応するハッシュです。 */
        this.hashes = new Int32Array(this.C.BVCNT + 1);
        /** moves要素に対応する局面のニューラルネットワークを計算したか否かを保持します。 */
        this.evaluated = new Uint8Array(this.C.BVCNT + 1);
        this.value = 0;
        this.totalCount = 0;
        this.hashValue = 0;
        this.moveNumber = -1;
        this.sortedIndices = null;
        this.position = null;
        this.clear();
    }

    /** 未使用状態にします。 */
    clear() {
        this.edgeLength = 0;
        this.value = 0;
        this.totalCount = 0;
        this.hashValue = 0;
        this.moveNumber = -1;
        this.sortedIndices = null;
        this.position = null;
    }

    /**
     * 初期化します。
     * @param {Integer} hash 現局面のハッシュです。
     * @param {Integer} moveNumber 現局面の手数です。
     * @param {UInt16[]} candidates Boardが生成する候補手情報です。
     * @param {Float32Array} prob 着手確率(ニューラルネットワークのポリシー出力)です。
     */
    initialize(hash, moveNumber, candidates, prob, value, position = null) {
        this.clear();
        this.hashValue = hash;
        this.moveNumber = moveNumber;
        this.value = value;
        this.position = position;

        for (const rv of argsort(prob, true)) {
            if (prob[rv] !== 0.0 && candidates.includes(rv)) {
                this.moves[this.edgeLength] = this.C.rv2ev(rv);
                this.probabilities[this.edgeLength] = prob[rv];
                this.values[this.edgeLength] = 0.0;
                this.totalActionValues[this.edgeLength] = 0.0;
                this.visitCounts[this.edgeLength] = 0;
                this.nodeIds[this.edgeLength] = -1;
                this.hashes[this.edgeLength] = 0;
                this.evaluated[this.edgeLength] = false;
                this.edgeLength += 1;
            }
        }
    }

    /**
     * indexのエッジの勝率を返します。
     * @param {Integer} index 
     * @returns {number}
     */
    winrate(index) {
        return this.totalActionValues[index] / Math.max(this.visitCounts[index], 1) / 2.0 + 0.5;
    }

    /**
     * visitCountsを1増やし、sortedIndicesをクリアします。
     * @param {Integer} index 
     */
    incrementVisitCount(index) {
        this.visitCounts[index] += 1;
        this.sortedIndices = null;
    }

    /**
     * visitCountsの多い順のインデックスの配列を返します。
     */
    getSortedIndices() {
        if (this.sortedIndices == null) {
            this.sortedIndices = argsort(this.visitCounts.slice(0, this.edgeLength), true, this.values.slice(0, this.edgeLength));
        }
        return this.sortedIndices;
    }


    /**
     * UCB評価で最善の着手情報を返します。
     * @param {number} c_puct
     * @returns {Array} [UCB選択インデックス, 最善ブランチの子ノードID, 着手]
     */
    selectByUCB(c_puct) {
        const cpsv = c_puct * Math.sqrt(this.totalCount);
        const meanActionValues = new Float32Array(this.edgeLength);
        for (let i = 0; i < meanActionValues.length; i++) {
            /* AlphaGo Zero論文ではmeanActionValueの初期値は0ですが、その場合、悪い局面でMCTSがエッジをすべて一度は評価しようとします。
             * それを避けるために初期値を現局面の評価値にしました。 */
            meanActionValues[i] = this.visitCounts[i] === 0 ? this.value : this.totalActionValues[i] / this.visitCounts[i];
        }
        const ucb = new Float32Array(this.edgeLength);
        for (let i = 0; i < ucb.length; i++) {
            ucb[i] = meanActionValues[i] + cpsv * this.probabilities[i] / (1 + this.visitCounts[i]);
        }
        const selectedIndex = argmax(ucb);
        const selectedId = this.nodeIds[selectedIndex];
        const selectedMove = this.moves[selectedIndex];
        return [selectedIndex, selectedId, selectedMove];
    }
}

/** モンテカルロツリー探索を実行するクラスです。 */
export class MCTS {
    /**
     * コンストラクタ
     * @param {NeuralNetwork} nn 
     * @param {BoardConstants} C
     * @param {Function} evaluatePlugin
     */
    constructor(nn, C, evaluatePlugin = null) {
        this.C_PUCT = 1.5; // ELF OpenGoのボット設定
        this.mainTime = 0.0;
        this.byoyomi = 1.0;
        this.leftTime = 0.0;
        this.nodes = [];
        this.nodesLength = 0;
        for (let i = 0; i < NODES_MAX_LENGTH; i++) {
            this.nodes.push(new Node(C));
        }
        this.rootId = 0;
        this.rootMoveNumber = 0;
        this.nodeHashes = new Map();
        this.evalCount = 0;
        this.nn = nn;
        this.terminateFlag = false;
        this.exitCondition = null;
        this.evaluatePlugin = evaluatePlugin;
        this.collisions = 0;
    }

    /**
     * 持ち時間の設定をします。
     * 残り時間もリセットします。
     * @param {number} mainTime 秒
     * @param {number} byoyomi 秒
     */
    setTime(mainTime, byoyomi) {
        this.mainTime = mainTime;
        this.leftTime = mainTime;
        this.byoyomi = byoyomi;
    }

    /**
     * 残り時間を設定します。
     * @param {number} leftTime 秒
     */
    setLeftTime(leftTime) {
        this.leftTime = leftTime;
    }

    /**
     * 内部状態をクリアします。
     * 同一時間設定で初手から対局できるようになります。
     */
    clear() {
        this.leftTime = this.mainTime;
        for (const node of this.nodes) {
            node.clear();
        }
        this.nodesLength = 0;
        this.rootId = 0;
        this.rootMoveNumber = 0;
        this.nodeHashes.clear();
        this.evalCount = 0;
    }

    /**
     * this.nodesに同じ局面があればそのインデックスを返します。
     * なければnullを返します。
     * @param {Board} b 
     */
    getNodeIdInNodes(b) {
        const hash = b.hash();
        if (this.nodeHashes.has(hash)) {
            const id = this.nodeHashes.get(hash);
            if (b.moveNumber === this.nodes[id].moveNumber) {
                if (!COLLISION_DETECT || b.toString() === this.nodes[id].position) {
                    return id;
                } else {
                    this.collisions += 1;
                    console.log('hash collision');
                    b.showboard(false);
                    console.log(this.nodes[id].position);
                }
            }
        }
        return null;
    }
    /**
     * 局面bのMCTSの探索ノードを生成してノードIDを返します。
     * @param {Board} b 
     * @param {Float32Array} prob 
     * @returns {Integer} ノードID
     */
    createNode(b, prob, value) {
        const hash = b.hash();
        let nodeId = Math.abs(hash) % NODES_MAX_LENGTH;
        while (this.nodes[nodeId].moveNumber !== -1) {
            nodeId = nodeId + 1 < NODES_MAX_LENGTH ? nodeId + 1 : 0;
        }

        this.nodeHashes.set(hash, nodeId);
        this.nodesLength += 1;

        const node = this.nodes[nodeId];
        node.initialize(hash, b.moveNumber, b.candidates(), prob, value, COLLISION_DETECT && b.toString());
        return nodeId;
    }

    /**
     * nodesの中の不要なノードを未使用状態に戻します。
     */
    cleanupNodes() {
        if (this.nodesLength < NODES_MAX_LENGTH / 2) {
            return;
        }
        for (let i = 0; i < NODES_MAX_LENGTH; i++) {
            const mn = this.nodes[i].moveNumber;
            if (mn >= 0 && mn < this.rootMoveNumber) {
                this.nodeHashes.delete(this.nodes[i].hashValue);
                this.nodes[i].clear();
                this.nodesLength -= 1;
            }
        }
    }

    /**
     * 検索するかどうかを決定します。
     * @param {Integer} best 
     * @param {Integer} second 
     * @returns {bool}
     */
    shouldSearch(best, second) {
        const node = this.nodes[this.rootId];
        const winrate = node.winrate(best);

        return winrate <= 0.5 || node.probabilities[best] <= 0.99;
    }

    /**
     * 次の着手の考慮時間を算出します。
     * @returns {number} 使用する時間(秒)
     */
    getSearchTime(C) {
        if (this.mainTime === 0.0 || this.leftTime < this.byoyomi * 2.0) {
            return Math.max(this.byoyomi, 1.0);
        } else {
            // 碁盤を埋めることを仮定し、残りの手数を算出します。
            const assumedRemainingMoves = (C.BVCNT - this.rootMoveNumber) / 2;
            //布石ではより多くの手数を仮定し、急ぎます。
            const openingOffset = Math.max(C.BVCNT / 10 - this.rootMoveNumber, 0);
            return this.leftTime / (assumedRemainingMoves + openingOffset) + this.byoyomi;
        }
    }

    /**
     * nodeIdのノードのedgeIndexのエッジに対応するノードが既に存在するか返します。
     * @param {Integer} edgeIndex 
     * @param {Integer} nodeId 
     * @param {Integer} moveNumber 
     * @returns {bool}
     */
    hasEdgeNode(edgeIndex, nodeId, moveNumber) {
        const node = this.nodes[nodeId];
        const edgeId = node.nodeIds[edgeIndex];
        if (edgeId < 0) {
            return false;
        }
        return node.hashes[edgeIndex] === this.nodes[edgeId].hashValue &&
            this.nodes[edgeId].moveNumber === moveNumber;
    }

    /**
     * 
     * @private
     * @param {Integer} nodeId 
     * @param {BoardConstants} c 
     */
    edgeWinrate(nodeId, c) {
        let value = NaN;
        let parity = 1;
        while (true) {
            const node = this.nodes[nodeId];
            if (node.edgeLength < 1) {
                break;
            }

            const [best] = node.getSortedIndices();
            if (node.visitCounts[best] === 0) {
                break;
            }
            value = node.values[best] * parity;
            if (!this.hasEdgeNode(best, nodeId, node.moveNumber + 1)) {
                break;
            }
            nodeId = node.nodeIds[best];
            parity *= -1;
        }

        return (value + 1.0) / 2.0;
    }

    /**
     * printInfoのヘルパー関数です。
     * @private
     * @param {Uint16} headMove nodeIdのノードに至る着手です
     * @param {Integer} nodeId 
     * @param {BoardConstants} c 
     */
    bestSequence(nodeId, c) {
        let seqStr = '';
        for (let i = 0; i < 7; i++) {
            const node = this.nodes[nodeId];
            if (node.edgeLength < 1) {
                break;
            }

            const [best] = node.getSortedIndices();
            if (node.visitCounts[best] === 0) {
                break;
            }
            const bestMove = node.moves[best];
            seqStr += '->' + (c.ev2str(bestMove) + '   ').slice(0, 3);

            if (!this.hasEdgeNode(best, nodeId, node.moveNumber + 1)) {
                break;
            }
            nodeId = node.nodeIds[best];
        }

        return seqStr;
    }

    /**
     * 探索結果の詳細を表示します。
     * @param {Integer} nodeId 
     * @param {BoardConstants} c
     */
    printInfo(nodeId, c) {
        const node = this.nodes[nodeId];
        const order = node.getSortedIndices();
        console.log('|move|count  |rate |value|prob | best sequence');
        const length = Math.min(order.length, 9);
        for (let i = 0; i < length; i++) {
            const m = order[i];
            const visitCount = node.visitCounts[m];
            if (visitCount === 0) {
                break;
            }

            const rate = visitCount === 0 ? 0.0 : node.winrate(m) * 100.0;
            const value = (node.values[m] / 2.0 + 0.5) * 100.0;
            console.log(
                '|%s|%s|%s|%s|%s| %s',
                (c.ev2str(node.moves[m]) + '    ').slice(0, 4),
                ('       ' + visitCount).slice(-7),
                ('  ' + rate.toFixed(1)).slice(-5),
                ('  ' + value.toFixed(1)).slice(-5),
                ('  ' + (node.probabilities[m] * 100.0).toFixed(1)).slice(-5),
                (c.ev2str(node.moves[m]) + '   ').slice(0, 3) + this.bestSequence(node.nodeIds[m], c)
            );
        }
    }

    /**
     * ニューラルネットワークで局面を評価します。
     * @param {Board} b
     * @param {bool} random
     * @returns {Float32Array[]}
     */
    async evaluate(b, random = true) {
        let [prob, value] = await b.evaluate(this.nn, random);
        if (this.evaluatePlugin) {
            prob = this.evaluatePlugin(b, prob);
        }
        this.evalCount += 1;
        return [prob, value];
    }

    /**
     * 検索の前処理です。
     * @private
     * @param {Board} b 
     * @returns {Node}
     */
    async prepareRootNode(b) {
        this.rootMoveNumber = b.moveNumber;
        this.rootId = this.getNodeIdInNodes(b);
        if (this.rootId == null) {
            const [prob, value] = await this.evaluate(b);
            this.rootId = this.createNode(b, prob, value);
        }
        // AlphaGo Zeroでは自己対戦時にはここでprobに"Dirichletノイズ"を追加します。
        return this.nodes[this.rootId];
    }

    /**
     * edgeIndexのエッジの局面を評価しノードを生成してバリューを返します。
     * @private
     * @param {Board} b 
     * @param {Integer} edgeIndex 
     * @param {Node} parentNode 
     * @returns {number} parentNodeの手番から見たedge局面のバリュー
     */
    async evaluateEdge(b, edgeIndex, parentNode) {
        let [prob, value] = await this.evaluate(b);
        value = -value[0]; // parentNodeの手番から見たバリューに変換します。
        if (this.nodesLength > 0.85 * NODES_MAX_LENGTH) {
            this.cleanupNodes();
        }
        const nodeId = this.createNode(b, prob, value);
        parentNode.nodeIds[edgeIndex] = nodeId;
        parentNode.hashes[edgeIndex] = b.hash();
        parentNode.values[edgeIndex] = value;
        parentNode.evaluated[edgeIndex] = true;
        /*
        if (!this.isConsistentNode(nodeId, b)) {
            const node = this.nodes[nodeId];
            for (let i = 0; i < node.edgeLength; i++) {
                console.log(b.C.ev2str(node.moves[i]));
            }
            b.showboard();
            throw new Error('inconsistent node');
        }
        */
        return value;
    }

    /**
     * MCTSツリーをUCBに従って下り、リーフノードに到達したら展開します。
     * @private
     * @param {Board} b 
     * @param {Integer} nodeId
     * @returns {number} バリュー
     */
    async playout(b, nodeId) {
        const node = this.nodes[nodeId];
        const [selectedIndex, selectedId, selectedMove] = node.selectByUCB(this.C_PUCT);
        try {
            b.play(selectedMove);
        } catch (e) {
            console.error(e);
            b.showboard();
            console.log('%s %d', b.C.ev2str(selectedMove), this.collisions);
            b.play(b.C.PASS);
        }
        let value;
        if (selectedId >= 0) {
            value = - await this.playout(b, selectedId); // selectedIdの手番でのバリューが返されるから符号を反転させます。
        } else {
            const nodeId = this.nodeHashes.get(b.hash());
            if (nodeId && this.nodes[nodeId].moveNumber === b.moveNumber) {
                const edgeNode = this.nodes[nodeId];
                node.nodeIds[selectedIndex] = nodeId;
                node.hashes[selectedIndex] = b.hash();
                node.values[selectedIndex] = -edgeNode.value;
                node.evaluated[selectedIndex] = true;
                value = - await this.playout(b, nodeId);
            } else {
                value = await this.evaluateEdge(b, selectedIndex, node);
            }
        }
        node.totalCount += 1;
        node.totalActionValue += value;
        node.totalActionValues[selectedIndex] += value;
        node.incrementVisitCount(selectedIndex);
        return value;
    }

    /**
     * プレイアウトを繰り返してMCTSツリーを更新します。
     * @private
     * @param {Board} b 
     */
    async keepPlayout(b) {
        let bCpy = new Board(b.C, b.komi);
        do {
            b.copyTo(bCpy);
            await this.playout(bCpy, this.rootId);
        } while (!this.exitCondition());
    }

    /**
     * MCTS探索メソッドです。
     * 局面bをルートノード設定して、終了条件を設定し、time時間探索し、結果をログ出力してルートノードを返します。
     * @param {Board} b 
     * @param {number} time 探索時間を秒単位で指定します
     * @param {bool} ponder ttrueのときstopメソッドが呼ばれるまで探索を継続します
     * @param {bool} clean 形勢が変わらない限りパス以外の着手を選びます
     * @returns {Object[]} [Node, Integer] ルートノードと評価数
     */
    async search(b, time, ponder) {
        const start = Date.now();
        this.evalCount = 0;
        const rootNode = await this.prepareRootNode(b);

        if (rootNode.edgeLength <= 1) { // 候補手がパスしかなければ
            console.log('\nmove number=%d:', this.rootMoveNumber + 1);
            this.printInfo(this.rootId, b.C);
            return [rootNode, this.evalCount];
        }

        this.cleanupNodes();

        const time_ = (time === 0.0 ? this.getSearchTime(b.C) : time) * 1000 - 500; // 0.5秒のマージン
        this.terminateFlag = false;
        this.exitCondition = ponder ? () => this.terminateFlag :
            () => this.terminateFlag || Date.now() - start > time_;

        let [best, second] = rootNode.getSortedIndices();
        if (ponder || this.shouldSearch(best, second)) {
            await this.keepPlayout(b);
            [best, second] = rootNode.getSortedIndices();
        }

        console.log(
            '\nmove number=%d: left time=%s[sec] evaluated=%d collisions=%d',
            this.rootMoveNumber + 1,
            Math.max(this.leftTime - time, 0.0).toFixed(1),
            this.evalCount,
            this.collisions
        );
        this.printInfo(this.rootId, b.C);
        this.leftTime = this.leftTime - (Date.now() - start) / 1000;
        return [rootNode, this.evalCount];
    }

    /**
     * 実行中のkeepPlayoutを停止させます。
     */
    stop() {
        this.terminateFlag = true;
    }

    /**
     * 
     * @param {Integer} nodeId
     * @param {Board} b 
     */
    isConsistentNode(nodeId, b) {
        const node = this.nodes[nodeId];
        for (let i = 0; i < node.edgeLength; i++) {
            const ev = node.moves[i];
            if (ev === b.C.EBVCNT) {
                continue
            }
            if (b.state[ev] !== IntersectionState.EMPTY) {
                console.log('isConsistentNode', b.C.ev2str(ev));
                return false;
            }
        }
        return true
    }
}