import rpy2
import rpy2.robjects as ro
from rpy2.robjects.vectors import FloatVector
from rpy2.robjects.packages import importr
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import torch
import itstgcn_gb as itstgcn
import random
def save_data(data_dict, fname):
    """Pickle ``data_dict`` into the file named ``fname``.

    The file is opened in binary write mode ('wb'), so any existing file at
    ``fname`` is truncated and replaced.

    Args:
        data_dict: Any picklable object (callers in this file pass a dict).
        fname: Path of the output file.
    """
    with open(fname, 'wb') as sink:
        pickle.dump(data_dict, sink)
def load_data(fname):
    """Load and return the object pickled in the file ``fname``.

    Reads what ``save_data`` writes.

    Args:
        fname: Path of a pickle file.

    Returns:
        The unpickled object.

    Warning:
        ``pickle.load`` can execute arbitrary code while unpickling; only
        load files produced by a trusted source.
    """
    with open(fname, 'rb') as infile:
        data_dict = pickle.load(infile)
    return data_dict
def toy_analyze(FX, edges, lags, train_ratio, mrate, filters, epoch, mtype,
                fname='toy_ex_dataset.pkl'):
    """Train the classic and proposed learners on ``FX`` with injected missing values.

    Steps (all model helpers come from ``itstgcn_gb``, imported as ``itstgcn``):
      1. Wrap ``FX`` and ``edges`` in a data dict, pickle it to ``fname`` and
         read it back, then build a dataset with ``lags`` lags.
      2. Split the dataset temporally with ``train_ratio``. Inject missing
         values into the training part only (``rand_mindex`` with ``mrate``,
         then ``miss`` with ``mtype``), and pad them with ``padding``.
      3. Train ``StgcnLearner`` (classic) and ``ITStgcnLearner`` (proposed) on
         the same padded training data with identical ``filters``/``epoch``.
      4. Predict with both learners on the full, unsplit dataset.

    Args:
        FX: Node signals; the last axis indexes nodes (one node id is created
            per entry of ``FX.shape[-1]``). Presumably shaped (time, nodes) —
            confirm against ``itstgcn.DatasetLoader``.
        edges: Edge list stored under the ``'edges'`` key, e.g. ``[[0, 2], [1, 2]]``.
        lags: Passed to ``loader.get_dataset(lags=...)``.
        train_ratio: Passed to ``itstgcn.temporal_signal_split``.
        mrate: Passed to ``itstgcn.rand_mindex`` (presumably the missing rate).
        filters: Passed to both learners' ``learn``.
        epoch: Passed to both learners' ``learn``.
        mtype: Missing-pattern type passed to ``itstgcn.miss`` (e.g. ``'rand'``).
        fname: File used for the intermediate pickle round-trip. Defaults to
            the previously hard-coded ``'toy_ex_dataset.pkl'`` in the current
            working directory (overwritten on every call).

    Returns:
        ``(yhat_classic, yhat_proposed)``: the ``'yhat'`` entries of each
        learner's output on the full dataset.
    """
    data_dict = {
        'edges': edges,
        'node_ids': {i: 'node' + str(i) for i in range(FX.shape[-1])},
        'FX': FX,
    }
    # The dict is written to disk and immediately read back, as in the
    # original workflow, so the pickled file on disk matches what was analyzed.
    save_data(data_dict, fname)
    data_dict = load_data(fname)

    loader = itstgcn.DatasetLoader(data_dict)
    dataset = loader.get_dataset(lags=lags)
    train_dataset, _test_dataset = itstgcn.temporal_signal_split(
        dataset, train_ratio=train_ratio)

    # Missing values are injected into the training portion only.
    mindex_rand = itstgcn.rand_mindex(train_dataset, mrate=mrate)
    train_dataset_miss_rand = itstgcn.miss(train_dataset, mindex_rand, mtype=mtype)
    # Equivalent to padding(train_dataset_miss, method='linear').
    train_dataset_padded_rand = itstgcn.padding(train_dataset_miss_rand)

    lrnr_classic = itstgcn.StgcnLearner(train_dataset_padded_rand)
    lrnr_proposed = itstgcn.ITStgcnLearner(train_dataset_padded_rand)
    lrnr_classic.learn(filters=filters, epoch=epoch)
    lrnr_proposed.learn(filters=filters, epoch=epoch)

    # Predictions are made on the full dataset, not just the test split.
    yhat_classic = lrnr_classic(dataset)['yhat']
    yhat_proposed = lrnr_proposed(dataset)['yhat']
    return yhat_classic, yhat_proposed
# Toy example using GNAR
# STGCN
# Import
# Data
# Toy data: three node signals over T time steps. Node 2 is the sum of
# nodes 0 and 1 plus Gaussian noise (std 0.1), and the edge list
# [[0, 2], [1, 2]] links nodes 0 and 1 to node 2.
T = 50
t = np.linspace(0, np.pi * 2, T)
e = np.random.randn(T) * 0.1
y1 = np.cos(2 * t)
y2 = np.cos(3 * t)
y3 = y1 + y2 + e
y = np.stack([y1, y2, y3], axis=1)  # shape (T, 3): one column per node
_, N = y.shape

train_ratio = 0.9
# round() instead of int(): 50 * (1 - 0.9) == 4.999999999999999 in floating
# point, which int() truncates to 4 instead of the intended 5 test steps.
test_len = round(T * (1 - train_ratio))
tr_len = T - test_len
test_index = [False] * tr_len + [True] * test_len

edges = [[0, 2], [1, 2]]
lags = 8
mrate = 0.8
filters = 2
epoch = 50
mtype = 'rand'
###
yhat_classic, yhat_proposed = toy_analyze(
    y, edges, lags, train_ratio, mrate, filters, epoch, mtype)
# Captured cell output: 50/50 (presumably the learners' epoch-progress counter)
# Compare observations of node 2 with both models' predictions.
node = 2
plt.rcParams['figure.dpi'] = 200
# Observations are sliced from `lags` onward, presumably to align with the
# lagged predictions; labelled so the dots appear in the legend.
plt.plot(y[lags:, node], '.', label='observed')
plt.plot(yhat_classic[:, node], '--', label='classic')
plt.plot(yhat_proposed[:, node], '--', label='proposed')
# Noise-free signal y1 + y2: the true target of node 2 only.
plt.plot(y[lags:, 0] + y[lags:, 1], '-', label='true')
plt.legend()