時間差分学習

2020年11月18日

はじめに

Sutton and Barto 著『強化学習』(Reinforcement Learning An Introduction) を参考に、時間差分学習について見る。

環境

  • Miniconda3 (Python 3.7)

時間差分学習

時間差分学習

状態 $s$、行動 $a$ についての方策 $\pi(s, a)$ に関する状態価値を表す Bellman 方程式は、次式で表される。

$$ V^\pi(s) = \sum_a \pi(s, a) \sum_{s'} P^a_{ss'} [R^a_{ss'} + \gamma V^\pi(s')] $$

ここで、ある状態・行動の系列 (エピソード) の時刻 $t$ において期待される収益 $R_t$ は次式で表される。

$$ R_t = r_{t+1} + \gamma r_{t+2} + \gamma^2 r_{t+3} + \cdots = r_{t+1} + \gamma R_{t+1} $$

$r_t$ は報酬、$\gamma$ は割引率である。

Bellman 方程式を近似的に解くために、動的計画法では、繰り返し計算の形が取られる。

$$ V_{k+1}(s) = \sum_a \pi(s, a) \sum_{s'} P^a_{ss'} [R^a_{ss'} + \gamma V_k(s')] $$

モンテカルロ法では、あるエピソードをとにかく止まるまで走ってみて、直接的に収益を計算し、その平均をとる。

$$ V(s_t) = V(s_t) + \alpha (R_t - V(s_t)) $$

ここで、$\alpha$ はステップサイズパラメータである。ここでは逐一訪問 MC 法を想定している。めんどくさいので添字を省略しているが、繰り返し計算の形である。

ところで、収益は $R_t = r_{t+1} + \gamma R_{t+1}$ と書けるが、状態価値は収益の期待値なので $R_{t+1} \approx V(s_{t+1})$ と置ける。そこで、上の式を次のように書く。

$$ V(s_t) = V(s_t) + \alpha \{r_{t+1} + \gamma V(s_{t+1}) - V(s_t)\} $$

この形式だと、モンテカルロ法のように、収益を求めるために最後まで走ってみる必要はなく、せいぜい 1 ステップだけ進めばよい。この方法を時間差分学習 (temporal-difference learning) あるいは TD 学習という。

本来であれば $V(s_t) = r_{t+1} + \gamma V(s_{t+1})$ なので、上式の第 2 項は 0 であるが、繰り返し計算中は成り立たない。そこで、次式 $\delta_t$ を TD 誤差 (TD error) と呼ぶ。

$$ \delta_t = r_{t+1} + \gamma V(s_{t+1}) - V(s_t) $$

TD(0)

上述の TD 形式の状態価値の式で方策を評価する方法を TD(0) と呼ぶ。

Sarsa

方策を改善するには、行動価値の方がよい。上述の TD 形式の状態価値の式を行動価値になおすと、次式のようになる。

$$ Q(s_t, a_t) = Q(s_t, a_t) + \alpha \{r_{t+1} + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t)\} $$

まず、現状の方策でもって $s_t$ から $a_t$ を得て、それでもって行動して $r_{t+1}$、$s_{t+1}$ を得る。さらに、現状の方策でもって $a_{t+1}$ を得て、上式で $Q(s_t, a_t)$ を計算して $\varepsilon$ グリーディ方策などとして方策改善を行う。$s_t$、$a_t$、$r_{t+1}$、$s_{t+1}$、$a_{t+1}$ を用いるので、Sarsa というらしい。実際の行動の方策と、行動価値を計算する際の行動の方策が同じなので、これは方策オン型である。

Q 学習

行動価値を計算する際の $a_{t+1}$ はグリーディに決めればいいじゃん、というのが Q 学習である。

$$ Q(s_t, a_t) = Q(s_t, a_t) + \alpha \{r_{t+1} + \gamma \max_a Q(s_{t+1}, a) - Q(s_t, a_t)\} $$

実際の行動の方策と、行動価値のための方策が別なので、方策オフ型である。

Expected Sarsa

行動価値 $Q(s_{t+1}, a_{t+1})$ を期待値 (平均値) にしよう、というのが Expected Sarsa である。

$$ Q(s_t, a_t) = Q(s_t, a_t) + \alpha \{r_{t+1} + \gamma \sum_a \pi(s_{t+1}, a) Q(s_{t+1}, a) - Q(s_t, a_t)\} $$

アクター・クリティック法

