強化学習と動的計画法

2020年11月15日

はじめに

Sutton and Barto 著『強化学習』(Reinforcement Learning An Introduction) を参考に、強化学習と動的計画法について見る。

環境

  • Miniconda3 (Python 3.7)

強化学習

強化学習

強化学習 (reinforcement learning) とは、環境と相互作用する意思決定エージェントの学習と行動のアルゴリズムである。出力に対してフィードバックを受けて学習するという意味では、教師あり学習 (supervised learning) と似ているが、教師あり学習は一方的に「正解」を押し付けられるのに対し、強化学習は報酬の形で環境から呼びかけられるだけであって、自分なりの正解を探索しなければならない点が異なる (厳密な意味では、教師あり学習の「正解」は「フィードバック」ではない)。また、学習の途中で得た情報を学習に利用するという点では、進化的アルゴリズムとも異なる。

強化学習のエージェントは、状態 (state) $s$ に対して行動 (action) $a$ を決定する方策 (policy) $\pi(s, a)$ を持つ。行動に対して環境は状況に応じた報酬 (reward) $r$ を返してくる。この報酬を最大化するように方策を変更していくのが、強化学習における学習の仕組みである。とはいえ、目の前の美味しそうなケーキに目を奪われて、将来の不健康という負の報酬を見落としてしまっては困るので、実際は将来にわたる報酬の総体である収益 (return) $R$ を最大化することを考える。ただ、目先の利益を優先すべきか否かというのは、問題にも依存する。そこで、収益には人間の心理同様に割引率が考慮される。

マルコフ決定過程

強化学習問題は、ある状態から別の状態へとある確率で遷移する確率過程とみなせる。単純な確率過程は、「ある状態への遷移はそのひとつ前の状態にしか依存しない」というものである。この性質をマルコフ性といい、この性質をもった確率過程をマルコフ決定過程 (MDP) という。強化学習問題では、一般的にはこのマルコフ決定過程が想定される。

マルコフ決定過程は、簡単にいうと、状態間を結んだネットワーク図として描けるものである。状態間の遷移は、行動およびその行動に関する遷移確率で決まり、その遷移についての距離だったり点数だったりが報酬である。

収益と価値

ある状態・行動の系列の時刻 $t$ において期待される収益 $R_t$ は、それ以降の報酬の総和で定義される。

$$ R_t = r_{t+1} + \gamma r_{t+2} + \gamma^2 r_{t+3} + \cdots = r_{t+1} +\gamma R_{t+1} $$

ここで $\gamma$ は割引率で $0 \le \gamma \le 1$ である。この収益が大きくなるように方策を考えるのだが、状態の経路は様々なので、ある状態からある方策を用いて辿ったときの収益の期待値を考える。これを価値 (value) といい、特に方策 $\pi$ についての状態 $s$ の価値を状態価値 $V^\pi (s)$ として次式で定義する。

$$ V^\pi(s) = E_\pi\{R_t|s_t=s\} $$

ここで $E$ は期待値の意味である。さらに行動 $a$ についての行動価値 $Q^\pi (s, a)$ も定義できる。

$$ Q^\pi(s, a) = E_\pi\{R_t|s_t=s,a_t=a\} $$

