# Notebook export: IT-STGCN experiments on the PedalMe dataset.
import itstgcnsnd
import torch
import itstgcnsnd.planner

# PedalMe GSO_st
# ITSTGCN
# edit

from torch_geometric_temporal.dataset import PedalMeDatasetLoader
loader2 = PedalMeDatasetLoader()
import numpy as np
from torch_geometric_temporal.nn.recurrent import GConvGRU
import copy
import torch_geometric_temporal
import torch.nn.functional as F
from rpy2.robjects.vectors import FloatVector
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
import rpy2.robjects.numpy2ri as rpyn

GNAR = importr('GNAR')  # import GNAR
# igraph = importr('igraph')  # import igraph
# R's EbayesThresh::ebayesthresh; trim() applies it to graph-Fourier coefficients.
ebayesthresh = importr('EbayesThresh').ebayesthresh
def flatten_weight(T,N,ws,wt):
    """Assemble the (T*N, T*N) weight matrix of the space-time product graph.

    Vertices are ordered time-major, so block (i, j) is the N x N coupling
    between time i and time j:

    * diagonal blocks (i == j) are the spatial weights ``ws``;
    * blocks of adjacent times (|i - j| == 1) are the N x N identity;
    * every other block is zero.

    For finite ``ws`` this equals kron(I_T, ws) + kron(A_path, I_N), where
    A_path has ones just above and below its main diagonal.

    Args:
        T: number of time points.
        N: number of nodes.
        ws: (N, N) spatial weight matrix.
        wt: (T, T) temporal weight matrix.  NOTE(review): currently unused --
            the temporal coupling is hard-wired to unit-weight links between
            consecutive times; confirm whether ``wt`` was meant to scale them.

    Returns:
        (T*N, T*N) numpy array.
    """
    identity_block = np.eye(N, N)
    zero_block = identity_block * 0

    def pick_block(row, col):
        # Choose the N x N block sitting at time-pair (row, col).
        if row == col:
            return ws
        if abs(row - col) == 1:
            return identity_block
        return zero_block

    # np.block joins each inner list along the columns, then stacks the rows.
    return np.block([[pick_block(row, col) for col in range(T)] for row in range(T)])
def make_Psi(T,N,edge_index,edge_weight):
    """Return the graph-Fourier basis of the space-time product graph.

    The T*N vertices are ordered time-major: all N nodes at time 0, then all
    N nodes at time 1, and so on.  ``flatten_weight`` builds the weight matrix W
    from two pieces:

    * a spatial matrix ``ws``, filled from ``edge_index``/``edge_weight``;
    * a temporal matrix ``wt`` with weight 1 between consecutive time points.

    The basis is the eigenvector matrix of the normalized Laplacian
    ``D^{-1/2} (D - W) D^{-1/2}``, where ``D = diag(row sums of W)``.

    Args:
        T: number of time points.
        N: number of graph nodes.
        edge_index: 2 x E indexable of integer node ids (row 0 = source,
            row 1 = target).  Edges whose source lies outside ``range(N)``
            are ignored.
        edge_weight: array of edge weights; ``edge_weight.shape[0]`` edges
            are read.

    Returns:
        (T*N, T*N) array whose columns are eigenvectors, ordered by ascending
        eigenvalue as returned by ``np.linalg.eigh``.
    """
    # Temporal weights: 1 between consecutive time points, 0 elsewhere.
    wt = np.eye(T, k=1) + np.eye(T, k=-1)

    # Spatial weights in a single pass over the edges (instead of rescanning
    # every edge for each node).  For repeated (source, target) pairs the
    # last edge still wins.
    ws = np.zeros((N, N))
    for j in range(edge_weight.shape[0]):
        src = edge_index[0][j]
        if 0 <= src < N:
            ws[src, edge_index[1][j]] = edge_weight[j]

    W = flatten_weight(T, N, ws, wt)  # (T*N, T*N)
    d = np.asarray(W.sum(axis=1))

    # D^{-1/2}.  Vertices with non-positive degree get 0.  Computing
    # 1/sqrt(0) = inf instead would turn the whole Laplacian into NaN and
    # break the eigendecomposition.
    inv_sqrt_d = np.zeros(d.shape)
    positive = d > 0
    inv_sqrt_d[positive] = 1 / np.sqrt(d[positive])

    # diag(a) @ M @ diag(a) as a row/column scaling: O((TN)^2) rather than
    # two dense O((TN)^3) matrix products.
    L = inv_sqrt_d[:, None] * (np.diag(d) - W) * inv_sqrt_d[None, :]

    # NOTE(review): eigh reads only one triangle of L, so this assumes the edge
    # list is symmetric (undirected graph) -- confirm for new datasets.
    _, Psi = np.linalg.eigh(L)
    return Psi  # (T*N, T*N)
