Search on the blog

2015年8月23日日曜日

オセロのAIを作る(2)

 Monte Carlo Tree Searchを使ったオセロAIを実装しました。

試しに対戦してみましたが、私(オセロ初心者)よりは少し強い感じです。
また、以前実装した中で最強プレイヤーだったDfsStoneNumPlayer(評価関数を石の数としたゲーム木探索プレイヤー)よりも強いです。

 せっかくなので、市販のオセロAIと対戦させてみよう。
ということで、iphoneアプリのHAYABUSAと対戦させてみました。

まずは最弱レベルであるLOCAL-EASYを相手にしてみました。
MctsPlayerというのが私が作ったAIの名前です。

MctsPlayer ●44 - ○20 HAYABUSA(LOCAL-EASY)
MctsPlayer ○50 - ●14 HAYABUSA(LOCAL-EASY)
MctsPlayer ●53 - ○11 HAYABUSA(LOCAL-EASY)
MctsPlayer ○41 - ●23 HAYABUSA(LOCAL-EASY)
MctsPlayer ●41 - ○23 HAYABUSA(LOCAL-EASY)

最弱モードと言えども、圧勝できると嬉しいものです。
相手を連続パスに追い込むという場面も多々見られ、なかなかいい感じです。

次にLOCAL-NORMAL(LOCAL-EASYより1つ上のレベル)と対戦させてみました。
ちなみに私が自分でプレーした戦績は0勝2敗でした。こいつに勝てると嬉しいですね。

MctsPlayer ○20 - ●44 HAYABUSA(LOCAL-NORMAL)
MctsPlayer ●15 - ○49 HAYABUSA(LOCAL-NORMAL)
MctsPlayer ○24 - ●40 HAYABUSA(LOCAL-NORMAL)
MctsPlayer ●15 - ○49 HAYABUSA(LOCAL-NORMAL)
MctsPlayer ○16 - ●48 HAYABUSA(LOCAL-NORMAL)

はい、完敗〜。
完膚なきまでに打ちのめされてしまいました。
打ちたくない場所にいいように打たされて、最終的にひっくり返されるというパターンで負けていました。

相手の動きを観察すると、
  • 敵が打てる場所を減らす
  • 敵に空マスの隣を取らせる
ということをやっているようです。
おそらく、disc count、mobility、frontierの加重和を評価関数にしてα-β探索でもやってるのかなと思います。もっと複雑なことやってるかもしれませんが。

2015年8月17日月曜日

オセロのAIを作る(1)

 先日MCTSを実装してから、オセロのAI作りたいなーという気持ちになっていた。



で、盆休みに2日くらいで作ったのがこれ
オセロの定石知らない私でも簡単に勝てるような弱小AIしか実装できてませんが、これからちょっとずつ強化していこうと思います。

2015年8月4日火曜日

Kaggle参加記:Search Results Relevance

 KaggleのSearch Results Relevanceに参加しました。

About competition
このコンテストでは、検索クエリとそれに対する検索結果がデータとして与えられます。教師ありデータではクエリに対する検索結果のマッチ度を1-4で評価しています。テストデータではこのマッチ度を予測します。

What I did
特徴エンジニアリングに利用するデータがテキストなので、NLPまわりの知識が必要になります。知識ゼロベースからのスタートだったので、とりあえず論文やらチュートリアルやらライブラリの使い方やらを読み漁るところから始めました。

 TF-IDFベクトルのコサイン類似度、Okappi BM25あたりが検索ワード-検索結果のランキングアルゴリズムでよく使われることが分かったので、この2つをC++でカスタム実装してみました。特徴量抽出後はsklearnのランダムフォレストで識別。いい結果出ず。テキストの前処理が甘いと思い、大文字⇒小文字変換、Stop-words除去、stemmingなどを行うもスコアが思ったように伸びず。

 こういうときは、Forumだ!ということで、Forumに上がっているスレッドを全部読みました。starter scriptなるものがあり、それを実行すると自分が書いたものよりいい結果が出て凹む。
しかもそのスクリプトがTF-IDFを抽出した後、次元圧縮し、SVMで識別するだけというシンプルなモデルだったため余計に凹む。しかし何故こんなモデルでうまくいくのだと悩んでいると同じようなことを質問しているユーザがいたため、何か手がかりになるはずと考え理由を考えました。

 そして検索対象がマッチしやすいかそうでないオブジェクトに分けるとうまくいくのではという結論に達しました。実際にデータを見てみると同じ検索クエリが複数回使われていて、平均マッチ度が高いクエリと低いクエリが存在することが判明しました。もっと早めに気づけよという感じです。「特徴エンジニアリングする前にデータをよく知る」ということが出来ていなかったです。

 この知見をベースに、クエリ毎に識別器を作ったり、始めにクエリをk-meansでクラスタリングした後にクラスタ毎に識別器を作ったりして、スコアもまあまあ上がりました。cross validationでいい結果の出たモデルをいくつか選んでスタッキングしてサブミットで試合終了となりました。

What I didn't do but top kagglers did
コンテスト終了後にトップランカーたちがやっていたことを調べました。その中でこれは!というものを書いておきます。1.-4.はテクニック的な話、5.はそもそもの心構え的な話です。
  1. コサイン類似度以外にも複数の距離(編集距離、圧縮距離など)を特徴量として使う
  2. 検索結果=4のものをクエリに足し合わせて拡張クエリを作る(マッチ度が高い検索結果 = クエリをより詳細に記述したものであるはず)
  3. 識別問題ではなく、回帰問題として扱う(評価基準がkappa estimatorなので、ラベルは離散値だけど実測値からの誤差を最小化するとうまくいく)
  4. 回帰モデルの予測値をソートし、"1","2","3","4"のスコアを訓練データと同じ割合で付与する(単純に回帰モデルの予測値を整数化してもいい結果はでない)
  5. validation時にスコアの悪いものを1つ1つ確認し、なぜスコアが悪いか考えモデルを改良していく