マルコフ決定過程を考え、行動確率 (=方策) を $\pi(s, a)$ とし、行動 $a$ による状態 $s$ から $s'$ への遷移確率を $P^a_{ss'}$、その時の報酬の期待値を $R^a_{ss'}$ とすると、状態価値 $V^\pi (s)$ は次式のように書ける。

$$ V^\pi(s) = \sum_a \pi(s, a) \sum_{s'} P^a_{ss'} [R^a_{ss'} + \gamma V^\pi(s')] = \sum_a \pi(s, a) Q^\pi(s, a) $$

これは、ややこしそうに見えるが、状態 $s$ 以降のすべての行動、すべての遷移についての収益の平均値を計算しているだけである。この式を Bellman 方程式という。強化学習問題を解くにはこの方程式を解けばよいのだが、最適方策を求めるための最適状態価値を求めるために最適方策が必要だし、すべての状態からのすべての経路を計算するのは大変だし、そもそも遷移確率がわからない (やってみないとわからない) こともあるしで、簡単にはいかない。そういうわけで、うまいこと最適状態価値あるいは行動価値を推定しながら最適方策を探るような Bellman 方程式の近似解法がいくつも提案されており、それがつまり強化学習のアルゴリズムである。

動的計画法

反復方策評価

Bellman 方程式をわりと直接的に解こうとする方法が、動的計画法 (dynamic programming) である。まず、とりあえずは方策がわかっているとして、この方策による価値を評価することを考える。遷移確率についてもわかっているものとする。それでも、Bellman 方程式を計算するには問題がある。というのも、$V^\pi(s)$ は両辺に出てくるからである。そこで、これを反復計算で計算する。

$$ V_{k+1}(s) = \sum_a \pi(s, a) \sum_{s'} P^a_{ss'} [R^a_{ss'} + \gamma V_k(s')] $$

ここで $k$ は反復回数である。これを適当な初期値から繰り返し計算していき、変化が小さくなったところで止めれば、状態価値の近似値が得られる。この方法は反復方策評価 (iterative policy evaluation) と呼ばれる。

方策反復

方策を行動確率ではなく状態から行動を返す関数 $\pi(s)$ で表すことにしよう。つまり、状態に対してそれぞれ行動がひとつ決まるものとする。もし行動価値 $Q^\pi(s, a)$ がわかっているとき、$\pi(s)$ を決める方法として、$Q^\pi(s, a)$ が最大となる $a$ を取ることが考えられる。これを次式のように書く。

$$ \pi(s) = \arg \max_a Q^\pi(s, a) = \arg \max_a \sum_{s'} P^a_{ss'} [R^a_{ss'} + \gamma V^\pi(s')] $$

これはグリーディ (greedy) 方策と呼ばれる。

このとき、$V^\pi(s)$ は単純に次のように表される。

$$ V^\pi(s) = Q^\pi(s, \pi(s)) = \sum_{s'} P^{\pi(s)}_{ss'} [R^{\pi(s)}_{ss'} + \gamma V^\pi(s')] $$

反復形式では次のようになる。

$$ V_{k+1}(s) = Q_{k+1}(s, \pi(s)) = \sum_{s'} P^{\pi(s)}_{ss'} [R^{\pi(s)}_{ss'} + \gamma V_k(s')] $$

反復計算で $V_k(s)$ を求め、それにより求まった $Q_k(s, a)$ から $\pi(s)$ を更新し、再度 $V^\pi_k(s)$ を計算して…という感じで $\pi(s)$ が変化しなくなるまで繰り返すと、改善された方策が得られる。この方法は方策反復 (policy iteration) と呼ばれる。

価値反復

ところで、グリーディ方策は、次式を満たす $\pi(s)$ である。

$$ Q^\pi(s, \pi(s)) = \max_a Q^\pi(s,a) $$

$\pi(s)$ が最適方策になっているとしたら、最適価値 $V^*(s)$ は次式で表される。

$$ V^*(s) = \max_a Q^*(s, a) = \max_a \sum_{s'} P^a_{ss'} [R^a_{ss'} + \gamma V^*(s')] $$

これを反復計算に用いる方法も考えられる。

$$ V_{k+1}(s) = \max_a \sum_{s'} P^a_{ss'} [R^a_{ss'} + \gamma V_k(s')] $$

反復計算で $V_k(s)$ を求め、それにより求まった $Q_k(s, a)$ から $\pi(s)$ を求めて終わりである。この方法は価値反復 (value iteration) と呼ばれる。

格子世界

2 次元の格子世界 (gridworld) を動き回る問題を考える。動く方向は上下右左の 4 方向とする。

Python で問題を記述する。必要なモジュールを準備する。

import numpy as np

格子世界を実装する。

class Gridworld:
    def __init__(self, nx, ny, goals):
        self.nx = nx
        self.ny = ny
        
        self.goals = []
        for g in goals:
            self.goals.append(g[0] + g[1]*self.nx)
        
        self.stop = self.goals.copy()
        
    def get_states(self):
        return np.arange(self.nx*self.ny)
    
    def get_actions(self):
        return np.arange(0, 4)
    
    def step(self, s, a):
        assert(a < 4)
        
        if s in self.goals:
            r = 0
        else:
            i = s % self.nx
            j = s//self.nx
            
            if a == 0: # up
                j += 1
            elif a == 1: # down
                j -= 1
            elif a == 2: # right
                i += 1
            elif a == 3: # left
                i -=1
            
            i = max(i, 0)
            i = min(i, self.nx - 1)
            j = max(j, 0)
            j = min(j, self.ny - 1)
            
            s = i + self.nx*j
            r = -1
            
        return s, r
    
    def display_grid(self):
        for j in range(self.ny-1, -1, -1):
            print("{} ".format(j % 10), end="")
            for i in range(self.nx):
                s = i + j*self.nx
                if s in self.goals:
                    print("G ", end="")
                else:
                    print(". ", end="")
            print()

        print("  ", end="")
        for i in range(self.nx):
            print("{} ".format(i % 10), end="")
        print()
        
    def display_grid_value(self, v, d=6):
        for j in range(self.ny-1, -1, -1):
            print("{} ".format(j), end="")
            for i in range(self.nx):
                s = i + j*self.nx
                print(("{:>%d.2f} " % d).format(v[s]), end="")
            print()

        print("  ", end="")
        for i in range(self.nx):
            print(("{:>%d} " % d).format(i), end="")
        print()

    def display_grid_action(self, pi):
        for j in range(self.ny-1, -1, -1):
            print("{} ".format(j % 10), end="")
            for i in range(self.nx):
                s = i + j*self.nx
                if s in self.goals:
                    print("G ", end="")
                else:
                    if len(pi.shape) == 1:
                        a = pi[s]
                    else:
                        a = np.argmax(pi[s, :])
                    if a == 0: # up
                        print("↑ ", end="")
                    elif a == 1: # down
                        print("↓ ", end="")
                    elif a == 2: # right
                        print("→ ", end="")
                    elif a == 3: # left
                        print("← ", end="")
            print()

        print("  ", end="")
        for i in range(self.nx):
            print("{} ".format(i % 10), end="")
        print()

格子の大きさを nx x ny とし、格子の位置を (i, j) で表すと、状態変数は s = i + j*nx とする。左下を (0, 0) とする。報酬はひとつ動くたびに -1 とし、ゴールでは 0 とする。終端状態に移動するときではなく、終端状態から移動しようとしたときに報酬 0 が得られる。

ここでは 4 x 4 の格子とし、ゴールは (0, 3), (3, 0) の 2 つとする。

nx = 4
ny = 4
goals = [(0, ny-1), (nx-1, 0)]

gw = Gridworld(nx, ny, goals)
gw.display_grid()
3 G . . . 
2 . . . . 
1 . . . . 
0 . . . G 
  0 1 2 3 

反復方策評価

行動はランダムとして、反復方策評価で状態価値を計算後、グリーディ方策を求める。

def gridworld_iterative_policy_evaluation(gw, tol=1e-5):
    states = gw.get_states()
    actions = gw.get_actions()
    n = len(states)
    
    v = np.zeros(n)
    p = 1./len(actions)
    pi = np.zeros(n, dtype=int)
    
    itr = 0
    while True:
        itr += 1
        v0 = v.copy()
        for s in states:
            vnew = 0.
            for a in actions:
                s2, r = gw.step(s, a)
                vnew += p*(r + v[s2])
            v[s] = vnew
        if np.max(np.abs(v0 - v)) < tol:
            break
    
    print("iter = {}".format(itr))
    
    for s in states:
        qs = np.zeros(len(actions))
        for a in actions:
            s2, r = gw.step(s, a)
            qs[a] = r + v[s2]
        pi[s] = actions[np.argmax(qs)]
            
    return v, pi

この問題の場合、行動の結果の状態は 1 通りなので、幾分簡単である。

実行。

v_ipe, pi_ipe = gridworld_iterative_policy_evaluation(gw)
iter = 142

状態価値の表示。

gw.display_grid_value(v_ipe)
3   0.00 -14.00 -20.00 -22.00 
2 -14.00 -18.00 -20.00 -20.00 
1 -20.00 -20.00 -18.00 -14.00 
0 -22.00 -20.00 -14.00   0.00 
       0      1      2      3

この場合の状態価値はゴールまでのステップ数を表すが、妙に大きい値である。これはランダムに動く前提のためである。

方策の表示。

gw.display_grid_action(pi_ipe)
3 G ← ← ↓ 
2 ↑ ← ↓ ↓ 
1 ↑ ↑ ↓ ↓ 
0 ↑ → → G 
  0 1 2 3 

ゴール (G) に向かう方向になっている。状態価値が同じ方向が複数ある場合があるが、ここではそのうちの 1 つだけ表示している。

方策反復

方策反復で方策を求める。ここでは、状態価値の計算が止まらなかったので、最大回数を設定している。

def gridworld_policy_iteration(gw, tol=1e-5, max_viter=100):
    states = gw.get_states()
    actions = gw.get_actions()
    n = len(states)
    
    v = np.zeros(n)
    pi = np.zeros(n, dtype=int)
    
    itr = 0
    while True:
        for i in range(max_viter):
            itr += 1
            v0 = v.copy()
            for s in states:
                a = pi[s]
                s2, r = gw.step(s, a)
                v[s] = r + v[s2]
            if np.max(np.abs(v0 - v)) < tol:
                break
        
        pi0 = pi.copy()
        for s in states:
            qs = np.zeros(len(actions))
            for a in actions:
                s2, r = gw.step(s, a)
                qs[a] = r + v[s2]
            pi[s] = actions[np.argmax(qs)]
        
        if np.all(pi == pi0):
            break
    
    print("iter = {}".format(itr))
            
    return v, pi

実行。

v_p, pi_p = gridworld_policy_iteration(gw)
iter = 302

状態価値。

gw.display_grid_value(v_p)
3   0.00  -1.00  -2.00  -3.00 
2  -1.00  -2.00  -3.00  -2.00 
1  -2.00  -3.00  -2.00  -1.00 
0  -3.00  -2.00  -1.00   0.00 
       0      1      2      3 

価値はゴールまでのステップ数になっている。

方策。

gw.display_grid_action(pi_p)
3 G ← ← ↓ 
2 ↑ ↑ ↑ ↓ 
1 ↑ ↑ ↓ ↓ 
0 ↑ → → G 
  0 1 2 3 

価値反復

価値反復で方策を求める。

def gridworld_value_iteration(gw, tol=1e-5, max_viter=100):
    states = gw.get_states()
    actions = gw.get_actions()
    n = len(states)
    
    v = np.zeros(n)
    pi = np.zeros(n, dtype=int)
    
    itr = 0
    for i in range(max_viter):
        itr += 1
        v0 = v.copy()
        for s in states:
            qs = np.zeros(len(actions))
            for a in actions:
                s2, r = gw.step(s, a)
                qs[a] = r + v[s2]
            v[s] = np.max(qs)
        if np.max(np.abs(v0 - v)) < tol:
            break
    
    print("iter = {}".format(itr))

    for s in states:
        qs = np.zeros(len(actions))
        for a in actions:
            s2, r = gw.step(s, a)
            qs[a] = r + v[s2]
        pi[s] = actions[np.argmax(qs)]
            
    return v, pi

実行。

v_v, pi_v = gridworld_value_iteration(gw)
iter = 4

状態価値。

gw.display_grid_value(v_v)
3   0.00  -1.00  -2.00  -3.00 
2  -1.00  -2.00  -3.00  -2.00 
1  -2.00  -3.00  -2.00  -1.00 
0  -3.00  -2.00  -1.00   0.00 
       0      1      2      3 

方策。

gw.display_grid_action(pi_v)
3 G ← ← ↓ 
2 ↑ ↑ ↑ ↓ 
1 ↑ ↑ ↓ ↓ 
0 ↑ → → G 
  0 1 2 3

結果は方策反復の場合と同じである。

風の吹く格子世界

下から上に向かって風が吹く格子世界を考える。

class WindyGridworld(Gridworld):
    def __init__(self, nx, ny, goals, wind):
        super().__init__(nx, ny, goals)
        self.wind = np.array(wind, dtype=int)
    
    def step(self, s, a):
        assert(a < 4)
        
        if s in self.goals:
            r = 0
        else:
            i = s % self.nx
            j = s//self.nx
            
            j += self.wind[i]
            
            if a == 0: # up
                j += 1
            elif a == 1: # down
                j -= 1
            elif a == 2: # right
                i += 1
            elif a == 3: # left
                i -=1
            
            i = max(i, 0)
            i = min(i, self.nx - 1)
            j = max(j, 0)
            j = min(j, self.ny - 1)
            
            s = i + self.nx*j
            r = -1
            
        return s, r
    
    def display_grid(self):
        super().display_grid()
        print("w ", end="")
        for w in self.wind:
            print("{} ".format(w), end="")
        print()

    def display_grid_action(self, pi):
        super().display_grid_action(pi)
        print("w ", end="")
        for w in self.wind:
            print("{} ".format(w), end="")
        print()

風は格子の水平方向にそれぞれ整数で設定されており、その位置から動くときに、上方向にその個数分だけ動く。

ここでは、10 x 7 の格子を考え、ゴールは (7, 3) とする。ゴール付近に風を設定する。

nx = 10
ny = 7
goals = [(7, 3)]
wind = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]

