EbayesThresh torch Python화 활용법

Author

SEOYEON CHOI

Published

July 31, 2024

Install

!pip install git+https://github.com/seoyeonc/ebayesthresh_torch.git
Collecting git+https://github.com/seoyeonc/ebayesthresh_torch.git
  Cloning https://github.com/seoyeonc/ebayesthresh_torch.git to /tmp/pip-req-build-2yr1qt5n
  Running command git clone --filter=blob:none --quiet https://github.com/seoyeonc/ebayesthresh_torch.git /tmp/pip-req-build-2yr1qt5n
  Resolved https://github.com/seoyeonc/ebayesthresh_torch.git to commit 8d7a32e5ed482c3091f6124f1842496e8703048e
  Preparing metadata (setup.py) ... done
Requirement already satisfied: torch==2.0.1 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from ebayesthresh-torch==0.0.1) (2.0.1)
Requirement already satisfied: scipy==1.10.1 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from ebayesthresh-torch==0.0.1) (1.10.1)
Requirement already satisfied: statsmodels==0.14.0 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from ebayesthresh-torch==0.0.1) (0.14.0)
Requirement already satisfied: numpy<1.27.0,>=1.19.5 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from scipy==1.10.1->ebayesthresh-torch==0.0.1) (1.23.5)
Requirement already satisfied: packaging>=21.3 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from statsmodels==0.14.0->ebayesthresh-torch==0.0.1) (23.0)
Requirement already satisfied: pandas>=1.0 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from statsmodels==0.14.0->ebayesthresh-torch==0.0.1) (2.0.1)
Requirement already satisfied: patsy>=0.5.2 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from statsmodels==0.14.0->ebayesthresh-torch==0.0.1) (0.5.3)
Requirement already satisfied: filelock in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from torch==2.0.1->ebayesthresh-torch==0.0.1) (3.9.0)
Requirement already satisfied: typing-extensions in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from torch==2.0.1->ebayesthresh-torch==0.0.1) (4.5.0)
Requirement already satisfied: sympy in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from torch==2.0.1->ebayesthresh-torch==0.0.1) (1.11.1)
Requirement already satisfied: networkx in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from torch==2.0.1->ebayesthresh-torch==0.0.1) (2.8.4)
Requirement already satisfied: jinja2 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from torch==2.0.1->ebayesthresh-torch==0.0.1) (3.1.2)
Requirement already satisfied: python-dateutil>=2.8.2 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from pandas>=1.0->statsmodels==0.14.0->ebayesthresh-torch==0.0.1) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from pandas>=1.0->statsmodels==0.14.0->ebayesthresh-torch==0.0.1) (2023.3)
Requirement already satisfied: tzdata>=2022.1 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from pandas>=1.0->statsmodels==0.14.0->ebayesthresh-torch==0.0.1) (2023.3)
Requirement already satisfied: six in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from patsy>=0.5.2->statsmodels==0.14.0->ebayesthresh-torch==0.0.1) (1.16.0)
Requirement already satisfied: MarkupSafe>=2.0 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from jinja2->torch==2.0.1->ebayesthresh-torch==0.0.1) (2.1.1)
Requirement already satisfied: mpmath>=0.19 in /home/csy/anaconda3/envs/temp_csy/lib/python3.8/site-packages (from sympy->torch==2.0.1->ebayesthresh-torch==0.0.1) (1.2.1)

Import

import ebayesthresh_torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

예제를 위해 필요한 함수 정의

def make_Psi(T):
    """Return the eigenvectors of the normalized Laplacian of a chain graph.

    The graph has ``T`` nodes, with node ``i`` linked only to ``i - 1`` and
    ``i + 1``. With adjacency ``W`` and degree matrix ``D = diag(W.sum(dim=1))``
    the function forms ``L = D^{-1/2} (D - W) D^{-1/2}`` and decomposes it
    with ``torch.linalg.eigh``.

    Args:
        T: Number of nodes, i.e. the length of the signal to be transformed.

    Returns:
        A ``(T, T)`` tensor holding the eigenvectors of ``L`` as columns, in
        the order produced by ``torch.linalg.eigh``.

    Note:
        For ``T == 1`` the only node has degree 0, so the ``1 / sqrt(d)``
        normalization divides by zero.
    """
    idx = torch.arange(T)
    # Two distinct integer indices satisfy |i - j| <= 1 exactly when
    # |i - j| == 1, so the adjacency can be built in one vectorized comparison
    # instead of a T x T Python loop.
    W = ((idx[:, None] - idx[None, :]).abs() == 1).to(torch.get_default_dtype())
    d = W.sum(dim=1)
    D = torch.diag(d)
    # Build D^{-1/2} once and reuse it on both sides of the Laplacian.
    d_inv_sqrt = torch.diag(1 / torch.sqrt(d))
    L = d_inv_sqrt @ (D - W) @ d_inv_sqrt
    _, Psi = torch.linalg.eigh(L)
    return Psi

