# Six-panel comparison figure for one series: column 0 of data_dict['FX']
# ("node 1" according to the commented-out suptitle below).
#
#   ax1: ground-truth signal                ax2: complete data
#   ax3: complete vs. observed (missing)    ax4: complete vs. linear interpolation
#   ax5: first TRAIN_SPLIT steps, fits      ax6: remaining steps, predictions
#
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names, so plt.style.context('seaborn-white') raises
# OSError on current versions. Use whichever name this installation provides.
PLOT_STYLE = ('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available
              else 'seaborn-white')
# Index where ax5 (fhat_tr) ends and ax6 (fhat_test) begins.
# NOTE(review): presumably the train/test boundary used to build evtor_rand*;
# confirm it matches len(evtor_rand.fhat_tr).
TRAIN_SPLIT = 400
FONT_SIZE = 20


def _observed_series(targets):
    """Return the plotted observed/padded series for one training dataset.

    Concatenates ``torch.tensor(dataset.features)[:2, 0, 0]`` with column 0 of
    ``targets`` viewed as shape (-1, 2).
    NOTE(review): presumably [:2, 0, 0] is the first two time steps of node 0 /
    feature 0, and targets interleave two nodes so column 0 is node 0 -- confirm
    against the dataset construction.
    """
    head = torch.tensor(dataset.features)[:2, 0, 0]
    tail = torch.tensor(targets).reshape(-1, 2)[:, 0]
    return torch.cat([head, tail], dim=0)


with plt.style.context(PLOT_STYLE):
    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(40, 20))
    # fig.suptitle('Figure 1(node 1)',fontsize=40)

    # Convert once instead of rebuilding the array in every panel.
    fx = np.asarray(data_dict['FX'])[:, 0]

    ax1.plot(a_ar_values_true, label='Ground Truth')

    ax2.plot(fx, '-', color='C3', label='Complete Data')

    ax3.plot(fx, '--', color='C5', alpha=0.5, label='Complete Data')
    ax3.plot(_observed_series(train_dataset_miss_rand.targets), '--o',
             color='C3', label='Observed Data')

    ax4.plot(fx, '--', color='C5', alpha=0.5, label='Complete Data')
    ax4.plot(_observed_series(train_dataset_padded_rand.targets), '--o',
             color='C3', alpha=0.8, label='Linear Interpolation')
    # Optional highlight circles on the interpolation panel:
    # ax4.plot(55, 0, 'o', markersize=100, markerfacecolor='none', markeredgecolor='red',markeredgewidth=3)
    # ax4.plot(150, 0, 'o', markersize=80, markerfacecolor='none', markeredgecolor='red',markeredgewidth=3)
    # ax4.plot(185, 0, 'o', markersize=80, markerfacecolor='none', markeredgecolor='red',markeredgewidth=3)

    ax5.plot(fx[:TRAIN_SPLIT], '--', color='C5', alpha=0.5, label='Complete Data')
    ax5.plot(a_ar_values_true[:TRAIN_SPLIT], color='black', label='Ground Truth', lw=3)
    ax5.plot(evtor_rand.fhat_tr[:, 0], color='brown', lw=3, label='GConvGRU', alpha=0.5)
    ax5.plot(evtor_rand_it.fhat_tr[:, 0], color='blue', lw=3, label='IT-TGNN', alpha=0.5)

    ax6.plot(fx[TRAIN_SPLIT:], '--', color='C5', alpha=0.5, label='Test')
    ax6.plot(a_ar_values_true[TRAIN_SPLIT:], color='black', label='Ground Truth', lw=3)
    ax6.plot(evtor_rand.fhat_test[:, 0], color='brown', lw=3, label='GConvGRU', alpha=0.5)
    ax6.plot(evtor_rand_it.fhat_test[:, 0], color='blue', lw=3, label='IT-TGNN', alpha=0.5)

    # Shared formatting, applied after all lines exist so each legend sees
    # every labelled artist. ax1 previously kept the default tick size; it
    # now matches the other panels.
    for ax in (ax1, ax2, ax3, ax4, ax5, ax6):
        ax.legend(fontsize=FONT_SIZE, loc='lower left', facecolor='white', frameon=True)
        ax.tick_params(axis='both', labelsize=FONT_SIZE)

# # plt.savefig('try2_node1.png')