๐WaveNet ๋ฆฌ๋ทฐ
์ด๋ฒ ํฌ์คํ ์ Google DeepMind์์ ๋ฐํํ WaveNet์ด๋ผ๋ ๋ ผ๋ฌธ์ ๋ํด ๋ฆฌ๋ทฐ๋ฅผ ํ๋ ค๊ณ ํฉ๋๋ค. WaveNet์ Autoregressiveํ Generative model๋ก์จ Google์ ์คํผ์ปค ์๋น์ค์ ์ฌ์ฉ๋์๋ค๊ณ ๋ง์ด ์๋ ค์ง ๋ชจ๋ธ์ ๋๋ค.
๋ฆฌ๋ทฐ์ ์์์ ๊ฐ์ฅ ๋์์ ๋ง์ด ๋ฐ๊ณ ์๋ ํฌ์คํ ์ ์๋นํ ์ด๋ฏธ์ง๋ค์ด ๊น์ ํฌ ๋์ [๋ ผ๋ฌธ๋ฆฌ๋ทฐ]WaveNet ํฌ์คํ ์์ ๊ฐ์ ธ์จ ๊ฒ์์ ๋ฐํ๋ฉฐ ๊ฐ์ฌ์ ๋ง์์ ์ ํด๋๋ฆฌ๊ณ ์ถ์ต๋๋ค. ๊ฐ ์ด๋ฏธ์ง์ ์ถ์ฒ๋ ์์ฒจ์๋ก Reference numbering์ ํ์ํ์์ต๋๋ค.
Background
WaveNet์ ์์ฑ ์์ฑ ๋ชจ๋ธ๋ก ๋ณธ๊ฒฉ์ ์ผ๋ก ๋ชจ๋ธ์ ๋ํด ์์๋ณด๊ธฐ ์ ์ ์๋ฆฌ
๋ผ๋ ๊ฒ์ด ์ด๋ป๊ฒ ์ ํธ
๊ฐ ๋๋๊ฐ๋ฅผ ์ดํด๋ณผ ํ์๊ฐ ์์ต๋๋ค. ์๋ฆฌ๋ ๊ณต๊ธฐ ์
์๋ค์ ๋จ๋ฆผ์ด๋ฉฐ ์ข
ํ์ ํํ์ ๊ฐ์ง๊ณ ์์ต๋๋ค. ์ด๋ฌํ ์๋ฆฌ
๋ผ๋ ํ์์ ํ๋์ผ๋ก ํํํด๋ณด์๋ฉด, ์๋์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด ๊ณต๊ธฐ ์
์๋ค์ด ๋ง์ด ๋ฐ์ง๋์ด ์๋ ๋ถ๋ถ์ ํ๋์ ์งํญ์ ํฌ๊ฒ, ์๋์ ์ผ๋ก ์
์๋ค์ ์๊ฐ ์ ์ ๊ณณ์ ์งํญ์ ์๊ฒํ์ฌ ํํํ ์ ์์ต๋๋ค.
์ด๋ ๊ฒ ํ๋ ๋ชจํ์ผ๋ก ๋ํ๋ด์ด์ง ์๋ฆฌ๋ Continutous(์ฐ์์ ์ธ) ์ ํธ ์
๋๋ค. ์ด๋ฌํ ์ ํธ๋ฅผ ์ปดํจํฐ์์ ์ฒ๋ฆฌํ๊ธฐ ์ํด์๋ ์ปดํจํฐ๊ฐ ์ดํดํ ์ ์๋๋ก Discrete(๋ถ์ฐ์์ ์ธ) ๊ฐ์ผ๋ก ๋ํ๋ผ ์ ์์ด์ผ ํ๋ฉฐ Continuousํ ์ ํธ โ Discreteํ ์ ํธ
๋ก ๋ฐ๊พธ๋ ๊ณผ์ ์ Sampling์ด๋ผ๊ณ ํฉ๋๋ค. ์ฌ๊ธฐ์ ๋์ด ์๋ ์ปดํจํฐ๋ ๋ฌดํํ (์ด์งํ๋)์ ์ ํํ์ ๊ฐ์ง ์ ์๋๊ฒ์ด ์๋๊ณ ๋ ํจ์จ์ ์ผ๋ก ์ ํธ๋ฅผ ์ฒ๋ฆฌํ๊ธฐ ์ํด Quantization(์์ํ)๊ณผ์ ์ ๊ฑฐ์น๊ฒ ๋ฉ๋๋ค. ์ด๋ ์ํ๋ง๋์ด ์ด์ฐํ ๋์ด ์๋ ์ ํธ ๊ฐ์ Section์ ๋๋์ด ์ผ์ ๊ตฌ๊ฐ ๋ด์ ์๋ ๊ฐ๋ค์ ํ๋์ ์์ํ๋ ๊ฐ์ผ๋ก ๋งค์นญํ๋ ๊ณผ์ ์
๋๋ค. ์ด๋ ๊ฒ ์ด์ง์๋ก ์ ์ํ๋ ์๋ฆฌ๋ ์๋์ ์ค๋ฅธ์ชฝ ๊ทธ๋ฆผ์์์ ๊ฐ์ด ์๊ฐ์ถ(x)์ ๋ฐ๋ผ ๋นจ๊ฐ ์ ์ผ๋ก ๋ํ๋ด์ด์ง๋ ์ ํธ๋ก ๋ณํ๋๊ฒ ๋ฉ๋๋ค. ์ด๋ฌํ ์ ํธ์ฒ๋ฆฌ ๊ณผ์ ์ Pulse-Code Modulation(PCM)์ด๋ผ๊ณ ํฉ๋๋ค.
๋ณดํต ์๋ฆฌ์ ์ ํธ์ฒ๋ฆฌ๋ 16-bit์ ์ ์ํํ(-255 ~ 256)์ผ๋ก ๋ํ๋ด์ง๋ง WaveNet์์๋ Nonlinearity๋ฅผ ์ฆ๊ฐ์ํค๊ณ ๋ ํจ์จ์ ์ด์๋ 8-bit ์ ์ ํํ ๋์งํธ ์ ํธ๋ฅผ ์ฌ์ฉํ์ต๋๋ค. ์ด๋ ์ฌ์ฉํ ๋ฐฉ๋ฒ์ ยต-law Companding Transformation(ฮผ-law algorithm)
์ผ๋ก ์ฌ๋์ด ์๋ฆฌ๋ฅผ ์ธ์ํ๋ ๋ฐฉ๋ฒ์ ๋ชจ๋ฐฉํ ๋ฐฉ์์ ์ฌ์ฉํ์ต๋๋ค. ์ฌ๋์ ์์ ์๋ฆฌ์ ๋ณํ์๋ ๋ฏผ๊ฐํ์ง๋ง ํฐ ์๋ฆฌ์ ๋ณํ์๋ ๋๊ฐํ๋ฏ๋ก ฮผ-law algorithm์์๋ ์์ ์๋ฆฌ์ ๊ตฌ๊ฐ(์๋ ๊ทธ๋ํ์์ ์ค์ ๋ถ๋ถ)์ ์ธ๋ฐํ๊ฒ ๋๋๊ณ ํฐ ์๋ฆฌ ๊ตฌ๊ฐ(์๋ ๊ทธ๋ํ์์ ์ข์ฐ ๋ ๋ถ๋ถ)์ ๊ธฐ์ธ๊ธฐ๋ฅผ ์๋งํ๊ฒ ํ์ฌ ๋น๊ต์ ๋ฌ์ฑํ๊ฒ ๋๋์์ต๋๋ค.
WaveNet์์ 16-bit๊ฐ ์๋ 8-bit๋ฅผ ์ฌ์ฉํ ์ด์ ๋ ์๋ ๊ทธ๋ฆผ์ ์ค๋ฅธ์ชฝ์์ WaveNet์ ์ ์ฒด ํ๋ฆ์์ ๋ณผ ๋ ์์ํ๋ ๊ฐ ๊ตฌ๊ฐ์ softmax๋ก ํด๋น ๊ฐ์ ํ๋ฅ ์ ๊ตฌํ๊ฒ ๋๋๋ฐ, 16-bit๋ผ๋ฉด softmax layer์์ ์ด 65,536(= -2^{15} ~ 2^{15}-1 )๊ฐ์ ํ๋ฅ ์ ๊ตฌํด์ผ ํ๋ฏ๋ก ๊ณ์ฐ์ด ๋งค์ฐ ๋ง์ด ํ์ํ๊ธฐ ๋๋ฌธ์ ๋๋ค.
TTS(Text-to-Speech)๋
์์ ์ด์ผ๊ธฐํ๋๋ก ๊ตฌ๊ธ์ ์คํผ์ปค ์๋น์ค์ WaveNet์ด ์ฐ์ธ ๊ฒ์ผ๋ก ํฐ ํ์ ์๋๋ฐ ์ด๋ ๋ฐ๋ก TTS ์๋น์ค์ WaveNet์ด ์ฐ์ธ ๊ฒ ์ด์์ต๋๋ค. TTS task๋ ํน์ text๊ฐ ์ฃผ์ด์ง๋ฉด ์ด๋ฅผ ์์ฑ ์ ํธ๋ก ๋ฐ๊ฟ์ฃผ๋(์์ฑ์ ์์ฑํ๋) task์ด๋ฉฐ Text analysis์ Speech synthesis๊ฐ ๊ฐ์ด ์ด๋ฃจ์ด์ง๋ task ์ ๋๋ค.
๊ธฐ์กด์ TTS ๊ธฐ์ ์ ํฌ๊ฒ 2๊ฐ์ง๊ฐ ์์์ต๋๋ค. ์ฒซ๋ฒ์งธ๋ก Concatenative
๋ ๋ค๋์ ์์ฑ ๋ฐ์ดํฐ๋ฅผ ์์ ๋จ์๋ก ์ชผ๊ฐ์ด ์ ํธ๋ฅผ ์ ์ฅํ ๊ฒ์ ์กฐํฉํ์ฌ ์๋ก์ด ์์ฑ์ ์์ฑํ๋ ๋ฐฉ์์ผ๋ก, ๋ง์น ํผํธ๋ก ์ท๊ฐ์ ํจํด์ ๋ง๋ค์ด๋ด๋ฏ์ด ์์ฑ ๋จ์๋ค์ ์ด์ด๋ถ์ด๋ ๋ฐฉ์์
๋๋ค. ์ด ๋ฐฉ๋ฒ์ ์ค์ ์์ฑ ๋ฐ์ดํฐ๋ฅผ ์ชผ๊ฐ ๊ฒ์ ์ฌ์ฉํ๋ ๊ฒ์ด๋ฏ๋ก ์์ฑ ๋ฐ์ดํฐ ํ๋ ํ๋์ ํ๋ฆฌํฐ๋ ์ข์ง๋ง ๋จ์ ์ผ๋ก๋ ์์ฑ์ ์กฐ์ ํ ์ ์๋ ์์ ๋๊ฐ ๋จ์ด์ง๋ค๋ ์ ๊ณผ ์์ฑ ๋ฐ์ดํฐ๊ฐ ๋งค์ฐ ๋ง์์ผ ํ๋ค๋ ์ ์ด ์์ต๋๋ค.
๋๋ฒ์งธ๋ก Parametric
์ ํต๊ณ์ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์์ฑ์ ํฉ์ฑํ๋ ๋ฐฉ์์ผ๋ก WaveNet์ ๋ถ๋ก์ ์์ธํ ์ค๋ช
์ด ๋์ด ์๋ฏ์ด Acoustic model์ ๋ง๋ค์ด์ ์์ฑ์ ๋ง๋ค์ด ๋
๋๋ค. Concatenative์ ๋ค๋ฅด๊ฒ ์๋ก์ด ์์ฑ ๋ฐ์ดํฐ๋ฅผ ๋ง๋ค์ด๋ธ๋ค๋ ์ ์์ ์์ฑ ์ ํธ๋ฅผ ์กฐ์ํ ์ ์๋ ์์ ๋๊ฐ ์ปค์ง๊ณ ๋ฐ์ดํฐ ์
์ด ๋ง์ด ํ์ ์์ผ๋ ์์ฑ์ ์์ฑํด๋ด๋ ํ๋ฆฌํฐ๊ฐ ๋ค์ ๋จ์ด์ง๋ ๋จ์ ์ด ์์ต๋๋ค. ๊ธฐ์กด์ 2๊ฐ์ง ๋ฐฉ์๊ณผ ๋ค๋ฅด๊ฒ WaveNet์ explicitํ acoustic feature๋ฅผ ๋ชจ๋ธ๋ง ํ์ง ์๊ณ ๋ฐ๋ก raw waveform์ ์์ฑํ๋ ๊ฒ์ด ๊ฐ์ฅ ํฐ ์ฐจ์ด๋ผ๊ณ ๋ณผ ์ ์์ต๋๋ค.
WaveNet์ vocoder๋ก ์ด์ฉํ์ฌ Tacotron2์ ๊ฐ์ ํ ์คํธ์์ ์ง์ ์์ฑ ํฉ์ฑ์ ์ํ ์ ๊ฒฝ๋ง ์ํคํ ์ฒ์์ ์ฐ๊ฒ ๋ฉ๋๋ค. ์๋๋ Tacotron2์ ๊ตฌ์กฐ์ด๋ฉฐ ์ค๋ฅธ์ชฝ ์๋จ์์ WaveNet MoL(mixture of logistic distributions)์ ์ฐพ์๋ณผ ์ ์์ต๋๋ค.
WaveNet
WaveNet์ ์ ์ฒด์ ์ธ ๊ตฌ์กฐ๋ ์๋์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด ํฌ๊ฒ 4๊ฐ์ง ๋ถ๋ถ์ผ๋ก ๋๋์ด์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
Dilated Casual Convolution
Residual Connection & Gated Activation Units
Skip Connection
Conditional WaveNets
WaveNet ๊ตฌํ์ ๋ด์ฉ ์ดํด๋ฅผ ์ฐ์ ์ผ๋ก ํ๊ธฐ ์ํด ๋น๊ต์ ๊ตฌํ์ด ๊ฐ๋จ ๋ช ๋ฃํ๊ฒ ๋์ด์๋ Reference[17]์ ์ฐธ๊ณ ํ์์ต๋๋ค.(Youtube ๊ฐ์) ์ฐ์ WaveNet์ ์ ์ฒด ์ฝ๋๋ ๋ค์๊ณผ ๊ฐ๊ณ class ๋ด๋ถ์ ์๋ ๋ค๋ฅธ module class์ ๋ํ ์์ธํ ์ฝ๋๋ ์๋ ๋ด์ฉ์์ ์ค๋ช ๊ณผ ํจ๊ป ๋์ฌ ์์ ์ ๋๋ค.
class WaveNet(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stack_size, layer_size):
super().__init__()
self.stack_size = stack_size
self.layer_size = layer_size
self.kernel_size = kernel_size
self.casualConv1D = CasualDilatedConv1D(in_channels, in_channels, kernel_size, dilation=1)
self.stackResBlock = StackOfResBlocks(self.stack_size, self.layer_size, in_channels, out_channels, kernel_size)
self.denseLayer = DenseLayer(out_channels)
def calculateReceptiveField(self):
return np.sum([(self.kernel_size - 1) * (2 ** l) for l in range(self.layer_size)] * self.stack_size)
def calculateOutputSize(self, x):
return int(x.size(2)) - self.calculateReceptiveField()
def forward(self, x):
# x: b c t -> input data size
x = self.casualConv1D(x)
skipSize = self.calculateOutputSize(x)
_, skipConnections = self.stackResBlock(x, skipSize)
dense=self.denseLayer(skipConnections)
return dense
1. Dilated Casual Convolution
๋จผ์ Dilated Casual Convolution
์ ยต-law Companding Transformation
์ฒ๋ฆฌ๋ฅผ ๊ฑฐ์น ์์ฑ ์ ํธ๋ฅผ ๋ฐ์์ค๋ ์ฒซ๋ฒ์งธ ๋ถ๋ถ์
๋๋ค.
์ฐ์ Casual ์ด๋ผ๋ ๊ฒ์ Time-series์ธ ์์ฑ ์ ํธ์ ์๊ฐ ์์๋ฅผ ๊ณ ๋ คํ์ฌ ํ์ฌ ์์ t๋ฅผ ๊ธฐ์ค์ผ๋ก ๋ฏธ๋ ์ ๋ณด๋ ์ฌ์ฉํ ์ ์๊ณ ํ์ฌ๊น์ง์(๊ณผ๊ฑฐ~ํ์ฌ t) ์ ๋ณด๋ง
์ฌ์ฉํ ์ ์๋ค๋ ์๋ฏธ์
๋๋ค. ์ผ์ชฝ Causal Convolution ๊ทธ๋ฆผ์์ Receptive Field๋ (๋ ์ด์ด ์) + (ํํฐ์ length) -1
๋ก ๊ณ์ฐ๋์ด ์ด ๋ ์ด์ด ์๋ 4๊ฐ์ด๊ณ ํํฐ length๋ ์ด์ ๋ ์ด์ด์์ 2๊ฐ์ ์ ๋ณด๊ฐ ๋ชจ์์ ธ์ ๋ค์ ๋ ์ด์ด์ ํ๋์ ๋ฐ์ดํฐ๋ก ์ฐ์ถ๋๋ฏ๋ก ํํฐ length๋ 2๋ผ๊ณ ๋ณผ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ 4+2-1๋ก Receptive Field๋ 5๊ฐ ๋๋ฉฐ ์ด๋ฅผ ๊ทธ๋ฆผ์์ ์ดํด๋ณด๋ฉด ์ฒ์ input
์์ 5๊ฐ์ ์์ฑ ์ ๋ณด๊ฐ output
์ 1๊ฐ์ ์ ๋ณด๋ก ๋์ค๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. ์ด๋ฐ Receptive Field๋ ๋งค์ฐ ์งง์ ์๊ฐ์ ๋ง์ ์์ฑ์ ํธ๊ฐ ๋งค์นญ๋๋ ์ํฉ์์ ๋งค์ฐ ์ข์ผ๋ฉฐ RF๋ฅผ ๋๋ฆฌ๊ธฐ ์ํด์๋ ๋ ์ด์ด ์๋ฅผ ๋๋ฆฌ๊ฑฐ๋ ํํฐ์ length๋ฅผ ๋๋ ค์ผ ํ๋๋ฐ ์ด๋ ๋ชจ๋ธ์ ๋งค์ฐ ํฌ๊ฒ ๋ง๋ค๊ฒ ๋๊ณ ๊ณ์ฐ๋ ๋ง์ด ์๊ตฌ๋ฉ๋๋ค.
๊ทธ๋์ ์ ์์ด ๋ ๋ฐฉ๋ฒ์ด ๋ฐ๋ก Dilated Convolution์ ๋๋ค. ์ด๋ convolution with holes๋ก ํด์ํ ์ ์๋๋ฐ ์์ ๊ทธ๋ฆผ์์ ๋ณผ ์ ์๋ฏ์ด ์ด์ ๋ ์ด์ด์์ ๋ฐ์ดํฐ๊ฐ Dilated๋์ด ๋ฐ์ดํฐ๊ฐ ๋ฌ์ฑ๋ฌ์ฑํ๊ฒ ๋ชจ์์ ธ์ ๋ค์ ๋ ์ด์ด๋ก ๋์ด๊ฐ๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. ์ด๋ skip์ด๋ pooling๊ณผ ์ ์ฌํด๋ณด์ด์ง๋ง input๊ณผ output์ ์ฐจ์์ด ์ ์ง๋๋ค๋ ์ ์์ ์ฐจ์ด๊ฐ ์์ต๋๋ค. ์ด๋์ RF๋ ๊ฐ ๋ ์ด์ด์ Dilation ๊ฐ์ ๋ชจ๋ ๋ํ๊ณ ๋ง์ง๋ง์ ํ์ฌ ์์ ์ ๋ฐ์ดํฐ 1์ ๋ํ๋ฉฐ RF๊ฐ ๊ณ์ฐ๋ฉ๋๋ค. WaveNet์์๋ Dilation์ ์ด 30๊ฐ์ ๋ ์ด์ด์ ์ ์ฉํ๊ณ Dilation ๊ฐ์ ํจํด์ input์์ ๋ถํฐ 1, 2, โฆ, 512 ๋ก 2๋ฐฐ์ฉ ๋๋ฆฐ 10๊ฐ์ ๋ ์ด์ด๋ฅผ ์ด 3๋ฒ ๋ฐ๋ณตํ์ต๋๋ค. ์ด๋, 1 ~ 512 Dilation ๊ฐ์ ๊ฐ์ง 10๊ฐ ๋ ์ด์ด์ RF๋ 1024๋ก ๊ณ์ฐ๋ฉ๋๋ค.
Code ๊ตฌํ์ผ๋ก ์ดํด๋ณด๋ฉด ์๋์ ๊ฐ์ด ๊ตฌํํ ์ ์์ต๋๋ค. Casual ํน์ฑ์ ๋ฐ์ํ๊ธฐ ์ํด self.ignoreOutIndex
์ ๋ง๋ค์ด์ dilation ๊ฐ์ ๊ณ ๋ คํ์ฌ (kernel_size - 1) * dilation
์ผ๋ก ๊ณ์ฐํ ํ์ ์๋ผ๋ด์ฃผ๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
class CasualDilatedConv1D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, dilation, padding=1):
super().__init__()
self.conv1D = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, bias=False, padding='same')
self.ignoreOutIndex = (kernel_size - 1) * dilation # casual
def forward(self, x):
return self.conv1D(x)[..., :-self.ignoreOutIndex] # casual
2. Residual Connection & Gated Activation Units
๋ค์์ผ๋ก Dilated Causal Convolution์ ๊ฑฐ์น ํ ํต๊ณผํ๊ฒ ๋๋ Residual Connection & Gated Activation Units
๋ถ๋ถ์ ๋ํด์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
WaveNet์์ ์ฌ์ฉ๋ Gated Activation Units๋ PixelCNN์์ ์ฌ์ฉ๋ ๋งค์ปค๋์ฆ์ ์ฐจ์ฉํ์ต๋๋ค. ์๋์ ๊ทธ๋ฆผ์์ ๋ณด์ด๋ ๋ณด๋ผ์ Dilated Conv๊ฐ ์์์ ์ค๋ช ํ DCC์ด๋ฉฐ ์ด๋ฅผ ๊ฑฐ์น ํ Convoltion layer์ ๊ฐ๊ฐ tanh, sigmoid activation์ ํต๊ณผํ์ฌ Filter, Gate๊ฐ ๋ฉ๋๋ค. ์ด 2๊ฐ์ง ๊ฒฝ๋ก๋ก ๊ณ์ฐ๋ ๊ฐ์ elementwise product๋ฅผ ํตํด ํ๋์ ๋ฒกํฐ๋ก ๋ณํ๋ฉ๋๋ค. ์ด๋ Dilated๋ฅผ ํต๊ณผํ๊ธฐ ์ ๊ฐ์ Residual Connection์ ํตํด ์ฐ๊ฒฐํจ์ผ๋ก์จ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ด ๋ ์ด์ด๋ฅผ ๋ ๊น๊ฒ ์์ ์ ์๋๋ก ๋๊ณ ๋ ๋น ๋ฅด๊ฒ ํ์ตํ ์ ์๋๋ก ํ ์ ์์๋ค๊ณ ํฉ๋๋ค.
3. Skip Connection
Skip Connection์ Dilated Convolution์ ํตํด ๋ค์ํ Receptive Field๋ฅผ ๊ฐ์ง ๊ฐ ๋ ์ด์ด๋ค์ ๊ฐ์ ํ์ฉํ์ฌ output์ ๋ง๋ค์ด๋ผ ์ ์๋๋ก ํ์ต๋๋ค. ์์ ์ค๋ช ํ๋ ๋๋ก ๊ฐ Residual Block์ Dilation ๊ฐ์ด ๋ค ๋ค๋ฅด๊ธฐ ๋๋ฌธ์ ๊ฐ Residual Block์ output์ ์๋ก ๋ค๋ฅธ Receptive Field๋ฅผ ๊ฐ์ง๊ฒ ๋ฉ๋๋ค.
Residual Connection๊ณผ Skip Connection์ Code๋ก ๊ตฌํํ๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค. ์์์ ์ค๋ช
ํ๋ Gated Activation Units์ tanh, sigmoid activation์ ๊ฐ๊ฐ์ activation function์ ๊ฑฐ์นํ self.resConv1D
์ ํต๊ณผํ๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. ๋ํ Skip Connection์ ๊ตฌํํ๋ ๋ถ๋ถ์ self.skipConv1D
์์ ํ์ธํ ์ ์์ต๋๋ค. ๋ง์ง๋ง return์์ resOutput
, skipOutput
์ผ๋ก 2๊ฐ์ output์ด ๋์ค๋ ๊ฒ์ ์ ์ ์์ต๋๋ค.
class ResBlock(nn.Module):
def __init__(self, res_channels, skip_channels, kernel_size, dilation):
super().__init__()
self.casualDilatedConv1D = CasualDilatedConv1D(res_channels, res_channels, kernel_size, dilation=dilation)
self.resConv1D = nn.Conv1d(res_channels, res_channels, kernel_size=1)
self.skipConv1D = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
self.tanh = nn.Tanh()
self.sigmoid = nn.Sigmoid()
def forward(self, inputX, skipSize):
x = self.casualDilatedConv1D(inputX)
x1 = self.tanh(x)
x2 = self.sigmoid(x)
x = x1 * x2
resOutput = self.resConv1D(x)
resOutput = resOutput + inputX[..., -resOutput.size(2):]
skipOutput = self.skipConv1D(x)
skipOutput = skipOutput[..., -skipSize:]
return resOutput, skipOutput
์์ ๊ฐ์ ResBlock
์ ์ ์ฒด ๊ตฌ์กฐ์์ ๋ณด์๋ค์ํผ ์ฌ๋ฌ๊ฐ๊ฐ stacked ๋์ด ์์ผ๋ฏ๋ก StackOfResBlocks
class๋ก ๊ตฌํํ์ฌ WaveNet์ ๋ฃ์ด์ฃผ๊ฒ ๋ฉ๋๋ค.
class StackOfResBlocks(nn.Module):
def __init__(self, stack_size, layer_size, res_channels, skip_channels, kernel_size):
super().__init__()
buildDilationFunc = np.vectorize(self.buildDilation)
dilations = buildDilationFunc(stack_size, layer_size)
self.resBlocks = []
for s,dilationPerStack in enumerate(dilations):
for l,dilation in enumerate(dilationPerStack):
resBlock=ResBlock(res_channels, skip_channels, kernel_size, dilation)
self.add_module(f'resBlock_{s}_{l}', resBlock) # Add modules manually
self.resBlocks.append(resBlock)
def buildDilation(self, stack_size, layer_size):
# stack1=[1,2,4,8,16,...512]
dilationsForAllStacks = []
for stack in range(stack_size):
dilations = []
for layer in range(layer_size):
dilations.append(2 ** layer)
dilationsForAllStacks.append(dilations)
return dilationsForAllStacks
def forward(self, x, skipSize):
resOutput = x
skipOutputs = []
for resBlock in self.resBlocks:
resOutput, skipOutput = resBlock(resOutput, skipSize)
skipOutputs.append(skipOutput)
return resOutput, torch.stack(skipOutputs)
4. Conditional WaveNets
Conditional Modeling์ Autoregressive model์ธ WaveNet์ ์ ์ฉํ๊ธฐ ์ฝ๊ณ ์ด ๋ํ PixelCNN์์์ ์์ด๋์ด์ ์ ์ฌํฉ๋๋ค. Feature h ๋ฒกํฐ๋ฅผ ์กฐ๊ฑด ๋ถ๋ถ์ ์ถ๊ฐํ์ฌ ์์ฑ ๋ฐ์ดํฐ์ ์กฐ๊ฑด์ ์ถ๊ฐํ ์ ์์ต๋๋ค.
p(\mathbf{x} \mid \mathbf{h})=\prod_{t=1}^T p\left(x_t \mid x_1, \ldots, x_{t-1}, \mathbf{h}\right)
Condition์๋ ํฌ๊ฒ 2๊ฐ์ง๋ก Global๊ณผ Local์ด ์์ต๋๋ค. ๋จผ์ Global์ Time-invariantํ ์กฐ๊ฑด์ผ๋ก ์์ ์ ๋ฐ๋ผ ๋ณํ์ง ์๋ ์กฐ๊ฑด ์ ๋ณด๋ฅผ ์ถ๊ฐํ๋ ๊ฒ์ ๋งํฉ๋๋ค. ์๋ฅผ ๋ค์ด ํ ๋ฐํ์์ ์์ฑ์ ํด๋น ์์ฑ ํ์ผ์ ์ด๋ค ์์ ์์๋ ๋๊ฐ์ condition์ด๊ธฐ ๋๋ฌธ์ Global condition์ด๋ผ๊ณ ํ ์ ์์ต๋๋ค. ์ด๋์ Feature vector h๋ linear projection์ ๊ฑฐ์น ํ data x์ ๋ํ๊ฒ ๋ฉ๋๋ค.
๋ค์์ผ๋ก Time-variantํ Local condition์ ์์ ์ ๋ฐ๋ผ ๋ณํ๋ ์กฐ๊ฑด ์ ๋ณด๋ฅผ ์ถ๊ฐํ๋ ๊ฒ์ ๋งํ๋๋ฐ ์์ฑ ๋ฐ์ดํฐ๋ณด๋ค ๊ธธ์ด๊ฐ ์งง์ง๋ง ์์๊ฐ ์๋ ์ผ์ ๊ธธ์ด์ Sequence vector๋ผ๊ณ ์๊ฐํ ์ ์์ต๋๋ค. ๊ฐ์ ๋ฐํ์์ฌ๋ ์ด๋ค ๋จ์ด๋ฅผ ๋งํ๋๋์ ๋ฐ๋ผ ์์ฑํ์ ์ธ ํน์ง(linguistic feature)๊ฐ ๋ค๋ฅผ ์ ์๊ธฐ ๋๋ฌธ์ localํ ์กฐ๊ฑด์ ํ ์์ฑ ํ์ผ์ ์ฌ๋ฌ๊ฐ๊ฐ ์์ ์ ์์ต๋๋ค. ์ด๋ Feature vector h๋ ์์ฑ ํ์ผ๊ณผ ๊ธธ์ด๊ฐ ๋ค๋ฅด๊ธฐ ๋๋ฌธ์ Upsampling์ ๊ฑฐ์นํ 1x1 convolution์ ๊ฑฐ์ณ์ data x์ ๋ํด์ง๋๋ค.
Experiments
์คํ์ ์ด 4๊ฐ์ง Free-form Speech Generation, TTS, Music Audio Modelling, Speech Recognition์ ์งํํ์ง๋ง ์ฃผ๋ ์คํ์ TTS๋ฅผ ์ค์ฌ์ผ๋ก ์ด๋ฃจ์ด์ก์ผ๋ฉฐ Evaluation์ 2๊ฐ์ง๋ก Paired Comparison Test
, Mean Opinion Score
์ผ๋ก ์งํํ์ต๋๋ค. Paired Comparison Test
์ ํผ์คํ์์๊ฒ 2๊ฐ์ ์คํ ๋ชจ๋ธ๋ก๋ถํฐ ์์ฑ๋ ์์ฑ ํ์ผ์ ๋ค๋ ค์ฃผ๊ณ ๋ ์ค ๋ ์์ฐ์ค๋ฝ๋ค๊ณ ์๊ฐ๋๋ ์์ฑ ํ์ผ์ ์ ํํ๊ฒ ํฉ๋๋ค. ์ด๋ ๋ ๊ฐ์ ์์ฑ๋ค์์ ๋ฑํ ์ ํธ๋๊ฐ ์์ ๊ฒฝ์ฐ์๋ No preference
๋ก ์๋ตํ ์ ์์ต๋๋ค. Mean Opinion Score
์คํ์์๋ ํผ์คํ์์๊ฒ ์์ฑ๋ ์์ฑ 1๊ฐ๋ฅผ ๋ค๋ ค์ฃผ๊ณ 1~5์ ์ฌ์ด์ ํ์ง ์ ์๋ฅผ ๋ฐ๊ฒ ๋ฉ๋๋ค. (1: Bad, 2: Poor, 3: Fair, 4: Good, 5: Excellent)
TTS ์คํ์์ Paired Comparison Test๋ฅผ ์งํํ๊ธฐ ์ํด ์
๋ ฅ text์์ ์ถ์ถ๋ linguistic feature[L]์ ์์ฑ์ ํน์ง ์ค ํ๋์ธ logarithmic fundamental frequency(F_o)[F]๋ฅผ local condition์ผ๋ก ๋ฃ์ด์ฃผ์์ต๋๋ค. ์ด๋ Receptive Field๋ 240 ๋ฐ๋ฆฌ์ธ์ปจ๋์์ผ๋ฉฐ ๋น๊ต๋ชจ๋ธ๋ก๋ concatenative ๊ณ์ด
์ HMM-driven unit selection๊ณผ parametric ๊ณ์ด
์ LSTM-RNN-based ๋ชจ๋ธ์ ๊ฐ์ง๊ณ ๋น๊ตํ์ต๋๋ค.
Preference score์ ๋น๊ตํด๋ดค์ ๋, ์ฐ์ ๊ธฐ์กด์ ๋ฐฉ๋ฒ๋ก ์ด์๋ LSTM์ Concat์ ๋น๊ตํด๋ณด๋ฉด(๊ฐ์ฅ ์ผ์ชฝ bar graph) ์์ด์์๋ Concat
์ด ์ค๊ตญ์ด์์๋ LSTM
์ด ๋ ๋์ ์ ์๋ฅผ ๋ฐ์ ๊ฒ์ ๋ณด์ ๋ฐ์ดํฐ๊ฐ ๋ง์ ์์ด์์๋ Concat ๋ฐฉ๋ฒ๋ก ์ด ๋ ์ข์ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. ๋ค์์ผ๋ก WaveNet์ local condition์ L๋ง ์ฃผ์์ ๋
์ L+F๋ฅผ ์ฃผ์์ ๋
๋ฅผ ๋น๊ตํด๋ณด๋ฉด(๊ฐ์ด๋ฐ bar graph) local condition ์กฐ๊ฑด์ด ๋ง์์๋ก, ์ฆ L+F๋ฅผ local condition์ผ๋ก ์ฃผ์์ ๋ ์ ํธ๋๊ฐ ๋์์ ์ ์ ์์์ต๋๋ค. ๋ง์ง๋ง์ผ๋ก ๋น๊ต๊ตฐ์ด์๋ ๊ธฐ์กด์ ๋ชจ๋ธ๋ค ์ค ๊ฐ์ฅ ์ ํธ๋๊ฐ ๋์ ๋ชจ๋ธ๊ณผ WaveNet์ ๋ชจ๋ local condition์ ์ฃผ์์ ๋๋ฅผ ๋น๊ตํด๋ณด๋ฉด(๊ฐ์ฅ ์ค๋ฅธ์ชฝ bar graph) ์์ด์ ์ค๊ตญ์ด ๋ชจ๋์์ WaveNet์ ์ ํธ๋๊ฐ ๋์ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
๋๋ฒ์งธ ์คํ์ธ Mean Opinion Score์์๋ WaveNet์ด 4์ Good์ ์์ด์ ์ค๊ตญ์ด์์ ๋ชจ๋ ๋์ ๊ฒ์ ํ์ธํ ์ ์์์ผ๋ฉฐ ์ค์ ์์ฑ(ground truth)์์ 8-bit ํน์ 16-bit๋ก ๋ณํํ ๊ฒ๊ณผ ๊ธฐ์กด ๋ชจ๋ธ๋ค(LSTM, HMM)์ฌ์ด์ ์ฐจ์ด๋ฅผ ๋ ์ค์ฌ์ค ๊ฒ์ ํ์ธํจ์ผ๋ก์จ ์์ฑ ์์ฑ ๋ชจ๋ธ์ ํผํฌ๋จผ์ค๊ฐ ํฅ์๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
Conclusion
WaveNet ๋ ผ๋ฌธ์์๋ ์์ฑ ์์ฑ์ raw data๋ก ๋ฐ๋ก ํ ์ ์์๋ค๋ ๊ฒ์ ๋ณด์ฌ์ค ๊ฒ์ ํฐ Contribution์ด ์์ต๋๋ค. ์ด๋ฅผ ์ํด Dilated Causal Convolution / Skip / Residual ๊ธฐ๋ฒ์ ์ด์ฉํ์ฌ Receptive Field๋ฅผ ๋๋ ค์ ๊ธด ์์ฑ ํํ์ ํ์ตํ ์ ์๋๋ก ํ์ต๋๋ค. ๋ํ ์์ฑ ํํ ๋ฐ์ดํฐ์๋ค๊ฐ conditioning model์ ๋ํจ์ผ๋ก์จ ๋ ํน์ง์ ์ด๊ณ ์์ฐ์ค๋ฌ์ด ์์ฑ์ ์์ฑ ํ ์ ์๋๋ก ํ์ต๋๋ค. ๋ง์ง๋ง์ผ๋ก TTS๋ฅผ ์ค์ฌ์ผ๋ก ์ฐ๊ตฌ๊ฐ ๋๊ธดํ์ง๋ง ์์ ๊ณผ ๊ฐ์ ์ฌ๋์ ์์ฑ์ด ์๋ ์์ฑ ๋ฐ์ดํฐ ์์ฑ์๋ potentialํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ฃผ์ด ๊ทธ ํ์ฅ์ฑ์ด ์ข๋ค๊ณ ํ ์ ์์ต๋๋ค.
Improved Works
WaveNet์ auto-regressiveํ ํน์ฑ์ผ๋ก ์ธํด ๊ณ์ฐ๋์ด ๋ง๊ณ ๋๋ฆฐ ์์ฑ์ ๋ณด์ํ Fast Wavenet Generation Algorithm ์ฐ๊ตฌ๊ฐ ์์์ต๋๋ค. ๋คํธ์ํฌ์ ๋ ์ด์ด ์๋ฅผ L์ด๋ผ๊ณ ํ์ ๋ ๊ธฐ์กด์ naive WaveNet์ด O(2^L) ๋ณต์ก๋๊ฐ ์์์ง๋ง ์ค๋ณต๋๋ convolution ์ฐ์ฐ์ cachingํจ์ผ๋ก์จ O(L) ๋ณต์ก๋๋ก ์ค์ผ ์ ์์์ต๋๋ค.
Reference
[1] Original paper - WaveNet: A Generative Model for Raw Audio
[2] Project page - https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio
[3] https://brilliant.org/practice/wave-anatomy-2/
[4] https://m.blog.naver.com/sbkim24/10084099777
[5] https://blog.naver.com/sorionclinic/221184537689
[6] https://joungheekim.github.io/2020/09/17/paper-review/
[7] https://tech.kakaoenterprise.com/66
[9] https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
[10] https://youtu.be/m2A9g6Xu91I
[11] https://youtu.be/GyQnex_DK2k
[12] https://wiki.aalto.fi/pages/viewpage.action?pageId=149890776
[13] https://youtu.be/MNZepE1m-kI
[14] https://medium.com/@satyam.kumar.iiitv/understanding-wavenet-architecture-361cc4c2d623
[15] https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio
[16] https://towardsdatascience.com/wavenet-google-assistants-voice-synthesizer-a168e9af13b1
[17] https://github.com/antecessor/Wavenet
[18] https://youtu.be/nsrSrYtKkT8
[19] https://research.google/pubs/pub45882/