目录

前言

本文由浅入深介绍如果使用强化学习玩cart pole游戏。

概念

在电子游戏世界(特指Atari 2600这一类的简单游戏。不包括推理解密类的游戏)中:

  • 环境指的是游戏本身,包括其内部的各种逻辑;
  • Agent指的是操作游戏的玩家,当然也可以是指操作游戏的AI算法; 状态就是指游戏在屏幕上展现的画面。游戏通过屏幕画面把状态信息传达给Agent。如果是棋类游戏,状态是离散的,状态的数量是有限的。但在动作类游戏(如打飞机)中,状态是画面中的每个物体(飞机,敌人,子弹等等)所处的位置和运动速度的组合。状态是连续的,而且数量几乎是无限的。
  • 动作是指手柄的按键组合,包括方向键和按钮的组合,当然也包括什么都不按(不做任何动作)。
  • 奖励是指游戏的得分,每击中一个敌人都可以得到一些得分的奖励。
  • 策略是Agent脑子里从状态到动作的映射。也就是说,每当Agent看到一个游戏画面(状态),就应该知道该如何操纵手柄(动作)。

Reinforcement Learning算法的任务就是找到最佳的策略。策略的表示方法可以有很多。比如: 当状态的个数是有限的情况下,我们可以给每个状态指定一个最佳动作,以后只要看到某种状态出现,就去查询这个状态对应的动作就可以了。这种表示方法类似于查询字典。用这种方法表示最佳策略是没问题的,但是却不利于我们找出最佳策略,因为这种表示法是不可微分的(non-differentiable)。 我们还可以给每个状态下所有可能的动作估一个价值(Q-value),当我们看到某个状态的时候,我们去比较这个状态下所有动作的价值,找出价值最大的那一个。这种表示方法实际上可以用一个二维的表格来表示(Tabular Representation)。 当状态的个数是无限多时,我们已经不可能用表格来表示策略了。这时我们只能用一些函数拟合(Function Approximation)的方法把状态,动作,价值这三者的关系表示出来。神经网络目前是最佳的选择。

Cart Pole游戏介绍

游戏规则

cart pole即车杆游戏,游戏如下,很简单,游戏里面有一个小车,上有竖着一根杆子。小车需要左右移动来保持杆子竖直。如果杆子倾斜的角度大于15°,那么游戏结束。小车也不能移动出一个范围(中间到两边各2.4个单位长度)。

action

  • 左移
  • 右移

state variables

  • position of the cart on the track
  • angle of the pole with the vertical
  • cart velocity
  • rate of change of the angle

分别表示车的位置,杆子的角度,车速,角度变化率

游戏奖励

在gym的Cart Pole环境(env)里面,左移或者右移小车的action之后,env都会返回一个+1的reward。到达200个reward之后,游戏也会结束。

avatar

gym介绍

简单玩法

一直左移

# coding: utf8

import gym
import time

env = gym.make("CartPole-v0")

env.reset()

# 一直左移
action = 0
sum_reward = 0

for i in range(1000):
    env.render()
    time.sleep(1)
    observation, reward, done, info = env.step(action)
    sum_reward += reward
    print(observation, reward, sum_reward, done, info)
    if done:
        break

最后输出: 注意每次输出不一定一样。

[ 0.01332186 -0.20951928 -0.01201124  0.27060169] 1.0 1.0 False {} 0
[ 0.00913147 -0.40446779 -0.0065992   0.55947214] 1.0 2.0 False {} 1
[ 0.00104212 -0.59949649  0.00459024  0.85006868] 1.0 3.0 False {} 2
[-0.01094781 -0.79468073  0.02159161  1.14419149] 1.0 4.0 False {} 3
[-0.02684143 -0.99007801  0.04447544  1.44356652] 1.0 5.0 False {} 4
[-0.04664299 -1.18571827  0.07334677  1.74980819] 1.0 6.0 False {} 5
[-0.07035735 -1.38159249  0.10834294  2.06437417] 1.0 7.0 False {} 6
[-0.0979892  -1.5776388   0.14963042  2.38850996] 1.0 8.0 False {} 7
[-0.12954198 -1.77372622  0.19740062  2.72318191] 1.0 9.0 False {} 8
[-0.1650165  -1.96963594  0.25186426  3.06899915] 1.0 10.0 True {} 9

根据杆子角度调整

如果杆子角度为负,则右移;否则左移。

# coding: utf8

import gym
import time

env = gym.make("CartPole-v0")

env.reset()

# 一直往左移
action = 0
sum_reward = 0

