๐NerveNet
Abstract
We address the problem of learning
structured policies
forcontinuous
control. In traditional reinforcement learning, policies of agents are learned by multi-layer perceptrons (MLPs) which take the concatenation of all observations from the environment as input for predicting actions. In this work, we proposeNerveNet
toexplicitly model the structure of an agent
, which naturally takes the form ofa graph
. Specifically, serving as the agentโspolicy network
, NerveNetfirst propagates information over the structure of the agent
and thenpredict actions for different parts of the agent
. In the experiments, we first show that our NerveNet is comparable to state-of-the-art methods on standard MuJoCo environments. We further propose our customized reinforcement learning environments for benchmarking two types of structuretransfer learning tasks
, i.e., size and disability transfer, as well asmulti-task learning
. We demonstrate that policies learned by NerveNet are significantly more transferable and generalizable than policies learned by other models and are able to transfer even in azero-shot setting
.
๋ณดํต ๊ฐํํ์ต์์ agent๋ค์ policy๋ multi-layer perceptrons (MLPs)์ผ๋ก ๋คํธ์ํฌ๋ฅผ ๋ง๋ค๊ธฐ ๋๋ฌธ์ agent๊ฐ environment์์ ๋ฐ์ observation๋ค์ ๋จ์ํ ์์์(concatenation) policy network์ ์
๋ ฅ์ผ๋ก ๋ค์ด๊ฐ๊ฒ ๋๋ค. ํ์ง๋ง ์์ ์๋ ์ ๋ณด
์ ๋ฐ์ ์๋ ์ ๋ณด
๊ฐ ๊ฐ์ ์๋
๋ฒ์ฃผ์ด์ง๋ง ์์น๊ฐ ๋ค๋ฅด๊ธฐ ๋๋ฌธ์
๊ตฌ๋ถ์ด ์์ ์ ์๋ฏ์ด agent์ ์ด๋ฐ ๊ตฌ์กฐ์ ์ธ ํน์ฑ์ ๋ฐ์ํด์ policy๋ฅผ ๋ง๋ ๋ค๋ฉด observation ์ ๋ณด๋ค๊ฐ์ ๊ตฌ๋ถ์ ํ ์ ์์ ๊ฒ์ด๋ค. ์ด๋ฐ agent์ ๊ตฌ์กฐ์ ๊ด๊ณ์ฑ์ ๋ํ๋ด๊ธฐ ์ํด์ MLP๋์ ๊ทธ๋ํ๋ฅผ ํ์ฉํ๊ฒ ๋์๊ณ NerveNet์ ๊ณ ์ํ๊ฒ ๋์๋ค. NerveNet์ ๊ทธ๋ํ ๊ตฌ์กฐ๋ก ๋์ด ์๋ policy network์์ ๊ฐ ๋
ธ๋๋ค์ ์ ๋ณด๋ค์ด ์ ํ(propagation)๋๋ฉฐ agent์ ๋ถ๋ถ๋ค์ ๋ํ๋ด๋ ๋
ธ๋๋ง๋ค action์ prediction ํ๊ฒ ๋๋ค. MuJoCo ํ๊ฒฝ์์ MLP ๊ธฐ๋ฐ์ ๋ฒค์น๋งํฌ๋ค๊ณผ ๋น๋ฑํ ํ์ต๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ฃผ์์ผ๋ฉฐ, transfer learning task
๋ก agent์ ํฌ๊ธฐ(size)์ agent์ ์ผ๋ถ ํํธ๊ฐ ์๋ํ์ง ์๋(disability) variation์ ์ฃผ์์ ๋๋ ์ ํ์ต๋์์ผ๋ฉฐ multi-task learning
์ผ๋ก walker ๊ทธ๋ฃน์ ๋ค์ํ ํ๊ฒฝ์์์ ํ์ต ๊ฒฐ๊ณผ๋ค๋ ์ข์๋ค. ์ด๋ฐ ๊ฒฐ๊ณผ๋ค์ ํตํด NerveNet์ด transferable
ํ ๋ฟ๋ง ์๋๋ผ zero-shot setting
๋ ๊ฐ๋ฅํจ์ ๋ณด์ฌ์ฃผ์๋ค.
transferable
- A task๋ฅผ ํ์ตํ ๋คํธ์ํฌ(weights)๋ฅผ ํ์ฉํ์ฌ B task ํ์ต์๋ ์ ์ฉํ์ฌ scratch์์ B task๋ฅผ ํ์ตํ๋ ๊ฒ๋ณด๋ค ๋ ๋น ๋ฅด๊ณ ํจ์จ์ ์ธ ํ์ต์ ๊ฐ๋ฅํ๊ฒ ํ ์ ์๋ค๋ ์๋ฏธ. A task ํ์ต์์ ์ต๋ํ ๋ ผ๋ฆฌ์ฒด๊ณ๋ฅผ B task์๋ ์ ์ฉํ ์ ์์์ผ๋ก ๋ณผ ์ ์๋ค.zero-shot
- Meta learning์์ ์ฌ์ฉ๋๋ ์ฉ์ด๋ก A task์ ๋ํด์ ํ์ต๋ ๋คํธ์ํฌ๊ฐ fine tuning์ด ์์ด ๋ฐ๋ก unseen new task B์ ๋ํด์ ์ข์ ์ฑ๋ฅ์ ๋ด๋ ๊ฒ์ ์๋ฏธ.
Introduction
๋ง์ ๊ฐํํ์ต ๋ฌธ์ ๋ค์์ agent๋ค์ ์ฌ๋ฌ๊ฐ์ ๋
๋ฆฝ์ ์ธ controller๋ค๋ก ๊ตฌ์ฑ๋์ด ์๋ค. ์๋ฅผ๋ค์ด ๋ก๋ด์ ์ ์ด์์ ๊ฐํํ์ต์ด ๋ง์ด ์ ์ฉ๋๊ณ ์๋๋ฐ, ๋ก๋ด์ ์ฌ๋ฌ๊ฐ์ ๋งํฌ(link)๋ค๊ณผ ์กฐ์ธํธ(joint)๋ค๋ก ์ด๋ฃจ์ด์ ธ ์๊ณ ์์ง์์ด ์ผ์ด๋๋ joint๋ค์ ๊ฐ ๊ฐ๋ณ์ ์ธ controller๋ก ๋ณผ ์ ์๋ค. ์ด๋ link๋ ๋ก๋ด์ ๋ฌผ๋ฆฌ์ ์ธ ํํ๋ฅผ ๊ตฌ์ฑํ๋ ๋ผ๋์ฒ๋ผ ์๊ฐํ๋ฉด ๋๊ณ joint๋ ๋ก๋ด์ ๋ชจ์
์ ๊ฒฐ์ ํ๋ ๊ด์ ๋ก ์๊ฐํ ์ ์๋ค.(๋ณดํต ํ์ ์ด ๊ฐ๋ฅํ revolute joint๋ฅผ ์ฌ์ฉํ๋ค.) ๋ก๋ด์ link-joint-link- ... -joint-link
์ ๊ฐ์ด link์ joint๊ฐ ์ฒด์ธ์ฒ๋ผ ์ฐ๊ฒฐ๋์ด์ ๊ตฌ์ฑ๋๋๋ฐ, ๊ฐ link์ joint์ ์์ง์์ ์์ ์ ์ํ์๋ง ์์กดํ๋ ๊ฒ์ด ์๋๋ผ ์ฐ๊ฒฐ๋ ์ฃผ๋ณ link์ joint๋ค์๊ฒ์๋ ์ํฅ์ ๋ฐ์ ์ ๋ฐ์ ์๋ค. ๋ก๋ด์ ์์ง์ด๋๋ก ์ ์ด
๋ฅผ ํ๋ค๋ ๊ฒ์ ๋ฐ๋ก ๋ชจ์
์ ๋ง๋๋ joint์ ์์ง์์ ๊ฒฐ์
ํ๋ ๊ฒ์ด๋ค. ๋ฐ๋ผ์ robot agent๋ ๊ฐํํ์ต์์์ action์ joint์ ์ ์ดํ๋ ๊ฒ์ผ๋ก ์๊ฐํ ์ ์๊ณ action์ผ๋ก ๋ก๋ด์ joint์ ๊ฐ๋๋ฅผ ๋ฐ๊ฟ๊ฐ๋ฉฐ ๋ก๋ด์ ์์ง์ด๊ฒ ๋๋ค.
๋ณดํต ๊ฐํํ์ต์์ agent์ policy๋ MLP๋ก ๊ตฌ์ฑํ๋ค. MLP๊ธฐ๋ฐ์ policy๋ ๊ฐ์ฅ ๋จ์ํ ๋คํธ์ํฌ ๊ตฌ์กฐ๋ก input์ผ๋ก agent๊ฐ ์ป๋ observation ์ ๋ณด๋ฅผ concatenation
ํด์ ๋ฃ์ด์ฃผ๊ฒ ๋๋ค. ๋ค์ ๋ก๋ด agent์ ์์๋ก ๋์๊ฐ์ ์๊ฐํด๋ณด๋ฉด, agent๊ฐ observation์ผ๋ก ์ฌ์ฉํ๋ ์ ๋ณด๋ก๋ ๊ฐ joint์ ํ์ ๊ฐ๋, ํ์ ๊ฐ์๋, ์์น ์ ๋ณด๋ฑ ๋ค์ํ ์ข
๋ฅ์ ์ ๋ณด๋ค์ด ์๊ณ ์ด ๋ค์ํ ์ ๋ณด๋ค์ ๊ฐ joint๋ก๋ถํฐ ์ป๊ฒ ๋๋ฏ๋ก ๊ฐ joint์์ ์ป์ ์ ์๋ ์ ๋ณด x joint์ ์
๊ฐ ๋ณดํต observation์ ์ฐจ์ ์๊ฐ ๋๋ค. ๋ก๋ด์ joint๊ฐ ๋ง์์ง ์๋ก ์ฐจ์์ด ๋ฐฐ๋ก ์ปค์ง๊ฒ ๋๊ณ ์ด๋ฐ ์ฌ๋ฌ ์ ๋ณด๋ค์ ๋จ์ํ concatenateํด์ policy ๋คํธ์ํฌ์ ๋ฃ์ด์ฃผ๋ ๊ฒ์ ๋ ๋ง์ training time์ ์๊ตฌํ๊ฒ ๋๊ณ ๋ ๋ง์ ํ๊ฒฝ๊ณผ agent ๊ฐ์ interaction ๊ณผ์ ์ด ํ์ํ๊ฒ ๋๋ค. ๋ฐ๋ผ์ ๋ณธ ๋
ผ๋ฌธ์์๋ agent๊ฐ ๊ฐ์ง๊ณ ์๋ (link์ joint๋ก ์ด๋ฃจ์ด์ง) ๊ตฌ์กฐ์ ์ธ ํน์ฑ์ ์ด์ฉํด์ policy๋ฅผ ๊ทธ๋ํ๋ก ๋ง๋ค์ด observation์ ๋ฃ์ด์ฃผ๊ณ ํ์ต์ ํ๋ ๊ฒ์ ์ ์ํ๊ฒ ๋๋ค.
๋ก๋ด์ด๋ ๋๋ฌผ๋ค์ ์ ์ฒด์ ์ธ ๊ตฌ์กฐ๋ฅผ ๋ณด๋ฉด ๊ทธ๋ํ ๊ตฌ์กฐ์ ์ ์ฌํ๋ค. ์์์ ์ค๋ช
ํ link์ joint์ ์ฒด์ธ๊ณผ ๊ฐ์ ์ฐ๊ฒฐ์ฑ์ Graph Neural Network
๋ฅผ ์ ์ฉํ๊ธฐ์ ์ข๋ค. ๊ทธ๋์ NerveNet์์ ์ ๋ณด๋ค์ propagation์ ์ด๋ฐ ๊ทธ๋ํ ๊ตฌ์กฐ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ผ์ด๋๊ฒ ๋๊ณ agent์ body์ ๋ณด๋ฅผ ์ฌ๋ฌ ๋ค๋ฅธ ํํธ๋ค์ ๊ทธ๋ํ์ node์ edge๋ก ์ ์ํ๋ฉด์ ์์ง์์ด ์ผ์ด๋๋ body node๋ค์ action์ ๊ฒฐ์ ํ๊ฒ ๋๋ค.
NerveNet
์ฐ์ ๊ฐํํ์ต์ Notation์ ์ ๋ฆฌํด๋ณด๋ฉด, ๋ณธ ๋
ผ๋ฌธ์์๋ locomotion control problem๋ค์ ๋ชฉํ๋ก ์ก์๊ธฐ ๋๋ฌธ์ infinite-horizon discounted Markov decision process (MDP)
๋ก ์ค์ ํ๋ค. ๋ณดํต continuous
ํ ์ ์ด ๋ฌธ์ ์์๋ ์๊ฐ ํ ์ธ์จ์ ๊ณ ๋ คํ ๋ฌดํ ์๊ฐ ์คํ
์ ๊ฐ์ง๊ณ MDP๋ฅผ ๊ตฌ์ฑํ๊ฒ ๋๋ค(์ค์ ๋ก๋ max step์ ์ค์ ํ๊ธดํ๋ ๋งค์ฐ ํฐ ์๋ก ์ก๋๋ค).
state ํน์ observation space๋ก S , action space๋ก A, stochastic policy \pi_{\theta}\left(a^{\tau} \mid s^{\tau}\right) ๋ ํ๋ผ๋ฏธํฐ \theta๋ฅผ ๊ฐ์ง๊ณ ํ์ฌ ์ํ s๋ฅผ ๊ธฐ๋ฐ์ผ๋ก a๋ฅผ ๋ง๋ค๊ฒ ๋๋ค. ์ด๋ ๊ฒ ๋์จ agent์ a์ s๋ฅผ ๊ฐ์ง๊ณ ํ๊ฒฝ์์๋ reward r\left(s^{\tau}, a^{\tau}\right) ๋ฅผ ์ฃผ๊ฒ ๋๊ณ agent๋ ์ด๋ ๊ฒ ๋ฐ๋ reward๋ฅผ ์ต๋ํํ๋ ๋ฐฉ๋ฒ์ ํ์ตํ๊ฒ ๋๋ค.
์ด๋ฌํ MDP ๊ตฌ์ฑ์ ๊ธฐ๋ณธ์ ์ธ ๊ฐํํ์ต์ Notation์์ ํฌ๊ฒ ๋ฒ์ด๋์ง ์๋๋ค.
Graph Construction
๋ณธ ๋
ผ๋ฌธ์์๋ ์ฌ์ฉํ MuJoCo์ agent๋ค์ ์ด๋ฏธ ๊ตฌ์กฐ์ ์ผ๋ก tree ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๊ณ ์๋ค. NerveNet์ ํต์ฌ ์์ด๋์ด์ธ ๊ทธ๋ํ๋ฅผ ๊ตฌ์ฑํ๊ธฐ ์ํด body
์ joint
, root
๋ผ๋ 3๊ฐ์ง ์ข
๋ฅ์ ๋
ธ๋๋ฅผ ์ค์ ํ๋ค. body
๋
ธ๋๋ ๋ก๋ด๊ณตํ์์ ๋งํ๋ link ๊ธฐ์ค์ ์ขํ์์คํ
์ ๋ํ๋ด๋ ๋
ธ๋์ด๊ณ , joint
๋
ธ๋๋ ๋ชจ์
์ ์์ ๋(freedom of motion)์ ๋ํ๋ด๋ฉฐ 2๊ฐ์ body ๋
ธ๋๋ค์ ์ฐ๊ฒฐํด์ฃผ๋ ๋
ธ๋์ด๋ค.
์๋๋ Ant
ํ๊ฒฝ์ ์์์ธ๋ฐ, ํ ๊ฐ์ง ๊ทธ๋ฆผ์์ ํท๊ฐ๋ฆฌ์ง ๋ง์์ผ ํ ์ ์ ๊ทธ๋ฆผ์์๋ ๋ง์น body
์ root
๋
ธ๋๋ง ๋
ธ๋๋ก ๋ง๋ ๊ฒ ์ฒ๋ผ ๋ณด์ด์ง๋ง root์ body, body์ body๋ฅผ ์ฐ๊ฒฐํ๋ ์ฃ์ง๋ค๋ ์ค์ ๋ก๋ joint
๋
ธ๋๋ค์ด๋ค.(we omit the joint nodes and use edges to represent the physical connections of joint nodes.)root
๋ผ๋ ๋
ธ๋๋ agent์ ์ถ๊ฐ์ ์ธ ์ ๋ณด๋ค์ ๋ด์ ๋ถ๋ถ์ผ๋ก ์ฌ์ฉํ๊ธฐ ์ํด ์ถ๊ฐํ ๋
ธ๋ ์ข
๋ฅ๋ก, ์๋ฅผ ๋ค์ด agent๊ฐ ๋๋ฌํด์ผ ํ๋ target position์ ๋ํ ์ ๋ณด ๋ฑ์ด ๋ด๊ฒจ์๋ค.
NerveNet as Policy
ํฌ๊ฒ 3๊ฐ์ง ํํธ๋ก NerveNet์ ์ดํด๋ณผ ๊ฒ์ธ๋ฐ ์ฐ์ (0) Notation
์ ๋ณด๊ณ ๋๋ค, (1) Input model
(2) Propagation model
(3) Output model
์์ผ๋ก ์ดํด๋ณผ ์์ ์ด๋ค.
0. Notation
๊ทธ๋ํ์์์ ๋ ธํ ์ด์ ์ ๋ค์๊ณผ ๊ฐ์ด G ๋ผ๋ ๊ทธ๋ํ๋ ๋ ธ๋ ์งํฉ V์ ์ฃ์ง ์งํฉ E๋ก ๊ตฌ์ฑ๋๋ค.
G=(V, E)
Nervenet policy๋ฅผ ๊ตฌ์ฑํ๋ ๊ทธ๋ํ๋ Directed graph(์ ํฅ ๊ทธ๋ํ)
์ด๊ธฐ ๋๋ฌธ์ ๊ฐ ๋
ธ๋์์์ in
๊ณผ out
์ด ๋ฐ๋ก ๋ช
์๋๊ฒ ๋๋ค.
- ๋ ธ๋ u๋ฅผ ์ค์ฌ์ผ๋ก ๋ ธ๋ u๋ก ๋ค์ด์ค๋ ์ด์ ๋ ธ๋์ด๋ฉด \mathcal{N}_{in}(u)
- ๋ ธ๋ u๋ฅผ ์ค์ฌ์ผ๋ก ๋ ธ๋ u์์ ๋๊ฐ๋ ์ด์ ๋ ธ๋์ด๋ฉด \mathcal{N}_{out}(u)
๊ทธ๋ํ์ ๋ชจ๋ ๋
ธ๋ u๋ ํ์
์ ๊ฐ์ง๊ฒ ๋๊ณ ์ด๋ฅผ p_{u} \in\{1,2, \ldots, P\} (associated note type)๋ก ๋ํ๋ด๋ฉฐ ์ฌ๊ธฐ์์๋ ์์ ์ค๋ช
ํ ๊ฒ๊ณผ ๊ฐ์ด body
, joint
, root
3๊ฐ์ง ํ์
์ด ์๋ค.
๋ ธ๋๋ค ๋ฟ๋ง ์๋๋ผ ์ฃ์ง๋ค๋ ํ์ ์ ์ ํ ์ ์๋๋ฐ c_{(u, v)} \in\{1,2, \ldots, C\} (associate each edge)๋ก ํ๊ธฐํ์ฌ ๋ ธ๋์ (u, v) ์ฌ์ด์ ์ฃ์ง ํ์ ์ ์ ์ํ ์ ์๋ค.(ํ๋์ ์ฃ์ง์ ๋ํด์ ์ฌ๋ฌ ์ฃ์ง ํ์ ์ ์ ์ํ ์ ์์ง๋ง ์ฌ๊ธฐ์์๋ ์ฌํ ์ด์ฆ ๋ ๋ฒ ์คํธ ์ฒ ํ์ผ๋ก ํ๋์ ์ฃ์ง๋ ํ๋์ ํ์ ๋ง ๊ฐ์ง๋๋ก ํ๋ค)
์ด๋ ๊ฒ ๋ ธ๋๋ณ, ์ฃ์ง๋ณ ํ์ ์ ๋๋์ผ๋ก์จ,
๋ ธ๋ ํ์
์ ๋ ธ๋๋ค๊ฐ์ ๋ค๋ฅธ ์ค์๋๋ฅผ ํ์ ํ๋๋ฐ ๋์์ด ๋๊ณ์ฃ์ง ํ์
์ ๋ ธ๋๋ค๊ฐ์ ์๋ก๋ค๋ฅธ ๊ด๊ณ๋ค์ ๋ํ๋ด๊ณ ์ด ๊ด๊ณ์ ์ข ๋ฅ์ ๋ฐ๋ผ ์ ๋ณด๋ฅผ ๋ค๋ฅด๊ฒ propagation ํ๊ฒ ๋๋ค.
์ด์ ์๊ฐ ๋ ธํ ์ด์ ์ ๋ํ ๋ถ๋ถ์ ์ดํด๋ณด์. NerveNet์๋ ์๊ฐ(time step)์ ๊ฐ๋ ์ด 2๊ฐ์ง ์กด์ฌํ๋ค.
- ๊ธฐ์กด ๊ฐํํ์ต์์ ํ๊ฒฝ๊ณผ agent ์ฌ์ด์ interaction time step์ ๋ํ๋ด๋ \tau
- NerveNet์ ๋ด๋ถ graph policy์์์ propagation step์ ๋ํ๋ด๋ t
๋ค์ ํ์ด์ ์๊ฐํด๋ณด๋ฉด, ๊ฐํํ์ต์ ์๊ฐ ๊ฐ๋ \tau ์คํ ์์ ํ๊ฒฝ์ผ๋ก๋ถํฐ observation์ ๋ฐ๊ณ , ๋ฐ์ observation์ ๊ธฐ๋ฐ์ผ๋ก t ์คํ ๋์ NerveNet์ ๋ด๋ถ์ ๊ทธ๋ํ์ propagation์ด ์ผ์ด๋๋ค.
1. Input model
์์์ ๋งํ๋ฏ์ด ํ๊ฒฝ๊ณผ ์ํธ์์ฉ์ผ๋ก observation s^{\tau} \in \mathcal{S}์ ๋ฐ๊ฒ ๋๋ค(time step \tau). ์ด s^{\tau}๋ concatenation๋ ๊ฐ ๋
ธ๋์ observation์ด๋ผ๊ณ ๋ณผ ์ ์๋ค. ์ด์ ๊ฐํํ์ต interaction ์์ค์ \tau ์คํ
์ ์ ์ ๋ฉ์ถฐ๋๊ณ ๊ทธ๋ํ ๋ด๋ถ์ ํ์ ์คํ
์ธ t ์์ค์์ ์๊ฐํด๋ณด์. observation์ node u์ ํด๋นํ๋ x_{u}๋ก ํํํ ์ ์๊ณ x_{u}๋ input network F_{\mathrm{in}}(MLP)๋ฅผ ๊ฑฐ์ณ์ ๊ณ ์ ๋ ํฌ๊ธฐ์ state vector์ธ h_{u}^{0}๊ฐ ๋๋ค. h_{u}^{0}์ ๋
ธํ
์ด์
์ ํ์ด์ ํด์ํ๋ฉด ๋
ธ๋ u
์ propagation step 0
์์์ state vector์ธ ๊ฒ์ด๋ค. ์ด๋ observation vector x_{u}๊ฐ ๋
ธ๋๋ง๋ค ํฌ๊ธฐ๊ฐ ๋ค๋ฅผ ๊ฒฝ์ฐ zero padding์ผ๋ก ๋ง์ถฐ์ input network์ ๋ฃ์ด์ฃผ๊ฒ ๋๋ค.
h_{u}^{0}=F_{\text {in }}\left(x_{u}\right)
2. Propagation model
NerveNet์ propagation ๊ณผ์ ๋
ธ๋๋ค ๊ฐ์ ์ฃผ๊ณ ๋ฐ๋ ์ ๋ณด๋ฅผ message
๋ผ๊ณ ํ๊ฒ ๋๊ณ ์ด๋ ๋
ธ๋๋ค ๊ฐ์ ์ฃผ๊ณ ๋ฐ๋ ์ํธ์์ฉ์ด๋ผ๊ณ ์๊ฐํ ์ ์๋ค. Propagation model์ 3๊ฐ์ง ๋จ๊ณ๋ก ๋๋์ด์ ๋ณผ ์ ์๋ค.
- Message Computation
์ ๋ฌํ ๋ฉ์ธ์ง๋ฅผ ๊ณ์ฐํ๋ค.
propagation step์ธ t์, ๋ชจ๋ ๋ ธ๋๋ค u์์ state vector h_{u}^{t}๋ฅผ ์ ์ํ ์ ์๋ค.
๋ ธ๋ u๋ก ๋ชจ์์ง๋(in-coming) ๋ชจ๋ ์ฃ์ง๋ค์ ๊ฐ์ง๊ณ ๋ฉ์์ง๋ฅผ ๊ตฌํ๊ฒ ๋๋๋ฐ, ์ด๋ M์ MLP์ด๊ณ M์ ์๋์ฒจ์ c_{(u, v)} ๋ ธํ ์ด์ ์์ ์ ์ ์๋ฏ์ด ๊ฐ์ ์ข ๋ฅ์ ์ฃ์ง์ ๋ํด์๋ ๊ฐ์ message function M์ ์ด๋ค.
m_{(u, v)}^{t}=M_{c_{(u, v)}}\left(h_{u}^{t}\right)
์๋ฅผ ๋ค์ด ์๋ ๊ทธ๋ฆผ์
CentipedeEight
์ ๋ชจ์ต์ธ๋ฐ, ์ผ์ชฝ์ ์ค์ agent์ ๋ชจ์ต์ ๋ํ๋ด๊ณ ์์ผ๋ฉฐ ์ค๋ฅธ์ชฝ์ agent๋ฅผ ๊ทธ๋ํ๋ก ๋ํ๋์ ๋์ ๋ชจ์ต์ด๋ค. ์ฌ๊ธฐ์์ 2๋ฒ์งธ torso์์ ์ฒซ๋ฒ์งธ ์ธ๋ฒ์งธ torso์์ ๋ณด๋ผ ๋ ๊ฐ์ ๋ฉ์ธ์ง ํ์ M_{1} ์ ์ฌ์ฉํ๊ณ , LeftHip๊ณผ RightHip์ผ๋ก ๋ณด๋ด๋ ๋ฉ์ธ์ง ํ์ M_{2}๋ฅผ ์ฌ์ฉํ๊ฒ ๋๋ ๊ฒ์ด๋ค.
- Message Aggregation
์ ๋จ๊ณ์์ ๋ชจ๋ ๋ ธ๋๋ค์ ๋ํด์ ๋ฉ์ธ์ง ๊ณ์ฐ์ด ๋๋ ํ์ in-coming ์ด์ ๋ ธ๋๋ค๋ก๋ถํฐ ์จ(๊ณ์ฐ๋) ๋ฉ์ธ์ง๋ฅผ ๋ชจ์ผ๊ฒ ๋๋ค. ์ด๋ summation, average, max-pooling ๋ฑ ๋ค์ํ aggregation ํจ์๋ฅผ ์ฌ์ฉํ ์ ์๋ค.
\bar{m}_{u}^{t}=A\left(\left\{h_{v}^{t} \mid v \in \mathcal{N}_{i n}(u)\right\}\right)
- States Update
์ด์ ๋ชจ์ ๋ฉ์ธ์ง๋ฅผ ๊ธฐ๋ฐ์ผ๋ก state vector๋ฅผ ์ ๋ฐ์ดํธ ํ๋ฉด ๋๋ค!
h_{u}^{t+1}=U_{p_{u}}\left(h_{u}^{t}, \bar{m}_{u}^{t}\right)
์ฌ๊ธฐ์ ์ ๋ฐ์ดํธ ํจ์ U ๋ a gated recurrent unit (GRU), a long short term memory (LSTM) unit ๋๋ MLP๊ฐ ๋ ์ ์๋ค.
Update function์ ์๋์ฒจ์ p_{u}์์ ๋ณผ ์ ์๋ค์ํผ ๊ฐ์ ๋ ธ๋ ํ์ ์ด๋ฉด ๊ฐ์ update function U๋ฅผ ์ฐ๊ฒ ๋๋ค. ์ด๋ ๊ฒ ์ ๋ฐ์ดํธ๋ state vector๋ ํ์ ์คํ t๊ฐ ํ๋ ์ฌ๋ผ๊ฐ t+1 ์ด ๋ h_{u}^{t+1}๊ฐ ๋๋ค.
์ด๋ ๊ฒ ๋ด๋ถ propagation ๊ณผ์ 3๋จ๊ณ(Message Computation, Message Aggregation, States Update)๊ฐ T ์คํ ๋์ ์ผ์ด๋๊ฒ ๋๊ณ ๊ฐ ๋ ธ๋์ ์ต์ข state vector๋ h_{u}^{T} ๊ฐ ๋๋ค.
3. Output model
์ ํ์ ์ธ RL์ MLP ํด๋ฆฌ์์์๋ ๋คํธ์ํฌ์์ ๊ฐ action์ gaussian distribution์ mean์ ๋ฝ์๋ด๊ฒ ๋๋ค. std๋ trainableํ ๋ฒกํฐ์ด๋ค. NerveNet์์๋ std๋ ๋น์ทํ๊ฒ ๋ค๋ฃจ์ง๋ง ๊ฐ ๋
ธ๋์ ๋ง๋ค
action prediction์ ๋ง๋ค๊ฒ ๋๋ค.
actuator์ ์ฐ๊ฒฐ๋์ด ์๋ ๋ ธ๋๋ค์ ์งํฉ์ O๋ผ๊ณ ํ์. ์ด ์งํฉ์ ์๋ ๋ ธ๋๋ค์ ์ต์ข state vector h_{u \in \mathcal{O}}^{T}๋ MLP์ธ Ouput model O_{q_{u}}์ ์ธํ์ผ๋ก ๋ค์ด๊ฐ๊ฒ ๋๊ณ ์์ํ์ผ๋ก ๊ฐ actuator์ action distribution์ธ gaussian distribution์ mean \mu์ ์ถ๋ ฅํ๊ฒ ๋๋ค. ์ฌ๊ธฐ์์ ์๋ก์ด ๋ ธํ ์ด์ q_{u}๋ฅผ ๋ณผ ์ ์๋๋ฐ q_{u}๋ ์์ํ ํ์ , ์ฆ ์์ํ์ ๋ด๋๋ ๋ ธ๋ u์ ํ์ ์ผ๋ก ์์ํ ํ์ ์ ์๋์ฒจ์์ q_{u}์ ๋ฐ๋ผ ์์ํ ๋ ธ๋์ ํ์ ์ด ๊ฐ์ผ๋ฉด Output function์ ๊ณต์ ํ ์ ์๋ค. ๋ค์๋งํด ์์ํ ๋ ธ๋ ํ์ ์ ๋ฐ๋ผ ์ปจํธ๋กค๋ฌ๋ฅผ ๊ณต์ ํ ์๋ ์๋ ๊ฒ์ด๋ค. ์์ Centipedes์ ์์๋ก ๋ณด๋ฉด, ๊ฐ์ LeftHip ๋ผ๋ฆฌ๋ ์ปจํธ๋กค๋ฌ๋ฅผ ๊ณต์ ํ ์ ์๋ค๋ ๊ฒ์ด๋ค.
\mu_{u \in \mathcal{O}}=O_{q_{u}}\left(h_{u}^{T}\right)
๋ ผ๋ฌธ์์ ์ค์ ๋ก ์คํ์ ํด๋ดค์ ๋ ๋ค๋ฅธ ํ์ ์ ์ปจํธ๋กค๋ฌ๋ค์ ํ๋๋ก ํตํฉํ๋๋ผ๋(O function์ ๋ค ๊ฐ์ MLP๋ก ์ฌ์ฉ) ํผํฌ๋จผ์ค๊ฐ ๊ทธ๋ ๊ฒ ํด์ณ์ง์ง ์์์ ํ์ธํ ์ ์์๋ค๊ณ ํ๋ค.
์ฌ๊ธฐ๊น์งํด์ ๊ทธ๋ํ ๋ ธํ ์ด์ ์ ๋น๋ ค ๊ทธ๋ํ ๊ธฐ๋ฐ ๊ฐ์ฐ์์ stochastic policy๋ฅผ ๋ํ๋ด๋ฉด ์๋์ ์์๊ณผ ๊ฐ๋ค.
\pi_{\theta}\left(a^{\tau} \mid s^{\tau}\right)=\prod_{u \in \mathcal{O}} \pi_{\theta, u}\left(a_{u}^{\tau} \mid s^{\tau}\right)=\prod_{u \in \mathcal{O}} \frac{1}{\sqrt{2 \pi \sigma_{u}^{2}}} e^{\left(a_{u}^{\tau}-\mu_{u}\right)^{2} /\left(2 \sigma_{u}^{2}\right)}
์ฌ๊ธฐ๊น์ง NerveNet์ ๊ฐ ๋จ๊ณ๋ฅผ Walker-Ostrich
ํ๊ฒฝ์์ ์์๋ก ํ๋์ ๋ณด๊ธฐ ์ฝ๊ฒ ์ ๋ฆฌํ ๊ทธ๋ฆผ์ ์๋์ ๊ฐ๋ค.
Learning Algorithm
์ด์ ํํธ์์ NerveNet์ ๋ด๋ถ์์ propagation ์คํ t ๋จ์์์ ๊ฐ ๋จ๊ณ๋ค์ ์์ธํ ์ดํด๋ณด์๋ค๋ฉด ์ด์ ๊ฐํํ์ต ํ์ ์คํ \tau ๋จ์์์ ํ์ต์ ๋ชฉ์ ํจ์์ ์๊ณ ๋ฆฌ์ฆ์ ์ดํด๋ณด์. ๋ชฉ์ ํจ์๋ ์ ํ์ ์ธ RL๊ณผ ๋ค๋ฅธ ์ ์ด ์์ด policy์ ํ๋ผ๋ฏธํฐ \theta๋ฅผ ๊ฐ์ง๊ณ Return ๊ฐ์ maximizationํ๋ ๊ฒ์ผ๋ก ํ๋ค.
J(\theta)=\mathbb{E}{\pi}\left[\sum{\tau=0}^{\infty} \gamma^{\tau} r\left(s^{\tau}, a^{\tau}\right)\right]
๊ฐํํ์ต ์๊ณ ๋ฆฌ์ฆ์ผ๋ก๋ PPO
๊ณผ GAE
๋ฅผ ์ฌ์ฉํ์ผ๋ฉฐ ํด๋น ์๊ณ ๋ฆฌ์ฆ๋ค์ ๋ด์ฉ์ ๊ฐ๊ฐ ์๊ณ ๋ฆฌ์ฆ๋ค์ ์๋ ์์๊ณผ ๋ด์ฉ๋ค๊ณผ ์์ดํ ์ ์ด ์์ผ๋ฏ๋ก ๊ฐ ๋
ผ๋ฌธ์ผ ์ฐธ๊ณ ํ๋ฉด ๋๊ธฐ ๋๋ฌธ์ ์ด๋ฒ ๋
ผ๋ฌธ ๋ฆฌ๋ทฐ์์๋ ์๋ตํ๋ค.
PPO์ GAE ์๊ณ ๋ฆฌ์ฆ์ ์ฐธ๊ณ ํ์ฌ ์์ ๋ชฉ์ ํจ์ J๋ฅผ ์ ๋ฆฌํ๋ฉด NerveNet์ ๋ชฉ์ ํจ์๋ ๋ค์๊ณผ ๊ฐ๋ค.
\begin{aligned} \tilde{J}(\theta)=& J(\theta)-\beta L_{K L}(\theta)-\alpha L_{V}(\theta) \\ =& \mathbb{E}_{\pi_{\theta}}\left[\sum_{\tau=0}^{\infty} \min \left(\hat{A}^{\tau} r^{\tau}(\theta), \hat{A}^{\tau} \operatorname{clip}\left(r^{\tau}(\theta), 1-\epsilon, 1+\epsilon\right)\right)\right] \\ &-\beta \mathbb{E}_{\pi_{\theta}}\left[\sum_{\tau=0}^{\infty} \operatorname{KL}\left[\pi_{\theta}\left(: \mid s^{\tau}\right) \mid \pi_{\theta_{o l d}}\left(: \mid s^{\tau}\right)\right]\right]-\alpha \mathbb{E}_{\pi_{\theta}}\left[\sum_{\tau=0}^{\infty}\left(V_{\theta}\left(s^{\tau}\right)-V\left(s^{\tau}\right)^{\operatorname{target}}\right)^{2}\right] \end{aligned}
์์ ์์์์ ๋ณผ ์ ์๋ value network V๋ฅผ ์ด๋ป๊ฒ ๋์์ธํ ๊ฒ์ธ์ง
๊ฐ ์ด๋ฒ ๋
ผ๋ฌธ์ ๋ค๋ฅธ ํฌ์ธํธ๋ก ๋ณผ ์ ์๋ค. ๋
ผ๋ฌธ์ ๊ธฐ๋ณธ ์์ด๋์ด๋ policy network๋ฅผ ๊ทธ๋ํ๋ก ํํํ๋ ๊ฒ์ด๊ณ , value network๋ ์ด๋ป๊ฒ ํ ์ง ์ฌ๋ฌ ์ ํ์ง๋ค์ด ๋จ์์๋ค. ๊ทธ๋์ ๋ณธ ๋
ผ๋ฌธ์์๋ value network์ ๋์์ธ์ ๋๊ณ ํฌ๊ฒ 3๊ฐ์ง NerveNet์ ๋ณํ ์๊ณ ๋ฆฌ์ฆ๋ค์ ์คํํด๋ณด์๋ค.
NerveNet-MLP : policy network๋ฅผ 1๊ฐ์ GNN์ผ๋ก ๊ตฌ์ฑํ๊ณ
value network๋ MLP๋ก
๊ตฌ์ฑNerveNet-2 : policy network๋ฅผ 1๊ฐ์ GNN์ผ๋ก ๊ตฌ์ฑํ๊ณ
value network๋ ๋ ๋ค๋ฅธ GNN์ผ๋ก
๊ตฌ์ฑ(์ด GNN 2๊ฐ - without sharing the parameters of the two GNNs)NerveNet-1 :
policy network์ value network ๋ชจ๋ 1๊ฐ์ GNN์ผ๋ก
๊ตฌ์ฑ(์ด GNN 1๊ฐ)
Experiments
๋จผ์ MuJoCo ์๋ฎฌ๋ ์ดํฐ์์ NerveNet์ ํจ๊ณผ๋ฅผ ํ์ธํ๊ณ ์ผ๋ถ ์ปค์คํ ํ ํ๊ฒฝ๋ค์์ NerveNet์ transferable๊ณผ multi-task learning ๋ฅ๋ ฅ์ ํ์ธํ๋ค.
1. Comparison on standard benchmarks of MuJoCo
- ๋น๊ต๊ตฐ์ผ๋ก
MLP
,TreeNet(๋ชจ๋ ๋ ธ๋๋ค์ด ์ฐ๊ฒฐ ๋์ด ์๋ ๊ทธ๋ํ, depth 1)
์ ์ฌ์ฉ - ์ด 8๊ฐ์ ํ๊ฒฝ์์ ์คํ -
Reacher, InvertedPendulum, InvertedDoublePendulum, Swimmer, HalfCheetah, Hopper, Walker2d, Ant
- ์ถฉ๋ถํ ํ์ตํ๋ ์คํ ์ ์ฃผ๊ธฐ ์ํด์ 1 million์ max๋ก ๋
- ํ์ดํผ ํ๋ผ๋ฏธํฐ์ ๊ฒฝ์ฐ ๊ทธ๋ฆฌ๋ ์์น๋ก ์ฐพ์์ผ๋ฉฐ(Appendix ์ฐธ๊ณ ) ๊ฐ ์๊ณ ๋ฆฌ์ฆ์ ํผํฌ๋จผ์ค๋ฅผ ์ธก์ ํ ๋ 3๋ฒ์ run์ ๋๋ค ์๋๋ฅผ ๋ฐ๊ฟ๊ฐ๋ฉฐ ์คํ์ํจ ํ ํ๊ท ์ ๊ตฌํด์ ๊ธฐ๋ก
- ๋๋ถ๋ถ์ ํ๊ฒฝ์์ MLP๊ฐ ์๋๊ณ NerveNet๋ ์ด์ ๋น๋ฑํ ํผํฌ๋จผ์ค๋ฅผ ๋๋ค.
(3๊ฐ์ง ์ผ์ด์ค์ ๋ํ learning curve, ๋ค๋ฅธ ์ผ์ด์ค๋ค์์๋ ๋์ฒด๋ก NerveNet๊ณผ MLP๊ฐ ๋น์ทํ๋ค.)
HalfCheetah |
InvertedDoublePendulum |
Swimmer |
---|---|---|
MLP์ NerveNet์ด ๋น์ทํ๊ณ TreeNet์ด ๋ง์ด ์์ข์์ | MLP๊ฐ ์ข๋ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ | NerveNet์ด MLP๋ณด๋ค ์ข์ ์ฑ๋ฅ์ ๋ |
- ๋๋ถ๋ถ ํ๊ฒฝ๋ค์์
TreeNet
์ดNerveNet
๋ณด๋ค ์ข์ง ์์๊ณ ์ด๋ฅผ ํตํด์ ๋ฌผ๋ฆฌ์ ์ธ ๊ทธ๋ํ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ ธ๊ฐ๋ ๊ฒ์ด ์ผ๋ง๋ ์ค์ํ์ง ์ ์ ์๋ค.
2. Structure transfer learning
- MuJoCo์ ํ๊ฒฝ ํ๋๋ฅผ ์ปค์คํ
ํด์
size
์disability
์ ๋ณํ๊ฐ ์์ ๋ transferable ํจ์ ๊ฒ์ฆsize transfer
- ์์ ์ฌ์ด์ฆ์ ๊ทธ๋ํ๋ฅผ ๊ฐ์ง agent๋ฅผ ํ์ต ์ํจ ํ ๋ ํฐ ์ฌ์ด์ฆ์ ๊ทธ๋ํ๋ฅผ ๊ฐ์ง agent๋ก transferable ํ์งdisability transfer
- ๋ชจ๋ ํํธ๋ค์ด ์ ์์๋ํ๋ agent๋ก ํ์ตํ ํ ์ผ๋ถ ํํธ๋ค์ด ์๋ํ์ง ์๋ ์ํฉ์ agent๋ก transferable ํ์ง
- 2๊ฐ ์ข
๋ฅ์ ํ๊ฒฝ์ ์ปค์คํ
ํ์ฌ ์คํ -
centipede
์snake
centipede - ์ง๋ค์ ๊ฐ์ด ์๊ธด agent๋ก torso body๋ค์ด ์ฌ๋ฌ๊ฐ ์ฒด์ธ์ฒ๋ผ ์ฐ๊ฒฐ ๋์ด ์๊ณ torso๋ฅผ ์ค์ฌ์ผ๋ก ์์ชฝ์ ๋ค๋ฆฌ๊ฐ 1์์ผ๋ก ๋ถ์ด ์๋ค. ํ๋์ ๋ค๋ฆฌ๋ thigh์ shin์ผ๋ก ๊ตฌ์ฑ๋์ด ์๊ณ hinge actuator๋ก ๊ตฌํ๋์ด ์๋ค. ์ปค์คํ ์ ๋ค๋ฆฌ์ ๊ฐฏ์๋ฅผ ๋ค์ํ๊ฒ ํด์ ์ฌ๋ฌ ์ปค์คํ ํ๊ฒฝ๋ค์ ๋ง๋ค์๋๋ฐ, ๊ฐ์ฅ ์งง์ agent๋ก๋
CentipedeFour
๋ถํฐ ๊ฐ์ฅ ๊ธด agent๋ก๋CentipedeFourty
๋ก ๋ค๋ฆฌ๊ฐ 40๊ฐ๊น์ง(20์) ์๋ ํ๊ฒฝ์ ๋ง๋ค์ ์์๋ค. disability๋ก ์ผ๋ถ ํํธ๊ฐ ์๋ํ์ง ์๋ ํ๊ฒฝ์Cp(Cripple)
๋ก ๋ฐ๋ก ํ๊ธฐํ๋ค. ์ด ํ๊ฒฝ์์ y-direction์ผ๋ก ๋นจ๋ฆฌ ์์ผ๋ก ๊ฐ๋๊ฒ ๋ชฉํ๋ค.snake -
swimmer
ํ๊ฒฝ์ ๊ธฐ๋ฐ์ผ๋ก ์ปค์คํ ํ์ผ๋ฉฐ ๊ฐ์ฅ ๋นจ๋ฆฌ ์งํ๋ฐฉํฅ์ผ๋ก ์์ง์ด๋ ๊ฒ ๋ชฉํ๋ค.
๋น๊ต๊ตฐ
NerveNet
: small agent๊ฐ ํ์ตํ ๋ชจ๋ธ์ ๋ฐ๋ก large agent์ ์ ์ฉํ ์ ์์๋ค. agent์ ๊ตฌ์กฐ๊ฐ ๋ฐ๋ณต์ ์ด๊ธฐ ๋๋ฌธ์ ๋ฐ๋ณต๋๋ ๋ถ๋ถ์ ๋ ๋๋ฆฌ๊ธฐ๋ง ํ๋ฉด ๋๊ธฐ ๋๋ฌธ์ด๋ค.MLP Pre-trained (MLPP)
: agent์ ํฌ๊ธฐ๊ฐ ์ปค์ง์ ๋ฐ๋ผ input size๊ฐ ๋ฌ๋ผ์ง๋ฏ๋ก ๊ฐ์ฅ straightforwardํ๊ฒ ์ฒซ๋ฒ์งธ hidden layer๋ฅผ ๊ทธ๋๋ก output layer๋ก ์ฌ์ฉํ๊ณ input layer์ ์ฌ์ด์ฆ๋ง ํค์์ ์ถ๊ฐํ๊ณ ์ด input layer๋ ๋๋ค ์ด๊ธฐํ๋ฅผ ํด์ค๋ค.MLP Activation Assigning (MLPAA)
: small agent์ weight๋ค์ ๋ฐ๋ก large agent์ ๋ชจ๋ธ์ ๋ฃ์ด์ฃผ๊ณ weight๋ค์ ๋จ๋ ๋ถ๋ถ๋ค์ 0์ผ๋ก ์ด๊ธฐํ ํด์ค๋ค.TreeNet
: MLPAA์ฒ๋ผ ์ค์ผ์ผ์ ํค์์ 0์ผ๋ก ์ด๊ธฐํ ํด์ค๋ค.Random
: action space์์ uniformlyํ๊ฒ ์ํ๋ง์ ํ๋ policy์ด๋ค.
Result
Centipede
1-1. Pretraining
- 6-๋ค๋ฆฌ ๋ชจ๋ธ๊ณผ 4-๋ค๋ฆฌ ๋ชจ๋ธ๋ก
NerveNet
,MLP
,TreeNet
์์์ ํผํฌ๋จผ์ค๋ฅผ ๋น๊ตํ๋ค. ์ฌ๊ธฐ์ 3๊ฐ์ ๋ชจ๋ธ์ ์์ benchmark ๋น๊ต ์คํ์์ ์ฌ์ฉํ ๋น๊ต๊ตฐ๋ค๊ณผ ๋์ผํ๋ค.
- 4-๋ค๋ฆฌ ๋ชจ๋ธ์์๋ NerveNet์ด ๊ฐ์ฅ Reward๊ฐ ๋๊ณ , 6-๋ค๋ฆฌ ๋ชจ๋ธ์์๋ MLP๊ฐ ๊ฐ์ฅ Reward๊ฐ ๋์์ ์ ์ ์๋ค. TreeNet์ ๋ ํ๊ฒฝ ๋ชจ๋์์ ๊ฐ์ฅ ๋ฎ๋ค.
- 6-๋ค๋ฆฌ ๋ชจ๋ธ๊ณผ 4-๋ค๋ฆฌ ๋ชจ๋ธ๋ก pretraining์ ์งํํ ํ transferable์ ์คํํ๋ค.
1-2. Zero-shot
- fine tuning ์์ด ํผํฌ๋จผ์ค๋ฅผ ์ธก์ ํ๋ค.
- ํผํฌ๋จผ์ค๋ฅผ ์ฝ๊ฒ ๋น๊ตํ ์ ์๋๋ก
average reward
์average running-length
๋ฅผ normalizationํด์ ์์ผ๋ก ์๋์ ๊ฐ์ด ํํํ๋ค.(green-good, red-bad)
- ๋์ผ๋ก ํ์คํ ํ์ธํ ์ ์๋ฏ์ด NerveNet์ ํผํฌ๋จผ์ค๊ฐ ๋ค๋ฅธ ๋น๊ต๊ตฐ์ ๋นํด ์๋ฑํ transferableํจ์ ์ ์ ์์๋ค.
- ๋ํ learning curve์์ ๋ณผ ์ ์๋ฏ์ด NerveNet+Pretrain ์ด ๋ค๋ฅธ Pretrain ๋น๊ต๊ตฐ๋ค์ ๋นํด ํจ์ฌ ๋์ reward ์์์ ์์ ์์ํ๊ณ ๋ ์ ์ timestep์ผ๋ก solved ์ ์์ ๋๋ฌํ๋ ๊ฒ์ ๋ณด์ ๊ทธ๋ํ์ ๊ตฌ์กฐ์ ์ด์ ์ ํ์คํ ํ์ฉํ๊ณ ์์์ ์ ์ ์๋ค.
- NerveNet์ agent๋ค์ ๋ค๋ฅธ ๋น๊ต๊ตฐ agent๋ค์์ ๋ณด์ด์ง ์๋
walk-cycle
์ ๊ฐ์ง๊ณ ์์์ ํ์ธํ ์ ์์๋๋ฐ, ์ด๋ ๋ณดํ ๋ก๋ด๋ค์ ๊ฑธ์์์์ ๋ฐ๋ณต์ ์ธ ์์ง์์ ํ๊ฒ ๋์ด ์๊ธฐ ๋๋ฌธ์ ์์ฐ์ค๋ฝ๊ฒ cycle์ ๊ฐ์ง๊ฒ ๋๋ ๊ฒ์ agent๊ฐ ํ์ตํ์์ ์ ์ ์๋ค. (๋ฐ๋ฉด MLP๋ 8-๋ค๋ฆฌ ๋ชจ๋ธ์์ ๋ชจ๋ ๋ค๋ฆฌ๋ฅผ ์์ง์ด์ง ์๋ ๋ชจ์ต์ ๋ณด์ด๊ธฐ๋ ํ๋ค.)
- 6-๋ค๋ฆฌ ๋ชจ๋ธ๊ณผ 4-๋ค๋ฆฌ ๋ชจ๋ธ๋ก
Snake
- snakeํ๊ฒฝ์์๋ NerveNet์ด ๋ค๋ฅธ ๋น๊ต๊ตฐ๋ค์ ๋นํด ๋ฐ์ด๋ reward ์ ์๋ฅผ ๋ณด์ฌ์ฃผ๋ฉฐ transferable ํจ์ ์๋์ ๋ํ์์์ฒ๋ผ ๋ณด์ฌ์ฃผ์๋ค.
- 350์ ์ ๋๊ฐ
snakeThree
์์ solved๋ ์ํ๋ผ๊ณ ๋ณผ ์ ์๋๋ฐ NerveNet์ ์์ ์ ์๋ค์ด ๋๋ถ๋ถ 300์ ๋์์ ์์ํ ๊ฒ์ผ๋ก ๋ณด์ ์ด๋ ์๋นํ zero-shot ์ญ๋์ด ์์์ ์ ์ ์๋ค. - ๋ค๋ฅธ ๋น๊ต๊ตฐ๋ค์ overfitting์ด ์ฌํด์ Random๋ณด๋ค ์์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ฃผ๋ ์ ๋ ํฅ๋ฏธ๋กญ๋ค.
- zero-shot ๋ฟ๋ง ์๋๋ผ fine tuning์ ํ๋ learning curve์์๋ NerveNet์ Pretrain์ ์ด์ ์ ๋ค๋ฅธ ๋น๊ต๊ตฐ๋ค์ ๋นํด ์ ํ์ฉํ๊ณ ์์์ ๋ณผ ์ ์์๋ค.
NerveNet+Pretrain
์ ์์ reward๊ฐ ๋์ผ๋ฉฐ, ํน์ size transfer ์คํ์์๋ scratch NerveNet์ด ๋์ง ๋ชปํ MLP ์ ์๋ฅผNerveNet+Pretrain
์ด ๋ฐ๋ผ์ก์๋ค.
3. Multi-task learning
NerveNet์ ๋คํธ์ํฌ์ structure prior๋ฅผ ํฌํจํ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ multi-task learning์ ์ ๋ฆฌํ ์ ์๋ค. ๋ฐ๋ผ์ ์ด๋ฅผ ์คํํ๊ธฐ ์ํด Walker
multi-task learning์ ์งํํ๋ค.
- 2d-walker ํ๊ฒฝ๋ค 5๊ฐ -
Walker-HalfHumanoid
,Walker-Hopper
,Walker-Horse
,Walker-Ostrich, Walker-Wolf
- 1๊ฐ์ ํตํฉ๋ network๋ก ํ์ต
๋น๊ต๊ตฐ
NerveNet
: agent๋ค์ ํํ๊ฐ ๋ฌ๋ผ weight๋ค์ด ๋ค๋ฅผ ์ ๋ฐ์ ์๊ธฐ ๋๋ฌธ์ propagation๊ณผ์ ์์์ weight matrices์ output๋ง ๊ณต์ ํ๋ค.MLP Sharing
: hidden layer๋ค ๊ฐ์ weight matrices ๋ฅผ ๊ณต์MLP Aggregation
: ์ฐจ์์ด ๋ค๋ฅธ observation๋ค์ aggregation๊ณผ์ ์ ํตํด ์ฒซ๋ฒ์งธ hidden layer์ ํฌ๊ธฐ๋ก ๋ค ๋ง์ถฐ์ฃผ์ด์ input์ผ๋ก ๋ฃ์ด์คTreeNet
: TreeNet๋ weight๋ฅผ ๊ณต์ ๋ฅผ ํ ์ ์์ง๋ง agent์ ๊ตฌ์กฐ์ ์ธ ์ ๋ณด๋ ์ ์ ์๋ค. ๋จ์ํ root node๋ฅผ ์ค์ฌ์ผ๋ก ๋ชจ๋ ๋ ธ๋์ ์ ๋ณด๋ค aggregation ๋๊ธฐ ๋๋ฌธ์ด๋ค.MLPs
: ๊ฐ agent๋ง๋ค ๋ฐ๋ก MLP policy๋ฅผ ๋ง๋ค์ด์ ํ์ต(single-task)
Result - multi-task learning ์คํ์ด๊ธฐ ๋๋ฌธ์ ํ ๋๊ฐ ๋ฌ๋ ๊ทธ๋ํ๋ง ๋ณผ ์ ์๊ณ 5๊ฐ์ ๋ฌ๋ ๊ทธ๋ํ๋ฅผ ๊ฐ์ด ๋ด์ผ ํ๋ค. - Single-task policy๋ฅผ ์ ์ธํ๊ณ ๋ชจ๋ ํ๊ฒฝ์์ NerveNet์ ํผํฌ๋จผ์ค๊ฐ ์ข์์ ์ ์ ์๋ค.
- ํ ์ด๋ธ์์ Ratio๊ฐ single-task policy์ ๋นํด multi-task policy์ ์ฑ๋ฅ์ percentage๋ก ๋ํ๋ธ ์์น์ธ๋ฐ, MLP์ ํผํฌ๋จผ์ค๊ฐ single-task์์ multi-task๋ก ๋์ด๊ฐ์ ๋ 42%๋ ํผํฌ๋จผ์ค๊ฐ ์ค์ด๋๋ ๊ฒ์ ํ์ธํ ์ ์๋ค. (Average-58.6%) ๋ฐ๋ฉด์ NerveNet์ ์ฑ๋ฅ์ด ์ ํ ๋จ์ด์ง์ง ์๋ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ฃผ์๋ค.
4. Robustness of learnt policies
๊ฐํํ์ต ์ ์ด์์ robustness๋ ์ค์ํ ์งํ์ธ๋ฐ ์ง๋์ด๋ ํ๊ณผ ๊ฐ์ ๋ฌผ๋ฆฌ์ ์ธ ๊ฐ๋ค์ ์ค์ฐจ ๋ฒ์๊ฐ ์ด๋์ ๋๊น์ง policy๊ฐ ํ์ฉํ๊ณ ์ ์๋ํ๋์ง๋ฅผ ํ์ธํด์ผ ํ๋ค.
- 5๊ฐ์ Walker ๊ทธ๋ฃน์ ํ๊ฒฝ์์ ์คํ
- pretrained agent๋ฅผ ๊ฐ์ง๊ณ agent์ ์ง๋๊ณผ joint์ strength์ ๋ณ๊ฒฝํ ๋ค ํผํฌ๋จผ์ค ์ธก์
- ๋๋ถ๋ถ์ ํ๊ฒฝ๊ณผ variation์์ NerveNet์ robustness๊ฐ MLP๋ณด๋ค ์ข์์ ์ ์ ์๋ค.
5. Interpreting the learned representations
์ค์ ํด๋ฆฌ์๋ค์ด ์ด๋ค representation๋ค์ ํ์ตํ๋์ง ์์๋ณด๊ธฐ ์ํด CentipedeEight
ํ๊ฒฝ์์ ํ์ต๋ agent์ final state vector๋ฅผ ๊ฐ์ง๊ณ 2D, 1D PCA๋ฅผ ์งํํ๋ค.
๊ฐ ๋ค๋ฆฌ์๋ค(Left Hip-Right Hip)๋ค์ agent์ ์ ์ฒด ๋ชธ์ฒด์์ ๊ฐ๊ธฐ ๋ค๋ฅธ ์์น์ ์์์๋ ๋ถ๊ตฌํ๊ณ invariant representation์ ๋ฐฐ์ธ ์ ์์์์ PCA๋ฅผ ํตํด์ ์ ์ ์์๋ค.
๋ํ ์์ Centipede transfer learning ์คํ ๊ฒฐ๊ณผ์์๋ ์ ๊น ์ธ๊ธํ๋ walk-cycle
์ด ์ฃผ๊ธฐ์ฑ์ด ๋๋ ทํ๊ฒ ๋ณด์๋ค.
6. Comparison of model variants
Value Network๋ฅผ ์ด๋ป๊ฒ ํ ๊ฒ์ธ์ง์ ๋ฐ๋ผ NerveNet์ ์ฌ๋ฌ ๋ณํ์ด ์์ ์ ์๋๋ฐ Swimmer
, Reacher
, HalfCheetah
์์ ๋น๊ตํด๋ณธ ๊ฒฐ๊ณผ, Value Network๋ MLP๋ก ํ NerveNet-MLP
์ ํผํฌ๋จผ์ค๊ฐ ๊ฐ์ฅ ์ข์๊ณ NerveNet-1
์ ํผํฌ๋จผ์ค๊ฐ 2๋ฑ์ผ๋ก NerveNet-MLP
์ ๋น์ทํ๋ค. ์ด์ ๋ํ ์ ์ฌ์ ์ธ ์ด์ ๋ก value network์ policy network๊ฐ weight๋ฅผ ๊ณต์ ํ๋ ๊ฒ์ด PPO ์๊ณ ๋ฆฌ์ฆ์์์ trust-region based optimitaion์์์ weight \alpha๋ฅผ ๋ sensitiveํ๊ฒ ๋ง๋ค๊ธฐ ๋๋ฌธ์ด๋ผ๊ณ ์ถ๋ก ํ ์ ์๋ค.
Conclusion
- NerveNet์ด๋ผ๋ ๊ทธ๋ํ ๊ตฌ์กฐ๋ฅผ ํ์ฉํ policy๋ฅผ ๊ฐ์ง๊ณ RL agent์ body structure๋ฅผ ํ์ฉํ ์ ์๋ ์๊ณ ๋ฆฌ์ฆ์ ์ ์
- ๊ฐ body์ joint์ observation์ ๋ฐ์ GNN์ ํตํด non-linear message๋ค์ ๊ณ์ฐํ๊ณ propagationํ๋ ๋ชจ๋ธ๋ง
- propagation์ ์ฃ์ง๋ก ํํ๋ joint๊ฐ์ ๋ฌผ๋ฆฌ์ ์ผ๋ก ์ฐ๊ฒฐ์ฑ์ ๊ฐ์ง๊ณ ๋ณธ๋ ์๋ ์์กด์ฑ์ ๊ธฐ๋ฐ์ผ๋ก ์ด๋ฃจ์ด์ง
- ์คํ์ ์ผ๋ก NerveNet์ด MoJoCo ์๋ฎฌ๋ ์ดํฐ ๊ธฐ๋ฐ ์ฌ๋ฌ ํ๊ฒฝ๋ค์์ MLP ๊ธฐ๋ฐ SOTA ์๊ณ ๋ฆฌ์ฆ๋ค๊ณผ ๊ฒฌ์ค๋งํ ํผํฌ๋จผ์ค๋ฅผ ๋ณด์ฌ์คo state-of-the-art methods on standard MuJoCo environments.
- ๋ช๊ฐ์ง ํ๊ฒฝ์ ์ปค์คํ ํด์ size์ disability transfer๋ฅผ ๊ฒ์ฆํ์ผ๋ฉฐ zero-shot setting์์๋ transferableํจ์ ๋ณด์
Review
๋ ผ๋ฌธ ๋ฆฌ๋ทฐํ์ ์ฃผ๊ด์ ์ธ ์ฅ๋จ์ ์ ์ ๋ฆฌํ๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
- Pros ๐
- ๋ก๋ด์ ๊ตฌ์กฐ์ ์ธ ํน์ง์ ๊ธฐ๋ฐ์ผ๋ก ํจ์จ์ ์ธ feature embedding์ด์์์ ๋ณด์ฌ์ค
- ๋ค์ํ robot configuration์์๋ ์ ์๋ํจ
- ๋ชจ๋ธ์ ํ์ฅ์ฑ์ ์ค๋ช ํ ์ ์๋ Transfer learning๊ณผ Multi-task learning์ด ์ธ์์ ์ด์๊ณ ํฐ ์ฅ์ ์ด๋ผ๊ณ ์๊ฐ
- Cons ๐
- ์๋ฎฌ๋ ์ด์ ์์๋ง ์คํํ๋ค๋ ์ ์ด ์์ฌ์
- ์๊ฐ๋ณด๋ค ๊ธฐ๋ณธ์ ์ธ gnn๋ชจ๋ธ์ด๋ผ์ edge์ ๋ํ ํฐ ๋์์ธ ์์๊ฐ ๋ค์ด๊ฐ์ง ์์ ๊ฒ ๊ฐ์
- ๋ค์ํ RL ์๊ณ ๋ฆฌ์ฆ๋ค๊ณผ์ ์๋์ง๋ฅผ ๋ณด๊ธฐ์๋ ์์งํ ๋ ผ๋ฌธ์ ์์ด ๋๋ฌด ๋ฐฉ๋ํด์ง ๊ฒ ๊ฐ๊ธดํ์ง๋ง ์ด์ ๋ํ ๋น๊ต๊ฐ ์์์ผ๋ฉด ์ข์์ ๊ฒ ๊ฐ์
Reference
- Original Project Homepage: http://www.cs.toronto.edu/~tingwuwang/nervenet.html
- Code
- Official: https://github.com/WilsonWangTHU/NerveNet
- Not official: https://github.com/HannesStark/gnn-reinforcement-learning