[CAM]chest xray

Author

SEOYEON CHOI

Published

September 21, 2023

Import

import torch 
from fastai.vision.all import * 
import cv2 as cv
import fastbook
from fastbook import *
from fastai.vision.widgets import *
import os

Data

Reference: https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia

# Dataset root (~/Dropbox/chest_xray/chest_xray), expanded to an absolute path.
path = Path('~/Dropbox/chest_xray/chest_xray').expanduser()
# List the top-level split folders (train / test / val).
path.ls()
(#3) [Path('/home/csy/Dropbox/chest_xray/chest_xray/train'),Path('/home/csy/Dropbox/chest_xray/chest_xray/test'),Path('/home/csy/Dropbox/chest_xray/chest_xray/val')]
# Every image file below `path`, i.e. from all split folders (train, test and val).
files = get_image_files(path)
# NOTE: when `valid_pct` is given, fastai's `from_folder` ignores the `train=`/`valid=`
# folder names. It collects every image under `path` (train, test and val here) and
# splits them at random, which is why the validation set has 1171 images (20% of all).
# A fixed `seed` makes that random split reproducible between runs.
dls = ImageDataLoaders.from_folder(path, train='train', valid_pct=0.2, seed=42, item_tfms=Resize(224))
# Class labels come from the parent folder names; their order fixes the output index.
dls.vocab
['NORMAL', 'PNEUMONIA']
# Preview up to 16 labelled samples from the DataLoaders as a sanity check.
dls.show_batch(max_n=16)

# Pretrained ResNet-34 learner. Only its convolutional body (learn.model[0]) is reused
# by the CAM model below. `vision_learner` replaces the deprecated `cnn_learner`.
learn = vision_learner(dls, resnet34, metrics=error_rate)
/home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages/fastai/vision/learner.py:288: UserWarning: `cnn_learner` has been renamed to `vision_learner` -- please update your code
  warn("`cnn_learner` has been renamed to `vision_learner` -- please update your code")
/home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
# CAM model = pretrained convolutional body + a minimal head:
#   global average pooling -> flatten -> bias-free linear layer.
# Without a bias, each class score is the spatial mean of (weight row . feature map).
# So the per-class activation map can later be read off as
#   einsum('ij,jkl -> ikl', net2[2].weight, net1(x).squeeze()).
net1 = learn.model[0]  # convolutional body from the ResNet-34 learner
net2 = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d(output_size=1),
    torch.nn.Flatten(),
    # 512 = channels produced by the ResNet-34 body; one output per class in dls.vocab.
    torch.nn.Linear(512, out_features=len(dls.vocab), bias=False))
net = torch.nn.Sequential(net1, net2)
lrnr2 = Learner(dls, net, metrics=accuracy)
lrnr2.fine_tune(200)
epoch train_loss valid_loss accuracy time
0 0.256840 0.114656 0.955594 08:44
33.00% [66/200 8:25:10<17:05:40]
epoch train_loss valid_loss accuracy time
0 0.116520 0.101848 0.963279 09:36
1 0.102428 0.094126 0.967549 08:12
2 0.093699 0.093717 0.967549 08:54
3 0.086143 0.088825 0.970111 09:29
4 0.081806 0.083427 0.972673 07:39
5 0.073343 0.083292 0.970965 08:30
6 0.067238 0.081730 0.971819 09:19
7 0.060490 0.080178 0.970965 07:47
8 0.057268 0.082641 0.971819 07:59
9 0.051426 0.083972 0.970965 09:19
10 0.048123 0.085828 0.968403 08:25
11 0.048666 0.088471 0.966695 07:34
12 0.039608 0.081649 0.970965 09:34
13 0.039144 0.082900 0.969257 08:33
14 0.036003 0.088085 0.969257 07:21
15 0.033820 0.090371 0.966695 09:41
16 0.030702 0.090631 0.965841 08:42
17 0.030634 0.088019 0.968403 06:44
18 0.028083 0.086056 0.974381 09:44
19 0.026960 0.093619 0.966695 09:01
20 0.027269 0.101308 0.966695 06:51
21 0.022599 0.092235 0.967549 10:03
22 0.020182 0.101367 0.966695 09:07
23 0.018661 0.107479 0.966695 06:53
24 0.022132 0.090607 0.971819 09:42
25 0.020702 0.119097 0.965841 09:13
26 0.018335 0.100486 0.970965 07:02
27 0.023388 0.091187 0.974381 09:31
28 0.014129 0.109383 0.971819 09:16
29 0.011590 0.091540 0.973527 07:19
30 0.011195 0.112240 0.969257 09:14
31 0.010830 0.123103 0.970111 09:16
32 0.011482 0.111904 0.973527 07:22
33 0.009102 0.113118 0.973527 09:06
34 0.012337 0.114353 0.970111 07:51
35 0.012090 0.123892 0.970965 05:32
36 0.012958 0.130399 0.972673 05:38
37 0.011067 0.154380 0.965841 08:05
38 0.012189 0.106678 0.973527 06:58
39 0.013642 0.094777 0.977797 04:06
40 0.013831 0.136403 0.970111 08:22
41 0.010596 0.125548 0.974381 07:40
42 0.011968 0.122372 0.974381 04:49
43 0.010313 0.140959 0.971819 07:01
44 0.015604 0.116439 0.972673 07:46
45 0.009300 0.109000 0.975235 05:58
46 0.010223 0.114949 0.975235 05:03
47 0.010521 0.112844 0.975235 08:26
48 0.010516 0.120698 0.978651 07:13
49 0.006499 0.103925 0.980359 04:13
50 0.015288 0.140839 0.973527 08:11
51 0.009214 0.115853 0.975235 07:46
52 0.010889 0.124385 0.971819 05:01
53 0.010107 0.112582 0.978651 06:47
54 0.008590 0.108547 0.976943 07:51
55 0.012988 0.102842 0.978651 06:07
56 0.013263 0.112827 0.979505 04:47
57 0.009377 0.111573 0.977797 08:15
58 0.008087 0.126913 0.976089 07:22
59 0.006141 0.105021 0.981213 04:23
60 0.009329 0.121833 0.975235 07:45
61 0.010449 0.128219 0.978651 07:51
62 0.011542 0.107680 0.979505 05:15
63 0.009637 0.151910 0.967549 06:17
64 0.012696 0.154750 0.972673 08:00
65 0.008207 0.125963 0.978651 06:21