Example

# Build a noisy three-tone sine signal sampled at T points on [0, 10) and
# plot the noisy samples together with the clean curve.
T = 100
x = 10 * (np.arange(T) / T)
y_true = (
    np.sin(x * 0.5) * 3
    + np.sin(x * 1.0) * 1.2
    + np.sin(x * 1.2) * 0.5
)
# Unseeded standard-normal noise: this cell gives a different y on every run.
y = np.random.normal(size=T) + y_true
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'o')
plt.plot(x, y_true, '--')

# Transform the noisy signal with the eigenvectors from make_Psi and threshold
# the resulting coefficients with ebayesthresh.
f = np.array(y)
# A 1-D signal is reshaped into a single column so that f is always (T, N).
if len(f.shape)==1: f = f.reshape(-1,1)
T,N = f.shape
Psi = make_Psi(T)
# NOTE(review): Psi is a torch tensor while f is a NumPy array; this line
# relies on mixed-type matmul working — confirm the resulting type/dtype.
fbar = Psi.T @ f # forward transform onto the columns of Psi (called "dft" in the original note)
# Only the first (here: only) column of coefficients is thresholded.
fbar_threshed = ebayesthresh_torch.ebayesthresh(fbar[:,0])
/home/csy/Dropbox/sy_hub/posts/1_Note/ebayesthresh_torch/utils.py:73: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  if torch.isnan(torch.tensor(sdev)):
# Compare squared coefficients before and after thresholding (full range).
plt.figure(figsize=(10,6))
plt.plot((fbar**2)) # squared transform coefficients (periodogram)
plt.plot((fbar_threshed**2)) # the same coefficients after thresholding

# Same comparison, restricted to coefficient indices 20..79.
plt.figure(figsize=(10,6))
plt.plot((fbar**2)[20:80]) # squared transform coefficients (periodogram)
plt.plot((fbar_threshed**2)[20:80]) # the same coefficients after thresholding

# Cast to float32 before multiplying by Psi, then plot the reconstruction
# against the noisy samples and the clean signal.
yhat = Psi @ fbar_threshed.float() # inverse transform back to the signal domain
plt.figure(figsize=(10,6))
plt.plot(x,y,'.')
plt.plot(x,y_true,'--')
plt.plot(x,yhat)

사용자 함수 정의