アクター・クリティック法 (actor-critic method) は、行動するアクター (actor) と、それを評価するクリティック (critic) とに分けて処理する手法である。と、これだけだと Sarsa や Q 学習と違わないように聞こえるが、この手法は行動価値を用いないところに特徴がある。

TD 学習に、n 本腕バンディット問題の解法の強化比較の考え方を用いる。強化比較では報酬と参照報酬との比較を用いているが、ここでは収益と状態価値との比較を用いる。つまり、TD 誤差を用いて優先度を計算する。

$$ p(s_t, a_t) = p(s_t, a_t) + \beta \delta_t \{1 - \pi(s_t, a_t)\} $$

状態価値は (上述の通りだが) 次式で更新する。

$$ V(s_t) = V(s_t) + \alpha \delta_t $$

方策はソフトマックス手法で求める。

$$ \pi_t(s, a) = \frac{e^{p(s, a)}}{\sum_{b=1}^n e^{p(s, b)}} $$

行動価値を用いて方策を求める場合、方策は「行動価値が最大になるような行動をとる」というように、陰的に表現されている。離散的な行動であればこれで問題ないが、連続的な行動の場合、行動価値から方策を求めるのはちょっと大変である。この点、アクター・クリティック法のように方策が陽に表現されていると有利である。

格子世界

2 次元の格子世界 (gridworld) を動き回る問題を考える。動く方向は上下右左の 4 方向とする。

Python で問題を記述する。必要なモジュールを準備する。

import numpy as np
from tqdm import tqdm

格子世界を実装する。

class Gridworld:
    def __init__(self, nx, ny, goals):
        self.nx = nx
        self.ny = ny
        
        self.goals = []
        for g in goals:
            self.goals.append(g[0] + g[1]*self.nx)
        
        self.stop = self.goals.copy()
        
    def get_states(self):
        return np.arange(self.nx*self.ny)
    
    def get_actions(self):
        return np.arange(0, 4)
    
    def step(self, s, a):
        assert(a < 4)
        
        if s in self.goals:
            r = 0
        else:
            i = s % self.nx
            j = s//self.nx
            
            if a == 0: # up
                j += 1
            elif a == 1: # down
                j -= 1
            elif a == 2: # right
                i += 1
            elif a == 3: # left
                i -=1
            
            i = max(i, 0)
            i = min(i, self.nx - 1)
            j = max(j, 0)
            j = min(j, self.ny - 1)
            
            s = i + self.nx*j
            r = -1
            
        return s, r
    
    def display_grid(self):
        for j in range(self.ny-1, -1, -1):
            print("{} ".format(j % 10), end="")
            for i in range(self.nx):
                s = i + j*self.nx
                if s in self.goals:
                    print("G ", end="")
                else:
                    print(". ", end="")
            print()

        print("  ", end="")
        for i in range(self.nx):
            print("{} ".format(i % 10), end="")
        print()
        
    def display_grid_value(self, v, d=6):
        for j in range(self.ny-1, -1, -1):
            print("{} ".format(j), end="")
            for i in range(self.nx):
                s = i + j*self.nx
                print(("{:>%d.2f} " % d).format(v[s]), end="")
            print()

        print("  ", end="")
        for i in range(self.nx):
            print(("{:>%d} " % d).format(i), end="")
        print()
        
    def display_grid_action(self, pi):
        for j in range(self.ny-1, -1, -1):
            print("{} ".format(j % 10), end="")
            for i in range(self.nx):
                s = i + j*self.nx
                if s in self.goals:
                    print("G ", end="")
                else:
                    if len(pi.shape) == 1:
                        a = pi[s]
                    else:
                        a = np.argmax(pi[s, :])
                    if a == 0: # up
                        print("↑ ", end="")
                    elif a == 1: # down
                        print("↓ ", end="")
                    elif a == 2: # right
                        print("→ ", end="")
                    elif a == 3: # left
                        print("← ", end="")
            print()

        print("  ", end="")
        for i in range(self.nx):
            print("{} ".format(i % 10), end="")
        print()

格子の大きさを nx x ny とし、格子の位置を (i, j) で表すと、状態変数は s = i + j*nx とする。左下を (0, 0) とする。報酬はひとつ動くたびに -1 とし、ゴールでは 0 とする。

ここでは 4 x 4 の格子とし、ゴールは (0, 3), (3, 0) の 2 つとする。

nx = 4
ny = 4
goals = [(0, ny-1), (nx-1, 0)]