def trim(f,edge_index,edge_weight):
    """Denoise a signal by thresholding its graph-Fourier coefficients.

    The (T, N) signal is flattened time-major and projected onto the basis
    from ``make_Psi``.  The coefficients are then reshaped to (T, N), and each
    node's column is passed through R's ``ebayesthresh``.  Finally the result
    is mapped back with the inverse transform.

    Args:
        f: (T, N) array-like signal; a 1-D input is treated as a single node
            (shape (T, 1)).
        edge_index, edge_weight: spatial graph, forwarded to ``make_Psi``.

    Returns:
        (T, N) numpy array with the reconstructed signal.
    """
    signal = np.array(f)
    if signal.ndim == 1:
        signal = signal.reshape(-1, 1)
    n_time, n_node = signal.shape

    basis = make_Psi(n_time, n_node, edge_index, edge_weight)  # (TN, TN)

    # Forward transform of the time-major flattened signal, viewed as (T, N).
    coeffs = (basis.T @ signal.reshape(-1, 1)).reshape(-1, n_node)

    # Threshold each node's coefficient column separately in R.
    shrunk_columns = []
    for node in range(n_node):
        shrunk_columns.append(ebayesthresh(FloatVector(coeffs[:, node])))
    shrunk = np.stack(shrunk_columns, axis=1)

    # Inverse transform, back to (T, N).
    restored = basis @ shrunk.reshape(-1, 1)
    return restored.reshape(-1, n_node)
def update_from_freq_domain(signal, missing_index,edge_index,edge_weight):
    """Overwrite missing entries of ``signal`` with their ``trim`` estimate.

    Args:
        signal: (T, N) array-like.  A copy is made; the input is not modified.
        missing_index: per-node indexable; ``missing_index[i]`` selects the
            rows of column ``i`` to overwrite.  ``None`` means nothing is
            marked missing.
        edge_index, edge_weight: spatial graph, forwarded to ``trim``.

    Returns:
        (T, N) numpy array.
    """
    signal = np.array(signal)
    _, N = signal.shape
    if missing_index is None:
        # Nothing to impute: skip trim(), whose (TN x TN) eigendecomposition
        # is expensive.  The result is the same unchanged copy as before.
        return signal
    signal_trimed = trim(signal, edge_index, edge_weight)
    for i in range(N):
        try:
            rows = missing_index[i]
            signal[rows, i] = signal_trimed[rows, i]
        except (IndexError, KeyError, TypeError):
            # Best effort, as before: nodes without a usable index entry
            # (e.g. missing_index shorter than N) are left untouched.  Only
            # indexing errors are tolerated now, not every exception.
            continue
    return signal