class ebayesthresh_nn(torch.nn.Module):
    """``torch.nn.Module`` wrapper around ``ebayesthresh_torch.ebayesthresh``.

    All settings are taken from keyword arguments and forwarded unchanged to
    ``ebayesthresh`` on every call.

    Keyword Args:
        prior: Defaults to ``'laplace'``.
        a: Defaults to ``0.5``; stored as a float32 tensor with
            ``requires_grad=True``.
        bayesfac: Defaults to ``False``.
        sdev: Defaults to ``0.5``; stored as a float32 tensor with
            ``requires_grad=True``.
        verbose: Defaults to ``True``.
        threshrule: Defaults to ``'median'``.
        universalthresh: Defaults to ``True``.
        stabadjustment: Defaults to ``None``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.verbose = kwargs.get('verbose', True)
        self.threshrule = kwargs.get('threshrule', 'median')
        self.universalthresh = kwargs.get('universalthresh', True)
        self.stabadjustment = kwargs.get('stabadjustment', None)
        self.prior = kwargs.get('prior', 'laplace')
        # Single assignment; the effective default has always been False.
        self.bayesfac = kwargs.get('bayesfac', False)
        # Both tensors are created the same way so that `sdev`, like `a`,
        # tracks gradients from construction onwards.
        self.a = self._as_leaf(kwargs.get('a', 0.5))
        self.sdev = self._as_leaf(kwargs.get('sdev', 0.5))

    @staticmethod
    def _as_leaf(value):
        """Return ``value`` as a new float32 leaf tensor that requires grad."""
        tensor = torch.as_tensor(value, dtype=torch.float32)
        return tensor.detach().clone().requires_grad_(True)

    def forward(self, x):
        """Threshold ``x`` and return the ``'muhat'`` entry of the result.

        The ``'muhat'``, ``'a'`` and ``'sdev'`` entries of the result are also
        stored on the instance, replacing the previous ``a`` and ``sdev``.

        NOTE(review): the result is indexed by key, which presumably requires
        ``verbose=True`` — confirm against ``ebayesthresh_torch``.
        """
        out = ebayesthresh_torch.ebayesthresh(
            x,
            self.prior,
            a=self.a,
            bayesfac=self.bayesfac,
            sdev=self.sdev,
            verbose=self.verbose,
            threshrule=self.threshrule,
            universalthresh=self.universalthresh,
            stabadjustment=self.stabadjustment,
        )
        self.muhat = out['muhat']
        self.a = out['a']
        self.sdev = out['sdev']
        return self.muhat

데이터

# Reproducible test data: a three-tone sine signal on [0, 10) plus Gaussian
# noise scaled by 0.7, drawn after fixing NumPy's global seed.
np.random.seed(111)
T = 100
x = 10 * (np.arange(T) / T)
ytrue = (
    np.sin(x * 0.5) * 3
    + np.sin(x * 1.0) * 1.2
    + np.sin(x * 1.2) * 0.5
)
noise = 0.7 * np.random.normal(size=T)
y = noise + ytrue
# Optional visual check of the noisy samples against the clean signal:
# plt.figure(figsize=(10,6))
# plt.plot(y,'.',color='r')
# plt.plot(ytrue,'-',color='b')

레이어 정의

# Create the thresholding layer without keyword overrides, so every setting
# falls back to the default chosen in ebayesthresh_nn.__init__.
thresh_layer = ebayesthresh_nn()

Fourier Transform

# Build the eigenvector basis for a length-T signal and project y onto it.
Psi = make_Psi(T)
# NOTE(review): Psi is a torch tensor and y a NumPy array; the type and dtype
# of ybar depend on how this mixed matmul is resolved — confirm.
ybar = Psi.T @ y

Learn

# Calling the module runs ebayesthresh_nn.forward on the transformed signal.
power_threshed = thresh_layer(ybar)
/home/csy/Dropbox/sy_hub/posts/1_Note/ebayesthresh_torch/utils.py:68: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  else: sdev = torch.tensor(sdev*1.0,requires_grad=True)
power_threshed
tensor([  6.606373838818255, -18.459911752393730,  -5.548808386322293,
          0.336519592469161,  -7.753497824494824,   2.122283659415806,
         -2.837627123184061,  -0.000000000000000,   0.000000000000000,
         -0.000000000000000,   1.390184407168864,   0.000000000000000,
         -0.000000000000000,  -0.000000000000000,  -0.000000000000000,
         -0.000000000000000,   0.781773544538979,  -0.000000000000000,
         -0.000000000000000,  -0.000000000000000,  -0.000000000000000,
         -0.632339497985928,   0.000000000000000,   0.000000000000000,
         -0.000000000000000,  -0.000000000000000,   0.000000000000000,
          0.000000000000000,  -0.000000000000000,   0.000000000000000,
         -0.791569520349330,   0.000000000000000,  -0.696915272959604,
         -0.000000000000000,  -1.430598198373583,   0.000000000000000,
          0.942973224278802,   0.000000000000000,  -0.000000000000000,
          0.000000000000000,  -0.000000000000000,  -0.000000000000000,
          0.000000000000000,  -0.000000000000000,  -0.000000000000000,
          0.000000000000000,  -0.000000000000000,  -0.000000000000000,
         -0.000000000000000,   0.000000000000000,   0.000000000000000,
         -0.864876475773310,  -0.000000000000000,   0.000000000000000,
         -0.000000000000000,  -0.000000000000000,  -0.000000000000000,
         -0.000000000000000,   0.000000000000000,   0.000000000000000,
         -0.000000000000000,  -0.000000000000000,  -0.000000000000000,
          0.000000000000000,  -0.000000000000000,   0.000000000000000,
         -0.000000000000000,  -0.000000000000000,  -0.134906442023211,
          0.000000000000000,   0.000000000000000,  -0.000000000000000,
         -0.000000000000000,   0.000000000000000,   0.000000000000000,
          0.077974319687101,   0.000000000000000,   0.000000000000000,
          0.000000000000000,   0.000000000000000,   0.000000000000000,
          0.000000000000000,   0.000000000000000,   0.374434118750229,
          1.342773817293552,   0.060728811672864,   0.000000000000000,
          1.244851651799166,  -0.000000000000000,   0.000000000000000,
          0.000000000000000,  -0.000000000000000,  -0.000000000000000,
          0.000000000000000,   0.000000000000000,   0.000000000000000,
          0.430651230343156,   0.000000000000000,  -0.000000000000000,
         -0.000000000000000], dtype=torch.float64, grad_fn=<MulBackward0>)
thresh_layer.a
tensor(0.500000000000000, requires_grad=True)
thresh_layer.sdev
tensor(0.500000000000000, requires_grad=True)
# power_threshed is not a scalar, so backward() needs an explicit `gradient`
# tensor of the same shape; all ones gives every output element weight 1.
gradient = torch.ones(power_threshed.shape)
power_threshed.backward(gradient=gradient)
power_threshed
tensor([  6.606373838818255, -18.459911752393730,  -5.548808386322293,
          0.336519592469161,  -7.753497824494824,   2.122283659415806,
         -2.837627123184061,  -0.000000000000000,   0.000000000000000,
         -0.000000000000000,   1.390184407168864,   0.000000000000000,
         -0.000000000000000,  -0.000000000000000,  -0.000000000000000,
         -0.000000000000000,   0.781773544538979,  -0.000000000000000,
         -0.000000000000000,  -0.000000000000000,  -0.000000000000000,
         -0.632339497985928,   0.000000000000000,   0.000000000000000,
         -0.000000000000000,  -0.000000000000000,   0.000000000000000,
          0.000000000000000,  -0.000000000000000,   0.000000000000000,
         -0.791569520349330,   0.000000000000000,  -0.696915272959604,
         -0.000000000000000,  -1.430598198373583,   0.000000000000000,
          0.942973224278802,   0.000000000000000,  -0.000000000000000,
          0.000000000000000,  -0.000000000000000,  -0.000000000000000,
          0.000000000000000,  -0.000000000000000,  -0.000000000000000,
          0.000000000000000,  -0.000000000000000,  -0.000000000000000,
         -0.000000000000000,   0.000000000000000,   0.000000000000000,
         -0.864876475773310,  -0.000000000000000,   0.000000000000000,
         -0.000000000000000,  -0.000000000000000,  -0.000000000000000,
         -0.000000000000000,   0.000000000000000,   0.000000000000000,
         -0.000000000000000,  -0.000000000000000,  -0.000000000000000,
          0.000000000000000,  -0.000000000000000,   0.000000000000000,
         -0.000000000000000,  -0.000000000000000,  -0.134906442023211,
          0.000000000000000,   0.000000000000000,  -0.000000000000000,
         -0.000000000000000,   0.000000000000000,   0.000000000000000,
          0.077974319687101,   0.000000000000000,   0.000000000000000,
          0.000000000000000,   0.000000000000000,   0.000000000000000,
          0.000000000000000,   0.000000000000000,   0.374434118750229,
          1.342773817293552,   0.060728811672864,   0.000000000000000,
          1.244851651799166,  -0.000000000000000,   0.000000000000000,
          0.000000000000000,  -0.000000000000000,  -0.000000000000000,
          0.000000000000000,   0.000000000000000,   0.000000000000000,
          0.430651230343156,   0.000000000000000,  -0.000000000000000,
         -0.000000000000000], dtype=torch.float64, grad_fn=<MulBackward0>)
thresh_layer.a
tensor(0.500000000000000, requires_grad=True)
thresh_layer.sdev
tensor(0.500000000000000, requires_grad=True)

왜 gradient가 필요한가?

스칼라가 아닌(비스칼라) 텐서에 대해 Tensor.backward()를 호출할 때는, 기울기를 계산하기 위해 gradient 값을 명시적으로 넘겨야 합니다. 이는 미분할 함수의 기울기 방향과 크기를 정의하는 것입니다. 좀 더 정확히는, gradient는 출력 텐서의 각 원소에 곱해지는 가중치(벡터-야코비안 곱의 벡터)이며, 모든 원소가 1인 gradient를 넘기는 것은 출력의 합에 대해 sum().backward()를 호출하는 것과 같습니다.

비유: 수학적으로 기울기를 계산할 때, 우리는 어떤 기준값을 가지고 기울기를 측정합니다. gradient는 이러한 기준값을 제공하여, PyTorch가 적절한 기울기를 계산할 수 있도록 돕습니다.

예시 없이 설명하기:

  • 스칼라 텐서: x = torch.tensor(3.0, requires_grad=True)와 같은 경우, x.backward()만 호출하면 기울기가 자동으로 계산됩니다. x는 단일 값이므로 기울기를 계산할 기준이 명확하기 때문입니다.
  • 비스칼라 텐서: x = torch.randn(2, 2, requires_grad=True)와 같은 경우, x.backward()를 호출하려면 gradient 파라미터를 제공해야 합니다. 이 경우 gradient는 x의 모양과 같은 2x2 텐서여야 하며, 각 요소의 기울기 초기 값을 제공합니다.

— GPT

  • gradient (Tensor, optional) – The gradient of the function being differentiated w.r.t. self. This argument can be omitted if self is a scalar. (출처: PyTorch 공식문서, Tensor.backward)