gw = Gridworld(nx, ny, goals)
gw.display_grid()
3 G . . . 
2 . . . . 
1 . . . . 
0 . . . G 
  0 1 2 3 

TD(0)

行動はランダムとして、TD(0) で状態価値を計算後、グリーディ方策を求める。

def gridworld_td0(gw, episodes, steps, alpha):
    states = gw.get_states()
    actions = gw.get_actions()
    n = len(states)
    
    v = np.zeros(n)
    pi = np.zeros(n, dtype=int)
    
    for episode in tqdm(range(episodes)):
        s = np.random.choice(states)
        for i in range(steps):
            a = np.random.choice(actions)
            s2, r = gw.step(s, a)
            v[s] += alpha*(r + v[s2] - v[s])
            
            if s in gw.stop:
                break
            
            s = s2
            
    for s in states:
        qs = np.zeros(len(actions))
        for a in actions:
            s2, r = gw.step(s, a)
            qs[a] = r + v[s2]
        pi[s] = actions[np.argmax(qs)]
            
    return v, pi

実行。

np.random.seed(0)
v_td0, pi_td0 = gridworld_td0(gw, 10000, 100, 0.01)

ここでは再現性確保のため、乱数シードを固定している。実行関数の数字の 1 つめはエピソード数、2 つめは計算ステップ数、3 つめはステップサイズパラメータである。

状態価値の表示。

gw.display_grid_value(v_td0)
3   0.00 -14.62 -20.34 -21.69 
2 -13.67 -17.92 -19.94 -19.66 
1 -19.57 -19.89 -18.14 -13.92 
0 -21.83 -19.74 -13.96   0.00 
       0      1      2      3 

動的計画法で計算した状態価値は、以下の通りなので、まあまあの値である。

3   0.00 -14.00 -20.00 -22.00 
2 -14.00 -18.00 -20.00 -20.00 
1 -20.00 -20.00 -18.00 -14.00 
0 -22.00 -20.00 -14.00   0.00 
       0      1      2      3

方策の表示。

gw.display_grid_action(pi_td0)
3 G ← ← ↓ 
2 ↑ ← ← ↓ 
1 ↑ ↑ → ↓ 
0 ↑ → → G 
  0 1 2 3 

動的計画法で計算した方策は、以下の通りである。ただし、行動価値が同じ方向が複数ある場合でも、1 つだけを表示している。

3 G ← ← ↓ 
2 ↑ ← ↓ ↓ 
1 ↑ ↑ ↓ ↓ 
0 ↑ → → G 
  0 1 2 3

Sarsa

Sarsa で方策を求める。ここでは、初期状態はランダムに選ぶようにした。

def gridworld_sarsa(gw, episodes, steps, alpha, eps=0.):
    states = gw.get_states()
    actions = gw.get_actions()
    n = len(states)
    m = len(actions)
    
    v = np.zeros(n)
    q = np.zeros((n, m))
    pi = np.zeros((n, m)) + 1./m
    
    for episode in tqdm(range(episodes)):
        s = np.random.choice(states)
        a = np.random.choice(actions, p=pi[s, :])
        for i in range(steps):
            s2, r = gw.step(s, a)
            
            a2 = np.random.choice(actions, p=pi[s2, :])
            q[s, a] += alpha*(r + q[s2, a2] - q[s, a])
            
            a_max = actions[np.argmax(q[s, :])]
            pi[s, :] = eps/m
            pi[s, a_max] += 1. - eps
            
            if s in gw.stop:
                break
            
            s = s2
            a = a2
            
    for s in states:
        v[s] = np.max(q[s, :])
            
    return v, pi

実行。

np.random.seed(0)
v_sarsa, pi_sarsa = gridworld_sarsa(gw, 1000, 5, 0.5, 0.01)

ここでは、$\varepsilon$ を 0.01 に設定している。

状態価値。

gw.display_grid_value(v_sarsa)
3   0.00  -1.00  -2.00  -3.00 
2  -1.00  -2.00  -3.00  -2.00 
1  -2.00  -3.00  -2.00  -1.00 
0  -3.00  -2.00  -1.00   0.00 
       0      1      2      3

動的計画法による結果は、以下の通りである。

3   0.00  -1.00  -2.00  -3.00 
2  -1.00  -2.00  -3.00  -2.00 
1  -2.00  -3.00  -2.00  -1.00 
0  -3.00  -2.00  -1.00   0.00 
       0      1      2      3

