import rpy2
import rpy2.robjects as ro
from rpy2.robjects.vectors import FloatVector
from rpy2.robjects.packages import importr
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import torch
import itstgcn_gb as itstgcn
import random
def save_data(data_dict, fname):
    """Serialize ``data_dict`` into the file at ``fname`` using pickle.

    The file is opened in binary write mode, so any existing content at
    ``fname`` is replaced. Returns ``None``.
    """
    with open(fname, mode='wb') as sink:
        pickle.dump(data_dict, sink)
def load_data(fname):
    """Return the object stored with pickle in the file at ``fname``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserializing, so only call this on files you created or trust.
    """
    with open(fname, mode='rb') as source:
        return pickle.load(source)
def toy_analyze(FX, edges, lags, train_ratio, mrate, filters, epoch, mtype,
                fname='toy_ex_dataset.pkl'):
    """Train a classic STGCN and the proposed IT-STGCN on data with injected missing values.

    Steps:
      1. Build a dataset dict from ``FX`` and ``edges``.
      2. Save it to ``fname`` and load it back.
      3. Split it into train and test parts with ``train_ratio``.
      4. Mask training entries chosen by ``itstgcn.rand_mindex`` (``mrate``,
         ``mtype``) and pad them.
      5. Fit both learners on the padded training data.

    Args:
        FX: Signal array. ``FX.shape[-1]`` is used as the number of nodes.
            NOTE(review): presumably shaped (time, nodes); confirm.
        edges: Graph edges, passed through unchanged to the dataset dict.
        lags: Number of lags passed to ``DatasetLoader.get_dataset``.
        train_ratio: Fraction passed to ``itstgcn.temporal_signal_split``.
        mrate: Passed to ``itstgcn.rand_mindex``. Presumably the missing
            rate; confirm against itstgcn.
        filters: Passed to both learners' ``learn`` method.
        epoch: Passed to both learners' ``learn`` method.
        mtype: Missing-pattern type passed to ``itstgcn.miss``.
        fname: Path of the intermediate pickle file. It is overwritten on
            every call. Defaults to the historical ``'toy_ex_dataset.pkl'``
            in the current working directory. Pass a distinct path to avoid
            collisions between concurrent runs.

    Returns:
        ``(yhat_classic, yhat_proposed)``: the ``'yhat'`` entries produced by
        each trained learner on the FULL dataset (train + test), not only
        the test split.
    """
    data_dict = {
        'edges': edges,
        # One label per index along FX's last axis: 'node0', 'node1', ...
        'node_ids': {i: 'node' + str(i) for i in range(FX.shape[-1])},
        'FX': FX,
    }
    # Round-trip through disk, as the original workflow does. The reloaded
    # dict is a fresh copy, independent of the caller's objects.
    save_data(data_dict, fname)
    data_dict = load_data(fname)

    loader = itstgcn.DatasetLoader(data_dict)
    dataset = loader.get_dataset(lags=lags)
    train_dataset, test_dataset = itstgcn.temporal_signal_split(dataset, train_ratio=train_ratio)

    # Inject missing values into the training split only.
    mindex_rand = itstgcn.rand_mindex(train_dataset, mrate=mrate)
    train_dataset_miss_rand = itstgcn.miss(train_dataset, mindex_rand, mtype=mtype)
    # Equivalent to padding(train_dataset_miss, method='linear').
    train_dataset_padded_rand = itstgcn.padding(train_dataset_miss_rand)

    lrnr_classic = itstgcn.StgcnLearner(train_dataset_padded_rand)
    lrnr_proposed = itstgcn.ITStgcnLearner(train_dataset_padded_rand)
    lrnr_classic.learn(filters=filters, epoch=epoch)
    lrnr_proposed.learn(filters=filters, epoch=epoch)

    yhat_classic = lrnr_classic(dataset)['yhat']
    yhat_proposed = lrnr_proposed(dataset)['yhat']
    return yhat_classic, yhat_proposed