기계학습 특강 (8주차) 10월26일--(2)
이미지자료분석 - Transfer Learning, CAM (설명가능한 인공지능모형, XAI)
import torch
import torchvision
from fastai.vision.all import *
path = untar_data(URLs.CIFAR)
path.ls()
!ls '/home/csy/.fastai/data/cifar10/train'
(1) dls
dls = ImageDataLoaders.from_folder(path,train='train',valid='test')
_X,_y = dls.one_batch()
_X.shape, _y.shape
!ls '/home/csy/.fastai/data/cifar10/train' # 10개의 클래스
dls.show_batch()
(2) lrnr 생성
net1 = torch.nn.Sequential(
torch.nn.Conv2d(3,128,(5,5)),
torch.nn.ReLU(),
torch.nn.MaxPool2d((2,2)),
torch.nn.Flatten()
)
net는 cpu에 있고 X는 gpu에 있으니 cpu로 불러오자
net = torch.nn.Sequential(
net1,
torch.nn.Linear(25088,10)
)
loss_fn = torch.nn.CrossEntropyLoss()
lrnr = Learner(dls,net,loss_fn,metrics=accuracy)
net.to("cuda:0")
(3) 학습
X,y=dls.one_batch()
lrnr.model(X).shape
lrnr.fit(10)
- 이게 생각보다 잘 안맞아요.. 70넘기 힘듬
(2) lrnr 생성
학습되어 있는 파라메터까지 같이 가져오기
net = torchvision.models.resnet18(weights=torchvision.models.resnet.ResNet18_Weights.IMAGENET1K_V1)
net
- $k=1000$ 즉 1000개의 물체를 구분하는 모형임
net.fc = torch.nn.Linear(in_features=512, out_features=10)
loss_fn = torch.nn.CrossEntropyLoss()
lrnr = Learner(dls,net,loss_fn,metrics=accuracy)
(3) 학습
lrnr.fit(10)
- CIFAR10을 맞추기 위한 네트워크가 아님에도 불구하고 상당히 잘맞음
- 일반인이 거의 밑바닥에서 설계하는것보다 전이학습을 이용하는 것이 효율적일 경우가 많다.
-
예전코드 복습
path = untar_data(URLs.PETS)/'images'
files= get_image_files(path)
def label_func(fname):
if fname[0].isupper():
return 'cat'
else:
return 'dog'
dls = ImageDataLoaders.from_name_func(path,files,label_func,item_tfms=Resize(512))
lrnr = vision_learner(dls,resnet34,metrics=accuracy)
lrnr = cnn_learner(dls,resnet34,metrics=accuracy)
lrnr.fine_tune(1)
-
사실 위의 코드가 transfer learning 이었음.
lrnr.model
-
Class Activation Mapping (CAM)은 설명가능한 인공지능모형 (eXplainable Artificial Intelligence, XAI) 중 하나로 CNN의 판단근거를 시각화하는 기술
path = untar_data(URLs.PETS)/'images'
path.ls()
files= get_image_files(path)
def label_func(fname):
if fname[0].isupper():
return 'cat'
else:
return 'dog'
dls = ImageDataLoaders.from_name_func(path,files,label_func,item_tfms=Resize(512))
ximg = PILImage.create('/home/csy/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_106.jpg')
ximg
x = first(dls.test_dl([ximg]))[0]
x,x.shape
ap = torch.nn.AdaptiveAvgPool2d(output_size=1)
X = torch.arange(48).reshape(1,3,4,4)*1.0
X
ap(X)
X[0,0,...].mean(),X[0,1,...].mean(),X[0,2,...].mean()
(예시1)
tsr = torch.arange(12).reshape(4,3)
tsr
torch.einsum('ij->ji',tsr)
(예시2)
tsr1 = torch.arange(12).reshape(4,3).float()
tsr2 = torch.arange(15).reshape(3,5).float()
tsr1 @ tsr2
torch.einsum('ij,jk -> ik',tsr1,tsr2)
(예시3)
x.to("cpu").shape
torch,einsum
을 사용하여 shape을 아래로 변경
torch.einsum('ocij -> ijc',x.to("cpu")).shape
plt.imshow(torch.einsum('ocij -> ijc',x.to("cpu")))
lrnr = vision_learner(dls,resnet34,metrics=accuracy)
lrnr = cnn_learner(dls,resnet34,metrics=accuracy)
lrnr.fine_tune(1)
-
모형의 분해
net1= lrnr.model[0]
net2= lrnr.model[1]
net1이 2d part, net1이 1d part
-
net2를 좀더 살펴보자.
net2
_X, _y = dls.one_batch()
net1.to("cpu")
net2.to("cpu")
_X = _X.to("cpu")
print(net1(_X).shape)
print(net2[0](net1(_X)).shape)
print(net2[1](net2[0](net1(_X))).shape)
print(net2[2](net2[1](net2[0](net1(_X)))).shape)
-
net2를 아래와 같이 수정하고 재학습하자 (왜?)
net2= torch.nn.Sequential(
torch.nn.AdaptiveAvgPool2d(output_size=1), # (64,512,16,16) -> (64,512,1,1)
torch.nn.Flatten(), # (64,512,1,1) -> (64,512)
torch.nn.Linear(512,2,bias=False) # (64,512) -> (64,2)
)
net = torch.nn.Sequential(
net1,
net2
)
lrnr2= Learner(dls,net,metrics=accuracy) # loss_fn??
lrnr2.loss_func, lrnr.loss_func ## 알아서 기존의 loss function으로 잘 들어가 있음.
lrnr2.fine_tune(5) # net2를 수정해서 accuracy가 안좋아지긴 했는데 그래도 쓸만함
-
1개의 observation을 고정하였을 경우 출력과정 상상
ximg = PILImage.create('/home/csy/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_106.jpg')
x = first(dls.test_dl([ximg]))[0]
net2
print(net1(x).shape)
print(net2[0](net1(x)).shape)
print(net2[1](net2[0](net1(x))).shape)
print(net2[2](net2[1](net2[0](net1(x)))).shape)
-
최종결과 확인
net(x)
아마 모델 달라서 값이 다른 것일까..!
dls.vocab
- net(x)에서 뒤쪽의 값이 클수록 'dog'를 의미한다.
-
net2의 순서 바꾸기 전 전체 네트워크:
$$\underset{(1,3,512,512)}{\boldsymbol x} \overset{net_1}{\to} \left( \underset{(1,512,16,16)}{\tilde{\boldsymbol x}} \overset{ap}{\to} \underset{(1,512,1,1)}{{\boldsymbol \sharp}}\overset{flatten}{\to} \underset{(1,512)}{{\boldsymbol \sharp}}\overset{linear}{\to} \underset{(1,2)}{\hat{\boldsymbol y}}\right) = [-9.0358, 9.0926]$$
-
아래와 같이 순서를 바꿔서 한번 계산해보고 싶다. (왜???..)
$$\underset{(1,3,224,224)}{\boldsymbol x} \overset{net_1}{\to} \left( \underset{(1,512,16,16)}{\tilde{\boldsymbol x}} \overset{linear}{\to} \underset{(1,2,16,16)}{{\bf why}}\overset{ap}{\to} \underset{(1,2,1,1)}{{\boldsymbol \sharp}}\overset{flatten}{\to} \underset{(1,2)}{\hat{\boldsymbol y}}\right) = [−9.0358,9.0926]$$
- 여기에서 (1,512,16,16) -> (1,2,16,16) 로 가는 선형변환을 적용하는 방법? (16,16) each pixel에 대하여 (512 $\to$ 2)로 가는 변환을 수행
-
통찰: 이 경우 특이하게도 레이어의 순서를 바꿨을때 출력이 동일함 (선형변환하고 평균내거나 평균내고 선형변환하는건 같으니까)
_x =torch.tensor([1,2,3.14,4]).reshape(4,1)
_x
_l1 = torch.nn.Linear(1,1,bias=False)
_l1(_x).mean() # _x -> 선형변환 -> 평균
_l1(_x.mean().reshape(1,1)) # _x -> 평균 -> 선형변환
-
구현해보자.
net2[2].weight.shape,net1(x).shape
why = torch.einsum('cb,abij->acij',net2[2].weight,net1(x))
why.shape
net2[0](why)
net(x)
-
이미지
ximg
-
네트워크의 결과
net2(net1(x))
- -9.0358 << 9.0926 이므로 'ximg'는 높은 확률로 개라는 뜻이다.
내거에서는 9.0926이 10.2985
-
아래의 네트워크를 관찰
$$\underset{(1,2,16,16)}{{\bf why}}\overset{ap}{\to} \underset{(1,2,1,1)}{{\boldsymbol \sharp}}\overset{flatten}{\to} \underset{(1,2)}{\hat{\boldsymbol y}} = [-9.0358,9.0926]$$
net2[0](why)
더 파고들어서 분석해보자.
why.shape
(why[0,0,:,:]).mean(), (why[0,1,:,:]).mean()
why[0,0,:,:]
(why[0,0,:,:]).to(torch.int64)
- 이 값들의 평균은 -9.0358 이다. (이 값이 클수록 이 그림이 고양이라는 의미 = 이 값이 작을수록 이 그림이 고양이가 아니라는 의미)
- 그런데 살펴보니 대부분의 위치에서 0에 가까운 값을 가짐. 다만 특정위치에서 엄청 큰 작은값이 있어서 -9.0358이라는 평균값이 나옴 $\to$ 특정위치에 존재하는 엄청 작은 값들은 ximg가 고양이가 아니라고 판단하는 근거가 된다.
why[0,1,:,:]
(why[0,1,:,:]).to(torch.int64)
- 이 값들의 평균은 9.0926 이다. (이 값이 클수록 이 그림이 강아지라는 의미)
- 그런데 살펴보니 대부분의 위치에서 0에 가까운 값을 가짐. 다만 특정위치에서 엄청 큰 값들이 있어서 9.0926이라는 평균값이 나옴 $\to$ 특정위치에 존재하는 엄청 큰 값들은 결국 ximg를 강아지라고 판단하는 근거가 된다.
-
시각화
why_cat = why[0,0,:,:]
why_dog = why[0,1,:,:]
fig, ax = plt.subplots(1,3,figsize=(8,4))
ax[0].imshow(torch.einsum('ocij -> ijc',dls.decode((x,))[0]).to("cpu"))
ax[1].imshow(why_cat.to("cpu").detach(),cmap='magma')
ax[2].imshow(why_dog.to("cpu").detach(),cmap='magma')
- magma = 검은색 < 보라색 < 빨간색 < 노란색
- 왼쪽그림의 검은 부분은 고양이가 아니라는 근거, 오른쪽그림의 노란부분은 강아지라는 근거
-
why_cat, why_dog를 (16,16) $\to$ (512,512) 로 resize
fig, ax = plt.subplots(1,3,figsize=(8,4))
ax[0].imshow(torch.einsum('ocij -> ijc',dls.decode((x,))[0]).to("cpu"))
ax[1].imshow(why_cat.to("cpu").detach(),cmap='magma',extent=(0,511,511,0),interpolation='bilinear')
ax[2].imshow(why_dog.to("cpu").detach(),cmap='magma',extent=(0,511,511,0),interpolation='bilinear')
-
겹쳐그리기
fig, ax = plt.subplots(1,2,figsize=(8,4))
ax[0].imshow(torch.einsum('ocij -> ijc',dls.decode((x,))[0]).to("cpu"))
ax[0].imshow(why_cat.to("cpu").detach(),cmap='magma',extent=(0,511,511,0),interpolation='bilinear',alpha=0.5)
ax[1].imshow(torch.einsum('ocij -> ijc',dls.decode((x,))[0]).to("cpu"))
ax[1].imshow(why_dog.to("cpu").detach(),cmap='magma',extent=(0,511,511,0),interpolation='bilinear',alpha=0.5)
-
하니이미지 시각화
!wget https://github.com/guebin/DL2022/blob/master/_notebooks/2022-09-06-hani01.jpeg
#!wget https://github.com/guebin/DL2022/blob/master/_notebooks/2022-09-06-hani01.jpeg?raw=true
ximg= PILImage.create('2022-09-07-dogs.jpeg')
x= first(dls.test_dl([ximg]))[0]
why = torch.einsum('cb,abij->acij',net2[2].weight,net1(x))
why_cat = why[0,0,:,:]
why_dog = why[0,1,:,:]
fig, ax = plt.subplots(1,2,figsize=(8,4))
ax[0].imshow(torch.einsum('ocij -> ijc',dls.decode((x,))[0]).to("cpu"))
ax[0].imshow(why_cat.to("cpu").detach(),cmap='magma',extent=(0,511,511,0),interpolation='bilinear',alpha=0.5)
ax[1].imshow(torch.einsum('ocij -> ijc',dls.decode((x,))[0]).to("cpu"))
ax[1].imshow(why_dog.to("cpu").detach(),cmap='magma',extent=(0,511,511,0),interpolation='bilinear',alpha=0.5)
-
하니이미지 시각화 with prob
sftmax=torch.nn.Softmax(dim=1)
sftmax(net(x))
catprob, dogprob = sftmax(net(x))[0,0].item(), sftmax(net(x))[0,1].item()
fig, ax = plt.subplots(1,2,figsize=(8,4))
ax[0].imshow(torch.einsum('ocij -> ijc',dls.decode((x,))[0]).to("cpu"))
ax[0].imshow(why_cat.to("cpu").detach(),cmap='magma',extent=(0,511,511,0),interpolation='bilinear',alpha=0.5)
ax[0].set_title('catprob= %f' % catprob)
ax[1].imshow(torch.einsum('ocij -> ijc',dls.decode((x,))[0]).to("cpu"))
ax[1].imshow(why_dog.to("cpu").detach(),cmap='magma',extent=(0,511,511,0),interpolation='bilinear',alpha=0.5)
ax[1].set_title('dogprob=%f' % dogprob)
sftmax = torch.nn.Softmax(dim=1)
fig, ax = plt.subplots(5,5)
k=0
for i in range(5):
for j in range(5):
x, = first(dls.test_dl([PILImage.create(get_image_files(path)[k])]))
why = torch.einsum('cb,abij -> acij', net2[2].weight, net1(x))
why_cat = why[0,0,:,:]
why_dog = why[0,1,:,:]
catprob, dogprob = sftmax(net(x))[0][0].item(), sftmax(net(x))[0][1].item()
if catprob>dogprob:
dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
ax[i][j].imshow(why_cat.to("cpu").detach(),alpha=0.5,extent=(0,511,511,0),interpolation='bilinear',cmap='magma')
ax[i][j].set_title("cat(%2f)" % catprob)
else:
dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
ax[i][j].imshow(why_dog.to("cpu").detach(),alpha=0.5,extent=(0,511,511,0),interpolation='bilinear',cmap='magma')
ax[i][j].set_title("dog(%2f)" % dogprob)
k=k+1
fig.set_figwidth(16)
fig.set_figheight(16)
fig.tight_layout()
fig, ax = plt.subplots(5,5)
k=25
for i in range(5):
for j in range(5):
x, = first(dls.test_dl([PILImage.create(get_image_files(path)[k])]))
why = torch.einsum('cb,abij -> acij', net2[2].weight, net1(x))
why_cat = why[0,0,:,:]
why_dog = why[0,1,:,:]
catprob, dogprob = sftmax(net(x))[0][0].item(), sftmax(net(x))[0][1].item()
if catprob>dogprob:
dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
ax[i][j].imshow(why_cat.to("cpu").detach(),alpha=0.5,extent=(0,511,511,0),interpolation='bilinear',cmap='magma')
ax[i][j].set_title("cat(%2f)" % catprob)
else:
dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
ax[i][j].imshow(why_dog.to("cpu").detach(),alpha=0.5,extent=(0,511,511,0),interpolation='bilinear',cmap='magma')
ax[i][j].set_title("dog(%2f)" % dogprob)
k=k+1
fig.set_figwidth(16)
fig.set_figheight(16)
fig.tight_layout()
fig, ax = plt.subplots(5,5)
k=50
for i in range(5):
for j in range(5):
x, = first(dls.test_dl([PILImage.create(get_image_files(path)[k])]))
why = torch.einsum('cb,abij -> acij', net2[2].weight, net1(x))
why_cat = why[0,0,:,:]
why_dog = why[0,1,:,:]
catprob, dogprob = sftmax(net(x))[0][0].item(), sftmax(net(x))[0][1].item()
if catprob>dogprob:
dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
ax[i][j].imshow(why_cat.to("cpu").detach(),alpha=0.5,extent=(0,511,511,0),interpolation='bilinear',cmap='magma')
ax[i][j].set_title("cat(%2f)" % catprob)
else:
dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
ax[i][j].imshow(why_dog.to("cpu").detach(),alpha=0.5,extent=(0,511,511,0),interpolation='bilinear',cmap='magma')
ax[i][j].set_title("dog(%2f)" % dogprob)
k=k+1
fig.set_figwidth(16)
fig.set_figheight(16)
fig.tight_layout()
fig, ax = plt.subplots(5,5)
k=75
for i in range(5):
for j in range(5):
x, = first(dls.test_dl([PILImage.create(get_image_files(path)[k])]))
why = torch.einsum('cb,abij -> acij', net2[2].weight, net1(x))
why_cat = why[0,0,:,:]
why_dog = why[0,1,:,:]
catprob, dogprob = sftmax(net(x))[0][0].item(), sftmax(net(x))[0][1].item()
if catprob>dogprob:
dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
ax[i][j].imshow(why_cat.to("cpu").detach(),alpha=0.5,extent=(0,511,511,0),interpolation='bilinear',cmap='magma')
ax[i][j].set_title("cat(%2f)" % catprob)
else:
dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
ax[i][j].imshow(why_dog.to("cpu").detach(),alpha=0.5,extent=(0,511,511,0),interpolation='bilinear',cmap='magma')
ax[i][j].set_title("dog(%2f)" % dogprob)
k=k+1
fig.set_figwidth(16)
fig.set_figheight(16)
fig.tight_layout()