方策。

gw.display_grid_action(pi_sarsa)
3 G ← ← ↓ 
2 ↑ ← → ↓ 
1 ↑ ↑ ↓ ↓ 
0 → → → G 
  0 1 2 3

動的計画法による結果は、以下の通りである。

3 G ← ← ↓ 
2 ↑ ↑ ↑ ↓ 
1 ↑ ↑ ↓ ↓ 
0 ↑ → → G 
  0 1 2 3

Q 学習

Q 学習で方策を求める。ここでは、初期状態はランダムに選ぶようにした。

def gridworld_q(gw, episodes, steps, alpha, eps=0.):
    states = gw.get_states()
    actions = gw.get_actions()
    n = len(states)
    m = len(actions)
    
    v = np.zeros(n)
    q = np.zeros((n, m))
    pi = np.zeros((n, m)) + 1./m
    
    for episode in tqdm(range(episodes)):
        s = np.random.choice(states)
        for i in range(steps):
            a = np.random.choice(actions, p=pi[s, :])
            s2, r = gw.step(s, a)
            
            a2 = actions[np.argmax(q[s2, :])]
            q[s, a] += alpha*(r + q[s2, a2] - q[s, a])
            
            a_max = actions[np.argmax(q[s, :])]
            pi[s, :] = eps/m
            pi[s, a_max] += 1. - eps
            
            if s in gw.stop:
                break
            
            s = s2
            
    for s in states:
        v[s] = np.max(q[s, :])
            
    return v, pi

実行。

np.random.seed(0)
v_q, pi_q = gridworld_q(gw, 1000, 5, 0.5, 0.01)

ここでは、$\varepsilon$ を 0.01 に設定している。

状態価値。

gw.display_grid_value(v_q)
3   0.00  -1.00  -2.00  -3.00 
2  -1.00  -2.00  -3.00  -2.00 
1  -2.00  -3.00  -2.00  -1.00 
0  -3.00  -2.00  -1.00   0.00 
       0      1      2      3 

方策。

gw.display_grid_action(pi_q)
3 G ← ← ↓ 
2 ↑ ← ↑ ↓ 
1 ↑ ↓ → ↓ 
0 → → → G 
  0 1 2 3 

Expected Sarsa

Expected Sarsa で方策を求める。ここでは、初期状態は各状態を順番に選ぶものとした。

def gridworld_expected_sarsa(gw, episodes, steps, alpha, eps=0.):
    states = gw.get_states()
    actions = gw.get_actions()
    n = len(states)
    m = len(actions)
    
    v = np.zeros(n)
    q = np.zeros((n, m))
    pi = np.zeros((n, m)) + 1./m
    
    for episode in tqdm(range(episodes)):
        s = np.random.choice(states)
        for i in range(steps):
            a = np.random.choice(actions, p=pi[s, :])
            s2, r = gw.step(s, a)
            
            q2 = 0.
            for a2 in actions:
                q2 += pi[s2, a2]*q[s2, a2]
                
            q[s, a] += alpha*(r + q2 - q[s, a])
            
            a_max = actions[np.argmax(q[s, :])]
            pi[s, :] = eps/m
            pi[s, a_max] += 1. - eps
            
            if s in gw.stop:
                break
            
            s = s2
            
    for s in states:
        v[s] = np.max(q[s, :])
            
    return v, pi

実行。

np.random.seed(0)
v_esarsa, pi_esarsa = gridworld_expected_sarsa(gw, 1000, 5, 0.5, 0.01)

状態価値。

gw.display_grid_value(v_esarsa)
3   0.00  -1.00  -2.00  -3.01 
2  -1.00  -2.00  -3.00  -2.00 
1  -2.00  -3.00  -2.00  -1.00 
0  -3.01  -2.01  -1.00   0.00 

方策。

gw.display_grid_action(pi_esarsa)
3 G ← ← ← 
2 ↑ ↑ ↑ ↓ 
1 ↑ ↑ → ↓ 
0 ↑ → → G 
  0 1 2 3 

アクター・クリティック法

アクター・クリティック法で方策を求める。ここでは、初期状態は各状態を順番に選ぶものとした。