10.53% [2/19 00:11<01:40 0.0056]
# Train the CAM model for a further 100 epochs (the captured output below is incomplete).
lrnr2.fine_tune(100) 
0.00% [0/1 00:00<?]
epoch train_loss valid_loss accuracy time

49.32% [36/73 03:03<03:08 0.0001]
# Confusion matrix of lrnr2's predictions.
# Presumably computed on the validation split (fastai's default) — confirm for your fastai version.
interp = ClassificationInterpretation.from_learner(lrnr2)
interp.plot_confusion_matrix()

# 5x5 grid: for each of the first 25 files in `files` (the get_image_files(path) listing
# from the Data section), overlay the CAM of the more probable class and title the
# panel with that class's softmax probability.
fig, ax = plt.subplots(5,5) 
for k, axk in enumerate(ax.flat):  # row-major, same order as the nested i/j loops
    x, = first(dls.test_dl([PILImage.create(files[k])]))
    with torch.no_grad():  # inference only: no autograd graph per image
        feats = net1(x)  # body feature map, computed once for both logits and CAM
        logits = net2(feats)  # equal to net(x), since net = Sequential(net1, net2)
    # One activation map per class: weight row i of the bias-free linear layer,
    # applied channel-wise to the feature map.
    camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, feats.squeeze())
    # torch.softmax is numerically stable; np.exp on raw logits can overflow to inf/nan.
    normalprob, pneumoniaprob = torch.softmax(logits[0].double(), dim=0).tolist()
    # Pick the map of the more probable class (a tie goes to pneumonia).
    if normalprob > pneumoniaprob:
        idx, label, prob = 0, "normal", normalprob
    else:
        idx, label, prob = 1, "pneumonia", pneumoniaprob
    dls.train.decode((x,))[0].squeeze().show(ax=axk)
    axk.imshow(camimg[idx].to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='magma')
    axk.set_title("%s(%s)" % (label, round(prob, 5)))
fig.set_figwidth(16)
fig.set_figheight(16)
fig.tight_layout()

특정 이미지 선택 (select a specific image)

# Pick one specific image by its index in the get_image_files(path) listing.
# The index is only meaningful for an unchanged dataset folder.
img = PILImage.create(get_image_files(path)[304])
img

x, = first(dls.test_dl([img]))  # tensorize the image with the DataLoaders' transforms
with torch.no_grad():  # inference only: one forward pass, no autograd graph
    logits = net(x)