class StgcnLearner:
    """Train and apply a ``RecurrentGCN`` on a temporal graph dataset.

    Missing-data metadata (``mindex``, ``mrate_eachnode``, ``mrate_total``,
    ``mtype``, ``interpolation_method``) is copied from the training dataset
    when present, so that ``Evaluator`` can report it.
    """

    def __init__(self,train_dataset,dataset_name = None):
        self.train_dataset = train_dataset
        # number of input lags = trailing dimension of each feature window
        self.lags = np.asarray(train_dataset.features).shape[-1]
        self.dataset_name = str(train_dataset) if dataset_name is None else dataset_name
        self.mindex = getattr(self.train_dataset, 'mindex', None)
        self.mrate_eachnode = getattr(self.train_dataset, 'mrate_eachnode', 0)
        self.mrate_total = getattr(self.train_dataset, 'mrate_total', 0)
        self.mtype = getattr(self.train_dataset, 'mtype', None)
        self.interpolation_method = getattr(self.train_dataset, 'interpolation_method', None)
        self.method = 'STGCN'

    def learn(self,filters=32,epoch=50):
        """Fit a fresh ``RecurrentGCN`` with Adam (lr=0.01).

        Training is online: one optimizer step per snapshot, repeated for
        ``epoch`` passes over the training dataset.

        Args:
            filters: hidden size of the recurrent graph layer.
            epoch: number of passes over the training data.
        """
        self.model = RecurrentGCN(node_features=self.lags, filters=filters)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01)
        self.model.train()
        for e in range(epoch):
            for snapshot in self.train_dataset:
                yt_hat = self.model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
                # The model returns (N, 1) while targets are typically (N,).
                # Without reshape_as, the difference broadcasts to (N, N) and
                # every node's prediction is scored against every node's target.
                cost = torch.mean((yt_hat.reshape_as(snapshot.y) - snapshot.y)**2)
                cost.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
            print('{}/{}'.format(e+1,epoch),end='\r')
        # recording HP
        self.nof_filters = filters
        self.epochs = epoch  # number of completed passes (was off by one: epoch+1)

    def __call__(self,dataset):
        """Apply the fitted model to every snapshot of ``dataset``.

        Returns:
            dict with float32 tensors: ``'X'`` (stacked features), ``'y'``
            (stacked targets) and ``'yhat'`` (stacked, squeezed predictions).
        """
        X = torch.tensor(np.asarray(dataset.features)).float()
        y = torch.tensor(np.asarray(dataset.targets)).float()
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            predictions = [
                self.model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
                for snapshot in dataset
            ]
        yhat = torch.stack(predictions).detach().squeeze().float()
        return {'X':X, 'y':y, 'yhat':yhat}
class ITStgcnLearner(StgcnLearner):
    """IT-STGCN: STGCN training that re-imputes missing entries every epoch.

    Each epoch proceeds as follows:

    1. Rebuild the full (T, N) training signal from the current dataset
       (``convert_train_dataset``).
    2. Replace the entries listed in ``self.mindex`` with the estimate from
       ``update_from_freq_domain``.
    3. Build a new sliding-window dataset from the result.
    4. Run the usual per-snapshot updates on it.
    """

    def __init__(self,train_dataset,dataset_name = None):
        # Forward dataset_name.  It used to be dropped, so a custom name was
        # silently replaced by str(train_dataset).
        super().__init__(train_dataset, dataset_name)
        self.method = 'IT-STGCN'

    def learn(self,filters=32,epoch=50):
        """Fit a fresh ``RecurrentGCN`` (Adam, lr=0.01) with per-epoch re-imputation.

        Args:
            filters: hidden size of the recurrent graph layer.
            epoch: number of passes (each preceded by re-imputation).
        """
        self.model = RecurrentGCN(node_features=self.lags, filters=filters)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01)
        self.model.train()
        edge_index = self.train_dataset.edge_index
        edge_weight = self.train_dataset.edge_weight
        edges = edge_index.T.tolist()  # loop-invariant
        train_dataset_temp = copy.copy(self.train_dataset)
        for e in range(epoch):
            f, _ = convert_train_dataset(train_dataset_temp)
            f = update_from_freq_domain(f, self.mindex, edge_index, edge_weight)
            _, N = f.shape
            data_dict_temp = {
                'edges': edges,
                'node_ids': {'node'+str(i): i for i in range(N)},
                'FX': f,
            }
            # NOTE(review): DatasetLoader gives every edge weight 1.  From here
            # on, training therefore ignores edge_weight, even though
            # update_from_freq_domain above uses it.  Confirm this is intended.
            train_dataset_temp = DatasetLoader(data_dict_temp).get_dataset(lags=self.lags)
            for snapshot in train_dataset_temp:
                yt_hat = self.model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
                # The model returns (N, 1) and targets here are (N,).  Without
                # reshape_as, the difference broadcasts to (N, N).
                cost = torch.mean((yt_hat.reshape_as(snapshot.y) - snapshot.y)**2)
                cost.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
            print('{}/{}'.format(e+1,epoch),end='\r')
        # record
        self.nof_filters = filters
        self.epochs = epoch  # number of completed passes (was off by one: epoch+1)