def gridworld_actor_critic(gw, episodes, steps, alpha, beta):
    states = gw.get_states()
    actions = gw.get_actions()
    n = len(states)
    m = len(actions)
    
    v = np.zeros(n)
    p = np.ones((n, m))
    pi = np.zeros((n, m)) + 1./m
    
    for episode in tqdm(range(episodes)):
        s = np.random.choice(states)    
        for i in range(steps):
            # actor
            a = np.random.choice(actions, p=pi[s, :])
            s2, r = gw.step(s, a)
            
            # critic
            delta = r + v[s2] - v[s]
            p[s, a] += beta*delta*(1. - pi[s, a])
            v[s] += alpha*delta
            
            pin = np.exp(p[s, :])
            pid = np.sum(pin)
            pi[s, :] = pin/pid
            
            if s in gw.stop:
                break
            
            s = s2
            
    return v, pi

実行。

np.random.seed(0)
v_ac, pi_ac = gridworld_actor_critic(gw, 1000, 5, 0.5, 1.)

状態価値。

gw.display_grid_value(v_ac)
3   0.00  -1.00  -2.00  -3.00 
2  -1.00  -2.00  -3.00  -2.00 
1  -2.00  -3.00  -2.00  -1.00 
0  -3.02  -2.06  -1.00   0.00 
       0      1      2      3 

方策。

gw.display_grid_action(pi_ac)
3 G ← ← ← 
2 ↑ ↑ ← ↓ 
1 ↑ ↑ → ↓ 
0 ↑ → → G 
  0 1 2 3 

崖のある格子世界

崖がある格子世界を考える。

class CliffGridworld(Gridworld):
    def __init__(self, nx, ny, goals):
        super().__init__(nx, ny, goals)
        self.cliff = list(range(1, nx - 1))
        self.stop += self.cliff
    
    def step(self, s, a):
        assert(a < 4)
        
        if s in self.goals:
            r = 0
        elif s in self.cliff:
            r = -100
        else:
            i = s % self.nx
            j = s//self.nx
            
            if a == 0: # up
                j += 1
            elif a == 1: # down
                j -= 1
            elif a == 2: # right
                i += 1
            elif a == 3: # left
                i -=1
            
            i = max(i, 0)
            i = min(i, self.nx - 1)
            j = max(j, 0)
            j = min(j, self.ny - 1)
            
            s = i + self.nx*j
            r = -1
            
        return s, r
    
    def display_grid(self):
        for j in range(self.ny-1, -1, -1):
            print("{} ".format(j % 10), end="")
            for i in range(self.nx):
                s = i + j*self.nx
                if s in self.goals:
                    print("G ", end="")
                elif s in self.cliff:
                    print("C ", end="")
                else:
                    print(". ", end="")
            print()

        print("  ", end="")
        for i in range(self.nx):
            print("{} ".format(i % 10), end="")
        print()

    def display_grid_action(self, pi):
        for j in range(self.ny-1, -1, -1):
            print("{} ".format(j % 10), end="")
            for i in range(self.nx):
                s = i + j*self.nx
                if s in self.goals:
                    print("G ", end="")
                elif s in self.cliff:
                    print("C ", end="")
                else:
                    if len(pi.shape) == 1:
                        a = pi[s]
                    else:
                        a = np.argmax(pi[s, :])
                    if a == 0: # up
                        print("↑ ", end="")
                    elif a == 1: # down
                        print("↓ ", end="")
                    elif a == 2: # right
                        print("→ ", end="")
                    elif a == 3: # left
                        print("← ", end="")
            print()

        print("  ", end="")
        for i in range(self.nx):
            print("{} ".format(i % 10), end="")
        print()

格子世界の左下をスタートとし、右下をゴールとして、その間を崖とする。移動中の報酬は -1 だが、崖の報酬は -100 で、崖に入ったらエピソードは終了するものとする。

nx = 12
ny = 4
goals = [(nx-1, 0)]

gw = CliffGridworld(nx, ny, goals)
gw.display_grid()
3 . . . . . . . . . . . . 
2 . . . . . . . . . . . . 
1 . . . . . . . . . . . . 
0 . C C C C C C C C C C G 
  0 1 2 3 4 5 6 7 8 9 0 1

上の "C" が崖である。スタート地点は (0, 0) とするが、計算上は特にスタート地点を設けないものとする。

Sarsa

Sarsa は、格子世界の関数をそのまま使えばよい。ここでは $\varepsilon = 0.1$ とする。

np.random.seed(0)
v_sarsa, pi_sarsa = gridworld_sarsa(gw, 10000, 100, 0.05, 0.1)

