빅데이터 분석 (11주차) 11월18일, 11월23일
체인룰, backpropagation
import torch
-
회귀분석에서 손실함수에 대한 미분은 아래와 같은 과정으로 계산할 수 있다.
- $loss = ({\bf y}-{\bf X}{\bf W})^\top ({\bf y}-{\bf X}{\bf W})={\bf y}^\top {\bf y} - {\bf y}^\top {\bf X}{\bf W} - {\bf W}^\top {\bf X}^\top {\bf y} + {\bf W}^\top {\bf X}^\top {\bf X} {\bf W}$
- $\frac{\partial }{\partial {\bf W}}loss = -2{\bf X}^\top {\bf y} +2 {\bf X}^\top {\bf X} {\bf W}$
-
체인룰을 이해하자.
- core가 여러개 있는 컴퓨터에서 어떻게 빠르게 계산할까.
- 합성함수의 미분방법인 체인룰
-
손실함수가 사실 아래와 같은 변환을 거쳐서 계산되었다고 볼 수 있다.
- ${\bf X} \to {\bf X}{\bf W} \to {\bf y} -{\bf X}{\bf W} \to ({\bf y}-{\bf X}{\bf W})^\top ({\bf y}-{\bf X}{\bf W})$
입력 데이터-임의의w-오차-오차제곱
-
위의 과정을 수식으로 정리해보면 아래와 같다.
이렇게 가정할거
-
${\bf u}={\bf X}{\bf W}$, $\quad {\bf u}: n \times 1$
-
${\bf v} = {\bf y}- {\bf u},$ $\quad {\bf v}: n \times 1$
-
$loss={\bf v}^\top {\bf v},$ $\quad loss: 1 \times 1 $
-
손실함수에 대한 미분은 아래와 같다.
(그런데 이걸 어떻게 계산함?)
-
계산할 수 있는것들의 모음..
- $\frac{\partial}{\partial {\bf v}} loss = 2{\bf v} $ $\quad \to$ (n,1) 벡터
$vv^T=v_{1}^2+...+v_{n}^2$
-
$\frac{\partial }{\partial {\bf u}} {\bf v}^\top = -{\bf I}$ $\quad \to $ (n,n) 매트릭스
$\frac{\partial }{\partial {\bf u}} {\bf v}^\top =(\frac{\partial }{\partial {\bf u}} y^T-\frac{\partial }{\partial {\bf u}} u^T)=0-\frac{\partial }{\partial {\bf u}} u^T = -I$
- $\frac{\partial }{\partial \bf W}{\bf u}^\top = {\bf X}^\top $ $\quad \to$ (p,n) 매트릭스
u(n 1)=X(n p)W(p * 1)
$\frac{\partial }{\partial \bf W}$ (p 1) $u^T$(1 n)
-
혹시.. 아래와 같이 쓸 수 있을까?
- 가능할것 같다. 뭐 기호야 정의하기 나름이니까!
-
그렇다면 혹시 아래와 같이 쓸 수 있을까?
- 이건 선을 넘는 것임.
- 그런데 어떠한 공식에 의해서 가능함. 그 공식 이름이 체인룰이다.
-
결국 정리하면 아래의 꼴이 되었다.
-
그렇다면?
그런데, ${\bf v}={\bf y}-{\bf u}={\bf y} -{\bf X}{\bf W}$ 이므로
정리하면
대응하는 미분법을 알고 있으면 코드짜기 편하겠다!
-
미분계수를 계산하는 문제
-
체인룰을 이용하여 미분계수를 계산하여 보자.
ones= torch.ones(5)
x = torch.tensor([11.0,12.0,13.0,14.0,15.0])
X = torch.vstack([ones,x]).T
y = torch.tensor([17.7,18.5,21.2,23.6,24.2])
W = torch.tensor([3.0,3.0])
u = X@W
v = y-u
loss = v.T @ v
loss
loss/5 # mse, 중간고사.
-
$\frac{\partial}{\partial\bf W}loss $ 의 계산
X.T @ -torch.eye(5) @ (2*v)
-
참고로 중간고사 답은
X.T @ -torch.eye(5)@ (2*v) / 5
입니다.
-
확인
_W = torch.tensor([3.0,3.0],requires_grad=True)
_loss = (y-X@_W).T @ (y-X@_W)
_loss.backward()
_W.grad.data
-
$\frac{\partial}{\partial \bf v} loss= 2{\bf v}$ 임을 확인하라.
v
_v= torch.tensor([-18.3000, -20.5000, -20.8000, -21.4000, -23.8000],requires_grad=True)
_loss = _v.T @ _v
_loss.backward()
_v.grad.data, v
2배가 되어야 겠지? 확인.
-
$\frac{\partial }{\partial {\bf u}}{\bf v}^\top$ 의 계산
u
_u = torch.tensor([36., 39., 42., 45., 48.],requires_grad=True)
_u
_v = y - _u ### 이전의 _v와 또다른 임시 _v
(_v.T).backward()
- 사실 토치에서는 스칼라아웃풋에 대해서만 미분을 계산할 수 있음
그런데 $\frac{\partial}{\partial {\bf u}}{\bf v}^\top=\frac{\partial}{\partial {\bf u}}(v_1,v_2,v_3,v_4,v_5)=\big(\frac{\partial}{\partial {\bf u}}v_1,\frac{\partial}{\partial {\bf u}}v_2,\frac{\partial}{\partial {\bf u}}v_3,\frac{\partial}{\partial {\bf u}}v_4,\frac{\partial}{\partial {\bf u}}v_5\big)$ 이므로
조금 귀찮은 과정을 거친다면 아래와 같은 알고리즘으로 계산할 수 있다.
(0) $\frac{\partial }{\partial {\bf u}} {\bf v}^\top$의 결과를 저장할 매트릭스를 만든다. 적당히 A
라고 만들자.
(1) _u
하나를 임시로 만든다. 그리고 $v_1$을 _u
로 미분하고 그 결과를 A
의 첫번째 칼럼에 기록한다.
(2) _u
를 또하나 임시로 만들고 $v_2$를 _u
로 미분한뒤 그 결과를 A
의 두번째 칼럼에 기록한다.
(3) (1)-(2)와 같은 작업을 $v_5$까지 반복한다.
(0)을 수행
A = torch.zeros((5,5))
A
(1)을 수행
u,v
_u = torch.tensor([36., 39., 42., 45., 48.],requires_grad=True)
v1 = (y-_u)[0]
- 이때 $v_1=g(f({\bf u}))$와 같이 표현할 수 있다. 여기에서 $f((u_1,\dots,u_5)^\top)=(y_1-u_1,\dots,y_5-u_5)^\top$, 그리고 $g((v_1,\dots,v_n)^\top)=v_1$ 라고 생각한다. 즉 $f$는 벡터 뺄셈을 수행하는 함수이고, $g$는 프로젝션 함수이다. 즉 $f:\mathbb{R}^5 \to \mathbb{R}^5$인 함수이고, $g:\mathbb{R}^5 \to \mathbb{R}$인 함수이다.
v1
grad_fn= (2)를 수행 (3)을 수행 // 그냥 (1)~(2)도 새로 수행하자. $\frac{\partial }{\partial {\bf W}}{\bf u}^\top = \frac{\partial }{\partial {\bf W}}(u_1,\dots,u_5)=\big(\frac{\partial }{\partial {\bf W}}u_1,\dots,\frac{\partial }{\partial {\bf W}}u_5 \big) $ 순전파 (1) 순전파를 하면서 입출력값을 모두 저장하고 (2) 그에 대응하는 층별 미분계수값 $2{\bf v}, -{\bf I}, {\bf X}^\top$ 를 구하고 (3) 층별미분계수값을 다시 곱하면 (그러니까 ${\bf X}^\top (-{\bf I}) 2{\bf v}$ 를 계산) 된다. (1) 순전파를 계산하고 각 층별 입출력 값을 기록 (2) 역전파를 수행하여 손실함수의 미분값을 계산 gpu특징: 큰 차원의 매트릭스 곱셈 전문가 (원리? 어마어마한 코어숫자)v1.backward()
_u.grad.data
A[:,0]= _u.grad.data
A
_u = torch.tensor([36., 39., 42., 45., 48.],requires_grad=True)
v2 = (y-_u)[1]
v2.backward()
_u.grad.data
A[:,1]= _u.grad.data
A
for i in range(5):
_u = torch.tensor([36., 39., 42., 45., 48.],requires_grad=True)
_v = (y-_u)[i]
_v.backward()
A[:,i]= _u.grad.data
A
-
$\frac{\partial }{\partial {\bf W}}{\bf u}^\top$의 계산B = torch.zeros((2,5))
B
W
_W = torch.tensor([3., 3.],requires_grad=True)
_W
for i in range(5):
_W = torch.tensor([3., 3.],requires_grad=True)
_u = (X@_W)[i]
_u.backward()
B[:,i]= _W.grad.data
B # X의 트랜스포즈
X
-
결국 위의 예제에 한정하여 임의의 ${\bf \hat{W}}$에 대한 $\frac{\partial}{\partial {\bf \hat W}}loss$는 아래와 같이 계산할 수 있다.
-
단계1에서 ${\bf v}$는 어떻게 알지?
-
(중요) step2에서 loss만 구해서 저장할 생각 하지말고 중간과정을 다 저장해라. (그중에 v와 같이 필요한것이 있을테니까) 그리고 그걸 적당한 방법을 통하여 이용하여 보자.-
아래와 같이 함수의 변환을 아키텍처로 이해하자. (함수의입력=레이어의입력, 함수의출력=레이어의출력)
-
그런데 위의 계산과정을 아래와 같이 요약할 수도 있다. (${\bf X} \to {\bf \hat y} \to loss$가 아니라 ${\bf W} \to loss({\bf W})$로 생각해보세요) <loss가 y의 함수가 아니라 W의 함수라 생각해봐>
-
그렇다면 아래와 같은 사실을 관찰할 수 있다.
-
요약: $2{\bf v},-{\bf I}, {\bf X}^\top$와 같은 핵심적인 값들이 사실 각 층의 입/출력 값들의 함수꼴로 표현가능하다. $\to$ 각 층의 입/출력 값들을 모두 기록하면 미분계산을 유리하게 할 수 있다.
-
결국
-
참고로 (1)에서 층별 입출력값은 GPU의 메모리에 기록된다.. 무려 GPU 메모리..-
작동원리를 GPU의 관점에서 요약 (슬기로운 GPU 활용)
net.to("cuda:0")
net(X)
loss = loss_fn(yhat,y)
loss.backward()
-
역전파기법은 체인룰 + $\alpha$ 이다.-
오차역전파기법이라는 용어를 쓰는 사람도 있다.-
이미 훈련한 네트워크에 입력 $X$를 넣어 결과값만 확인하고 싶을 경우 순전파만 사용하면 되고, 이 상황에서는 좋은 GPU가 필요 없다.