gw = WindyGridworld(nx, ny, goals, wind)
gw.display_grid()
6 . . . . . . . . . . 
5 . . . . . . . . . . 
4 . . . . . . . . . . 
3 . . . . . . . G . . 
2 . . . . . . . . . . 
1 . . . . . . . . . . 
0 . . . . . . . . . . 
  0 1 2 3 4 5 6 7 8 9 
w 0 0 0 1 1 1 2 2 1 0 

一番下の数字の列が風である。この場合、ゴール (G) 横の (6, 3) から右に動くと、(7, 5) に行ってしまうし、(8, 3) から左に動くと (7, 4) に行ってしまう。

方策反復

方策反復で計算するには、格子世界で用いた関数がそのまま使える。

v_p, pi_p = gridworld_policy_iteration(gw)
iter = 1502

状態価値。

gw.display_grid_value(v_p)
6 -15.00 -14.00 -13.00 -12.00 -11.00 -10.00  -9.00  -8.00  -7.00  -6.00 
5 -15.00 -14.00 -13.00 -12.00 -11.00 -10.00  -9.00  -8.00  -7.00  -5.00 
4 -15.00 -14.00 -13.00 -12.00 -11.00 -10.00  -9.00  -8.00  -6.00  -4.00 
3 -15.00 -14.00 -13.00 -12.00 -11.00 -10.00  -9.00   0.00  -5.00  -3.00 
2 -15.00 -14.00 -13.00 -12.00 -11.00 -10.00  -9.00  -1.00  -1.00  -2.00 
1 -15.00 -14.00 -13.00 -12.00 -11.00 -10.00  -1.00  -2.00  -2.00  -3.00 
0 -15.00 -14.00 -13.00 -12.00 -11.00  -2.00  -2.00  -1.00  -2.00  -3.00 
       0      1      2      3      4      5      6      7      8      9 

