/**
* @file ニューラルネットワークを計算するクラスです。
*/
/*
* @author 市川雄二
* @copyright 2018 ICHIKAWA, Yuji (New 3 Rs)
* @license MIT
*/
/* global WebDNN $ */
import { softmax } from './utils.js';
/* sliceのpolyfill */
if (!ArrayBuffer.prototype.slice) {
ArrayBuffer.prototype.slice = function(start, end) {
var that = new Uint8Array(this);
if (end == undefined) end = that.length;
var result = new ArrayBuffer(end - start);
var resultArray = new Uint8Array(result);
for (var i = 0; i < resultArray.length; i++)
resultArray[i] = that[i + start];
return result;
};
}
/**
* ウェイトをロードする際のプログレスバーを更新します。
* @param {number} percentage
*/
function setLoadingBar(percentage) {
const $loadingBar = $('#loading-bar');
$loadingBar.attr('aria-valuenow', percentage);
$loadingBar.css('width', percentage.toString() + '%');
}
/** ニューラルネットワークを計算するクラス(WebDNNのDescriptorRunnerのラッパークラス) */
export class NeuralNetwork {
constructor() {
this.version = 1;
this.nn = null;
}
/**
* ウェイトファイルをダウンロードします。
* @param {string} path WebDNNデータのURL
* @param {Integer} version Leela Zeroのウェイトフォーマット番号
*/
async load(path, version = 1) {
this.version = version;
if (this.nn) {
setLoadingBar(100);
return;
}
const options = {
backendOrder: ['webgpu', 'webgl'],
progressCallback: function(loaded, total) {
setLoadingBar(loaded / total * 100);
}
};
setLoadingBar(0);
this.nn = await WebDNN.load(path, options);
setLoadingBar(100); // progressCallbackがコールさえないパターンがあるので完了時にコールします。
}
/**
* ニューラルネットワークを評価した結果を返します。
* @param {Array} inputs
* @returns {Array}
*/
async evaluate(...inputs) {
const views = this.nn.getInputViews();
for (let i = 0; i < inputs.length; i++) {
views[i].set(inputs[i]);
}
await this.nn.run();
const result = this.nn.getOutputViews().map(e => e.toActual());
result[0] = softmax(result[0]);
result[1] = result[1].slice(0);// to.ActualそのものではpostMessageでdetachができないのでコピーする。
if (this.version === 2 && inputs[0][inputs[0].length - 1] === 1.0) {
result[1][0] = -result[1][0];
}
return result;
}
}