My model after the competition
ということで、上の反省をふまえてモデリングをやり直し、一から実装しなおしました。
まあこんな感じのものです。

ソースコードはこちら
private score = 0.66830でした。これでも180位くらいなのでまだまだですね。


はじめてのMonte Carlo Tree Search

 以前から実装してみたかったMonte Carlo Tree Search(MCTS)を実装してみました。MCTSはその名のとおり、モンテカルロな木探索です。主にゲーム木の探索に用いられます。

 普通の全探索では木の節点すべてを調べますが、MCTSでは根から葉へのパスをいくつかランダムに選んで探索を行います。

 まず1度、根から葉へのパスを調べたとします。これはランダムにゲームを終局まで進めたことに対応します。これによりパス上の節点において、選んだ手を打ったときの勝率が概算できます。もちろん1回ランダムにゲームをしただけなので精度は低いです。
2回目以降も同様にランダムに根から葉へのパスを選びゲームを進めていきます。回数を重ねるごとに情報が溜まってきて、各ノードからどのノードに行くのが良さそうか分かってきます。

 これだけだと精度を上げるのに時間がかかるため、子ノードの選択を工夫します。
  1. 勝率の高そうな子ノードを重点的に選択する
  2. 訪問数の少ない子ノードも積極的に開拓する
1.と2.はトレードオフの関係にあるため、どちらを優先するかはパラメータの設定で決めます。

例題
まずは簡単な例題で実装してみました。
「最高3つカウントダウン出来て、1を言った人が勝ち」というゲームを考えます。
以下にプレイヤーAとBの対戦例を示します。

初期値: 10
A: 10, 9
B: 8
A: 7,6,5
B: 4, 3
A: 2, 1

でAの勝ちです。このゲームの最善手は動的計画法を使えば分かりますが(もしくは法則を知っていれば”4k+1を取れば勝ち”で終わり)、敢えてMCTSで解いてみます。

ソースコード
最近Pythonが多いので、久しぶりにJavaで書いてみました。 実行するとそれっぽい結果が得られました。ある状態で自分の手番が回ってきたときの勝率と最善手を計算できます。 i=19まではそれっぽいですが、i=20以降はダメです(あれ・・、全探索した方が速いような・・)。ロールアウト回数を増やすと、計算時間は増えますが大きなiでもそれっぽい結果になるはずです。
package com.kenjih.mcts;

import java.util.HashMap;
import java.util.Map;
import java.util.Stack;

public class MCTS {
    
    private static final int MAX_NUM = 3;
    private int rollOut;
    private Node root;
    
    public MCTS(int initState, int rollOut) {
        this.rollOut = rollOut;
        this.root = new Node(initState);
    }

    public int getNextBestHand() {
        
        for (int _ = 0; _ < rollOut; _++) {
            Node node = root;
            
            Stack<Node> stack = new Stack<Node>();
            while (!node.isLeaf()) {
                stack.add(node);
                node = node.expand().select();
            }
            stack.add(node);
            
            int win = 1;   // 1:win, 0:lose
            while (!stack.empty()) {
                node = stack.pop();
                ++node.n;
                node.w += win;
                win ^= 1;
            }
        }
        
        int ret = -1;
        double bestRate = getBestWinRate();
        
        for (int i : root.children.keySet()) {
            double rate = root.children.get(i).getWinRate();
            if (rate == bestRate) {
                ret = i;
                break;
            }
        }
        
        return ret;
    }
    
    public double getBestWinRate() {
        double ret = -1.0;
        
        for (int i : root.children.keySet()) {
            double rate = root.children.get(i).getWinRate();
            ret = Math.max(ret, rate);
        }
        
        return ret;        
    }
    
    class Node {
        int w;      // # of wins
        int n;      // # of visits
        int state;  // game state (current number in this case)
        Map<Integer, Node> children = null;  // hand -> next state
        
        Node(int state) {
            this.state = state;
            this.w = 0;
            this.n = 1;    
        }
        
        boolean isLeaf() {
            return state == 0;
        }
        
        Node expand() {
            if (children == null) {
                children = new HashMap<Integer, MCTS.Node>();
                
                for (int i = 1; i <= MAX_NUM; i++) {
                    if (state - i >= 0) {
                        children.put(i, new Node(state - i));
                    }
                }
            }
            
            return this;
        }
                
        Node select() {
            double bestScore = -1.0;
            Node ret = null;
            
            for (int hand : children.keySet()) {
                Node nxt = children.get(hand);
                double score = nxt.getScore(n);
                if (score > bestScore) {
                    bestScore = score;
                    ret = nxt;
                }
            }
            return ret;
        }
        
        public double getWinRate() {
            return 1.*w/n;
        }
        
        private double getScore(int total) {
            double c = Math.sqrt(2);
            double t = Math.log(total);
            return 1.*w/n + c*Math.sqrt(t / n);
        }
        
    }
    
    public static void main(String[] args) {
        
        for (int i = 1; i < 50; i++) {
            MCTS mcts = new MCTS(i, 10000);
            int ret = mcts.getNextBestHand();
            double rate = mcts.getBestWinRate();
            System.out.println(i + "->" + ret + ": " + rate);    
        }        
                    
    }

}