a, b = logits[0].tolist()  # raw scores for NORMAL / PNEUMONIA (dls.vocab order)
# Class probabilities. torch.softmax is numerically stable, whereas
# np.exp(a)/(np.exp(a)+np.exp(b)) gives nan once a logit exceeds ~709.
probs = torch.softmax(logits[0].double(), dim=0)
probs[0].item(), probs[1].item()
(1.54300621731869e-12, 0.9999999999984569)
# Class activation maps: apply each row of the bias-free final linear layer to the
# body's feature map, giving one spatial map per class.
# camimg[0] -> NORMAL, camimg[1] -> PNEUMONIA (dls.vocab order).
camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x).squeeze())
# Overlay each class map on the decoded input. `extent` stretches the CAM grid over
# the 224x224 image, and bilinear interpolation smooths it.
fig, (ax1,ax2) = plt.subplots(1,2) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(camimg[0].to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax1.set_title("NORMAL PART")
# Right panel: PNEUMONIA map.
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(camimg[1].to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax2.set_title("DISEASE PART")
# Figure size and layout.
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

보라색일수록 특징이 강함을 의미한다.

# Turn the PNEUMONIA CAM into soft weights with theta = 0.04:
#   A1 = exp(-theta * (cam - min(cam)))  -> 1 where the map is at its minimum, decaying
#                                           where the disease evidence is strong ("res" weight)
#   A2 = 1 - A1                          -> complementary weight ("X1" weight)
test=camimg[1]-torch.min(camimg[1])
A1=torch.exp(-0.04*test)
A2=1-A1
# Left: A2 overlay; right: A1 overlay.
fig, (ax1, ax2) = plt.subplots(1,2) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(A2.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax1.set_title("X1 WEIGHT WITH THETA=0.04")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(A1.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax2.set_title("X1 RES WEIGHT WITH THETA=0.04")
#
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# Build the two 1st-stage images from the theta = 0.04 weights (A1, A2).
# Each weight map is resized to 224x224 and then multiplied into the (3, 224, 224) input.
x_cpu = x.squeeze().to('cpu')
# mode1_res * x: residual weight A1 (large where the PNEUMONIA CAM is weak).
X1 = np.array(A1.to("cpu").detach(), dtype=np.float32)
Y1 = torch.Tensor(cv.resize(X1, (224, 224), interpolation=cv.INTER_LINEAR))
x1 = (x_cpu*Y1 - torch.min(x_cpu*0.7*Y1))*0.4
# mode1 * x: weight A2 = 1 - A1 (large where the PNEUMONIA CAM is strong).
# Y12 must be resized from X12; resizing X1 would just rebuild the residual mask.
X12 = np.array(A2.to("cpu").detach(), dtype=np.float32)
Y12 = torch.Tensor(cv.resize(X12, (224, 224), interpolation=cv.INTER_LINEAR))
x12 = x_cpu*Y12*3
del x_cpu

1st cam 결과 분리

# Original image for reference.
fig, (ax1) = plt.subplots(1,1) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)            
fig.set_figheight(4)
fig.tight_layout()
# The two 1st-stage images. x12 and x1 are built from the transformed tensor x, not from
# a decoded image, so matplotlib clips them to the displayable range (see warnings).
fig, (ax1, ax2) = plt.subplots(1,2) 
x12.squeeze().show(ax=ax1)  #MODE1
x1.squeeze().show(ax=ax2)  #MODE1_res
ax1.set_title("X1")
ax2.set_title("X1 RES")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

# Restore the batch axis the network expects: (3, 224, 224) -> (1, 3, 224, 224).
x1 = x1.reshape(1, 3, 224, 224)
# Continue on the CPU. `net` is built from these same two modules, so it moves as well.
net1.cpu()
net2.cpu()
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# 2nd CAM: class activation maps recomputed on the 1st-stage residual image x1.
ver2 = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x1).squeeze())
fig, (ax1,ax2) = plt.subplots(1,2) 
# Left: NORMAL map over x1.
x1.squeeze().show(ax=ax1)
ax1.imshow(ver2[0].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("NORMAL PART")
# Right: PNEUMONIA map over x1.
x1.squeeze().show(ax=ax2)
ax2.imshow(ver2[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("DISEASE PART")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

# Compare the PNEUMONIA maps of the 1st CAM (on x) and the 2nd CAM (on x1),
# both drawn over the original image.
# NOTE(review): these overlays use extent (0,223,223,0), while camimg overlays elsewhere use (0,224,224,0).
fig, (ax1,ax2) = plt.subplots(1,2) 
# 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(camimg[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("1ST CAM")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(ver2[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("2ND CAM")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# Class probabilities for the 1st-stage residual image x1.
with torch.no_grad():  # one forward pass, no autograd graph
    logits = net(x1)
a1, b1 = logits[0].tolist()  # raw NORMAL / PNEUMONIA scores
# torch.softmax is stable; the np.exp ratio gives nan once a logit exceeds ~709.
probs = torch.softmax(logits[0].double(), dim=0)
probs[0].item(), probs[1].item()
(6.664552246309979e-19, 1.0)

\(\theta\)는 hyperparameter로 생각한다.

# Soft weights from the 2nd CAM with theta = 0.04:
#   A3 = exp(-theta * (map - min(map))),  A4 = 1 - A3
# NOTE(review): this uses ver2[0] (the NORMAL map), whereas the 1st-stage weights used
# camimg[1] (the PNEUMONIA map) — confirm which class map was intended.
test1=ver2[0]-torch.min(ver2[0])
A3=torch.exp(-0.04*test1)  
A4=1-A3
fig, (ax1, ax2) = plt.subplots(1,2) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(A3.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax1.set_title("MODE2 WEIGHT WITH THETA=0.04")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(A4.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax2.set_title("MODE2 RES WEIGHT WITH THETA=0.04")
#
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# 2nd-stage images, applied on top of the 1st-stage residual mask Y1.
# mode2_res: x weighted by Y1 * Y3 (A3 resized to 224x224), shifted so its minimum is 0.
# NOTE(review): A3 is the map titled "MODE2 WEIGHT" in the previous plot but feeds
# mode2_res here (and A4, titled "MODE2 RES WEIGHT", feeds mode2) — confirm the labels.
X3=np.array(A3.to("cpu").detach(),dtype=np.float32)
Y3=torch.Tensor(cv.resize(X3,(224,224),interpolation=cv.INTER_LINEAR))
x3=x.squeeze().to('cpu')*Y1*Y3-torch.min(x.squeeze().to('cpu')*Y1*Y3)
# mode2*x: x weighted by Y1 * Y4 (A4 resized to 224x224).
X4=np.array(A4.to("cpu").detach(),dtype=np.float32)
Y4=torch.Tensor(cv.resize(X4,(224,224),interpolation=cv.INTER_LINEAR))
x4=x.squeeze().to('cpu')*Y1*Y4

2nd 분리 결과

# Original image for reference.
fig, (ax1) = plt.subplots(1,1) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)            
fig.set_figheight(4)
fig.tight_layout()
# 1st-stage images (x12 = mode1, x1 = mode1 residual).
fig, (ax1, ax2) = plt.subplots(1,2) 
x12.squeeze().show(ax=ax1)  
x1.squeeze().show(ax=ax2)  
ax1.set_title("MODE1")
ax2.set_title("MODE1 RES")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
# 2nd-stage images (x4 = mode2, x3 = mode2 residual). Values outside the displayable
# range are clipped by matplotlib (see the warnings below).
fig, (ax1, ax2) = plt.subplots(1,2) 
x4.squeeze().show(ax=ax1)  
x3.squeeze().show(ax=ax2)  
ax1.set_title("MODE2")
ax2.set_title("MODE2 RES")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

# Restore the batch axis the network expects: (3, 224, 224) -> (1, 3, 224, 224).
x3 = x3.reshape(1, 3, 224, 224)
# Keep both halves of the CAM network on the CPU; `net` shares these modules.
net1.cpu()
net2.cpu()
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# 3rd CAM: class activation maps recomputed on the 2nd-stage residual image x3.
ver22 = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x3).squeeze())

CAM

# 3rd CAM over x3: NORMAL map (left) and PNEUMONIA map (right).
fig, (ax1,ax2) = plt.subplots(1,2) 
# 
x3.squeeze().show(ax=ax1)
ax1.imshow(ver22[0].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("NORMAL PART")
#
x3.squeeze().show(ax=ax2)
ax2.imshow(ver22[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("DISEASE PART")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

# Side-by-side PNEUMONIA maps of the 1st, 2nd and 3rd CAM, all drawn over the original image.
fig, (ax1,ax2, ax3) = plt.subplots(1,3) 
# 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(camimg[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("1ST CAM")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(ver2[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("2ND CAM")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax3)
ax3.imshow(ver22[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax3.set_title("3RD CAM")
#

fig.set_figwidth(12)            
fig.set_figheight(12)
fig.tight_layout()

# Class probabilities for the 2nd-stage residual image x3.
with torch.no_grad():  # one forward pass, no autograd graph
    logits = net(x3)
a2, b2 = logits[0].tolist()  # raw NORMAL / PNEUMONIA scores
# torch.softmax is stable; the np.exp ratio gives nan once a logit exceeds ~709.
probs = torch.softmax(logits[0].double(), dim=0)
probs[0].item(), probs[1].item()
(3.322455317236289e-16, 0.9999999999999998)

다른 특정 그림 select

# Another specific image, chosen by its index in the get_image_files(path) listing.
img = PILImage.create(get_image_files(path)[3031])
img

철심

x, = first(dls.test_dl([img]))  # tensorize the image with the DataLoaders' transforms
x = x.to('cpu')  # net1/net2 were moved to the CPU in an earlier cell
with torch.no_grad():  # inference only: one forward pass, no autograd graph
    logits = net(x)
a, b = logits[0].tolist()  # raw scores for NORMAL / PNEUMONIA (dls.vocab order)
# torch.softmax is stable; the np.exp ratio gives nan once a logit exceeds ~709.
probs = torch.softmax(logits[0].double(), dim=0)
probs[0].item(), probs[1].item()
(2.994815878500848e-17, 1.0)
# 1st CAM for this image: one map per class (index 0 = NORMAL, 1 = PNEUMONIA).
camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x).squeeze())
fig, (ax1,ax2) = plt.subplots(1,2) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(camimg[0].to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax1.set_title("NORMAL PART")
# Right panel: PNEUMONIA map.
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(camimg[1].to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax2.set_title("DISEASE PART")
# Figure size and layout.
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

철심 있는 부분 디텍팅

# Soft weights with theta = 0.05 from both class maps of the 1st CAM:
#   A1  = exp(-theta * shifted NORMAL map)     (large where the NORMAL map is weak)
#   A2  = 1 - A1
#   A11 = exp(-theta * shifted PNEUMONIA map)  (large where the PNEUMONIA map is weak)
test=camimg[0]-torch.min(camimg[0])
test1=camimg[1]-torch.min(camimg[1])
A1=torch.exp(-0.05*test)
A2=1-A1
A11=torch.exp(-0.05*test1)
fig, (ax1, ax2) = plt.subplots(1,2) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(A2.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax1.set_title("X1 WEIGHT WITH THETA=0.05")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(A11.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax2.set_title("X1 RES WEIGHT WITH THETA=0.05")
#
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# 1st-stage images from the theta = 0.05 weights (A1, A2, A11).
# Each weight map is resized to 224x224 before multiplying the (3, 224, 224) input.
x_cpu = x.squeeze().to('cpu')
# mode1_res * x: A11 is large where the PNEUMONIA CAM is weak.
X1 = np.array(A11.to("cpu").detach(), dtype=np.float32)
Y1 = torch.Tensor(cv.resize(X1, (224, 224), interpolation=cv.INTER_LINEAR))
x1 = (x_cpu*Y1 - torch.min(x_cpu*Y1))*0.18
# Mask for the mode1 ("X1") image: A1, large where the NORMAL CAM is weak.
X_1 = np.array(A1.to("cpu").detach(), dtype=np.float32)
Y_1 = torch.Tensor(cv.resize(X_1, (224, 224), interpolation=cv.INTER_LINEAR))
# NOTE(review): x_1 is weighted by Y1 (not Y_1) and is not read by any later cell here.
x_1 = x_cpu*Y1 - torch.min(x_cpu*Y1)*0.05
# mode1 * x, built from the A1 mask.
# NOTE(review): X12 (= A2, plotted as "X1 WEIGHT") is computed but not used for x12.
X12 = np.array(A2.to("cpu").detach(), dtype=np.float32)
Y12 = Y_1.clone()  # same values as cv.resize(X_1, (224, 224), ...)
x12 = x_cpu*Y12*0.3
del x_cpu

1st cam 결과

# Original image for reference.
fig, (ax1) = plt.subplots(1,1) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)            
fig.set_figheight(4)
fig.tight_layout()
# The two 1st-stage images. matplotlib clips out-of-range values (see warning).
fig, (ax1, ax2) = plt.subplots(1,2) 
x12.squeeze().show(ax=ax1)  #MODE1
x1.squeeze().show(ax=ax2)  #MODE1_res
ax1.set_title("X1")
ax2.set_title("X1 RES")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

# Restore the batch axis the network expects: (3, 224, 224) -> (1, 3, 224, 224).
x1 = x1.reshape(1, 3, 224, 224)
# Keep both halves of the CAM network on the CPU; `net` shares these modules.
net1.cpu()
net2.cpu()
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)

2차

# 2nd CAM: class activation maps recomputed on the 1st-stage residual image x1.
ver2 = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x1).squeeze())

CAM

# 2nd CAM over x1: NORMAL map (left) and PNEUMONIA map (right).
fig, (ax1,ax2) = plt.subplots(1,2) 
# 
x1.squeeze().show(ax=ax1)
ax1.imshow(ver2[0].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("NORMAL PART")
#
x1.squeeze().show(ax=ax2)
ax2.imshow(ver2[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("DISEASE PART")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# PNEUMONIA maps of the 1st and 2nd CAM, both drawn over the original image.
fig, (ax1,ax2) = plt.subplots(1,2) 
# 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(camimg[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("1ST CAM")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(ver2[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("2ND CAM")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# Class probabilities for the 1st-stage residual image x1.
with torch.no_grad():  # one forward pass, no autograd graph
    logits = net(x1)
a1, b1 = logits[0].tolist()  # raw NORMAL / PNEUMONIA scores
# torch.softmax is stable; the np.exp ratio gives nan once a logit exceeds ~709.
probs = torch.softmax(logits[0].double(), dim=0)
probs[0].item(), probs[1].item()
(3.013152213595138e-16, 0.9999999999999997)
# Soft weights with theta = 0.08 from both class maps of the 2nd CAM:
#   A3  = exp(-theta * shifted NORMAL map),  A4 = 1 - A3
#   A33 = exp(-theta * shifted PNEUMONIA map)
test=ver2[0]-torch.min(ver2[0])
test1=ver2[1]-torch.min(ver2[1])
A3=torch.exp(-0.08*test)  
A4=1-A3
A33 = torch.exp(-0.08*test1)
fig, (ax1, ax2) = plt.subplots(1,2) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(A4.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax1.set_title("X2 WEIGHT WITH THETA=0.08")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(A33.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax2.set_title("X2 RES WEIGHT WITH THETA=0.08")
#
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# 2nd-stage images from the theta = 0.08 weights (A3, A4, A33), all applied on top of
# the 1st-stage residual image.
x_cpu = x.squeeze().to('cpu')
res1 = x_cpu*Y1 - torch.min(x_cpu*Y1)  # 1st-stage residual, shifted so its minimum is 0

# mode2_res: A33 is large where the 2nd-CAM PNEUMONIA map is weak.
X3 = np.array(A33.to("cpu").detach(), dtype=np.float32)
Y3 = torch.Tensor(cv.resize(X3, (224, 224), interpolation=cv.INTER_LINEAR))
x3 = res1*0.2*Y3

# Mask for the mode2 ("X2") image: A3, large where the 2nd-CAM NORMAL map is weak.
X_3 = np.array(A3.to("cpu").detach(), dtype=np.float32)
Y_3 = torch.Tensor(cv.resize(X_3, (224, 224), interpolation=cv.INTER_LINEAR))
# NOTE(review): x_3 is weighted by Y3 (not Y_3) and is not read by any later cell here.
x_3 = res1*Y3

# mode2 * x, built from the A3 mask.
# NOTE(review): X4 (= A4) is computed but not used for x4.
X4 = np.array(A4.to("cpu").detach(), dtype=np.float32)
Y4 = Y_3.clone()  # same values as cv.resize(X_3, (224, 224), ...)
x4 = res1*Y4*0.2
del x_cpu, res1

2nd cam 결과 분리

# Original image for reference.
fig, (ax1) = plt.subplots(1,1) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)            
fig.set_figheight(4)
fig.tight_layout()
# 1st-stage images (x12 = mode1, x1 = mode1 residual).
fig, (ax1, ax2) = plt.subplots(1,2) 
x12.squeeze().show(ax=ax1)  
x1.squeeze().show(ax=ax2)  
ax1.set_title("MODE1")
ax2.set_title("MODE1 RES")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
# 2nd-stage images (x4 = X2, x3 = X2 residual).
fig, (ax1, ax2) = plt.subplots(1,2) 
x4.squeeze().show(ax=ax1)  
x3.squeeze().show(ax=ax2)  
ax1.set_title("X2")
ax2.set_title("X2 RES")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

# Restore the batch axis the network expects: (3, 224, 224) -> (1, 3, 224, 224).
x3 = x3.reshape(1, 3, 224, 224)
# Keep both halves of the CAM network on the CPU; `net` shares these modules.
net1.cpu()
net2.cpu()
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# 3rd CAM: class activation maps recomputed on the 2nd-stage residual image x3.
ver22 = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x3).squeeze())

CAM

# 3rd CAM over x3: NORMAL map (left) and PNEUMONIA map (right).
fig, (ax1,ax2) = plt.subplots(1,2) 
# 
x3.squeeze().show(ax=ax1)
ax1.imshow(ver22[0].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("NORMAL PART")
#
x3.squeeze().show(ax=ax2)
ax2.imshow(ver22[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("DISEASE PART")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# PNEUMONIA maps of the 1st, 2nd and 3rd CAM, all drawn over the original image.
fig, (ax1,ax2, ax3) = plt.subplots(1,3) 
# 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(camimg[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("1ST CAM")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(ver2[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("2ND CAM")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax3)
ax3.imshow(ver22[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax3.set_title("3RD CAM")
#

fig.set_figwidth(12)            
fig.set_figheight(12)
fig.tight_layout()

# Class probabilities for the 2nd-stage residual image x3.
with torch.no_grad():  # one forward pass, no autograd graph
    logits = net(x3)
a2, b2 = logits[0].tolist()  # raw NORMAL / PNEUMONIA scores
# torch.softmax is stable; the np.exp ratio gives nan once a logit exceeds ~709.
probs = torch.softmax(logits[0].double(), dim=0)
probs[0].item(), probs[1].item()
(1.1537635027871649e-15, 0.9999999999999989)

다른 그림

# A third specific image, chosen by its index in the get_image_files(path) listing.
img = PILImage.create(get_image_files(path)[3107])
img

x, = first(dls.test_dl([img]))  # tensorize the image with the DataLoaders' transforms
x = x.to('cpu')  # net1/net2 were moved to the CPU in an earlier cell
with torch.no_grad():  # inference only: one forward pass, no autograd graph
    logits = net(x)
a, b = logits[0].tolist()  # raw scores for NORMAL / PNEUMONIA (dls.vocab order)
# torch.softmax is stable; the np.exp ratio gives nan once a logit exceeds ~709.
probs = torch.softmax(logits[0].double(), dim=0)
probs[0].item(), probs[1].item()
(6.990148118894787e-19, 1.0)
# 1st CAM for this image: one map per class (index 0 = NORMAL, 1 = PNEUMONIA).
camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x).squeeze())
fig, (ax1,ax2) = plt.subplots(1,2) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(camimg[0].to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax1.set_title("NORMAL PART")
# Right panel: PNEUMONIA map.
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(camimg[1].to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax2.set_title("DISEASE PART")
#
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# Soft weights with theta = 0.05 from both class maps:
#   A1  = exp(-theta * shifted NORMAL map),  A2 = 1 - A1
#   A11 = exp(-theta * shifted PNEUMONIA map)
test=camimg[0]-torch.min(camimg[0])
test1=camimg[1]-torch.min(camimg[1])
A1=torch.exp(-0.05*test)
A2=1-A1
A11=torch.exp(-0.05*test1)
fig, (ax1, ax2) = plt.subplots(1,2) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(A2.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax1.set_title("X1 WEIGHT WITH THETA=0.05")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(A11.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax2.set_title("X1 RES WEIGHT WITH THETA=0.05")
#
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# 1st-stage images from the theta = 0.05 weights (A1, A2, A11).
# Each weight map is resized to 224x224 before multiplying the (3, 224, 224) input.
x_cpu = x.squeeze().to('cpu')
# mode1_res * x: A11 is large where the PNEUMONIA CAM is weak.
X1 = np.array(A11.to("cpu").detach(), dtype=np.float32)
Y1 = torch.Tensor(cv.resize(X1, (224, 224), interpolation=cv.INTER_LINEAR))
# NOTE(review): the 0.02 factor scales only the min offset, not the whole image.
x1 = x_cpu*Y1 - torch.min(x_cpu*Y1)*0.02
# Mask for the mode1 ("X1") image: A1, large where the NORMAL CAM is weak.
X_1 = np.array(A1.to("cpu").detach(), dtype=np.float32)
Y_1 = torch.Tensor(cv.resize(X_1, (224, 224), interpolation=cv.INTER_LINEAR))
# NOTE(review): x_1 is weighted by Y1 (not Y_1) and is not read by any later cell here.
x_1 = x_cpu*Y1 - torch.min(x_cpu*Y1)*0.05
# mode1 * x, built from the A1 mask.
# NOTE(review): X12 (= A2, plotted as "X1 WEIGHT") is computed but not used for x12.
X12 = np.array(A2.to("cpu").detach(), dtype=np.float32)
Y12 = Y_1.clone()  # same values as cv.resize(X_1, (224, 224), ...)
x12 = x_cpu*Y12*0.3
del x_cpu

1st cam 결과 분리

# Original image for reference.
fig, (ax1) = plt.subplots(1,1) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)            
fig.set_figheight(4)
fig.tight_layout()
# The two 1st-stage images. matplotlib clips out-of-range values (see warnings).
fig, (ax1, ax2) = plt.subplots(1,2) 
x12.squeeze().show(ax=ax1)  #MODE1
x1.squeeze().show(ax=ax2)  #MODE1_res
ax1.set_title("X1")
ax2.set_title("X1 RES")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

# Restore the batch axis the network expects: (3, 224, 224) -> (1, 3, 224, 224).
x1 = x1.reshape(1, 3, 224, 224)
# Keep both halves of the CAM network on the CPU; `net` shares these modules.
net1.cpu()
net2.cpu()
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)

2nd cam 분리

# 2nd CAM: class activation maps recomputed on the 1st-stage residual image x1.
ver2 = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x1).squeeze())

cam

# 2nd CAM over x1: NORMAL map (left) and PNEUMONIA map (right).
fig, (ax1,ax2) = plt.subplots(1,2) 
# 
x1.squeeze().show(ax=ax1)
ax1.imshow(ver2[0].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("NORMAL PART")
#
x1.squeeze().show(ax=ax2)
ax2.imshow(ver2[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("DISEASE PART")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

# PNEUMONIA maps of the 1st and 2nd CAM, both drawn over the original image.
fig, (ax1,ax2) = plt.subplots(1,2) 
# 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(camimg[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("1ST CAM")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(ver2[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("2ND CAM")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# Class probabilities for the 1st-stage residual image x1.
with torch.no_grad():  # one forward pass, no autograd graph
    logits = net(x1)
a1, b1 = logits[0].tolist()  # raw NORMAL / PNEUMONIA scores
# torch.softmax is stable; the np.exp ratio gives nan once a logit exceeds ~709.
probs = torch.softmax(logits[0].double(), dim=0)
probs[0].item(), probs[1].item()
(1.4950733419675e-20, 1.0)
# Soft weights with theta = 0.1 from both class maps of the 2nd CAM:
#   A3  = exp(-theta * shifted NORMAL map),  A4 = 1 - A3
#   A33 = exp(-theta * shifted PNEUMONIA map)
test=ver2[0]-torch.min(ver2[0])
test1=ver2[1]-torch.min(ver2[1])
A3=torch.exp(-0.1*test)  
A4=1-A3
A33 = torch.exp(-0.1*test1)
fig, (ax1, ax2) = plt.subplots(1,2) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(A4.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax1.set_title("X2 WEIGHT WITH THETA=0.1")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(A33.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax2.set_title("X2 RES WEIGHT WITH THETA=0.1")
#
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# 2nd-stage images from the theta = 0.1 weights (A3, A4, A33), applied on top of the
# 1st-stage mask Y1.
x_cpu = x.squeeze().to('cpu')
res1 = x_cpu*Y1 - torch.min(x_cpu*Y1)  # 1st-stage residual, shifted so its minimum is 0

# mode2_res: A33 is large where the 2nd-CAM PNEUMONIA map is weak.
X3 = np.array(A33.to("cpu").detach(), dtype=np.float32)
Y3 = torch.Tensor(cv.resize(X3, (224, 224), interpolation=cv.INTER_LINEAR))
x3 = res1*0.3*Y3

# Mask for the mode2 image: A3, large where the 2nd-CAM NORMAL map is weak.
X_3 = np.array(A3.to("cpu").detach(), dtype=np.float32)
Y_3 = torch.Tensor(cv.resize(X_3, (224, 224), interpolation=cv.INTER_LINEAR))
# NOTE(review): x_3 is weighted by Y3 (not Y_3) and is not read by any later cell here.
x_3 = res1*Y3

# mode2 * x. X4 (= A4) is computed but not used; Y4 comes from the A3 mask.
X4 = np.array(A4.to("cpu").detach(), dtype=np.float32)
Y4 = Y_3.clone()  # same values as cv.resize(X_3, (224, 224), ...)
# NOTE(review): by operator precedence Y4 scales only the min offset here, not the image
# (the parenthesized form would be (x*Y1 - min(x*Y1))*Y4*...). Confirm this is intended.
x4 = x_cpu*Y1 - torch.min(x_cpu*Y1)*Y4*0.05
del x_cpu, res1

2nd cam 결과 분리

# Original image for reference.
fig, (ax1) = plt.subplots(1,1) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)            
fig.set_figheight(4)
fig.tight_layout()
# 1st-stage images (x12 = mode1, x1 = mode1 residual).
fig, (ax1, ax2) = plt.subplots(1,2) 
x12.squeeze().show(ax=ax1)  
x1.squeeze().show(ax=ax2)  
ax1.set_title("MODE1")
ax2.set_title("MODE1 RES")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
# 2nd-stage images (x4 = mode2, x3 = mode2 residual). Out-of-range values are clipped.
fig, (ax1, ax2) = plt.subplots(1,2) 
x4.squeeze().show(ax=ax1)  
x3.squeeze().show(ax=ax2)  
ax1.set_title("MODE2")
ax2.set_title("MODE2 RES")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

# Restore the batch axis the network expects: (3, 224, 224) -> (1, 3, 224, 224).
x3 = x3.reshape(1, 3, 224, 224)
# Keep both halves of the CAM network on the CPU; `net` shares these modules.
net1.cpu()
net2.cpu()
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# 3rd CAM: class activation maps recomputed on the 2nd-stage residual image x3.
ver22 = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x3).squeeze())

cam

# 3rd CAM over x3: NORMAL map (left) and PNEUMONIA map (right).
fig, (ax1,ax2) = plt.subplots(1,2) 
# 
x3.squeeze().show(ax=ax1)
ax1.imshow(ver22[0].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("NORMAL PART")
#
x3.squeeze().show(ax=ax2)
ax2.imshow(ver22[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("DISEASE PART")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()

# PNEUMONIA maps of the 1st, 2nd and 3rd CAM, all drawn over the original image.
fig, (ax1,ax2, ax3) = plt.subplots(1,3) 
# 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(camimg[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("1ST CAM")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(ver2[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("2ND CAM")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax3)
ax3.imshow(ver22[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax3.set_title("3RD CAM")
#

fig.set_figwidth(12)            
fig.set_figheight(12)
fig.tight_layout()

# Class probabilities for the 2nd-stage residual image x3.
with torch.no_grad():  # one forward pass, no autograd graph
    logits = net(x3)
a2, b2 = logits[0].tolist()  # raw NORMAL / PNEUMONIA scores
# torch.softmax is stable; the np.exp ratio gives nan once a logit exceeds ~709.
probs = torch.softmax(logits[0].double(), dim=0)
probs[0].item(), probs[1].item()
(4.3886825515670436e-18, 1.0)