for i in range(1000):
    env.render()
    time.sleep(0.1)
    observation, reward, done, info = env.step(action)
    if observation[1] < 0:
        action = 1
    else:
        action = 0
    sum_reward += reward
    print(observation, reward, sum_reward, done, info)
    if done:
        break

最后输出

[-0.01882535 -0.17022571 -0.03011459  0.27344505] 1.0 1.0 False {}
[-0.02222987  0.0253127  -0.02464569 -0.02858193] 1.0 2.0 False {}
[-0.02172361 -0.16944731 -0.02521733  0.25622426] 1.0 3.0 False {}
[-0.02511256  0.02602544 -0.02009284 -0.04430474] 1.0 4.0 False {}
[-0.02459205 -0.16880272 -0.02097894  0.2419716 ] 1.0 5.0 False {}
[-0.0279681   0.02661253 -0.0161395  -0.05725412] 1.0 6.0 False {}
[-0.02743585 -0.16827434 -0.01728459  0.2302933 ] 1.0 7.0 False {}
[-0.03080134  0.02709028 -0.01267872 -0.06779128] 1.0 8.0 False {}
[-0.03025953 -0.16784762 -0.01403455  0.22086463] 1.0 9.0 False {}
...
[-0.10049214 -0.18592199  0.11853794  0.62351417] 1.0 57.0 False {}
[-0.10421058  0.00736289  0.13100822  0.37038998] 1.0 58.0 False {}
[-0.10406332 -0.18935337  0.13841602  0.70134211] 1.0 59.0 False {}
[-0.10785039  0.00360618  0.15244286  0.45523554] 1.0 60.0 False {}
[-0.10777827 -0.19330536  0.16154758  0.79182232] 1.0 61.0 False {}
[-0.11164438 -0.00072631  0.17738402  0.55400362] 1.0 62.0 False {}
[-0.1116589   0.19151957  0.18846409  0.32203753] 1.0 63.0 False {}
[-0.10782851 -0.00571617  0.19490485  0.66773352] 1.0 64.0 False {}
[-0.10794283  0.18623823  0.20825952  0.44219351] 1.0 65.0 False {}
[-0.10421807 -0.01112806  0.21710339  0.79263264] 1.0 66.0 True {}

Random Guessing Algorithm

observation是一个四维向量,如果对这个向量求它的加权和,就可以得到一个值,那么就可以根据加权和的符号来决定action,同样可以用sigmoid函数当成二分类问题。这样就可以通过改变权重值来改变policy。

代码如下, 包含了hill climbing算法的代码。

# coding: utf8

import numpy as np
import gym
import time

def get_action(weights, observation):
    wxb = np.dot(weights[:4], observation) + weights[4]
    if wxb >= 0:
        return 1
    else:
        return 0

def get_sum_reward_by_weights(env, weights):
    observation = env.reset()
    sum_reward = 0
    for t in range(1000):
        # time.sleep(0.01)
        # env.render()
        action = get_action(weights, observation)
        observation, reward, done, info = env.step(action)
        sum_reward += reward
        # print(sum_reward, action, observation, reward, done, info)
        if done:
            break
    return sum_reward


def get_weights_by_random_guess():
    return np.random.rand(5)

def get_weights_by_hill_climbing(best_weights):
    return best_weights + np.random.normal(0, 0.1, 5)

def get_best_result(algo="random_guess"):
    env = gym.make("CartPole-v0")
    np.random.seed(10)
    best_reward = 0
    best_weights = np.random.rand(5)

    for iter in range(10000):
        cur_weights = None

        if algo == "hill_climbing":
            # print(best_weights)
            cur_weights = get_weights_by_hill_climbing(best_weights)
        else:
            cur_weights = get_weights_by_random_guess()

        cur_sum_reward = get_sum_reward_by_weights(env, cur_weights)

        # print(cur_sum_reward, cur_weights)
        if cur_sum_reward > best_reward:
            best_reward = cur_sum_reward
            best_weights = cur_weights

        if best_reward >= 200:
            break

    print(iter, best_reward, best_weights)
    return best_reward, best_weights

print(get_best_result("hill_climbing"))

# env = gym.make("CartPole-v0")
# get_sum_reward_by_weights(env, [0.22479665, 0.19806286, 0.76053071, 0.16911084, 0.08833981])

Hill Climbing Algorithm

Hill Climbing Algorithm会给当前最好的权重加上一组随机值,如果加上这组值持续时间变长了那么就更新最好的权重,如果没有变的更好就不更新。