def convert_train_dataset(train_dataset):
    """Reassemble the full time-by-node signal from a sliding-window dataset.

    The first snapshot's input ``x`` is transposed so that its rows are time
    steps; it supplies the leading ``lags`` rows.  The targets of all
    snapshots follow.  For the datasets built in this file ``x`` is
    (N, lags) and each target is (N,).

    Returns:
        tuple ``(f, lags)``: the signal as a numpy array and the window length.
    """
    lags = torch.tensor(train_dataset.features).shape[-1]
    leading_rows = train_dataset[0].x.T
    target_rows = torch.tensor(train_dataset.targets)
    stacked = torch.cat([leading_rows, target_rows], dim=0)
    return stacked.numpy(), lags
class RecurrentGCN(torch.nn.Module):
    """GConvGRU layer followed by ReLU and a linear read-out to one value per node.

    Args:
        node_features: number of input features per node (the number of lags).
        filters: hidden size of the GConvGRU layer.
    """

    def __init__(self, node_features, filters):
        super().__init__()
        # Third positional argument of GConvGRU is the Chebyshev filter size K.
        self.recurrent = GConvGRU(node_features, filters, 2)
        self.linear = torch.nn.Linear(filters, 1)

    def forward(self, x, edge_index, edge_weight):
        """Return one prediction per node (last dimension of size 1).

        No hidden state is passed to GConvGRU, so each call starts from its
        default initial state.
        """
        hidden = F.relu(self.recurrent(x, edge_index, edge_weight))
        return self.linear(hidden)
class DatasetLoader(object):
    """Build a ``StaticGraphTemporalSignal`` from an in-memory dictionary.

    Keys of ``data_dict`` read by this class:

    * ``"edges"`` -- list of ``[source, target]`` pairs (E x 2);
    * ``"FX"``    -- node signal, array-like of shape (T, N) (rows = time).

    Any other key (e.g. ``"node_ids"``) is ignored, and every edge is given
    weight 1.
    """

    def __init__(self,data_dict):
        self._dataset = data_dict

    def _get_edges(self):
        # (E, 2) pairs -> (2, E) edge index
        self._edges = np.array(self._dataset["edges"]).T

    def _get_edge_weights(self):
        # unit weight for every edge
        self._edge_weights = np.ones(self._edges.shape[1])

    def _get_targets_and_features(self):
        stacked_target = np.array(self._dataset["FX"])
        n_windows = stacked_target.shape[0] - self.lags
        # Window i: rows i .. i+lags-1, transposed so nodes are rows.
        self.features = [
            stacked_target[i : i + self.lags, :].T
            for i in range(n_windows)
        ]
        # Target of window i: the row right after the window.
        self.targets = [
            stacked_target[i + self.lags, :].T
            for i in range(n_windows)
        ]

    def get_dataset(self, lags: int = 4) -> "torch_geometric_temporal.signal.StaticGraphTemporalSignal":
        """Return the sliding-window dataset built from ``data_dict``.

        Args:
            lags: number of past time steps used as node features.

        Returns:
            ``torch_geometric_temporal.signal.StaticGraphTemporalSignal`` with
            ``T - lags`` snapshots.
        """
        self.lags = lags
        self._get_edges()
        self._get_edge_weights()
        self._get_targets_and_features()
        dataset = torch_geometric_temporal.signal.StaticGraphTemporalSignal(
            self._edges, self._edge_weights, self.features, self.targets
        )
        return dataset
