import torch
from torch_geometric_temporal.signal import temporal_signal_split
# Checking temporal_signal_split (original heading: "temporal_signal_split 확인")
import matplotlib.pyplot as plt
import pandas as pd
from torch_geometric_temporal.dataset import WikiMathsDatasetLoader
# Load the WikiMaths temporal signal and split it into a training part and a
# test part. With train_ratio=0.8 the recorded run produced 578 training and
# 145 test snapshots out of 723.
dataset = WikiMathsDatasetLoader().get_dataset()
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)
# Sanity-check the split. The trailing comments are the values recorded when
# the notebook was run; they are not recomputed here.
dataset.snapshot_count  # 723
train_dataset.snapshot_count  # 578
test_dataset.snapshot_count  # 145

# Shape of the first snapshot's feature array in each signal; splitting does
# not change the per-snapshot shape.
dataset.features[0].shape  # (1068, 8)
train_dataset.features[0].shape  # (1068, 8)
test_dataset.features[0].shape  # (1068, 8)
# Plot the per-snapshot mean of the targets over the whole signal, with a
# one-standard-deviation band and a marker at the train/test boundary.
mean_cases = [snapshot.y.mean().item() for snapshot in dataset]
std_cases = [snapshot.y.std().item() for snapshot in dataset]
df = pd.DataFrame({'mean': mean_cases, 'std': std_cases})

# Index of the first test snapshot (578 in the recorded run). Derived from the
# split so the marker stays correct if train_ratio changes.
split_index = train_dataset.snapshot_count

plt.figure(figsize=(10, 5))
plt.plot(df['mean'], 'k-', label='Mean')
plt.grid(linestyle=':')
plt.fill_between(
    df.index,
    df['mean'] - df['std'],
    df['mean'] + df['std'],
    color='r',
    alpha=0.1,
)
plt.axvline(x=split_index, color='b', linestyle='--')
plt.text(split_index, 1.5, 'Train/test split', rotation=-90, color='b')
plt.xlabel('Time (days)')
plt.ylabel('Normalized number of visits')
plt.legend(loc='upper right')
# Same plot restricted to the training snapshots.
mean_cases = [snapshot.y.mean().item() for snapshot in train_dataset]
std_cases = [snapshot.y.std().item() for snapshot in train_dataset]
df = pd.DataFrame({'mean': mean_cases, 'std': std_cases})

plt.figure(figsize=(10, 5))
plt.plot(df['mean'], 'k-', label='Mean')
plt.grid(linestyle=':')
plt.fill_between(
    df.index,
    df['mean'] - df['std'],
    df['mean'] + df['std'],
    color='r',
    alpha=0.1,
)
# The split marker sits just past the last training index (x=578 in the
# recorded run), i.e. at the right edge of this plot.
plt.axvline(x=train_dataset.snapshot_count, color='b', linestyle='--')
# Label left disabled, as in the original cell:
# plt.text(360, 1.5, 'Train/test split', rotation=-90, color='b')
plt.xlabel('Time (days)')
plt.ylabel('Normalized number of visits')
plt.legend(loc='upper right')
# Same plot restricted to the test snapshots.
mean_cases = [snapshot.y.mean().item() for snapshot in test_dataset]
std_cases = [snapshot.y.std().item() for snapshot in test_dataset]
df = pd.DataFrame({'mean': mean_cases, 'std': std_cases})

plt.figure(figsize=(10, 5))
plt.plot(df['mean'], 'k-', label='Mean')
plt.grid(linestyle=':')
plt.fill_between(
    df.index,
    df['mean'] - df['std'],
    df['mean'] + df['std'],
    color='r',
    alpha=0.1,
)
# The frame is re-indexed from 0 here, so the marker is drawn at x=0, the
# left edge of this plot.
plt.axvline(x=0, color='b', linestyle='--')
# Label left disabled, as in the original cell:
# plt.text(360, 1.5, 'Train/test split', rotation=-90, color='b')
plt.xlabel('Time (days)')
plt.ylabel('Normalized number of visits')
plt.legend(loc='upper right')