方策。

gw.display_grid_action(pi_sarsa)
3 → → → → → → → → → → → ↓ 
2 → → → → → → → → → → → ↓ 
1 ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ → ↓ 
0 ↑ C C C C C C C C C C G 
  0 1 2 3 4 5 6 7 8 9 0 1 

崖から少し離れたところを通っている。パラメータをいじると通る道が多少変わるが、基本的に崖から離れがちである。方策オン型らしい挙動である。

Sarsa は、他の手法に比べると、結果が安定しにくい (探索行動の影響を受けやすい?) 感じがする。

Q 学習

Q 学習も、格子世界の関数をそのまま使えばよい。

np.random.seed(0)
v_q, pi_q = gridworld_q(gw, 3000, 100, 0.5, 0.1)

方策。

gw.display_grid_action(pi_q)
3 ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ 
2 ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ 
1 → → → → → → → → → → → ↓ 
0 ↑ C C C C C C C C C C G 
  0 1 2 3 4 5 6 7 8 9 0 1 

最短の経路を通っている。それどころか、動的計画法と同様の結果である。方策オフ型らしく、探索の影響をあまり受けずに最適経路を見つけられている。また、Sarsa もそうだが、モンテカルロ法に比べて少ないエピソード数で解が得られる。Q 学習の場合は、ステップサイズパラメータを結構大きくしても解が得られる。Sarsa はもっと小さくする必要があるようである。

これだけ見ると Q 学習は Sarsa より圧倒的に性能がよさそうに見えるが、これがシミュレータ上であればよいが、もしソフト方策で動く現実のロボットに学習させる場合、Q 学習ではぎりぎりを攻めることになるため、たまに崖に落ちることになるわけで、どちらがよいかは問題によりけりである。

Expected Sarsa

np.random.seed(0)
v_esarsa, pi_esarsa = gridworld_expected_sarsa(gw, 3000, 100, 0.5, 0.1)

方策。

gw.display_grid_action(pi_esarsa)
3 → → → → → → → → → → → ↓ 
2 ↑ → → → → → → → → → → ↓ 
1 ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ → ↓ 
0 ↑ C C C C C C C C C C G 
  0 1 2 3 4 5 6 7 8 9 0 1 

Sarsa よりもステップサイズパラメータを大きくできる。方策オン型らしい挙動を示している。

アクター・クリティック法

np.random.seed(0)
v_ac, pi_ac = gridworld_actor_critic(gw, 3000, 100, 0.1, 0.01)

方策。

gw.display_grid_action(pi_ac)
3 → → → → → → → → → → → ↓ 
2 → → → → → → → → → → → ↓ 
1 ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ → ↓ 
0 ↑ C C C C C C C C C C G 
  0 1 2 3 4 5 6 7 8 9 0 1 

$\varepsilon$ グリーディ方策ではないので、パラメータの意味が取りづらい。$\alpha$ を小さくすると最短経路を通るので、$\alpha$ に探索的傾向が含まれるようであるが、そう単純でもなさそう。

風の吹く格子世界

下から上に向かって風が吹く格子世界を考える。

class WindyGridworld(Gridworld):
    def __init__(self, nx, ny, goals, wind):
        super().__init__(nx, ny, goals)
        self.wind = np.array(wind, dtype=int)
    
    def step(self, s, a):
        assert(a < 4)
        
        if s in self.goals:
            r = 0
        else:
            i = s % self.nx
            j = s//self.nx
            
            j += self.wind[i]
            
            if a == 0: # up
                j += 1
            elif a == 1: # down
                j -= 1
            elif a == 2: # right
                i += 1
            elif a == 3: # left
                i -=1
            
            i = max(i, 0)
            i = min(i, self.nx - 1)
            j = max(j, 0)
            j = min(j, self.ny - 1)
            
            s = i + self.nx*j
            r = -1
            
        return s, r
    
    def display_grid(self):
        super().display_grid()
        print("w ", end="")
        for w in self.wind:
            print("{} ".format(w), end="")
        print()

    def display_grid_action(self, pi):
        super().display_grid_action(pi)
        print("w ", end="")
        for w in self.wind:
            print("{} ".format(w), end="")
        print()

風は格子の水平方向にそれぞれ整数で設定されており、その位置から動くときに、上方向にその個数分だけ動く。

ここでは、10 x 7 の格子を考え、ゴールは (7, 3) とする。ゴール付近に風を設定する。