方策。

gw.display_grid_action(pi_p)
6 → → → → → → → → → ↓ 
5 → → → → → → → → → ↓ 
4 → → → → → → → → → ↓ 
3 → → → → → → → G → ↓ 
2 → → → → → → → ↓ ← ← 
1 → → → → → → → ↓ ← ↑ 
0 → → → → → → ↓ ↑ ↑ ← 
  0 1 2 3 4 5 6 7 8 9 
w 0 0 0 1 1 1 2 2 1 0 

もし (0, 3) からスタートしたなら、右に (3, 3) まで進んだ後、風に押されて (6, 6) まで行ってしまう。そこから (9, 6) に進み、風がないので (9, 2) まで進める。そこから左に動けば (8, 2) から風に乗ってゴールの (7, 3) に到達できる。

価値反復

価値反復の場合も、格子世界の関数がそのまま使える。

v_v, pi_v = gridworld_value_iteration(gw)
iter = 16

状態価値。

gw.display_grid_value(v_v)
6 -15.00 -14.00 -13.00 -12.00 -11.00 -10.00  -9.00  -8.00  -7.00  -6.00 
5 -15.00 -14.00 -13.00 -12.00 -11.00 -10.00  -9.00  -8.00  -7.00  -5.00 
4 -15.00 -14.00 -13.00 -12.00 -11.00 -10.00  -9.00  -8.00  -6.00  -4.00 
3 -15.00 -14.00 -13.00 -12.00 -11.00 -10.00  -9.00   0.00  -5.00  -3.00 
2 -15.00 -14.00 -13.00 -12.00 -11.00 -10.00  -9.00  -1.00  -1.00  -2.00 
1 -15.00 -14.00 -13.00 -12.00 -11.00 -10.00  -1.00  -2.00  -2.00  -3.00 
0 -15.00 -14.00 -13.00 -12.00 -11.00  -2.00  -2.00  -1.00  -2.00  -3.00 
       0      1      2      3      4      5      6      7      8      9 

方策。

gw.display_grid_action(pi_v)
6 → → → → → → → → → ↓ 
5 → → → → → → → → → ↓ 
4 → → → → → → → → → ↓ 
3 → → → → → → → G → ↓ 
2 → → → → → → → ↓ ← ← 
1 → → → → → → → ↓ ← ↑ 
0 ↑ → → → → → ↓ ↑ ↑ ← 
  0 1 2 3 4 5 6 7 8 9 
w 0 0 0 1 1 1 2 2 1 0 

方策反復の場合と同じである。

まとめ

動的計画法は、全状態についての計算を行なっている。格子世界でゴールは設定されているのにスタートが設定されていないのはそのためである。どこからスタートしてもよい情報を求めているわけだが、いつもそれほどの情報が必要なわけではないし、そもそも全状態の計算など現実的には不可能な場合もある。とにかく実際に走ってみるしかできないこともある。そういった現実的な問題のために、さまざまな手法が提案されているが、動的計画法はそれらの手法の基盤となるものである。