๐GN-Block
์ด๋ฒ post๋ Graph Networks as Learnable Physics Engines for Inference and Control ๋ผ๋ ๋ ผ๋ฌธ์ ์ฝ๊ณ ๋ฆฌ๋ทฐํ ๋ด์ฉ์ ๋๋ค.
Abstract
Understanding and interacting with everyday physical scenes requires rich knowledge about the structure of the world, represented either implicitly in a value or policy function, or explicitly in a transition model. Here we introduce a new class of learnable modelsโ
based on graph networksโwhich implement an inductive bias
for object- and relation-centric representations of complex, dynamical systems. Our results show that as aforward model
, our approach supports accurate predictions from real and simulated data, and surprisingly strong and efficient generalization, across eight distinct physical systems which we varied parametrically and structurally. We also found that ourinference model
can perform system identification. Our models are alsodifferentiable, and support online planning via gradient-based trajectory optimization
, as well asoffline policy optimization
. Our framework offers new opportunities for harnessing and exploiting rich knowledge about the world and takes a key step toward building machines withmore human-like representations of the world
.
1. Introduction
์ฌ๋์ ๊ฑธ์๋ ๋ง์ฐฐ๋ ฅ, ์์ฉ ๋ฐ์์ฉ ๋ฒ์น์ ์๊ฐํ๋ฉด์ ๊ฑท์ง ์์ต๋๋ค. ๋ง์ ๋ฌผ๋ฆฌ ๋ฌธ์ ๋ค๊ณผ ๋ฒ์น๋ค์ ์ดํดํ๊ธฐ ์ด๋ ต์ง๋ง ๋ณต์กํ ๋ฌผ๋ฆฌ์ ์์ฉ๋ค์ด ์ผ์ด๋๋ ๊ฑท๋ ํ๋์ ์์ด์ ์ด๋ ค์์ด ์์ต๋๋ค. ์ฌ์ค ์ฌ๋์ด ํ์ด๋์ ์์์ด ๋ง์ ๊ฒฝํ๋ค์ ๋์ ์ผ๋ก ์ฐ๋ฆฌ๋ ํฌ๊ฒ ์ ๊ฒฝ์ฐ์ง ์์๋ ์ด๋ป๊ฒ ํ์ ์ฃผ๋ฉด ๋ค๋ฆฌ๊ฐ ์์ง์ด๋ ์ง ์๊ณ ์๊ธฐ ๋๋ฌธ์ ๋๋ค. ์ด์ฒ๋ผ ์ธ๊ณต์ง๋ฅ๋ ์ด๋ป๊ฒ ํ๋ฉด ๋ณต์กํ ์์คํ ์ ์ดํดํ๊ณ ์ ์๋ํ ์ ์์๊น? ๋ผ๋ ๋ฌผ์์ GN-Block์ด๋ผ๋ Graph ์์ด๋์ด๋ก ํด๊ฒฐํ ์ ์๋ค๊ณ ์ฃผ์ฅํ๋ ๋ ผ๋ฌธ์ ๋๋ค.
โ How can an intelligent agent understand and control such complex systems?
์ธ๊ณต์ง๋ฅ์ด ์ด๋ ๊ฒ ์ฌ๋์ด ์์ฐ์ค๋ฝ๊ฒ ์ตํ๋ ์ธ์์ ๋ฌผ๋ฆฌ์ ์ธ ํ์์ ์ดํดํ๊ณ ์ํธ์์ฉ ์์ฉํ๋ ค๋ฉด ์์์ ์ด๋ ๋ช
์์ ์ผ๋ก๋ ์ธ๊ณ์ ๋ํ ํ๋ถํ(rich) ํํ๊ณผ ์ง์์ด ํ์ํฉ๋๋ค. ๋ค์ ๋งํด, ์์คํ
์ ์๋ objects๋ค๊ณผ objects๊ฐ์ ๊ด๊ณ๋ฅผ ํํํด์ ๋์ผ object์๋ ๋์ผํ object-wiseํ ๊ณ์ฐ์, ์ด๋ค ์ฌ์ด์ ์ผ์ด๋๋ interation๋ค์ ๋ํด์๋ relation-wise ๊ณ์ฐ์ ์ ์ฉํด์ ํ์ต์ ํ ์ ์์ด์ผ ํฉ๋๋ค. ๋ง์น ๋ ๊ณ ๋ธ๋ญ๋ค ํ๋ํ๋๋ฅผ ์ดํดํ๊ณ ์ด๋ป๊ฒ ํ๋ฉด ์ฑ์ ์์ ์ ์๋์ง ์๋ ๊ฒ์ฒ๋ผ combinatorial generalization
๋ฅ๋ ฅ์ ๊ฐ์ง ์ ์๋ ๊ฒ์ด๋ผ๊ณ ์๊ฐํด ๋ณผ ์ ์์ต๋๋ค.
ํด๋น ๋
ผ๋ฌธ์ ๋ชฉํ๋ physical dynamics models์ ๊ทธ๋ํ ๊ธฐ๋ฐ์ ๋ฐฉ๋ฒ์ผ๋ก ํ์ตํ๋ ๊ฒ์
๋๋ค. Graph Neural Network(GN)์ node update function
์ ๊ฐ์ง๊ณ body์ dynamics์ ๋ํ ํ์ต์ ํ ์ ์๊ณ , edge update function
์ ๊ฐ์ง๊ณ interaction์ dynamics๋ฅผ ์ธ์ฝ๋ฉํ ์ ์์ผ๋ฉฐ, global update function
์ ๊ฐ์ง๊ณ global system์ ์์ฑ๋ค์ ์ธ์ฝ๋ฉ ํ ์ ์๋ค๊ณ ์ฃผ์ฅํฉ๋๋ค. ์ฌ๊ธฐ์ ๋ค๋ฅธ ๋
ผ๋ฌธ๋ค๊ณผ ๋ค๋ฅด๊ฒ ํน์ดํ ์ ์ globalํ ์์คํ
์ ์์ฑ์ด๋ผ๋ ๋ถ๋ถ์ ๋ฐ๋ก ๊ณ ๋ ค๋ฅผ ํ๋ค๋ ์ ์ด๋ผ๊ณ ํ ์ ์์ต๋๋ค.
๋
ผ๋ฌธ์ contribution์ ํฌ๊ฒ 3๊ฐ์ง forward model, inference model, control algorithm ์
๋๋ค. (ํ์ง๋ง ๋ฆฌ๋ทฐํ๋ฉด์ ๋๋ ์ ์ control algorithm ๋ถ๋ถ์ contribution์ด๋ผ๊ณ ํ๊ธฐ๋ณด๋ค๋ GN-based model์ ๊ฐ์ง๊ณ control pipeline์ ์ ๋ถ์ธ ๊ฒ์ด๋ผ๊ณ ์๊ฐ๋ฉ๋๋ค.) ๋ค๋ฅธ physics engine๋ค๊ณผ ๋ค๋ฅด๊ฒ ๋ฌผ๋ฆฌ๋ฒ์น์ ๋ํ ์ฌ์ ์ง์(prior knowledge)์ ์ ํ ๊ฐ์ ํ์ง ์์ง๋ง ๋์ object- and relation-centric inductive bias
๋ฅผ ์ด์ฉํ์ฌ current-state/next-state pairs์ ๋ํ ํ์ต์ ํฉ๋๋ค. ๊ทธ๋ํ ๊ธฐ๋ฐ์ผ๋ก ๋ฌผ๋ฆฌ ์์คํ
์ forward model๊ณผ inference model์ด ํ์ตํ๊ฒ ๋๋ฉด control algorithm์ ์ด ๋ชจ๋ธ๋ค์ ์ด์ฉํ์ฌ planning์ด๋ policy learning์ ํ๊ฒ ๋ฉ๋๋ค.
Model | Role |
---|---|
GN-based forward models | ์ ํํ๊ณ ์ผ๋ฐํ๋ prediction์ ํ ์ ์์ |
GN-based inference models | observation์ ์จ๊ฒจ์ ธ ์๋ ์์ฑ๋ค์ ๊ธฐ๋ฐ์ผ๋ก system identification์ ํ ์ ์์ |
NOT GN-based control algorithms | ๋ค๋ฅธ ๋ฒ ์ด์ค๋ผ์ธ๋ค๋ณด๋ค ์ข์ control ํผํฌ๋จผ์ค๋ฅผ ๋ณด์ฌ์ค |
2. Model
Graph representation of a physical system
๋ฌผ๋ฆฌ์์คํ ์ ์ด๋ป๊ฒ ๊ทธ๋ํ๋ก ๋ํ๋ผ ์ ์๋์ง ๋ช๊ฐ์ง ์ฉ์ด์ ์์๋ค์ ์ ๋ฆฌํด๋ณด๊ฒ ์ต๋๋ค.
- ๋ฌผ๋ฆฌ ์์คํ
์ body๋ ๊ทธ๋ํ์
node
๋ก ํํํฉ๋๋ค. - ๋ฌผ๋ฆฌ ์์คํ
์ joint๋ ๊ทธ๋ํ์
edge
๋ก ํํํฉ๋๋ค. - ๋ฌผ๋ฆฌ ์์คํ
์ globalํ ์์ฑ์
global feature
๋ก ํํํฉ๋๋ค.
์๋ ์ฌ์ง์์ ๋ณด์ด๋ half-cheetah์์ ์ง๊ด์ ์ผ๋ก ์ด๋ป๊ฒ ๊ทธ๋ํ๊ฐ ๊ทธ๋ ค์ง ์ ์๋์ง ์ ์ ์๊ณ ์ด ๊ทธ๋ํ๋ฅผ G๋ก ๋ํ๋ผ ์ ์์ต๋๋ค.
์์ ์ค๋ช ํ ๋ถ๋ถ์ ์์์ผ๋ก ๋ํ๋ด๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
G=\left(\mathbf{g},\left\{\mathbf{n}_{i}\right\}_{i=1 \cdots N_{n}},\left\{\mathbf{e}_{j}, s_{j}, r_{j}\right\}_{j=1 \cdots N_{e}}\right)
- g : global features ์์คํ ์ ์ค๋ ฅ์ด๋ time step๊ณผ ๊ฐ์ ์์ฑ์ ๋ํ๋ด๋ ๋ฒกํฐ์ ๋๋ค.
- \mathbf{n}_{i} : node features๋ฅผ ๋ํ๋ด๋ ๋ฒกํฐ์ ๋๋ค.
- \mathbf{e}_{j} : edge features๋ฅผ ๋ํ๋ด๋ ๋ฒกํฐ์ ๋๋ค.
- s_{j} : ์ด edge๋ฅผ ํตํด์ message๋ฅผ ๋ณด๋ด๋ sender nodes์ ์ธ๋ฑ์ค์ ๋๋ค.
- r_{j} : ์ด edge๋ฅผ ํตํด์ message๋ฅผ ๋ฐ๋ receiver nodes์ ์ธ๋ฑ์ค์ ๋๋ค.
Static & Dynamic properties
์ฌ๊ธฐ์ static graph G_s์ dynamic graph G_d ๋ผ๋ ๊ทธ๋ํ๋ 2๊ฐ์ง ์ข ๋ฅ๊ฐ ์์ต๋๋ค. ์ด 2๊ฐ์ ๊ทธ๋ํ๋ ๊ฐ๊ฐ ์์คํ ์ ์์ฑ์ด ์๊ฐ์ ๋ฐ๋ผ ๋ณํํ๋์ง(dynamic/time-variant) ์ํ๋์ง(static/time-invaritant)์ ๋ฐ๋ผ ๊ทธ๋ํ๋ฅผ ๊ตฌ์ฑํ๋ ์ ๋ณด์ ์ข ๋ฅ๊ฐ ๋ค๋ฆ ๋๋ค.(์์ธํ ์ ๋ณด๋ Appendix G section์์ Mujoco ๊ธฐ๋ฐ์ ์ด๋ค ์ ๋ณด๋ก ๊ฐ ๊ทธ๋ํ๋ฅผ ๊ตฌ์ฑํ๋์ง ๋์์์ต๋๋ค.)
- A static graph G_s: ์์คํ
์ staticํ ์ ๋ณด๋ฅผ ๊ฐ์ง๊ณ ์๋ ๊ทธ๋ํ
global parameters
: the time step, viscosity, gravity, etcbody/node parameters
: mass, inertia tensor, etc.joint/edge parameters
: joint type๊ณผ properties, motor type and properties, etc
- A dynamic graph G_d: ์์คํ
์ ์ผ์์ ์ธ state ์ ๋ณด๋ฅผ ๊ฐ์ง๊ณ ์๋ ๊ทธ๋ํ
body/node
: 3D Cartesian position, 4D quaternion orientation, 3D linear velocity, 3D angular velocityjoint/edge
: joint์ ์ ์ฉ๋ action๋ค์ ํฌ๊ธฐ
Graph networks
graph2graph
๋ชจ๋์ ํ์ฉํ์ฌ ์ธํ์ ๊ทธ๋ํ๋ก ๋ฐ๊ณ ์์ํ๋ ๊ทธ๋ํ๋ก ๋ฐ๋ ๋ชจ๋ธ์ ๋๋ค. ๋ฐ๋ผ์ ์์ํ์ ๊ทธ๋ํ๋ ์ธํ ๊ทธ๋ํ์ ๋ค๋ฅธ edge, node, global features๋ฅผ ๊ฐ์ง๊ฒ ๋ฉ๋๋ค.
๋ณธ ๋ ผ๋ฌธ์ ํต์ฌ ์์ด๋์ด์ธ GN ๋ธ๋ก์ ๊ตฌ์กฐ์ ๋ํด ์์๋ณด๊ฒ ์ต๋๋ค. - A core GN block
- 3๊ฐ์ sub function, MLP๋ก ์ด๋ฃจ์ด์ ธ ์์ต๋๋ค.
- edge-wise $f_e$ : ๋ชจ๋ edge๋ค์ ๋ํ update๋ฅผ ์งํํฉ๋๋ค.
- node-wise $f_n$ : ๋ชจ๋ node๋ค์ ๋ํ update๋ฅผ ์งํํฉ๋๋ค.
- global $f_g$ : ๋ง์ง๋ง์ผ๋ก global feature๋ค์ update ํฉ๋๋ค.
ํ๋์ feedforward GN pass๋ ๊ทธ๋ํ ์์์ message-passing ๋จ๊ณ์ ํ ์คํ ์ผ๋ก ๊ฐ์ฃผํ ์ ์์ต๋๋ค. ์ด๋ฌํ GN-block ๋ด์์์ ์๊ณ ๋ฆฌ์ฆ์ ์๋์ ๊ฐ์ต๋๋ค.
Forward models
Forward model์ ๋ชฉ์ ์ ํ์ฌ ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๋ค์ step์ ์ํ๋ฅผ ์์ธก(prediction)ํ๋ ๊ฒ์ ๋๋ค. (์ด๋ ์์ด ๋จ์ด์ ๋น์ทํ ์๋ฏธ๋๋ฌธ์ ๋ค์์ ๋์ค๋ inference model์ ๋ชฉ์ ๊ณผ ๋ง์ด ํผ๋๋ ์ ์์ผ๋ ์ ์ ์ํ๊ณ ๋์ด๊ฐ๋ ๊ฒ์ด ์ข์ต๋๋ค.) forward model์ RNN(GRU)๋ฅผ ๋์ ํ๋์ง ์ฌ๋ถ์ ๋ฐ๋ผ 2๊ฐ์ง ํ์ ์ด ์์ต๋๋ค.
Type1. GNN feed-forward
๊ฐ์ฅ ๊ฐ๋จํ GNN feed-forward ๋ชจ๋ธ์ ๋๋ค. ๊ทธ๋ํ๋ ์ฒ์์ GN_1์ ๊ฑฐ์ณ latent graph์ธ G'์ด ๋ฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ๋ค์ GN_2์ ์ธํ์ผ๋ก๋ GN_1์ ๊ฑฐ์น๊ธด ์ ์ ๊ทธ๋ํ์๋ G์ G'๋ฅผ concatenate๋ฅผ ํด์ ๋ฃ์ด์ฃผ๊ฒ ๋ฉ๋๋ค. ์ ์๋ค์ ์ด๋ ๊ฒ ๋์์ธํ ์ด์ ๋ก, ๊ทธ๋ํ์ ๋ชจ๋ ๋ ธ๋๋ค๊ณผ ์ฃ์ง๋ค์ด ๋ชจ๋ communicateํ๊ฒ ํ๊ธฐ ์ํจ์ด๋ผ๊ณ ์ด์ผ๊ธฐํฉ๋๋ค. ์ด๋ ๊ฒ GN_1, GN_2๋ฅผ ๊ฑฐ์ณ ์ต์ข ์ ์ผ๋ก ๋์ค๋ G^*์ node feature๋ค์ด ๊ฐ body์ ์ํ prediction ๊ฐ์ด ๋๋ ๊ฒ ์ ๋๋ค.
Type2. RNN+GNN
๋ค์์ผ๋ก ์์ ๊ธฐ๋ณธ์ด ๋๋ ๋ชจ๋ธ์ G-GRU๋ฅผ ์ถ๊ฐํ ํ์ ๋๋ค. Type 1๊ณผ ๋น์ทํ๊ฒ skip connection, latent graph๋ฅผ ๋ชจ๋ ์ฌ์ฉํ๋๋ฐ GN block์ GRU ๋ฒ์ ผ์ธ G-GRU๊ฐ ๋ค์ด๊ฐ๋ฉด์ G_h๋ผ๋ RNN์์ hidden vector์ ๊ฐ์ ๊ฐ๋ ์ hidden graph๊ฐ ์ถ๊ฐ๋ ๊ฒ์ ๋๋ค. ๋ชจ๋ edge, node, global feature๋ค์ ๋ํด ๊ฐ๊ฐ RNN์ด ์ ์ฉ๋์ด ์ด 3๊ฐ์ RNN sub-modules์ด ์์ต๋๋ค.
๋๊ฐ์ง ํ์ ์ GNN forward ๋ชจ๋ธ์ ๊ณตํต์ ์ธ ์ฌํญ
state differences
๋ฅผ ์์ธกํ๋ ๊ฒ์ ํ์ตํด์ state prediction์ ์ ๋๊ฐ(absolute)์ ๊ณ์ฐํฉ๋๋ค. ์ด ๊ณ์ฐ๋ absolute state prediction์ ๊ฐ์ง๊ณ state๋ฅผ updateํ๊ฒ ๋๋ ๊ฒ์ ๋๋ค.long-range rollout
trajectory๋ฅผ ๋ง๋ค์ด๋ด๊ธฐ ์ํด์ state prediction ๊ฐ๊ณผ control input์ ๋ฐ๋ณต์ ์ผ๋ก model์ ๋ฃ์ด์ฃผ์ด์ ์ฌ๋ฌ ์คํ ์ trajectory๋ฅผ ์์ฑํ๊ฒ ๋ฉ๋๋ค.GN model์ ์ธํ๊ณผ ์์ํ๋ค์ normalize ๋ฉ๋๋ค.
์ฌ์ค ๋ฆฌ๋ทฐ๋ฅผ ํ๋ฉด์ forward model๊ณผ inference model ์ฌ์ด์ ๊ตฌ๋ถ์ด๋ ๋ชจ๋ธ์ ๊ตฌ์ฒด์ ์ธ ํ๋ก์ธ์ค ์ดํด๊ฐ pseudo algorithm์ ๋ณด๊ธฐ ์ ๊น์ง ์๋์ง ์์์ต๋๋ค. Appendix์ ๋์์์ด์ ์ ๋ณด์ง ์์ ํ๋ฅ ์ด ๋์ง๋ง ๋ ผ๋ฌธ์ ๊ฐ๋ ์ ๋๋ต์ ์ผ๋ก ์ดํดํ๊ณ ๋ ํ์๋ ๊ผญ line by line์ผ๋ก ๋ณด์๊ธธ ์ถ์ฒํฉ๋๋ค.
๋จผ์ forward model์ ํ์ต๊ณผ์ ์ ๋ณด์ฌ์ฃผ๋ pseudo algorithm ์ ๋๋ค. ๋ค์ํ๋ฒ ์ด ๋ชจ๋ธ์ ๋ชฉ์ ์ ์๊ธฐ์์ผ๋ณด์๋ฉด, ํ์ฌ ์ํ x^{t_0} ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก a^{t_0}์ ํจ๊ป ์ฃผ์ด์ก์ ๋, x^{t_0+1}์ ์์ธกํ๋ ๊ฒ์ ๋๋ค. ์์ ์ค๋ช ํ ๋ถ๋ถ๋ค์ธ, state์ ์์ฐจ๋ฅผ ํ์ตํ๋ ๋ถ๋ถ์ด๋ normalization ๋ฑ์ด ์๊ณ ๋ฆฌ์ฆ๋ด์ ์ ๋์์์ต๋๋ค.
๋ค์์ ํ์ต๋ forward model์ ๊ฐ์ง๊ณ ๋ค์ ์ํ์ธ x^{t_0+1}์ ์ด๋ป๊ฒ ์์ธกํ๋์ง ๋ณด์ฌ์ฃผ๋ ์๊ณ ๋ฆฌ์ฆ์ ๋๋ค.
๋ง์ง๋ง์ผ๋ก ๋ฐ๋ก ์ ์๊ณ ๋ฆฌ์ฆ๊ณผ ๋์ผํ๊ฒ ํ์ต๋ forward model์ ๊ฐ์ง๊ณ ๋ค์ ์ํ์ธ x^{t_0+1}์ ์ด๋ป๊ฒ ์์ธกํ๋์ง ๋ณด์ฌ์ฃผ๋ ์๊ณ ๋ฆฌ์ฆ์ด์ง๋ง inference model์์ ํ์ต๋ GN_p๋ฅผ ๊ฐ์ง๊ณ system identification
์ด ์ถ๊ฐ๋ ์ํ์์ ์ด๋ป๊ฒ ์๊ณ ๋ฆฌ์ฆ์ด ํ๋ฌ๊ฐ๋์ง ๋ณด์ฌ์ค๋๋ค.(์ด์ ์ ์๊ณ ๋ฆฌ์ฆ์์๋ system parameter p๋ผ๊ณ ํ์๋์๋ ๋ถ๋ถ์ด ๋์ฒด๋ ๊ฒ์
๋๋ค.)
Inference Models
Inference model์ ๋ชฉ์ ์ ํ ๋ง๋๋ก ํํํ์๋ฉด System identification
์ด๋ผ๊ณ ํ ์ ์์ต๋๋ค. System identification์ด๋ ๊ด์ฐฐํ ์ ์๋(unobserved) dynamic system์ ์์ฑ๋ค์ ๊ด์ฐฐ๋๋(observed) behavior(๋๋ ์ด๋ค ์์)๋ฅผ ๊ฐ์ง๊ณ ์ถ๋ก ํ๋ ๊ฒ์ ๋งํฉ๋๋ค. ์ฆ ์์์ ์ผ๋ก system์ ๊ตฌ์ฑํ๋ ์์๋ค์ (๋ช
์์ ์ด์ง ์์) ์ธก์ ํ๊ฑฐ๋ ๊ด์ฐฐํ ์ ์์ง๋ง latent representations์ ํตํด ์ถ๋ก ํ ์ ์์ต๋๋ค.
Inference model๋ Recurrent GN-based model ์
๋๋ค. forward ๋ชจ๋ธ๊ณผ ๋ค๋ฅธ ์ ์ผ๋ก๋ ์ค์ง trajectory์ dynamic states
๋ค๋ง input์ผ๋ก ๋ฐ์ต๋๋ค. ๋ฐ๋ผ์ dynamic state graph์ธ G_d์ control input์ ๋ฐ์ต๋๋ค. ์์ํ์ผ๋ก๋ ์ผ์ time step T์ดํ์ G^*(T)์ด ๋๋ฉฐ, ๋ณธ ๋
ผ๋ฌธ์์ ์ดํ ์คํํํธ์์ 20 step์ ์ฌ์ฉํ์ต๋๋ค.
inference model ํ์ต๊ณผ์ ์ pseudo ์๊ณ ๋ฆฌ์ฆ์ ์๋์ ๊ฐ์ต๋๋ค.
Control algorithm
control algorithm์์๋ ๊ทธ๋ํ ๊ธฐ๋ฐ์ด ์๋๊ณ ์์ ์ค๋ช ํ ๊ทธ๋ํ ๊ธฐ๋ฐ์ forward model๊ณผ inference model์ ์ ํ์ฉํด์ ์ด๋ป๊ฒ controlํ ์ ์์์ง๋ฅผ ๋ณด์ฌ์ค๋๋ค. ๋ณธ ๋ ผ๋ฌธ์์๋ ํฌ๊ฒ 2๊ฐ์ง control algorithm์ ์ฌ์ฉํ์ต๋๋ค. ๊ฐํํ์ต์ ์ฃผ๋ก ์ฐ๊ตฌํ๋ ์ ์ฅ์์ ๋ฆฌ๋ทฐํด๋ณด๋ฉด, ๋๋ถ๋ถ ๊ฐํํ์ต์ model-free ๊ธฐ๋ฐ์ ์๊ณ ๋ฆฌ์ฆ์ด ๋ง์ด ๋ฐ์ ํ๋๋ฐ GN๊ธฐ๋ฐ์ ๋ค์ ์ํ๋ฅผ ์์ธกํ ์ ์๋ model์ ๋ง๋ฆ์ผ๋ก์จ model-based ๊ธฐ๋ฐ์ ๊ฐํํ์ต ์๊ณ ๋ฆฌ์ฆ์ ์ ์ฉํ ์ ์๋ค๋ ๊ฒ์ด ๋งค์ฐ ํฅ๋ฏธ๋ก์ ์ต๋๋ค.
MPC(Model Predictive Control)
GN์ ๋ฏธ๋ถ ๊ฐ๋ฅํ๊ธฐ ๋๋ฌธ์ MPC๊ฐ์ gradient-based trajectory optimization ๋ฐฉ๋ฒ์ผ๋ก model-based planning์ ํ ์ ์์ต๋๋ค. ๋ํ์ ์ผ๋ก MPC๊ฐ ์๊ณ ํ์ต๊ธฐ๋ฐ์ด ์๋๋ผ ์ต์ ํ ์๊ณ ๋ฆฌ์ฆ์ด๋ฉฐ ์๊ณ ๋ฆฌ์ฆ์ ํ๋ฆ์ ์๋์ ๊ฐ์ต๋๋ค.
SVG(Stochastic Value Gradients)
๊ฐํํ์ต ์๊ณ ๋ฆฌ์ฆ ์ค ํ๋์ด๋ฉฐ, GN-based model๊ณผ SVG์ policy function์ ๋์์ ํ์ตํ๋ agent๋ก control์ ํ๋ ๋ฐฉ๋ฒ์ ๋๋ค. SVG(1)์ ํ ์คํ ์ ์์ธกํ๋ GN model์ ๊ฐ์ง๊ณ ๊ฐํํ์ต ์๊ณ ๋ฆฌ์ฆ์ผ๋ก control์ ํ ๊ฒ์ด๋ฉฐ(model-based) SVG(0)์ ์์ธกํ๋ GN model ์์ด model-free ๊ธฐ๋ฐ์ผ๋ก controlํ ๊ฒ์ผ๋ก ์ดํดํ์๋ฉด ๋ฉ๋๋ค.
์ฌ์ค MPC์ SVG๋ ๋งค์ฐ ๋น์ทํ ์ธก๋ฉด์ด ์์ต๋๋ค. MPC์์๋ control inputs๋ค์ด ํ ์ํผ์๋์์ ์ด๊ธฐ ์กฐ๊ฑด๋ค์ด ์ฃผ์ด์ก์ ๋ ์ต์ ํ ๋๋ ๊ฒ์ด๋ผ๋ฉด, SVG์์๋ state์ control์ ๋งค์นญ์ํค๋ policy function์ด ํ์ต๊ณผ์ ์์ ๊ฒฝํํ states์ ๋ํด์ ์ต์ ํ ๋๋ ๊ฒ์ ๋๋ค.
3. Methods
Environments
- MuJoCo ์๋ฎฌ๋ ์ด์
ํ๊ฒฝ์ ์ด์ฉํ๋ค.
- Pendulum, Cartpole, Acrobot, Swimmer, Cheetah, Walker2d, JACO(robotic arm)
- generated training data for our forward models by applying simulated random controls to the systems, and recording the state transitions
- generalization and system identification ์คํ์ ์ํด์
- created a dataset of versions of several of our systemsโPendulum, Cartpole, Swimmer, Cheetah and JACOโ with procedurally varied parameters and structure.
- ๋ณํ์ํจ ์์ฑ๋ค๋ก๋ link lengths, body masses, and motor gears. + varied the number of links in the Swimmerโs structure, from 3-15 (we refer to a swimmer with N links as SwimmerN )
MPC planning
- N-step trajectory(N: planning horizon)์ ๊ทธ trajectory๋ฅผ ์คํํ์ ๋ ๋ฐ์ total reward๋ฅผ GN forward ๋ชจ๋ธ๋ก ์์ธกํ๋ค.
- ์ด๋์ action sequences(=trajectory)๋ค์ total reward์ backpropagating gradient๋ฅผ ๊ฐ์ง๊ณ ์ต์ ํํ๊ฒ ๋๋ค.
Model-based reinforcement learning
- GN-based model์ ๊ฐํํ์ต์ ์ ์ฉํด๋ณด๋ค
- SVG๋ฅผ ์ด์ฉํ๋ค.
- GN forward model์ด ๋ฏธ๋ถ๊ฐ๋ฅํ๊ธฐ ๋๋ฌธ์ GN ๋ชจ๋ธ๋ก ์์ฑ๋ next state๋ฅผ ๊ฐ์ง๊ณ expected return ๊ฐ์ ๊ทธ๋ผ๋์ธํธ๋ฅผ ๊ตฌํ ์ ์๋ค.
- 1 step์ ์์ธกํ๋ SVG(1)๊ณผ
- model-free RL ๋ฒ ์ด์ค๋ผ์ธ๊ณผ ๋น๊ตํ๊ธฐ ์ํด SVG(0)๊ณผ deterministic policy ์๊ณ ๋ฆฌ์ฆ์ธ DDPG(Deep Deterministic Policy Gradients)์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ๋ค.
Baseline comparisons
- constant prediction baseline: input state๋ฅผ ๊ทธ๋๋ก output state๋ก ์ฌ์ฉ
- MLP baseline: GN model์ ์ฐ์ธ ๋ฐ์ดํฐ๋ฅผ ๊ทธ๋๋ก MLP์ flattened & concatenated ํด์ ํ์ต
- MPC baseline: ๋ฌผ๋ฆฌ ๋ชจ๋ธ์ ๊ฐ์ง๊ณ Differential Dynamic Programming algorithm์ ์ฌ์ฉํด์ ground truth ๊ฐ์ ๊ฐ์ง
- SVG(0): model-free RL agents
- DDPG: model-free RL agents
Prediction performance evaluation
- calculated independent errors for position, orientation, linear velocity angular velocity
- squared one-step dynamic state differences (one-step error)
- squared trajectory differences (rollout error) between the prediction and the ground truth.
4. Results
Prediction
Learning a forward model for a single system
ํ๋์ ์์คํ ์ ๊ฐ์ง๊ณ ํ์ตํ forward model์ Prediction ์ฑ๋ฅ ์ดํด๋ณด๊ธฐ
random control๋ก ๋ง๋ค์ด์ง ๋ฐ์ดํฐ๋ค์ ๊ฐ์ง๊ณ ํ์ต๋ GN-based model
[Visually] Swimmer6์์ ๊ทธ๋ฆผ์์ ์ฒ๋ผ ground truth์ ์์ธก ๊ฒฐ๊ณผ๊ฐ ๊ตฌ๋ถ์ด ์ ๊ฐ ์ ๋๋ก ํก์ฌํ๋ค.(์์์์๋ ๊ฑฐ์ ๊ตฌ๋ถ์ด ์ ๊ฐ ์ ๋๋ก ์ ์์ธกํ๊ณ ์์์ ์ ์ ์๋ค.)
[Quantitatively] 100 step์์ 3์ถ ๋ฐฉํฅ์ผ๋ก์ ์์น, ์ ์๋, ๊ฐ์๋, ์ฟผํฐ๋์ ๋ฐฉํฅ ๋น๊ต
constant prediction baseline์ ์์ํ์ผ๋ก ์ธํ์ ๊ทธ๋๋ก ๋ณต์ฌํด์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ์ ๋ฌ ์ต๋์น๋ก normalization ํ๊ธฐ ์ํด ๊ฒ์์ ์ ์ ์ผ๋ก ํ๊ธฐ
์ฐ์ ๊ฒ์ ์ ์ ๊ณผ ๋ง๋๊ธฐ๋ค์ ๋ญ๋ฑ๊ทธ๋ ค์ ๋ณด๋ฉด,
1 step๊ณผ 100 step์ rollout ๊ฒฐ๊ณผ๋ฅผ ๋น๊ตํ์ ๋ ๊ฒ์ ์ ์ ์ ๋นํด ํ๋์ ๋ง๋๊ธฐ๋ค์ error ๊ฐ์ด ๋ฎ์์ ์ ์ ์๋ค.
GN ๋ชจ๋ธ์ด MLP-based ๋ณด๋ค ๋ ๋ฎ์ ์ ๋ฌ๋ฅผ ๊ฐ์ง๋ ๊ฒ์ ์ ์ ์๋ค. ์ด๋ ํน๋ณํ Swimmer6์ฒ๋ผ ์์ด์ ํธ์ ๊ตฌ์กฐ๊ฐ ๋ฐ๋ณต์ ์ธ ๊ฒฝ์ฐ์ ๋์ฑ ๋์ ๋๊ฒ ๋ฎ์์ ์ ์ ์์๋ค. ์ด๋ฅผ ํตํด GN-based forward ๋ชจ๋ธ์ด ๋ค์ํ ๋ฌผ๋ฆฌ ์์คํ ๋ค์์ dynamics๋ฅผ ์ ์์ธกํจ์ ์ ์ ์๋ค.
- GN์ด MLP๋ณด๋ค ๋ generalization์ด ์ ๋จ์ ํ์ธํ ์ ์์๋๋ฐ, Swimmer6๋ฅผ ์ง์ค์ ์ผ๋ก
train
,valid
,test
๋ฐ์ดํฐ์ ๋ํด 1-step, rollout error๋ฅผ ๊ฐ๊ฐ ํ์ธํด๋ดค์ ๋, Best GN์ error ๊ฐ์ด Best MLP๋ณด๋ค ๋ฎ์์ ์ ์ ์๋ค. ๋ฟ๋ง ์๋๋ผ test data์ error ์ฆ๊ฐ์จ์ ๋ดค์ ๋์๋ GN ๋ชจ๋ธ์ test data์ error๊ฐ ๋ ์ ๊ฒ ์ฆ๊ฐํจ์ ๊ด์ฐฐํ ์ ์์๊ณ ์ด๋ agent์ bodies์ joints๋ค์ ๋ํ inductive bias๊ฐ GN์ ํตํด ์ ํ์ต๋์์์ ์ฆ๋ช ํ ์ ์๋ค.
Learning a forward model for multiple systems
ํ ๊ฐ์ ์์คํ ์์์ forward model์ ์ดํด๋ณด์์ผ๋ ์ด์ ์ฌ๋ฌ ์์คํ ์์์ forward model์ ์ฑ๋ฅ์ ์ดํด๋ณด์. GN์ ์ฌ์ฉํ๋ฉด ์ฌ๋ฌ ์์คํ ๋ค์ ๋ค์ํ ๋ณ์๋ค๋ ์ ๋ค๋ฃฐ ์ ์๋ค๋ ๊ฐ์ ์ด ์์๋ค. ์ด๋ฅผ ํ์ธํ๊ธฐ ์ํด ์ฐ์์ ์ผ๋ก static parameter๋ค(์ง๋, body์ ๊ธธ์ด, joint์ ๊ฐ๋ ๋ฑ)์ ๋ฐ๊ฟ๊ฐ๋ฉด์ forward dynamics๋ฅผ ์ด๋ป๊ฒ ํ์ตํด๊ฐ๋์ง ํ์ธํ๋ค.
Inference
Control
5. Discussion
Review
๋ ผ๋ฌธ ๋ฆฌ๋ทฐํ์ ์ฃผ๊ด์ ์ธ ์ฅ๋จ์ ์ ์ ๋ฆฌํ๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
- Pros ๐
- Cons ๐
- mlp comparison
Reference
- Original paper Graph Networks as Learnable Physics Engines for Inference and Control
- Official code https://github.com/fxia22/gn.pytorch