class Evaluator:
    """Apply a fitted learner to train/test data and report MSEs and plots.

    On construction the learner is called on both datasets and the full
    series are reassembled, with rows = time and columns = nodes:

    * ``f``    -- observed: first training input window, then train targets,
      then test targets;
    * ``fhat`` -- fitted: first training input window, then train
      predictions, then test predictions.
    """

    def __init__(self,learner,train_dataset,test_dataset):
        self.learner = learner
        # Switch to inference mode when the learner has a model.  A learner
        # that was never fitted has no ``model`` attribute.
        try:
            self.learner.model.eval()
        except AttributeError:
            pass
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.lags = self.learner.lags
        rslt_tr = self.learner(self.train_dataset)
        rslt_test = self.learner(self.test_dataset)
        # First training input window, transposed so rows are time steps.
        # It is prepended so that f/fhat cover the whole training period.
        first_window = self.train_dataset[0].x.T
        self.X_tr = rslt_tr['X']
        self.y_tr = rslt_tr['y']
        self.f_tr = torch.cat([first_window, self.y_tr], dim=0).float()
        self.yhat_tr = rslt_tr['yhat']
        self.fhat_tr = torch.cat([first_window, self.yhat_tr], dim=0).float()
        self.X_test = rslt_test['X']
        self.y_test = rslt_test['y']
        self.f_test = self.y_test
        self.yhat_test = rslt_test['yhat']
        self.fhat_test = self.yhat_test
        self.f = torch.cat([self.f_tr, self.f_test], dim=0)
        self.fhat = torch.cat([self.fhat_tr, self.fhat_test], dim=0)

    def calculate_mse(self):
        """Fill ``self.mse`` with per-node and overall MSEs.

        The ``'test(base)'`` entry scores a baseline that predicts each
        node's mean over the test period.
        """
        n_node = self.y_test.shape[-1]
        base_sq = (self.y_test - self.y_test.mean(dim=0).reshape(-1, n_node))**2
        train_sq = (self.y_tr - self.yhat_tr)**2
        test_sq = (self.y_test - self.yhat_test)**2
        self.mse = {
            'train': {'each_node': train_sq.mean(dim=0).tolist(), 'total': train_sq.mean().item()},
            'test': {'each_node': test_sq.mean(dim=0).tolist(), 'total': test_sq.mean().item()},
            'test(base)': {'each_node': base_sq.mean(dim=0).tolist(), 'total': base_sq.mean().item()},
        }

    def _plot(self,*args,t=None,h=2.5,max_node=5,**kwargs):
        """Plot the observed series of the first ``max_node`` nodes (at least 2 axes).

        ``*args``/``**kwargs`` are forwarded to ``Axes.plot``; ``h`` is the
        height per subplot.  The figure is closed (detached from pyplot) and
        returned.
        """
        # Local import: matplotlib is otherwise only imported further down the file.
        import matplotlib.pyplot as plt
        T, N = self.f.shape
        if t is None:
            t = range(T)
        fig = plt.figure()
        nof_axs = max(min(N, max_node), 2)
        if min(N, max_node) < 2:
            print('max_node should be >=2')
        ax = fig.subplots(nof_axs, 1)
        for n in range(nof_axs):
            ax[n].plot(t, self.f[:, n], *args, color='gray', **kwargs)
            ax[n].set_title('node='+str(n))
        fig.set_figheight(nof_axs*h)
        fig.tight_layout()
        plt.close(fig)  # close this figure specifically, not whatever is current
        return fig

    def plot(self,*args,t=None,h=2.5,**kwargs):
        """Plot observed vs. fitted series per node, titled with MSEs and run settings.

        Args:
            *args, **kwargs: forwarded to ``_plot`` (``max_node`` may be given
                here; the rest reaches ``Axes.plot`` for the observed series).
            t: x-axis values for the T rows of ``self.f``; defaults to range(T).
            h: height per subplot.

        Returns:
            The matplotlib figure.
        """
        self.calculate_mse()
        # Materialize t once, so it can be passed to _plot and sliced below.
        time_axis = list(range(len(self.f))) if t is None else list(t)
        # t and h are forwarded (previously they were replaced by t=None, h=2.5).
        fig = self._plot(*args, t=time_axis, h=h, **kwargs)
        _t1 = self.lags
        _t2 = self.yhat_tr.shape[0] + self.lags
        _t3 = len(self.f)
        mrate_eachnode = self.learner.mrate_eachnode
        for i, a in enumerate(fig.get_axes()):
            _mse1 = self.mse['train']['each_node'][i]
            _mse2 = self.mse['test']['each_node'][i]
            _mse3 = self.mse['test(base)']['each_node'][i]
            # A scalar missing rate applies to every node; otherwise index per node.
            _mrate = mrate_eachnode[i] if hasattr(mrate_eachnode, '__getitem__') else mrate_eachnode
            _title = 'node{0}, mrate: {1:.2f}% \n mse(train) = {2:.2f}, mse(test) = {3:.2f}, mse(test_base) = {4:.2f}'.format(i,_mrate*100,_mse1,_mse2,_mse3)
            a.set_title(_title)
            a.plot(time_axis[_t1:_t2], self.yhat_tr[:, i], label='fitted (train)', color='C0')
            a.plot(time_axis[_t2:_t3], self.yhat_test[:, i], label='fitted (test)', color='C1')
            a.legend()
        _mse1 = self.mse['train']['total']
        _mse2 = self.mse['test']['total']
        _mse3 = self.mse['test(base)']['total']
        _title = \
        'dataset: {0} \n method: {1} \n mrate: {2:.2f}% \n interpolation:{3} \n epochs={4} \n number of filters={5} \n lags = {6} \n mse(train) = {7:.2f}, mse(test) = {8:.2f}, mse(test_base) = {9:.2f} \n'.\
        format(self.learner.dataset_name,self.learner.method,self.learner.mrate_total*100,self.learner.interpolation_method,self.learner.epochs,self.learner.nof_filters,self.learner.lags,_mse1,_mse2,_mse3)
        fig.suptitle(_title)
        fig.tight_layout()
        return fig
