{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# A2: Reinforcement Learning (2) - 4x4 grid\n", "\n", "최규빈  \n", "2023-08-29\n", "\n", "My notes are in green\n", "\n", "# Lecture video\n", "\n", "\n", "\n", "# Game2: 4x4 grid\n", "\n", "`-` Problem description: how to train an agent that moves up, down, left, and right in a 4x4 grid world\n", "so that it reaches the goal\n", "\n", "# imports" ], "id": "992b0ad2-2037-4ba6-8159-240c1c121ea3" }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import gymnasium as gym\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from matplotlib.animation import FuncAnimation\n", "import IPython" ], "id": "5b1cfd48-d338-4ddd-8b59-f8ab593ebda9" }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Preliminaries: visualization" ], "id": "57672c31-f451-4652-9490-303ef9ccb898" }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def show(states):\n", "    fig = plt.Figure()\n", "    ax = fig.subplots()\n", "    ax.matshow(np.zeros([4,4]), cmap='bwr',alpha=0.0)\n", "    sc = ax.scatter(0, 0, color='red', s=500) \n", "    ax.text(0, 0, 'start', ha='center', va='center')\n", "    ax.text(3, 3, 'end', ha='center', va='center')\n", "    # Adding grid lines to the plot\n", "    ax.set_xticks(np.arange(-.5, 4, 1), minor=True)\n", "    ax.set_yticks(np.arange(-.5, 4, 1), minor=True)\n", "    ax.grid(which='minor', color='black', linestyle='-', linewidth=2)\n", "    def update(t):\n", "        sc.set_offsets(states[t])\n", "    ani = FuncAnimation(fig,update,frames=len(states))\n", "    display(IPython.display.HTML(ani.to_jshtml()))" ], "id": "0b3ec243-70a1-46e7-bd95-2bb4a5338898" }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/html": [ " 
\n" ] } } ], "source": [ "show([[0,0],[0,1],[1,1],[1,2],[1,3],[1,2],[1,3],[1,2],[1,3],[1,2],[1,3]])" ], "id": "5379f2af-62d5-43ff-b16f-4689e7a53fe5" }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Implementing the Env class\n", "\n", "`-` GridWorld: a basic simulation environment that is often used as an example\n", "in reinforcement learning\n", "\n", "1. **State**: each grid cell is one state, and the agent can be in\n", "   any one of these states.\n", "2. **Action**: the agent can take one of four actions (up, down, left, right)\n", "   to move from the current state to the next state.\n", "3. **Reward**: the reward obtained when the agent takes a particular action in the current state.\n", "4. **Terminated**: a status indicating that an episode has ended." ], "id": "994533f5-c7f2-4b72-852a-67e79e55f931" }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "action = 3\n", "current_state = np.array([1,1])" ], "id": "aedc105c-7a89-46c1-8f5b-100d938085a5" }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "action_to_direction = { \n", "    0 : np.array([1, 0]), # x+ \n", "    1 : np.array([0, 1]), # y+ \n", "    2 : np.array([-1 ,0]), # x- \n", "    3 : np.array([0, -1]) # y- \n", "}" ], "id": "023fb0ce-13a1-4e4c-8152-9a5f617dfd1c" }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "array([1, 0])" ] } } ], "source": [ "next_state = current_state + action_to_direction[action]\n", "next_state" ], "id": "e50e1ecd-765b-4764-9ac4-30c4234352e5" }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "class GridWorld:\n", "    def __init__(self):\n", "        self.reset()\n", "        self.state_space = gym.spaces.MultiDiscrete([4,4])\n", "        self.action_space = gym.spaces.Discrete(4) \n", "        self._action_to_direction = { \n", "            0 : np.array([1, 0]), # x+ \n", "            1 : np.array([0, 1]), # y+ \n", "            2 : np.array([-1 ,0]), # x- \n", "            3 : np.array([0, -1]) # y- \n", "        }\n", "    def reset(self):\n", "        self.agent_action = None \n", "        self.agent_state = np.array([0,0]) \n", "        return self.agent_state 
\n", "    def step(self,action):\n", "        direction = self._action_to_direction[action]\n", "        self.agent_state = self.agent_state + direction\n", "        if self.agent_state not in self.state_space: # the agent is outside the 4x4 grid\n", "            reward = -10 \n", "            terminated = True\n", "            self.agent_state = self.agent_state -1/2 * direction\n", "        elif np.array_equal(self.agent_state, np.array([3,3])): # the agent reached the goal \n", "            reward = 100 \n", "            terminated = True\n", "        else: \n", "            reward = -1 \n", "            terminated = False \n", "        return self.agent_state, reward, terminated\n" ], "id": "414b9e1a-c494-4c0f-abc3-16fe43d7ede5" }, { "cell_type": "markdown", "metadata": {}, "source": [ "Leaving the grid gives a reward of -10" ], "id": "b5bdd9f8-9804-475c-9d1f-3bdb50a47540" }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "env = GridWorld()" ], "id": "d90d8d11-e438-4528-ac96-107fe32c39be" }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "states = [] \n", "state = env.reset()\n", "states.append(state) \n", "for t in range(50):\n", "    action = env.action_space.sample() \n", "    state,reward,terminated = env.step(action)\n", "    states.append(state) \n", "    if terminated: break " ], "id": "f8bcc04d-d188-4c44-86bb-bf2394cd66d8" }, { "cell_type": "markdown", "metadata": {}, "source": [ "The case where the agent knows nothing" ], "id": "dc5e9ad9-942b-43ed-a2e8-3f9b33cb985e" }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "[array([0, 0]),\n", " array([0, 1]),\n", " array([0, 0]),\n", " array([1, 0]),\n", " array([2, 0]),\n", " array([ 2. 
, -0.5])]" ] } } ], "source": [ "states" ], "id": "92960f3b-396e-4673-abb8-9a9840a2d47f" }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/html": [ "" ] } } ], "source": [ "show(states)" ], "id": "886e5543-619a-4488-a39f-77bc8a2fa254" }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Implementing the Agent1 class + Run\n", "\n", "`-` Features we want to implement\n", "\n", "- `.act()`: choose an action -> here, just a random action\n", "- `.save_experience()`: store the data -> let's focus on this for now\n", "- `.learn()`: learn from the data -> skipped for now\n", "\n", "`-` First attempt" ], "id": "5a15c0c2-ee8c-437b-837e-acec15690ece" }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "class Agent1:\n", "    def __init__(self,env):\n", "        self.action_space = env.action_space\n", "        self.state_space = env.state_space \n", "        self.n_experiences = 0 \n", "        self.n_episodes = 0 \n", "        self.score = 0 \n", "        \n", "        # episode-wise info \n", "        self.scores = [] \n", "        self.playtimes = []\n", "\n", "        # time-wise info\n", "        self.current_state = None \n", "        self.action = None \n", "        self.reward = None \n", "        self.next_state = None \n", "        self.terminated = None \n", "\n", "        # replay_buffer \n", "        self.actions = []\n", "        self.current_states = [] \n", "        self.rewards = []\n", "        self.next_states = [] \n", "        self.terminations = [] \n", "\n", "    def act(self):\n", "        self.action = self.action_space.sample() \n", "\n", "    def save_experience(self):\n", "        self.actions.append(self.action) \n", "        self.current_states.append(self.current_state)\n", "        self.rewards.append(self.reward)\n", "        self.next_states.append(self.next_state)\n", "        self.terminations.append(self.terminated) \n", "        self.n_experiences += 1 \n", "        self.score = self.score + self.reward \n", "    \n", "    def learn(self):\n", "        pass " ], "id": "e146d5e8-4c91-45ce-990f-f501db858092" }, { "cell_type": "markdown", "metadata": {}, "source": [ "For now, actions are chosen at random\n", "\n", "An episode simply means which game (round) is currently being played!" ], "id": "d7a8a9c8-2bc2-47d1-8012-63b3e3c1f516" }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epsiode: 1 Score: -10 Playtime: 1\n", "Epsiode: 2 Score: -10 Playtime: 1\n", "Epsiode: 3 Score: -10 Playtime: 1\n", "Epsiode: 4 Score: -17 Playtime: 8\n", "Epsiode: 5 Score: -10 Playtime: 1\n", "Epsiode: 6 Score: -11 Playtime: 2\n", "Epsiode: 7 Score: -22 Playtime: 13\n", "Epsiode: 8 Score: -10 Playtime: 1\n", "Epsiode: 9 Score: -21 Playtime: 12\n", "Epsiode: 10 Score: -10 Playtime: 1\n", "Epsiode: 11 Score: -18 Playtime: 9\n", "Epsiode: 12 Score: -11 Playtime: 2\n", "Epsiode: 13 Score: -11 Playtime: 2\n", "Epsiode: 14 Score: -10 Playtime: 1\n", "Epsiode: 15 Score: -10 Playtime: 1\n", "Epsiode: 16 Score: -12 Playtime: 3\n", "Epsiode: 17 Score: 91 Playtime: 10\n", "Epsiode: 18 Score: -10 Playtime: 1\n", "Epsiode: 19 Score: -12 Playtime: 3\n", "Epsiode: 20 Score: -12 Playtime: 3" ] } ], "source": [ "env = GridWorld() \n", "agent = Agent1(env) \n", "for _ in range(20):\n", "    ## essential code \n", "    agent.current_state = env.reset()\n", "    agent.terminated = False \n", "    agent.score = 0 \n", "    for t in range(50):\n", "        # step1: agent >> env \n", "        agent.act() \n", "        env.agent_action = agent.action \n", "        # step2: agent << env \n", "        agent.next_state, agent.reward, agent.terminated = env.step(env.agent_action)\n", "        agent.save_experience() \n", "        # step3: learn \n", "        # agent.learn()\n", "        # step4: state update \n", "        agent.current_state = agent.next_state \n", "        # step5: \n", "        if agent.terminated: break \n", "    agent.scores.append(agent.score) \n", "    agent.playtimes.append(t+1)\n", "    agent.n_episodes = agent.n_episodes + 1 \n", "    ## less essential code \n", "    print(\n", "        f\"Epsiode: {agent.n_episodes} \\t\"\n", "        f\"Score: {agent.scores[-1]} \\t\"\n", "        f\"Playtime: {agent.playtimes[-1]}\"\n", "    ) " ], "id": "1d53c7dc-4e20-4e65-9d10-56b071baef9a" }, { "cell_type": "markdown", "metadata": {}, "source": [ "`if agent.terminated: break` means: stop the inner loop when `agent.terminated` is True\n", "" ], "id": "f3f576e2-3690-483f-a54c-50e6954e1a30" }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "27" ] } } ], "source": [ "sum(agent.playtimes[:7])" ], "id": "39b306fe-1afb-4a96-9908-f931d1e96e8c" }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "28" ] } } ], "source": [ "sum(agent.playtimes[:8])" ], "id": "c93e1dd8-ceff-4789-8809-244f639a2d09" }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pull out only the episode above that reached the goal (48 = cumulative\n", "sum of Playtime)" ], "id": "c0c40e04-68f3-4d73-974c-bc366c285543" }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/html": [ "" ] } } ], "source": [ "states = [np.array([0,0])] + agent.next_states[48:60]\n", "show(states)" ], "id": "e251a785-472d-45ce-9aa8-9e5086fdd5d4" }, { "cell_type": "markdown", "metadata": {}, "source": [ "- A case that reached the goal by chance\n", "\n", "# Understanding the environment (a first-level understanding)\n", "\n", "`-` Let's play 10000 games at random." 
], "id": "7e1670b5-082c-4c8d-97ee-42dbaf375cae" }, { "cell_type": "code", "execution_count": 106, "metadata": {}, "outputs": [], "source": [ "env = GridWorld() \n", "agent = Agent1(env) \n", "for _ in range(10000):\n", "    ## essential code \n", "    agent.current_state = env.reset()\n", "    agent.terminated = False \n", "    agent.score = 0 \n", "    for t in range(50):\n", "        # step1: agent >> env \n", "        agent.act() \n", "        env.agent_action = agent.action \n", "        # step2: agent << env \n", "        agent.next_state, agent.reward, agent.terminated = env.step(env.agent_action)\n", "        agent.save_experience() \n", "        # step3: learn \n", "        # agent.learn()\n", "        # step4: state update \n", "        agent.current_state = agent.next_state \n", "        # step5: \n", "        if agent.terminated: break \n", "    agent.scores.append(agent.score) \n", "    agent.playtimes.append(t+1)\n", "    agent.n_episodes = agent.n_episodes + 1 " ], "id": "40b12099-567e-419d-b629-d34f01810511" }, { "cell_type": "code", "execution_count": 107, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "32858" ] } } ], "source": [ "agent.n_experiences" ], "id": "ceba972e-030e-4bd1-ae1d-5367cea52107" }, { "cell_type": "markdown", "metadata": {}, "source": [ "`-` Inspecting the data" ], "id": "60aa5700-2561-49aa-aa7a-b9a0f7be0b4a" }, { "cell_type": "code", "execution_count": 108, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "(array([0, 0]), 3, -10, array([ 0. , -0.5]))" ] } } ], "source": [ "agent.current_states[0], agent.actions[0], agent.rewards[0], agent.next_states[0]" ], "id": "907a43bc-1644-4ad2-ba8f-f9296035721a" }, { "cell_type": "code", "execution_count": 109, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "(array([0, 0]), 3, -10, array([ 0. 
, -0.5]))" ] } } ], "source": [ "agent.current_states[1], agent.actions[1], agent.rewards[1], agent.next_states[1]" ], "id": "57ae01f6-3178-477f-ba2b-b708a91338c4" }, { "cell_type": "code", "execution_count": 110, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "(array([0, 0]), 0, -1, array([1, 0]))" ] } } ], "source": [ "agent.current_states[2], agent.actions[2], agent.rewards[2], agent.next_states[2]" ], "id": "9fdc44e3-02df-4b0b-a4e4-39765d50e7c2" }, { "cell_type": "code", "execution_count": 111, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "(array([1, 0]), 3, -10, array([ 1. , -0.5]))" ] } } ], "source": [ "agent.current_states[3], agent.actions[3], agent.rewards[3], agent.next_states[3]" ], "id": "69c21c21-3995-4167-8741-becd4ac9a2a1" }, { "cell_type": "code", "execution_count": 112, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "(array([0, 0]), 0, -1, array([1, 0]))" ] } } ], "source": [ "agent.current_states[4], agent.actions[4], agent.rewards[4], agent.next_states[4]" ], "id": "4288e684-decd-4056-83bd-9f2a0b79ec35" }, { "cell_type": "markdown", "metadata": {}, "source": [ "`-` A record for understanding the environment (1)\n", "\n", "q is indexed by x, y, a\n", "\n", "x, y: think of them as the two grid axes " ], "id": "6c0c405a-af95-4df6-bb1d-a9d84fe46ec6" }, { "cell_type": "code", "execution_count": 113, "metadata": {}, "outputs": [], "source": [ "q = np.zeros([4,4,4])\n", "count = np.zeros([4,4,4])\n", "for i in range(agent.n_experiences):\n", "    x,y = agent.current_states[i] \n", "    a = agent.actions[i] \n", "    q[x,y,a] = q[x,y,a] + agent.rewards[i] \n", "    count[x,y,a] = count[x,y,a] + 1 " ], "id": "dc434942-f67f-4c5f-8315-6a6f2da7a175" }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add each reward into q at index [x,y,a]\n", "\n", "Drawback: we also have to keep track of count" ], "id": "5f5c7ac7-bbcb-4642-a825-035e03b82301" }, { "cell_type": "code", "execution_count": 114, 
"metadata": {}, "outputs": [], "source": [ "count[count == 0] = 0.01 \n", "q = q/count" ], "id": "6a2698f4-75cd-463b-9256-c84a23793b1a" }, { "cell_type": "code", "execution_count": 115, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "array([[-10., -1., -1., -1.],\n", " [-10., -1., -1., -1.],\n", " [-10., -1., -1., -1.],\n", " [-10., -1., -1., 0.]])" ] } } ], "source": [ "q[:,:,3]" ], "id": "b0da7c22-98c3-421f-888c-a441da28700d" }, { "cell_type": "code", "execution_count": 116, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "action = 0\n", "action-value function = \n", " [[ -1. -1. -1. -1.]\n", " [ -1. -1. -1. -1.]\n", " [ -1. -1. -1. 100.]\n", " [-10. -10. -10. 0.]]\n", "\n", "action = 1\n", "action-value function = \n", " [[ -1. -1. -1. -10.]\n", " [ -1. -1. -1. -10.]\n", " [ -1. -1. -1. -10.]\n", " [ -1. -1. 100. 0.]]\n", "\n", "action = 2\n", "action-value function = \n", " [[-10. -10. -10. -10.]\n", " [ -1. -1. -1. -1.]\n", " [ -1. -1. -1. -1.]\n", " [ -1. -1. -1. 0.]]\n", "\n", "action = 3\n", "action-value function = \n", " [[-10. -1. -1. -1.]\n", " [-10. -1. -1. -1.]\n", " [-10. -1. -1. -1.]\n", " [-10. -1. -1. 
0.]]\n" ] } ], "source": [ "for a in range(4):\n", "    print(\n", "        f\"action = {a}\\n\" \n", "        f\"action-value function = \\n {q[:,:,a]}\\n\" \n", ")" ], "id": "c7575209-a13d-4851-a793-dc61bd46f5da" }, { "cell_type": "markdown", "metadata": {}, "source": [ "`-` A record for understanding the environment (2)\n", "\n", "Update using the difference between the realistic value and the estimate" ], "id": "40b9ea8b-5408-45ea-878e-4f812a3d411c" }, { "cell_type": "code", "execution_count": 117, "metadata": {}, "outputs": [], "source": [ "q = np.zeros([4,4,4])\n", "for i in range(agent.n_experiences):\n", "    x,y = agent.current_states[i]\n", "    a = agent.actions[i]\n", "    q_estimated = q[x,y,a] # our current understanding of the environment, i.e. our own answer \n", "    q_realistic = agent.rewards[i] # the actual answer \n", "    diff = q_realistic - q_estimated # actual answer minus our answer = error feedback \n", "    q[x,y,a] = q_estimated + 0.05 * diff ## new answer = old answer + 0.05 * error feedback " ], "id": "914e7423-5d32-4b5a-9a2c-3cf0324e721a" }, { "cell_type": "markdown", "metadata": {}, "source": [ "Reflect only 5% of the error in each update" ], "id": "bd9f3c27-e958-4ccc-8c19-8cd1ad3ac57a" }, { "cell_type": "code", "execution_count": 118, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "action = 0\n", "action-value function = \n", " [[-1. -1. -1. -0.99879276]\n", " [-1. -1. -0.99999999 -0.99923914]\n", " [-1. -1. -0.99999572 99.15219633]\n", " [-9.99277183 -9.99788945 -9.9626859 0. ]]\n", "\n", "action = 1\n", "action-value function = \n", " [[-1. -1. -1. -9.98910469]\n", " [-1. -1. -0.99999997 -9.99411261]\n", " [-1. -1. -0.99999418 -9.88466698]\n", " [-0.99981905 -0.99974088 99.40794708 0. ]]\n", "\n", "action = 2\n", "action-value function = \n", " [[-10. -10. -9.99999978 -9.9923914 ]\n", " [ -1. -1. -0.99999999 -0.99934766]\n", " [ -0.99999998 -0.99999998 -0.99997791 -0.98722072]\n", " [ -0.9990167 -0.99960942 -0.98584013 0. ]]\n", "\n", "action = 3\n", "action-value function = \n", " [[-10. -1. -1. -0.99764828]\n", " [-10. -1. 
-0.99999996 -0.99818028]\n", " [ -9.99999999 -0.99999999 -0.99999716 -0.99645516]\n", " [ -9.98357707 -0.99988595 -0.99645516 0. ]]\n" ] } ], "source": [ "for a in range(4):\n", "    print(\n", "        f\"action = {a}\\n\" \n", "        f\"action-value function = \\n {q[:,:,a]}\\n\" \n", ")" ], "id": "0d2778f3-b24b-4a2b-9abd-78f1a687e9ca" }, { "cell_type": "markdown", "metadata": {}, "source": [ "# A deeper understanding of the environment (a higher-level understanding)\n", "\n", "`-` The value (= expected reward) of each state when action=1" ], "id": "a615da03-1e60-435b-b7d6-0335eca5d4e2" }, { "cell_type": "code", "execution_count": 119, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "array([[-1. , -1. , -1. , -9.98910469],\n", " [-1. , -1. , -0.99999997, -9.99411261],\n", " [-1. , -1. , -0.99999418, -9.88466698],\n", " [-0.99981905, -0.99974088, 99.40794708, 0. ]])" ] } } ], "source": [ "q[:,:,1]" ], "id": "a1214644-c901-4eb9-8d32-563df474baf2" }, { "cell_type": "markdown", "metadata": {}, "source": [ "`-` Analysis 1" ], "id": "6295a252-45a9-403e-ad8c-75edbe2e89dd" }, { "cell_type": "code", "execution_count": 120, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "99.40794707796658" ] } } ], "source": [ "q[3,2,1]" ], "id": "08ca1c41-211a-4c55-85e9-bce10f20b3d5" }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Taking action 1 in state (3,2) yields a reward of 100, so the expected reward is\n", "  around 100 -> reasonable\n", "\n", "`-` Analysis 2" ], "id": "f0607bed-e1cd-47d7-852f-f077c4c61cf7" }, { "cell_type": "code", "execution_count": 121, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "-0.9997408802884766" ] } } ], "source": [ "q[3,1,1]" ], "id": "0586f726-9795-4435-89aa-f73dd145c176" }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Taking action 1 in state (3,1) yields a reward of -1, so the expected reward is\n", "  around -1 -> is that reasonable??\n", "\n", "`-` Critique: Analysis 2 looks reasonable, but after analyzing the data it is not\n", "all that reasonable\n", "\n", "`-` Imagine a situation\n", "\n", "- You are given a blank sheet of paper\n", "- On the blank sheet you can write 0 or 1 (action = 0 or 1)\n", "- The reward differs depending on whether you write 0 or 1\n", "- After analyzing a huge amount of data, you “came to know” that writing 0 pays 0 won and writing 1\n", "  pays a reward of 100,000 won\n", "- So is the blank sheet worth 50,000 won or 100,000 won? -> Surely 100,000 won?\n", "\n", "`-` Intuition: on reflection, the estimated value at $s=(3,1)$, $a=1$ is currently\n", "`q[3,1,1]= -0.9997408802884766`\\[1\\], but realistically\n", "it is reasonable to consider “the actual reward (-1) and the potential reward (100)” together\n", "\n", "\\[1\\] That is, the potential value of next_state is not taken into account" ], "id": "2b6acd67-3c8b-4b95-8283-b3d8ebb2b899" }, { "cell_type": "code", "execution_count": 122, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "-0.9997408802884766" ] } } ], "source": [ "q_estimated = q[3,1,1]\n", "q_estimated" ], "id": "cfa4a389-21a8-4505-a864-fe28600a15c5" }, { "cell_type": "code", "execution_count": 123, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "98.0" ] } } ], "source": [ "q_realistic = (-1) + 0.99 * 100 \n", "q_realistic" ], "id": "c9c2af60-dfdb-4053-8609-3aa02e559007" }, { "cell_type": "markdown", "metadata": {}, "source": [ "The 0.01 shortfall from 1 acts as a small penalty on the future reward..\n", "\n", "- Here 0.99 is “the weight that determines how important future rewards are\n", "  relative to the present” (the discount factor).\n", "- The closer it is to 1, the more weight is placed on future rewards (i.e. the blank sheet is\n", "  regarded as being worth 100,000 won)\n", "\n", "`-` That is, if for all $s$, $a$\n", "\n", "$$q(s,a) \\approx \\text{reward}(s,a) + 0.99 \\times \\max_{a}q(s',a)$$\n", "\n", "holds, then $q(s,a)$ can be regarded as a reasonable estimate. 
Of course, written\n", "more rigorously, the formula is as follows.\n", "\n", "$$q(s,a) \\approx \\begin{cases} \\text{reward}(s,a) & \\text{terminated} \\\\ \\text{reward}(s,a) + 0.99 \\times \\max_{a}q(s',a) & \\text{not terminated}\\end{cases}$$\n", "\n", "$s$ is the state and $a$ is the action\n", "\n", "In $q(s,a) \\approx \\text{reward}(s,a) + 0.99 \\times \\max_{a}q(s',a)$,\n", "$\\text{reward}(s,a)$ is what is received immediately, and $\\max_{a}q(s',a)$ is\n", "the largest value I can still obtain from the next state $s'$" ], "id": "74f146e4-0183-4426-8a2c-da5e600e5312" }, { "cell_type": "code", "execution_count": 125, "metadata": {}, "outputs": [], "source": [ "q = np.zeros([4,4,4])\n", "for i in range(agent.n_experiences):\n", "    x,y = agent.current_states[i]\n", "    xx,yy = agent.next_states[i]\n", "    a = agent.actions[i]\n", "    q_estimated = q[x,y,a] \n", "    if agent.terminations[i]:\n", "        q_realistic = agent.rewards[i]\n", "    else:\n", "        q_future = q[xx,yy,:].max()\n", "        q_realistic = agent.rewards[i] + 0.99 * q_future\n", "    diff = q_realistic - q_estimated \n", "    q[x,y,a] = q_estimated + 0.05 * diff " ], "id": "4d6ffc83-3edb-42a7-a12c-23dc77e849cc" }, { "cell_type": "code", "execution_count": 126, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "action = 0\n", "action-value function = \n", " [[88.53307362 90.49709464 92.36408758 88.68673925]\n", " [90.28856398 92.49369954 94.61842445 95.49617968]\n", " [90.91491115 94.21657181 96.89901308 99.15219633]\n", " [-9.99277183 -9.99788945 -9.9626859 0. ]]\n", "\n", "action = 1\n", "action-value function = \n", " [[88.55960706 90.3131511 83.87809217 -9.98910469]\n", " [90.472853 92.49913938 92.55929104 -9.99411261]\n", " [92.35065011 94.61963597 96.65724194 -9.88466698]\n", " [93.42457258 96.5945232 99.40794708 0. ]]\n", "\n", "action = 2\n", "action-value function = \n", " [[-10. -10. -9.99999978 -9.9923914 ]\n", " [ 86.56190669 88.46124563 89.96848094 80.18849597]\n", " [ 88.03732538 90.28026548 91.62827094 84.50628885]\n", " [ 87.41906298 91.06145181 87.82431486 0. 
]]\n", "\n", "action = 3\n", "action-value function = \n", " [[-10. 86.5658665 88.21628148 86.74874619]\n", " [-10. 88.40364698 90.19865977 90.75947241]\n", " [ -9.99999999 89.90158238 91.81108837 92.72733049]\n", " [ -9.98357707 88.02167685 91.41860035 0. ]]\n" ] } ], "source": [ "for a in range(4):\n", "    print(\n", "        f\"action = {a}\\n\" \n", "        f\"action-value function = \\n {q[:,:,a]}\\n\" \n", ")" ], "id": "ff789b47-1166-40f2-af86-1403f91a3643" }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Establishing an action strategy\n", "\n", "`-` Suppose we are in state (0,0)." ], "id": "dcf74b84-7638-4a55-b31f-d69691d06e4d" }, { "cell_type": "code", "execution_count": 127, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "array([ 88.53307362, 88.55960706, -10. , -10. ])" ] } } ], "source": [ "q[0,0,:]" ], "id": "4f17808b-c852-4027-baa0-f364d2dd3ee0" }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Taking action 0 or action 1 is advantageous. // Taking action 2 or 3 is a disaster.\n", "\n", "`-` Suppose we are in state (2,3)." ], "id": "42eec9d8-dd06-4712-b5a3-341d2a62bced" }, { "cell_type": "code", "execution_count": 128, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "array([99.15219633, -9.88466698, 84.50628885, 92.72733049])" ] } } ], "source": [ "q[2,3,:]" ], "id": "b0a53547-7ae8-4950-b10c-2146755a87b7" }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Taking action 0 is advantageous.\n", "\n", "`-` Suppose we are in state (3,2)." ], "id": "89ad034d-2e9f-40bf-9f9b-3f98da5456b9" }, { "cell_type": "code", "execution_count": 129, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "array([-9.9626859 , 99.40794708, 87.82431486, 91.41860035])" ] } } ], "source": [ "q[3,2,:]" ], "id": "92aac9f3-5e47-42bb-8e62-5bf64f36cc29" }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Taking action 1 is advantageous\n", "\n", "`-` The optimal action in each state is as follows." 
], "id": "ee06abb7-5ec6-4a92-87a8-8fb725b953b1" }, { "cell_type": "code", "execution_count": 130, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "1" ] } } ], "source": [ "q[0,0,:].argmax()" ], "id": "f1b692c3-94b6-4c9e-8a92-ff0bb836c83b" }, { "cell_type": "code", "execution_count": 131, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "0" ] } } ], "source": [ "q[2,3,:].argmax()" ], "id": "982e1340-a7d2-4e2a-84d8-f4d41072227b" }, { "cell_type": "code", "execution_count": 132, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "1" ] } } ], "source": [ "q[3,2,:].argmax()" ], "id": "ede24e2e-ed08-405e-870e-f9b780ec5451" }, { "cell_type": "markdown", "metadata": {}, "source": [ "`-` Let's summarize the strategy (= policy)." ], "id": "bf2ba71a-4500-4e76-82de-ef674c456258" }, { "cell_type": "code", "execution_count": 133, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "array([['?????', '?????', '?????', '?????'],\n", " ['?????', '?????', '?????', '?????'],\n", " ['?????', '?????', '?????', '?????'],\n", " ['?????', '?????', '?????', '?????']], dtype='update every time experiences accumulate" ], "id": "c0c7f9bd-2987-4520-bc0c-2956053713ac" }, { "cell_type": "code", "execution_count": 148, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epsiode: 100 Score: -11.76 Playtime: 2.76\n", "Epsiode: 200 Score: -9.53 Playtime: 3.83\n", "Epsiode: 300 Score: -9.0 Playtime: 3.3\n", "Epsiode: 400 Score: -12.1 Playtime: 3.1\n", "Epsiode: 500 Score: -10.38 Playtime: 3.58\n", "Epsiode: 600 Score: -8.94 Playtime: 3.24\n", "Epsiode: 700 Score: -12.16 Playtime: 3.16\n", "Epsiode: 800 Score: -8.94 Playtime: 3.24\n", "Epsiode: 900 Score: -10.02 Playtime: 5.78\n", "Epsiode: 1000 Score: -50.0 Playtime: 50.0\n", "Epsiode: 1100 Score: -50.0 Playtime: 50.0\n", "Epsiode: 1200 Score: 
-50.0 Playtime: 50.0\n", "Epsiode: 1300 Score: -50.0 Playtime: 50.0\n", "Epsiode: 1400 Score: -50.0 Playtime: 50.0\n", "Epsiode: 1500 Score: -50.0 Playtime: 50.0\n", "Epsiode: 1600 Score: -50.0 Playtime: 50.0\n", "Epsiode: 1700 Score: -50.0 Playtime: 50.0\n", "Epsiode: 1800 Score: -50.0 Playtime: 50.0\n", "Epsiode: 1900 Score: -50.0 Playtime: 50.0\n", "Epsiode: 2000 Score: -50.0 Playtime: 50.0" ] } ], "source": [ "env = GridWorld() \n", "agent = Agent2(env) \n", "for _ in range(2000):\n", "    ## essential code \n", "    agent.current_state = env.reset()\n", "    agent.terminated = False \n", "    agent.score = 0 \n", "    for t in range(50):\n", "        # step1: agent >> env \n", "        agent.act() \n", "        env.agent_action = agent.action \n", "        # step2: agent << env \n", "        agent.next_state, agent.reward, agent.terminated = env.step(env.agent_action)\n", "        agent.save_experience() \n", "        # step3: learn \n", "        agent.learn()\n", "        # step4: state update \n", "        agent.current_state = agent.next_state \n", "        # step5: \n", "        if agent.terminated: break \n", "    agent.scores.append(agent.score) \n", "    agent.playtimes.append(t+1)\n", "    agent.n_episodes = agent.n_episodes + 1 \n", "    ## less essential code \n", "    if (agent.n_episodes % 100) ==0:\n", "        print(\n", "            f\"Epsiode: {agent.n_episodes} \\t\"\n", "            f\"Score: {np.mean(agent.scores[-100:])} \\t\"\n", "            f\"Playtime: {np.mean(agent.playtimes[-100:])}\"\n", "        ) " ], "id": "3bf71d4d-f5b3-48f7-a437-dd0eedb4d211" }, { "cell_type": "markdown", "metadata": {}, "source": [ "`(agent.n_episodes % 100) == 0` -> true when the episode count is a multiple of 100" ], "id": "eb66eb8b-3d93-4b53-afca-6f94ef50a0d4" }, { "cell_type": "code", "execution_count": 149, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/html": [ "" ] } } ], "source": [ "states = [np.array([0,0])] + agent.next_states[-agent.playtimes[-1]:] \n", "show(states)" ], "id": "3cc1eb34-48f5-4096-8bdb-d9af4cab8e62" }, { "cell_type": "code", "execution_count": 150, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/plain": [ "array([[88.59212369, 90.00967758, 80.87907551, 57.04075186],\n", " [90.44044671, 89.98145875, 79.93920263, 60.65439373],\n", " [88.59212369, 82.69209965, 67.47304639, 43.63362743],\n", " [47.92350641, 55.86149947, 40.12630608, 0. ]])" ] } } ], "source": [ "agent.q.max(-1).T" ], "id": "19b890c1-a87a-4371-ab09-093271b4a0e1" }, { "cell_type": "markdown", "metadata": {}, "source": [ "The agent is stuck, so nothing useful gets updated any more\n", "\n", "Instead of always following the max, occasionally try a different random action, and if it turns out well\n", "go that way\n", "\n", "# Implementing the Agent3 class + Run" ], "id": "e684251d-b645-46de-8064-6ed8a8d32c83" }, { "cell_type": "code", "execution_count": 151, "metadata": {}, "outputs": [], "source": [ "class Agent3(Agent2):\n", "    def __init__(self,env):\n", "        super().__init__(env)\n", "        self.eps = 0 \n", "    def act(self):\n", "        if np.random.rand() < self.eps:\n", "            self.action = self.action_space.sample() \n", "        else:\n", "            x,y = self.current_state \n", "            self.action = self.q[x,y,:].argmax()" ], "id": "50575733-9c15-4c33-81fd-9a8b321a54d0" }, { "cell_type": "code", "execution_count": 152, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epsiode: 200 Score: -12.82 Playtime: 3.82 Epsilon: 0.82\n", "Epsiode: 400 Score: -13.66 Playtime: 4.66 Epsilon: 0.67\n", "Epsiode: 600 Score: -11.49 Playtime: 6.89 Epsilon: 0.55\n", "Epsiode: 800 Score: -13.44 Playtime: 12.68 Epsilon: 0.45\n", "Epsiode: 1000 Score: -14.79 Playtime: 15.04 Epsilon: 0.37\n", "Epsiode: 1200 Score: -12.01 Playtime: 15.29 Epsilon: 0.30\n", "Epsiode: 1400 Score: 28.38 Playtime: 12.57 Epsilon: 0.25\n", "Epsiode: 1600 Score: 72.7 Playtime: 6.3 Epsilon: 0.20\n", "Epsiode: 1800 Score: 73.92 Playtime: 6.18 
Epsilon: 0.17\n", "Epsiode: 2000 Score: 82.54 Playtime: 6.36 Epsilon: 0.14\n", "Epsiode: 2200 Score: 82.56 Playtime: 6.34 Epsilon: 0.11\n", "Epsiode: 2400 Score: 77.03 Playtime: 6.37 Epsilon: 0.09\n", "Epsiode: 2600 Score: 81.36 Playtime: 6.53 Epsilon: 0.07\n", "Epsiode: 2800 Score: 88.94 Playtime: 6.56 Epsilon: 0.06\n", "Epsiode: 3000 Score: 83.42 Playtime: 6.76 Epsilon: 0.05\n", "Epsiode: 3200 Score: 93.8 Playtime: 6.1 Epsilon: 0.04\n", "Epsiode: 3400 Score: 91.67 Playtime: 6.03 Epsilon: 0.03\n", "Epsiode: 3600 Score: 89.19 Playtime: 6.4 Epsilon: 0.03\n", "Epsiode: 3800 Score: 91.32 Playtime: 6.38 Epsilon: 0.02\n", "Epsiode: 4000 Score: 93.71 Playtime: 6.19 Epsilon: 0.02\n", "Epsiode: 4200 Score: 93.89 Playtime: 6.01 Epsilon: 0.01\n", "Epsiode: 4400 Score: 92.45 Playtime: 6.44 Epsilon: 0.01\n", "Epsiode: 4600 Score: 94.94 Playtime: 6.06 Epsilon: 0.01\n", "Epsiode: 4800 Score: 90.08 Playtime: 6.61 Epsilon: 0.01\n", "Epsiode: 5000 Score: 93.91 Playtime: 5.99 Epsilon: 0.01" ] } ], "source": [ "env = GridWorld() \n", "agent = Agent3(env) \n", "agent.eps = 1\n", "for _ in range(5000):\n", "    ## essential code \n", "    agent.current_state = env.reset()\n", "    agent.terminated = False \n", "    agent.score = 0 \n", "    for t in range(50):\n", "        # step1: agent >> env \n", "        agent.act() \n", "        env.agent_action = agent.action \n", "        # step2: agent << env \n", "        agent.next_state, agent.reward, agent.terminated = env.step(env.agent_action)\n", "        agent.save_experience() \n", "        # step3: learn \n", "        agent.learn()\n", "        # step4: state update \n", "        agent.current_state = agent.next_state \n", "        # step5: \n", "        if agent.terminated: break \n", "    agent.scores.append(agent.score) \n", "    agent.playtimes.append(t+1)\n", "    agent.n_episodes = agent.n_episodes + 1\n", "    agent.eps = agent.eps * 0.999\n", "    ## less essential code \n", "    if (agent.n_episodes % 200) ==0:\n", "        print(\n", "            f\"Epsiode: {agent.n_episodes} \\t\"\n", "            f\"Score: {np.mean(agent.scores[-100:])} \\t\"\n", "            f\"Playtime: 
{np.mean(agent.playtimes[-100:])}\\t\"\n", "            f\"Epsilon: {agent.eps : .2f}\"\n", "        ) " ], "id": "50ba284c-2a6f-4218-b521-533352867e68" }, { "cell_type": "code", "execution_count": 153, "metadata": {}, "outputs": [ { "output_type": "display_data", "metadata": {}, "data": { "text/html": [ "" ] } } ], "source": [ "states = [np.array([0,0])] + agent.next_states[-agent.playtimes[-1]:] \n", "show(states)" ], "id": "c1513a2f-7822-4b77-af8f-6a11be9eb192" } ], "nbformat": 4, "nbformat_minor": 5, "metadata": { "kernelspec": { "name": "python3", "display_name": "Python 3", "language": "python" }, "language_info": { "name": "python", "codemirror_mode": { "name": "ipython", "version": "3" }, "file_extension": ".py", "mimetype": "text/x-python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.0" } } }