EnsekiTT Blog

EnsekiTTが書くブログです。

Unityで強化学習していたAgentのソースコードを読む話

こんにちは、えんせきです。
今日は久々にカレーを作りました。ニンジンが、柔らかくならない!学習が必要だ。

つまりなにしたの?

先日、球をのせ続けることが得意になったAgentが何をやっているのか読みといてみた。
Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.csが今日のターゲット
f:id:ensekitt:20171227002414j:plain

どんなことをやるAgentなのか(再掲)

f:id:ensekitt:20171225004753g:plain

Agentのコードの大まかな構成

  • initializeAgent()
  • CollectState()
  • AgentStep()
  • AgentReset()
  • AgentOnDone()

の5つが用意されている。
これらをオーバーライドしてやることでエージェントに状態を持たせてアクションを実施して報酬を得られるようにする。
色々事前に決まっているので中身を読みながら、ドキュメントを読むと良さそうだった。
github.com

この辺を参考に進めた。

InitializeAgent()

Agentを初期化するための関数。Agent生成時に呼び出される。まちがってもAwake(),Start(),OnEnable()をつかってはいけない。
ちなみに3DBallでは特に何もされていない。

CollectState()

Agentが保持するべき情報のリストを保持する。
3DBallでは、プレートの角度情報、ボールとプレートの位置関係、ボールの速度がリストに入っている

    public override List<float> CollectState()
    {
        List<float> state = new List<float>();
        // プレートの角度情報
        state.Add(gameObject.transform.rotation.z);
        state.Add(gameObject.transform.rotation.x);
        // ボールとプレートの位置関係
        state.Add((ball.transform.position.x - gameObject.transform.position.x));
        state.Add((ball.transform.position.y - gameObject.transform.position.y));
        state.Add((ball.transform.position.z - gameObject.transform.position.z));
        // ボールの速度
        state.Add(ball.transform.GetComponent<Rigidbody>().velocity.x);
        state.Add(ball.transform.GetComponent<Rigidbody>().velocity.y);
        state.Add(ball.transform.GetComponent<Rigidbody>().velocity.z);
        // Stateを返す
        return state;
    }

AgentStep(float[] act)

入力のアクション(act)が与えられた時にAgentが行う行動を決定する。また、報酬とエージェントのアクションが完了したかどうかを指定する。報酬はreward、完了はdoneというパブリック変数で指定することができる。
actはFloatの配列なのでact[0],act[1]などで値を取り出すことができる。

    public override void AgentStep(float[] act)
    {
        // 状態空間モデルを連続だった場合
        if (brain.brainParameters.actionSpaceType == StateType.continuous)
        {
            // アクションを読み取る(Z軸の回転)
            float action_z = act[0];
            // アクションの値域を制限する
            if (action_z > 2f)
            {
                action_z = 2f;
            }
            if (action_z < -2f)
            {
                action_z = -2f;
            }
            // オブジェクトの傾き角度が一定以上にならない場合は回転を実施する
            if ((gameObject.transform.rotation.z < 0.25f && action_z > 0f) ||
                (gameObject.transform.rotation.z > -0.25f && action_z < 0f))
            {
                gameObject.transform.Rotate(new Vector3(0, 0, 1), action_z);
            }

            // アクションを読み取る(X軸の回転)
            float action_x = act[1];
            // アクションの値域を制限する
            if (action_x > 2f)
            {
                action_x = 2f;
            }
            if (action_x < -2f)
            {
                action_x = -2f;
            }
            // オブジェクトの傾き角度が一定以上にならない場合は回転を実施する
            if ((gameObject.transform.rotation.x < 0.25f && action_x > 0f) ||
                (gameObject.transform.rotation.x > -0.25f && action_x < 0f))
            {
                gameObject.transform.Rotate(new Vector3(1, 0, 0), action_x);
            }
	
            // doneがfalseだったら報酬を0.1追加する
            if (done == false)
            {
                reward = 0.1f;
            }
        }
        // 状態空間モデルを連続じゃなかった場合
        else
        {
            // アクションを読み取る
            int action = (int)act[0];
            // アクションが0か1だったら±1にして、オブジェクトの傾き角度が一定以上にならない場合は回転を実施する
            if (action == 0 || action == 1)
            {
                action = (action * 2) - 1;
                float changeValue = action * 2f;
                if ((gameObject.transform.rotation.z < 0.25f && changeValue > 0f) ||
                    (gameObject.transform.rotation.z > -0.25f && changeValue < 0f))
                {
                    gameObject.transform.Rotate(new Vector3(0, 0, 1), changeValue);
                }
            }
            // アクションが2か3だったら±1にして、オブジェクトの傾き角度が一定以上にならない場合は回転を実施する
            if (action == 2 || action == 3)
            {
                action = ((action - 2) * 2) - 1;
                float changeValue = action * 2f;
                if ((gameObject.transform.rotation.x < 0.25f && changeValue > 0f) ||
                    (gameObject.transform.rotation.x > -0.25f && changeValue < 0f))
                {
                    gameObject.transform.Rotate(new Vector3(1, 0, 0), changeValue);
                }
            }

            // doneがfalseだったら報酬を0.1追加する
            if (done == false)
            {
                reward = 0.1f;
            }
        }
        // ボールがプレートよりも低くなるか、プレートの上を逸脱したらDoneをTrueにして、報酬をマイナス1にする
        if ((ball.transform.position.y - gameObject.transform.position.y) < -2f ||
            Mathf.Abs(ball.transform.position.x - gameObject.transform.position.x) > 3f ||
            Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
        {
            done = true;
            reward = -1f;
        }

    }

AgentReset()

この関数はAcademy(学習環境)がリセットされたときと、エージェントが完了(Doneフラグ)した時に呼び出される。
3DBallでは回転を0に戻して、回転角度をランダムに設定する。ボールの速度を0にして、高さ一定、位置ランダムに設定する。
どちらかと言うと初期化でやりそうなことをここでやっているみたい。

    public override void AgentReset()
    {
        gameObject.transform.rotation = new Quaternion(0f, 0f, 0f, 0f);
        gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));
        gameObject.transform.Rotate(new Vector3(0, 0, 1), Random.Range(-10f, 10f));
        ball.GetComponent<Rigidbody>().velocity = new Vector3(0f, 0f, 0f);
        ball.transform.position = new Vector3(Random.Range(-1.5f, 1.5f), 4f, Random.Range(-1.5f, 1.5f)) + gameObject.transform.position;
    }

AgentOnDone()

エージェントが完了(Doneフラグ)されないままエージェントが終了するとこの関数が呼び出される。Academyがリセットされた時にのみ呼び出される。

読んだらなんか雰囲気つかめてきた

あとはDecision.csみたいなのをBrainに設定すると自前のエージェントも作れる気がする。

これまでの記事はこちら

ensekitt.hatenablog.com
ensekitt.hatenablog.com

クリエイティブ・コモンズ・ライセンス
この 作品 は クリエイティブ・コモンズ 表示 4.0 国際 ライセンスの下に提供されています。