# --- Demo: IT-STGCN on PedalMe with missing entries from itstgcnsnd.rand_mindex (mrate=0.9) ---
from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
loader1 = ChickenpoxDatasetLoader()
a = loader2.get_dataset(lags=1)
train_dataset, test_dataset = torch_geometric_temporal.signal.temporal_signal_split(a, train_ratio=0.8)
mindex = itstgcnsnd.rand_mindex(train_dataset, mrate=0.9)
dataset_miss = itstgcnsnd.miss(train_dataset, mindex, mtype='rand')
# equivalent to padding(train_dataset_miss, method='linear')
dataset_padded = itstgcnsnd.padding(dataset_miss, imputation_method='linear')
lrnr = ITStgcnLearner(dataset_padded)
lrnr.learn(epoch=5)
# output: 5/5
ev = Evaluator(lrnr, train_dataset, test_dataset)
import matplotlib.pyplot as plt
fig = ev.plot('--.', h=5, max_node=5, label='complete data', alpha=0.5)
fig.set_figwidth(20)
fig.set_figheight(20)
fig.tight_layout()
fig  # notebook display
# random
# Simulation study on PedalMe with the random missing pattern (PLNR_STGCN_RAND).
plans_stgcn_rand = {
    'max_iteration': 30,
    'method': ['STGCN', 'IT-STGCN'],
    'mrate': [0.3, 0.6],
    'lags': [4],
    'nof_filters': [12],
    'inter_method': ['linear', 'nearest'],
    'epoch': [50],
}
plnr = itstgcnsnd.planner.PLNR_STGCN_RAND(plans_stgcn_rand, loader2, dataset_name='pedalme')
plnr.simulate()
# Recorded output:
#   1/30 is done
#   ...
#   30/30 is done
#   All results are stored in ./simulation_results/2023-07-02_07-01-12.csv
# block
# Simulation study on PedalMe with a hand-made block missing pattern.
# mindex has one entry per node (15 entries, #pedalme).  Each entry lists the
# row (time) indices that are missing for that node.
my_list = [[] for _ in range(15)]  # pedalme
another_list = list(range(10, 25))
for node in (1, 3, 4, 5):
    # independent copy per node: the entries no longer alias one list object
    my_list[node] = list(another_list)
another_list = list(range(5, 20))
for node in (7, 9, 10, 11):
    my_list[node] = list(another_list)
mindex = my_list
# mindex= [[],[],[],list(range(50,150)),[]] # node 1
# mindex= [list(range(10,100)),[],list(range(50,80)),[],[]] # node 2
# mindex= [list(range(10,100)),[],list(range(50,80)),list(range(50,150)),[]] # node3
plans_stgcn_block = {
    'max_iteration': 30,
    'method': ['STGCN', 'IT-STGCN'],
    'mindex': [mindex],
    'lags': [4],
    'nof_filters': [12],
    'inter_method': ['linear', 'nearest'],
    'epoch': [50],
}
plnr = itstgcnsnd.planner.PLNR_STGCN_MANUAL(plans_stgcn_block, loader2, dataset_name='pedalme')
plnr.simulate(mindex=mindex, mtype='block')
# Recorded output:
#   1/30 is done
#   ...
#   30/30 is done
#   All results are stored in ./simulation_results/2023-07-02_07-19-21.csv
# df1 = pd.read_csv('./simulation_results/2023-04-13_20-37-59.csv')
# data = pd.concat([df1],axis=0);data