nx = 10
ny = 7
goals = [(7, 3)]
wind = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]

gw = WindyGridworld(nx, ny, goals, wind)
gw.display_grid()
6 . . . . . . . . . . 
5 . . . . . . . . . . 
4 . . . . . . . . . . 
3 . . . . . . . G . . 
2 . . . . . . . . . . 
1 . . . . . . . . . . 
0 . . . . . . . . . . 
  0 1 2 3 4 5 6 7 8 9 
w 0 0 0 1 1 1 2 2 1 0 

一番下の数字の列が風である。この場合、ゴール (G) 横の (6, 3) から右に動くと、(7, 5) に行ってしまうし、(8, 3) から左に動くと (7, 4) に行ってしまう。

Sarsa

v_sarsa, pi_sarsa = gridworld_sarsa(gw, 5000, 10, 0.2, 0.1)
gw.display_grid_action(pi_sarsa)
6 → → → → → → → → → ↓ 
5 → → → → → → → → → ↓ 
4 → → → → → → → → → ↓ 
3 → → → → → → → G → ↓ 
2 → → → → → → → ↓ ← ← 
1 → → → → → → → ↓ → ↑ 
0 → → → → → → → ↑ ↑ ← 
  0 1 2 3 4 5 6 7 8 9 
w 0 0 0 1 1 1 2 2 1 0 

動的計画法による結果は以下の通りである。

6 → → → → → → → → → ↓ 
5 → → → → → → → → → ↓ 
4 → → → → → → → → → ↓ 
3 → → → → → → → G → ↓ 
2 → → → → → → → ↓ ← ← 
1 → → → → → → → ↓ ← ↑ 
0 → → → → → → ↓ ↑ ↑ ← 
  0 1 2 3 4 5 6 7 8 9 
w 0 0 0 1 1 1 2 2 1 0 

一部異なるところがあるが、ほぼ最適な結果になっている。

Q 学習

np.random.seed(0)
v_q, pi_q = gridworld_q(gw, 3000, 10, 0.5, 0.1)
gw.display_grid_action(pi_q)
6 → → → → → → → → → ↓ 
5 → → → → → → → → → ↓ 
4 → → → → → → → → → ↓ 
3 → → → → → → → G → ↓ 
2 → → → → → → → ↓ ← ← 
1 → → → → → → → ↓ ← ← 
0 → → → → → → → ↑ ↑ ← 
  0 1 2 3 4 5 6 7 8 9 
w 0 0 0 1 1 1 2 2 1 0 

Sarsa に比べてあっさりと最適解を出している。一部結果が異なるのは、経路が異なるだけである。

Expected Sarsa

np.random.seed(0)
v_esarsa, pi_esarsa = gridworld_expected_sarsa(gw, 3000, 10, 0.5, 0.1)
gw.display_grid_action(pi_esarsa)
6 → → → → → → → → → ↓ 
5 → → → → → → → → → ↓ 
4 → → → → → → → → → ↓ 
3 → → → → → → → G → ↓ 
2 → → → → → → → ↓ ← ← 
1 → → → → → → → ↓ ← ↑ 
0 → → → → → → → ↑ ↑ ← 
  0 1 2 3 4 5 6 7 8 9 
w 0 0 0 1 1 1 2 2 1 0 

これもまた Sarsa に比べてあっさりと最適解を出している。

アクター・クリティック法

np.random.seed(0)
v_ac, pi_ac = gridworld_actor_critic(gw, 10000, 100, 0.05, 0.05)
gw.display_grid_action(pi_ac)
6 → → → → → → → → → ↓ 
5 → → → → → → → → → ↓ 
4 → → → → → → → → → ↓ 
3 → → → → → → → G → ↓ 
2 → → → → → → → ↓ ← ← 
1 → → → → → → → ↓ ← ↑ 
0 → → → → → → → ↑ ↑ ← 
  0 1 2 3 4 5 6 7 8 9 
w 0 0 0 1 1 1 2 2 1 0 

最適解を出してはいるが、かなりパラメータをいじった。正直 Q 学習などに比べてかなり使いにくい。

ちなみに、強化比較の代わりに勾配法でもやってみたが、結果が安定しなかった。

まとめ

TD 学習は、モンテカルロ法と同等かそれ以上の結果をかなり軽い計算で出す。実装も簡単である。今回見た中では、方策オン型は Expected Sarsa、方策オフ型は Q 学習が使いやすい。