Imports

import os
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
2022-06-08 07:38:20.945738: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-06-08 07:38:20.945762: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

Data construction


p = 0.3
papers = pd.concat([pd.DataFrame(np.array([[p]*1500,[1-p]*1500]).reshape(1000,3),columns=['X1','X2','X3']),
           pd.DataFrame(np.array([[1-p]*1500,[p]*1500]).reshape(1000,3),columns=['X4','X5','X6']),
          pd.DataFrame(np.array([['Deep learning']*500,['Reinforcement learning']*500]).reshape(1000,1))],axis=1).reset_index().rename(columns={'index':'paper_id',0:'subject'})
papers['paper_id'] = papers['paper_id']+1
# Attempt 1: assign random paper IDs 1-1000 to all 2,500 rows, regardless of target/source.
citations = pd.DataFrame(np.array([[np.random.choice(range(1,1001),size=(2500,1))],
                                   [np.random.choice(range(1,1001),size=(2500,1))]]).reshape(2500,2)).rename(columns = {0:'target',1:'source'})
# Attempt 2: the first 1,250 rows draw target and source from deep-learning papers (IDs 1-500),
# the remaining 1,250 rows from reinforcement-learning papers (IDs 501-1000).
citations = pd.concat([pd.DataFrame(np.array([np.random.choice(range(1,501),size=(1250,1)),np.random.choice(range(1,501),size=(1250,1))]).reshape(1250,2)),
                       pd.DataFrame(np.array([np.random.choice(range(501,1001),size=(1250,1)),np.random.choice(range(501,1001),size=(1250,1))]).reshape(1250,2))],ignore_index=True).rename(columns={0:'target',1:'source'})
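For illustration, a hypothetical sanity check on the second attempt confirms that each block of 1,250 edges only cites within its own class:

# The first block should stay within IDs 1-500, the second within IDs 501-1000.
assert citations.iloc[:1250].isin(range(1, 501)).all().all()
assert citations.iloc[1250:].isin(range(501, 1001)).all().all()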

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train.shape, y_train.shape, x_test.shape, y_test.shape
((60000, 28, 28), (60000,), (10000, 28, 28), (10000,))
x_train, x_test = x_train[..., np.newaxis]/255.0, x_test[..., np.newaxis]/255.0

def filter_36(x, y):
    keep = (y == 3) | (y == 7)
    x, y = x[keep], y[keep]
    y = y == 3
    return x, y

x_train, y_train = filter_36(x_train, y_train)
x_test, y_test = filter_36(x_test, y_test)

X = x_train.reshape(-1,784)  # x_train is already scaled to [0, 1] above, so no further /255
y = y_train
#y = list(map(lambda x: 0 if x == True else 1,y_train))
#XX = x_test.reshape(-1,784)/255
#yy = list(map(lambda x: 0 if x == True else 1,y_test))

If y is 3, map it to 0;

if y is 7, map it to 1.
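The commented-out lambda above performs this mapping element-wise; a vectorized equivalent (a sketch, assuming y is the boolean array returned by filter_36) is:

# y is True where the digit is 3, False where it is 7.
y = np.where(y, 0, 1)  # 3 -> 0, 7 -> 1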

add_list = []
add_list.append(X)
# Stack the 784 pixel features and the label side by side; keep only the first 4,999 rows.
papers = pd.concat([pd.DataFrame(np.array(add_list).reshape(-1,784)),pd.DataFrame(np.array(y).reshape(-1,1))],axis=1).reset_index().iloc[:4999]
column_names = ["paper_id"] + [f"X_{idx}" for idx in range(1,785)] + ["subject"]
papers.columns = column_names
papers['paper_id'] = papers['paper_id'] + 1
papers
paper_id X_1 X_2 X_3 X_4 X_5 X_6 X_7 X_8 X_9 ... X_776 X_777 X_778 X_779 X_780 X_781 X_782 X_783 X_784 subject
0 1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 5
1 2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0
2 3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4
3 4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1
4 5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 9
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
4994 4995 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0
4995 4996 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 7
4996 4997 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 3
4997 4998 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2
4998 4999 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1

4999 rows × 786 columns

# Build every ordered (target, source) pair over all 4,999 paper IDs: a complete
# directed graph with self-loops (a vectorized alternative is sketched after the output below).
_a = []
for i in range(1,len(papers)+1):
    for j in range(1,len(papers)+1):
        _a.append([i])
        _a.append([j])
citations = pd.DataFrame(np.array(_a).reshape(-1,2)).rename(columns = {0:'target',1:'source'})
citations
target source
0 1 1
1 1 2
2 1 3
3 1 4
4 1 5
... ... ...
24989996 4999 4995
24989997 4999 4996
24989998 4999 4997
24989999 4999 4998
24990000 4999 4999

24990001 rows × 2 columns
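Since this edge list is simply every ordered pair of the 4,999 paper IDs, the double loop can be replaced by array operations that produce the same 24,990,001 rows far faster; a sketch assuming the papers frame above:

n = len(papers)
ids = np.arange(1, n + 1)
# target repeats each ID n times; source tiles the full ID range n times.
citations = pd.DataFrame({'target': np.repeat(ids, n), 'source': np.tile(ids, n)})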

Graph representation

class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])
plt.figure(figsize=(10, 10))
colors = papers["subject"].tolist()
cora_graph = nx.from_pandas_edgelist(citations.sample(n=800))
subjects = list(papers[papers["paper_id"].isin(list(cora_graph.nodes))]["subject"])
nx.draw_spring(cora_graph, node_size=15, node_color=subjects)

Test vs Train

train_data, test_data = [], []

for _, group_data in papers.groupby("subject"):
    # Select around 50% of the dataset for training.
    random_selection = np.random.rand(len(group_data.index)) <= 0.5
    train_data.append(group_data[random_selection])
    test_data.append(group_data[~random_selection])

train_data = pd.concat(train_data).sample(frac=1)
test_data = pd.concat(test_data).sample(frac=1)

print("Train data shape:", train_data.shape)
print("Test data shape:", test_data.shape)
Train data shape: (2497, 786)
Test data shape: (2502, 786)
hidden_units = [32,32]
learning_rate = 0.01
dropout_rate = 0.5
num_epochs = 300
batch_size = 256
def run_experiment(model, x_train, y_train):
    # Compile the model.
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
    )
    # Create an early stopping callback.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_acc", patience=50, restore_best_weights=True
    )
    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        epochs=num_epochs,
        batch_size=batch_size,
        validation_split=0.15,
        callbacks=[early_stopping],
    )

    return history
def display_learning_curves(history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    ax1.plot(history.history["loss"])
    ax1.plot(history.history["val_loss"])
    ax1.legend(["train", "test"], loc="upper right")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")

    ax2.plot(history.history["acc"])
    ax2.plot(history.history["val_acc"])
    ax2.legend(["train", "test"], loc="upper right")
    ax2.set_xlabel("Epochs")
    ax2.set_ylabel("Accuracy")
    plt.show()
def create_ffn(hidden_units, dropout_rate, name=None):
    fnn_layers = []

    for units in hidden_units:
        fnn_layers.append(layers.BatchNormalization())
        fnn_layers.append(layers.Dropout(dropout_rate))
        fnn_layers.append(layers.Dense(units, activation=tf.nn.gelu))

    return keras.Sequential(fnn_layers, name=name)
feature_names = list(set(papers.columns) - {"paper_id", "subject"})  # list, so the column order is fixed for reuse below
num_features = len(feature_names)
num_classes = len(class_idx)

# Create train and test features as a numpy array.
x_train = train_data[feature_names].to_numpy()
x_test = test_data[feature_names].to_numpy()
# Create train and test targets as a numpy array.
y_train = train_data["subject"]
y_test = test_data["subject"]
def create_baseline_model(hidden_units, num_classes, dropout_rate=0.2):
    inputs = layers.Input(shape=(num_features,), name="input_features")
    x = create_ffn(hidden_units, dropout_rate, name="ffn_block1")(inputs)
    for block_idx in range(4):
        # Create an FFN block.
        x1 = create_ffn(hidden_units, dropout_rate, name=f"ffn_block{block_idx + 2}")(x)
        # Add skip connection.
        x = layers.Add(name=f"skip_connection{block_idx + 2}")([x, x1])
    # Compute logits.
    logits = layers.Dense(num_classes, name="logits")(x)
    # Create the model.
    return keras.Model(inputs=inputs, outputs=logits, name="baseline")


baseline_model = create_baseline_model(hidden_units, num_classes, dropout_rate)
baseline_model.summary()
Model: "baseline"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_features (InputLayer)    [(None, 784)]        0           []                               
                                                                                                  
 ffn_block1 (Sequential)        (None, 32)           29440       ['input_features[0][0]']         
                                                                                                  
 ffn_block2 (Sequential)        (None, 32)           2368        ['ffn_block1[0][0]']             
                                                                                                  
 skip_connection2 (Add)         (None, 32)           0           ['ffn_block1[0][0]',             
                                                                  'ffn_block2[0][0]']             
                                                                                                  
 ffn_block3 (Sequential)        (None, 32)           2368        ['skip_connection2[0][0]']       
                                                                                                  
 skip_connection3 (Add)         (None, 32)           0           ['skip_connection2[0][0]',       
                                                                  'ffn_block3[0][0]']             
                                                                                                  
 ffn_block4 (Sequential)        (None, 32)           2368        ['skip_connection3[0][0]']       
                                                                                                  
 skip_connection4 (Add)         (None, 32)           0           ['skip_connection3[0][0]',       
                                                                  'ffn_block4[0][0]']             
                                                                                                  
 ffn_block5 (Sequential)        (None, 32)           2368        ['skip_connection4[0][0]']       
                                                                                                  
 skip_connection5 (Add)         (None, 32)           0           ['skip_connection4[0][0]',       
                                                                  'ffn_block5[0][0]']             
                                                                                                  
 logits (Dense)                 (None, 10)           330         ['skip_connection5[0][0]']       
                                                                                                  
==================================================================================================
Total params: 39,242
Trainable params: 37,098
Non-trainable params: 2,144
__________________________________________________________________________________________________
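The 29,440 parameters reported for ffn_block1 can be checked by hand: BatchNormalization carries 4 parameters per input unit and Dense carries (in + 1) × out. A quick arithmetic sketch:

bn = lambda units: 4 * units                    # gamma, beta, moving mean, moving variance
dense = lambda n_in, n_out: (n_in + 1) * n_out  # weights plus bias
print(bn(784) + dense(784, 32) + bn(32) + dense(32, 32))  # 3136 + 25120 + 128 + 1056 = 29440
# Every later block sees 32 inputs: bn(32) + dense(32, 32) twice = 2368, matching the summary.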
history = run_experiment(baseline_model, x_train, y_train)
Epoch 1/300
9/9 [==============================] - 2s 47ms/step - loss: 3.0821 - acc: 0.1640 - val_loss: 2.3358 - val_acc: 0.0853
Epoch 2/300
9/9 [==============================] - 0s 19ms/step - loss: 2.0163 - acc: 0.3313 - val_loss: 2.3259 - val_acc: 0.0907
Epoch 3/300
9/9 [==============================] - 0s 21ms/step - loss: 1.6156 - acc: 0.4505 - val_loss: 2.3136 - val_acc: 0.0907
Epoch 4/300
9/9 [==============================] - 0s 21ms/step - loss: 1.3902 - acc: 0.5302 - val_loss: 2.3356 - val_acc: 0.0907
Epoch 5/300
9/9 [==============================] - 0s 22ms/step - loss: 1.2061 - acc: 0.5877 - val_loss: 2.3850 - val_acc: 0.0907
Epoch 6/300
9/9 [==============================] - 0s 22ms/step - loss: 1.0709 - acc: 0.6348 - val_loss: 2.4163 - val_acc: 0.0907
Epoch 7/300
9/9 [==============================] - 0s 21ms/step - loss: 0.9084 - acc: 0.6998 - val_loss: 2.4824 - val_acc: 0.0907
Epoch 8/300
9/9 [==============================] - 0s 22ms/step - loss: 0.8695 - acc: 0.7187 - val_loss: 2.5133 - val_acc: 0.0907
Epoch 9/300
9/9 [==============================] - 0s 21ms/step - loss: 0.8098 - acc: 0.7234 - val_loss: 2.5732 - val_acc: 0.0907
Epoch 10/300
9/9 [==============================] - 0s 21ms/step - loss: 0.7636 - acc: 0.7484 - val_loss: 2.5890 - val_acc: 0.0907
Epoch 11/300
9/9 [==============================] - 0s 21ms/step - loss: 0.7496 - acc: 0.7502 - val_loss: 2.6102 - val_acc: 0.0907
Epoch 12/300
9/9 [==============================] - 0s 22ms/step - loss: 0.7353 - acc: 0.7639 - val_loss: 2.5950 - val_acc: 0.0907
Epoch 13/300
9/9 [==============================] - 0s 21ms/step - loss: 0.6886 - acc: 0.7747 - val_loss: 2.5921 - val_acc: 0.0907
Epoch 14/300
9/9 [==============================] - 0s 21ms/step - loss: 0.6587 - acc: 0.7832 - val_loss: 2.5883 - val_acc: 0.0907
Epoch 15/300
9/9 [==============================] - 0s 21ms/step - loss: 0.6733 - acc: 0.7771 - val_loss: 2.5800 - val_acc: 0.0907
Epoch 16/300
9/9 [==============================] - 0s 20ms/step - loss: 0.6241 - acc: 0.7931 - val_loss: 2.5873 - val_acc: 0.0907
Epoch 17/300
9/9 [==============================] - 0s 20ms/step - loss: 0.6367 - acc: 0.7959 - val_loss: 2.6958 - val_acc: 0.0907
Epoch 18/300
9/9 [==============================] - 0s 21ms/step - loss: 0.5997 - acc: 0.8030 - val_loss: 2.6058 - val_acc: 0.0907
Epoch 19/300
9/9 [==============================] - 0s 21ms/step - loss: 0.6123 - acc: 0.8016 - val_loss: 2.4661 - val_acc: 0.0907
Epoch 20/300
9/9 [==============================] - 0s 20ms/step - loss: 0.5830 - acc: 0.8129 - val_loss: 2.4605 - val_acc: 0.0907
Epoch 21/300
9/9 [==============================] - 0s 20ms/step - loss: 0.5894 - acc: 0.8106 - val_loss: 2.5693 - val_acc: 0.0907
Epoch 22/300
9/9 [==============================] - 0s 20ms/step - loss: 0.5536 - acc: 0.8101 - val_loss: 2.4752 - val_acc: 0.1067
Epoch 23/300
9/9 [==============================] - 0s 21ms/step - loss: 0.5718 - acc: 0.8139 - val_loss: 2.4138 - val_acc: 0.1147
Epoch 24/300
9/9 [==============================] - 0s 22ms/step - loss: 0.5555 - acc: 0.8091 - val_loss: 2.4457 - val_acc: 0.1653
Epoch 25/300
9/9 [==============================] - 0s 21ms/step - loss: 0.5322 - acc: 0.8228 - val_loss: 2.3710 - val_acc: 0.1707
Epoch 26/300
9/9 [==============================] - 0s 21ms/step - loss: 0.5270 - acc: 0.8308 - val_loss: 2.3896 - val_acc: 0.0933
Epoch 27/300
9/9 [==============================] - 0s 21ms/step - loss: 0.5071 - acc: 0.8303 - val_loss: 2.4367 - val_acc: 0.0933
Epoch 28/300
9/9 [==============================] - 0s 21ms/step - loss: 0.5298 - acc: 0.8280 - val_loss: 2.4682 - val_acc: 0.1280
Epoch 29/300
9/9 [==============================] - 0s 21ms/step - loss: 0.5443 - acc: 0.8228 - val_loss: 2.5354 - val_acc: 0.1893
Epoch 30/300
9/9 [==============================] - 0s 21ms/step - loss: 0.5235 - acc: 0.8247 - val_loss: 2.3308 - val_acc: 0.1920
Epoch 31/300
9/9 [==============================] - 0s 22ms/step - loss: 0.5040 - acc: 0.8252 - val_loss: 2.2851 - val_acc: 0.2000
Epoch 32/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4902 - acc: 0.8384 - val_loss: 2.2351 - val_acc: 0.1760
Epoch 33/300
9/9 [==============================] - 0s 21ms/step - loss: 0.5216 - acc: 0.8365 - val_loss: 2.0888 - val_acc: 0.1760
Epoch 34/300
9/9 [==============================] - 0s 22ms/step - loss: 0.5047 - acc: 0.8332 - val_loss: 1.9205 - val_acc: 0.2267
Epoch 35/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4851 - acc: 0.8511 - val_loss: 2.0260 - val_acc: 0.1493
Epoch 36/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4836 - acc: 0.8374 - val_loss: 2.2177 - val_acc: 0.0987
Epoch 37/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4719 - acc: 0.8360 - val_loss: 1.9455 - val_acc: 0.1493
Epoch 38/300
9/9 [==============================] - 0s 22ms/step - loss: 0.4435 - acc: 0.8525 - val_loss: 1.8354 - val_acc: 0.3653
Epoch 39/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4580 - acc: 0.8577 - val_loss: 2.0415 - val_acc: 0.2880
Epoch 40/300
9/9 [==============================] - 0s 20ms/step - loss: 0.4738 - acc: 0.8445 - val_loss: 1.9407 - val_acc: 0.2453
Epoch 41/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4725 - acc: 0.8487 - val_loss: 1.7966 - val_acc: 0.2613
Epoch 42/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4407 - acc: 0.8582 - val_loss: 1.7663 - val_acc: 0.3013
Epoch 43/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4838 - acc: 0.8544 - val_loss: 1.6597 - val_acc: 0.3413
Epoch 44/300
9/9 [==============================] - 0s 22ms/step - loss: 0.4817 - acc: 0.8431 - val_loss: 1.5467 - val_acc: 0.5333
Epoch 45/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4548 - acc: 0.8516 - val_loss: 1.2719 - val_acc: 0.4773
Epoch 46/300
9/9 [==============================] - 0s 22ms/step - loss: 0.4548 - acc: 0.8506 - val_loss: 1.3264 - val_acc: 0.5360
Epoch 47/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4477 - acc: 0.8544 - val_loss: 1.1974 - val_acc: 0.5973
Epoch 48/300
9/9 [==============================] - 0s 20ms/step - loss: 0.4455 - acc: 0.8563 - val_loss: 1.2033 - val_acc: 0.4800
Epoch 49/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4488 - acc: 0.8539 - val_loss: 1.2504 - val_acc: 0.5120
Epoch 50/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4409 - acc: 0.8534 - val_loss: 1.2395 - val_acc: 0.5013
Epoch 51/300
9/9 [==============================] - 0s 20ms/step - loss: 0.4548 - acc: 0.8497 - val_loss: 1.1209 - val_acc: 0.5360
Epoch 52/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4337 - acc: 0.8525 - val_loss: 1.0529 - val_acc: 0.6027
Epoch 53/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4449 - acc: 0.8586 - val_loss: 1.0501 - val_acc: 0.6213
Epoch 54/300
9/9 [==============================] - 0s 22ms/step - loss: 0.4363 - acc: 0.8619 - val_loss: 1.2335 - val_acc: 0.4373
Epoch 55/300
9/9 [==============================] - 0s 22ms/step - loss: 0.4641 - acc: 0.8459 - val_loss: 0.8597 - val_acc: 0.6720
Epoch 56/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4514 - acc: 0.8520 - val_loss: 0.7166 - val_acc: 0.7413
Epoch 57/300
9/9 [==============================] - 0s 22ms/step - loss: 0.4438 - acc: 0.8459 - val_loss: 0.5369 - val_acc: 0.8587
Epoch 58/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4267 - acc: 0.8582 - val_loss: 0.6229 - val_acc: 0.8107
Epoch 59/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4251 - acc: 0.8563 - val_loss: 1.0375 - val_acc: 0.6027
Epoch 60/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4351 - acc: 0.8577 - val_loss: 1.2214 - val_acc: 0.5627
Epoch 61/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4168 - acc: 0.8652 - val_loss: 0.6775 - val_acc: 0.7840
Epoch 62/300
9/9 [==============================] - 0s 20ms/step - loss: 0.4536 - acc: 0.8464 - val_loss: 0.5785 - val_acc: 0.8027
Epoch 63/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4343 - acc: 0.8582 - val_loss: 0.5882 - val_acc: 0.7840
Epoch 64/300
9/9 [==============================] - 0s 22ms/step - loss: 0.4054 - acc: 0.8615 - val_loss: 0.3738 - val_acc: 0.8933
Epoch 65/300
9/9 [==============================] - 0s 20ms/step - loss: 0.4120 - acc: 0.8680 - val_loss: 0.4552 - val_acc: 0.8560
Epoch 66/300
9/9 [==============================] - 0s 20ms/step - loss: 0.4369 - acc: 0.8464 - val_loss: 0.3934 - val_acc: 0.8667
Epoch 67/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4107 - acc: 0.8605 - val_loss: 0.4369 - val_acc: 0.8560
Epoch 68/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4065 - acc: 0.8737 - val_loss: 0.4085 - val_acc: 0.8667
Epoch 69/300
9/9 [==============================] - 0s 20ms/step - loss: 0.4255 - acc: 0.8558 - val_loss: 0.5107 - val_acc: 0.8133
Epoch 70/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4096 - acc: 0.8657 - val_loss: 0.4550 - val_acc: 0.8400
Epoch 71/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3981 - acc: 0.8652 - val_loss: 0.5169 - val_acc: 0.8267
Epoch 72/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3986 - acc: 0.8690 - val_loss: 0.3604 - val_acc: 0.8773
Epoch 73/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3941 - acc: 0.8737 - val_loss: 0.3677 - val_acc: 0.8720
Epoch 74/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4193 - acc: 0.8737 - val_loss: 0.3457 - val_acc: 0.8693
Epoch 75/300
9/9 [==============================] - 0s 22ms/step - loss: 0.3910 - acc: 0.8695 - val_loss: 0.2987 - val_acc: 0.9120
Epoch 76/300
9/9 [==============================] - 0s 22ms/step - loss: 0.4181 - acc: 0.8685 - val_loss: 0.2827 - val_acc: 0.9147
Epoch 77/300
9/9 [==============================] - 0s 22ms/step - loss: 0.4185 - acc: 0.8629 - val_loss: 0.3489 - val_acc: 0.8827
Epoch 78/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4177 - acc: 0.8638 - val_loss: 0.2914 - val_acc: 0.9013
Epoch 79/300
9/9 [==============================] - 0s 20ms/step - loss: 0.4170 - acc: 0.8615 - val_loss: 0.2853 - val_acc: 0.8987
Epoch 80/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3988 - acc: 0.8794 - val_loss: 0.2741 - val_acc: 0.9013
Epoch 81/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3976 - acc: 0.8676 - val_loss: 0.3185 - val_acc: 0.9093
Epoch 82/300
9/9 [==============================] - 0s 20ms/step - loss: 0.4124 - acc: 0.8633 - val_loss: 0.2895 - val_acc: 0.8933
Epoch 83/300
9/9 [==============================] - 0s 22ms/step - loss: 0.3936 - acc: 0.8713 - val_loss: 0.2568 - val_acc: 0.9200
Epoch 84/300
9/9 [==============================] - 0s 20ms/step - loss: 0.4076 - acc: 0.8586 - val_loss: 0.2750 - val_acc: 0.9173
Epoch 85/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3843 - acc: 0.8732 - val_loss: 0.2418 - val_acc: 0.9253
Epoch 86/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4158 - acc: 0.8615 - val_loss: 0.2593 - val_acc: 0.9093
Epoch 87/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3824 - acc: 0.8690 - val_loss: 0.2700 - val_acc: 0.9120
Epoch 88/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3851 - acc: 0.8732 - val_loss: 0.2723 - val_acc: 0.9093
Epoch 89/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4100 - acc: 0.8558 - val_loss: 0.2459 - val_acc: 0.9173
Epoch 90/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3625 - acc: 0.8878 - val_loss: 0.2614 - val_acc: 0.9093
Epoch 91/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3651 - acc: 0.8789 - val_loss: 0.2495 - val_acc: 0.9147
Epoch 92/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3904 - acc: 0.8775 - val_loss: 0.2503 - val_acc: 0.9093
Epoch 93/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3837 - acc: 0.8685 - val_loss: 0.2632 - val_acc: 0.9093
Epoch 94/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3696 - acc: 0.8779 - val_loss: 0.2899 - val_acc: 0.9093
Epoch 95/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3706 - acc: 0.8831 - val_loss: 0.2643 - val_acc: 0.9120
Epoch 96/300
9/9 [==============================] - 0s 20ms/step - loss: 0.4084 - acc: 0.8666 - val_loss: 0.2622 - val_acc: 0.9013
Epoch 97/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3698 - acc: 0.8798 - val_loss: 0.2632 - val_acc: 0.9093
Epoch 98/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3613 - acc: 0.8822 - val_loss: 0.2618 - val_acc: 0.9200
Epoch 99/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3821 - acc: 0.8779 - val_loss: 0.3220 - val_acc: 0.8907
Epoch 100/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4036 - acc: 0.8629 - val_loss: 0.2790 - val_acc: 0.9120
Epoch 101/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3985 - acc: 0.8657 - val_loss: 0.2644 - val_acc: 0.9200
Epoch 102/300
9/9 [==============================] - 0s 21ms/step - loss: 0.4047 - acc: 0.8680 - val_loss: 0.2613 - val_acc: 0.9147
Epoch 103/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3804 - acc: 0.8779 - val_loss: 0.2698 - val_acc: 0.8933
Epoch 104/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3576 - acc: 0.8812 - val_loss: 0.2700 - val_acc: 0.9200
Epoch 105/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3640 - acc: 0.8770 - val_loss: 0.2693 - val_acc: 0.9173
Epoch 106/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3804 - acc: 0.8775 - val_loss: 0.2445 - val_acc: 0.9067
Epoch 107/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3841 - acc: 0.8765 - val_loss: 0.2966 - val_acc: 0.9040
Epoch 108/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3366 - acc: 0.8836 - val_loss: 0.2943 - val_acc: 0.9013
Epoch 109/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3845 - acc: 0.8794 - val_loss: 0.2730 - val_acc: 0.9067
Epoch 110/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3688 - acc: 0.8779 - val_loss: 0.2763 - val_acc: 0.9147
Epoch 111/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3544 - acc: 0.8850 - val_loss: 0.2652 - val_acc: 0.9147
Epoch 112/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3774 - acc: 0.8761 - val_loss: 0.2848 - val_acc: 0.9093
Epoch 113/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3410 - acc: 0.8808 - val_loss: 0.2674 - val_acc: 0.9093
Epoch 114/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3693 - acc: 0.8737 - val_loss: 0.2793 - val_acc: 0.9013
Epoch 115/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3816 - acc: 0.8756 - val_loss: 0.2708 - val_acc: 0.9040
Epoch 116/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3717 - acc: 0.8718 - val_loss: 0.2807 - val_acc: 0.9040
Epoch 117/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3511 - acc: 0.8916 - val_loss: 0.2606 - val_acc: 0.9013
Epoch 118/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3761 - acc: 0.8709 - val_loss: 0.2898 - val_acc: 0.9040
Epoch 119/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3574 - acc: 0.8869 - val_loss: 0.3016 - val_acc: 0.9093
Epoch 120/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3882 - acc: 0.8704 - val_loss: 0.2950 - val_acc: 0.9147
Epoch 121/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3551 - acc: 0.8794 - val_loss: 0.2728 - val_acc: 0.9067
Epoch 122/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3594 - acc: 0.8775 - val_loss: 0.2788 - val_acc: 0.9013
Epoch 123/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3603 - acc: 0.8803 - val_loss: 0.2354 - val_acc: 0.8987
Epoch 124/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3576 - acc: 0.8831 - val_loss: 0.2927 - val_acc: 0.9040
Epoch 125/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3765 - acc: 0.8794 - val_loss: 0.2843 - val_acc: 0.9067
Epoch 126/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3888 - acc: 0.8718 - val_loss: 0.2683 - val_acc: 0.8907
Epoch 127/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3817 - acc: 0.8742 - val_loss: 0.2592 - val_acc: 0.9093
Epoch 128/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3505 - acc: 0.8817 - val_loss: 0.2601 - val_acc: 0.9147
Epoch 129/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3584 - acc: 0.8798 - val_loss: 0.2728 - val_acc: 0.9093
Epoch 130/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3557 - acc: 0.8779 - val_loss: 0.2635 - val_acc: 0.9093
Epoch 131/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3578 - acc: 0.8869 - val_loss: 0.2562 - val_acc: 0.9147
Epoch 132/300
9/9 [==============================] - 0s 22ms/step - loss: 0.3394 - acc: 0.8944 - val_loss: 0.2271 - val_acc: 0.9333
Epoch 133/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3361 - acc: 0.8869 - val_loss: 0.2303 - val_acc: 0.9200
Epoch 134/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3521 - acc: 0.8869 - val_loss: 0.2395 - val_acc: 0.9173
Epoch 135/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3660 - acc: 0.8812 - val_loss: 0.2418 - val_acc: 0.9120
Epoch 136/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3579 - acc: 0.8841 - val_loss: 0.2415 - val_acc: 0.9147
Epoch 137/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3612 - acc: 0.8803 - val_loss: 0.2421 - val_acc: 0.9227
Epoch 138/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3400 - acc: 0.8841 - val_loss: 0.2492 - val_acc: 0.9093
Epoch 139/300
9/9 [==============================] - 0s 22ms/step - loss: 0.3593 - acc: 0.8812 - val_loss: 0.2642 - val_acc: 0.9093
Epoch 140/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3574 - acc: 0.8789 - val_loss: 0.2576 - val_acc: 0.9120
Epoch 141/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3713 - acc: 0.8765 - val_loss: 0.2405 - val_acc: 0.9093
Epoch 142/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3274 - acc: 0.8893 - val_loss: 0.2596 - val_acc: 0.9013
Epoch 143/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3181 - acc: 0.9010 - val_loss: 0.2612 - val_acc: 0.9147
Epoch 144/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3423 - acc: 0.8916 - val_loss: 0.2976 - val_acc: 0.9067
Epoch 145/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3535 - acc: 0.8822 - val_loss: 0.2632 - val_acc: 0.9200
Epoch 146/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3328 - acc: 0.8940 - val_loss: 0.2414 - val_acc: 0.9200
Epoch 147/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3404 - acc: 0.8883 - val_loss: 0.2552 - val_acc: 0.9173
Epoch 148/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3221 - acc: 0.8982 - val_loss: 0.2356 - val_acc: 0.9093
Epoch 149/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3276 - acc: 0.8959 - val_loss: 0.2744 - val_acc: 0.9173
Epoch 150/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3580 - acc: 0.8794 - val_loss: 0.2918 - val_acc: 0.9147
Epoch 151/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3475 - acc: 0.8878 - val_loss: 0.2327 - val_acc: 0.9200
Epoch 152/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3577 - acc: 0.8855 - val_loss: 0.3062 - val_acc: 0.9040
Epoch 153/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3697 - acc: 0.8812 - val_loss: 0.2250 - val_acc: 0.9253
Epoch 154/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3241 - acc: 0.8893 - val_loss: 0.2475 - val_acc: 0.9147
Epoch 155/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3453 - acc: 0.8831 - val_loss: 0.2654 - val_acc: 0.9253
Epoch 156/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3564 - acc: 0.8855 - val_loss: 0.2437 - val_acc: 0.9227
Epoch 157/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3523 - acc: 0.8784 - val_loss: 0.2472 - val_acc: 0.9200
Epoch 158/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3286 - acc: 0.8944 - val_loss: 0.2896 - val_acc: 0.9040
Epoch 159/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3472 - acc: 0.8869 - val_loss: 0.2849 - val_acc: 0.8960
Epoch 160/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3384 - acc: 0.8803 - val_loss: 0.2578 - val_acc: 0.9067
Epoch 161/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3317 - acc: 0.8874 - val_loss: 0.2244 - val_acc: 0.9200
Epoch 162/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3362 - acc: 0.8845 - val_loss: 0.2440 - val_acc: 0.9333
Epoch 163/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3558 - acc: 0.8822 - val_loss: 0.2416 - val_acc: 0.9173
Epoch 164/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3315 - acc: 0.8963 - val_loss: 0.2518 - val_acc: 0.9173
Epoch 165/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3289 - acc: 0.8874 - val_loss: 0.2296 - val_acc: 0.9200
Epoch 166/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3562 - acc: 0.8765 - val_loss: 0.2271 - val_acc: 0.9333
Epoch 167/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3631 - acc: 0.8822 - val_loss: 0.2592 - val_acc: 0.9147
Epoch 168/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3211 - acc: 0.8921 - val_loss: 0.2306 - val_acc: 0.9173
Epoch 169/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3308 - acc: 0.8855 - val_loss: 0.2195 - val_acc: 0.9307
Epoch 170/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3541 - acc: 0.8761 - val_loss: 0.2356 - val_acc: 0.9173
Epoch 171/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3718 - acc: 0.8803 - val_loss: 0.2689 - val_acc: 0.9093
Epoch 172/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3431 - acc: 0.8902 - val_loss: 0.2397 - val_acc: 0.9200
Epoch 173/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3650 - acc: 0.8869 - val_loss: 0.2080 - val_acc: 0.9227
Epoch 174/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3296 - acc: 0.8883 - val_loss: 0.2373 - val_acc: 0.9227
Epoch 175/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3204 - acc: 0.8944 - val_loss: 0.2214 - val_acc: 0.9253
Epoch 176/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3416 - acc: 0.8888 - val_loss: 0.2356 - val_acc: 0.9253
Epoch 177/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3333 - acc: 0.8878 - val_loss: 0.2546 - val_acc: 0.9093
Epoch 178/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3461 - acc: 0.8987 - val_loss: 0.2298 - val_acc: 0.9227
Epoch 179/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3390 - acc: 0.8845 - val_loss: 0.2409 - val_acc: 0.9253
Epoch 180/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3191 - acc: 0.8954 - val_loss: 0.2431 - val_acc: 0.9173
Epoch 181/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3689 - acc: 0.8822 - val_loss: 0.2143 - val_acc: 0.9413
Epoch 182/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3395 - acc: 0.8944 - val_loss: 0.2207 - val_acc: 0.9227
Epoch 183/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3471 - acc: 0.8869 - val_loss: 0.2430 - val_acc: 0.9227
Epoch 184/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3660 - acc: 0.8803 - val_loss: 0.2273 - val_acc: 0.9227
Epoch 185/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3002 - acc: 0.8897 - val_loss: 0.2480 - val_acc: 0.9173
Epoch 186/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3468 - acc: 0.8841 - val_loss: 0.2469 - val_acc: 0.9200
Epoch 187/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3206 - acc: 0.8944 - val_loss: 0.2276 - val_acc: 0.9067
Epoch 188/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3217 - acc: 0.8930 - val_loss: 0.2584 - val_acc: 0.9173
Epoch 189/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3181 - acc: 0.8907 - val_loss: 0.2609 - val_acc: 0.9173
Epoch 190/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3190 - acc: 0.8911 - val_loss: 0.2420 - val_acc: 0.9280
Epoch 191/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3050 - acc: 0.8968 - val_loss: 0.2682 - val_acc: 0.9253
Epoch 192/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3249 - acc: 0.8926 - val_loss: 0.2110 - val_acc: 0.9387
Epoch 193/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3400 - acc: 0.8860 - val_loss: 0.2425 - val_acc: 0.9253
Epoch 194/300
9/9 [==============================] - 0s 21ms/step - loss: 0.2997 - acc: 0.9006 - val_loss: 0.2155 - val_acc: 0.9173
Epoch 195/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3044 - acc: 0.9020 - val_loss: 0.2238 - val_acc: 0.9227
Epoch 196/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3246 - acc: 0.8940 - val_loss: 0.2113 - val_acc: 0.9333
Epoch 197/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3301 - acc: 0.8954 - val_loss: 0.1922 - val_acc: 0.9280
Epoch 198/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3410 - acc: 0.8874 - val_loss: 0.2287 - val_acc: 0.9307
Epoch 199/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3468 - acc: 0.8784 - val_loss: 0.2327 - val_acc: 0.9227
Epoch 200/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3270 - acc: 0.8907 - val_loss: 0.2470 - val_acc: 0.9227
Epoch 201/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3468 - acc: 0.8916 - val_loss: 0.2391 - val_acc: 0.9333
Epoch 202/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3193 - acc: 0.8916 - val_loss: 0.2317 - val_acc: 0.9093
Epoch 203/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3092 - acc: 0.9057 - val_loss: 0.2428 - val_acc: 0.9280
Epoch 204/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3343 - acc: 0.8911 - val_loss: 0.2396 - val_acc: 0.9253
Epoch 205/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3410 - acc: 0.8907 - val_loss: 0.2445 - val_acc: 0.9093
Epoch 206/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3202 - acc: 0.8916 - val_loss: 0.2723 - val_acc: 0.9120
Epoch 207/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3495 - acc: 0.8850 - val_loss: 0.2387 - val_acc: 0.9253
Epoch 208/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3298 - acc: 0.8954 - val_loss: 0.2421 - val_acc: 0.9280
Epoch 209/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3088 - acc: 0.8954 - val_loss: 0.2493 - val_acc: 0.9280
Epoch 210/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3240 - acc: 0.8878 - val_loss: 0.2226 - val_acc: 0.9280
Epoch 211/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3289 - acc: 0.8897 - val_loss: 0.2310 - val_acc: 0.9280
Epoch 212/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3359 - acc: 0.8812 - val_loss: 0.2289 - val_acc: 0.9227
Epoch 213/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3075 - acc: 0.8897 - val_loss: 0.2808 - val_acc: 0.9067
Epoch 214/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3221 - acc: 0.8911 - val_loss: 0.2196 - val_acc: 0.9333
Epoch 215/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3306 - acc: 0.8940 - val_loss: 0.2581 - val_acc: 0.9173
Epoch 216/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3322 - acc: 0.8926 - val_loss: 0.2480 - val_acc: 0.9253
Epoch 217/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3205 - acc: 0.8926 - val_loss: 0.2247 - val_acc: 0.9200
Epoch 218/300
9/9 [==============================] - 0s 20ms/step - loss: 0.2994 - acc: 0.9048 - val_loss: 0.2313 - val_acc: 0.9253
Epoch 219/300
9/9 [==============================] - 0s 20ms/step - loss: 0.2976 - acc: 0.9010 - val_loss: 0.3234 - val_acc: 0.9120
Epoch 220/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3019 - acc: 0.9034 - val_loss: 0.2305 - val_acc: 0.9280
Epoch 221/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3291 - acc: 0.8911 - val_loss: 0.2374 - val_acc: 0.9173
Epoch 222/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3271 - acc: 0.8926 - val_loss: 0.2439 - val_acc: 0.9227
Epoch 223/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3160 - acc: 0.9001 - val_loss: 0.2464 - val_acc: 0.9280
Epoch 224/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3144 - acc: 0.8973 - val_loss: 0.2419 - val_acc: 0.9280
Epoch 225/300
9/9 [==============================] - 0s 22ms/step - loss: 0.3074 - acc: 0.8954 - val_loss: 0.2257 - val_acc: 0.9253
Epoch 226/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3185 - acc: 0.8841 - val_loss: 0.2430 - val_acc: 0.9173
Epoch 227/300
9/9 [==============================] - 0s 22ms/step - loss: 0.3171 - acc: 0.8940 - val_loss: 0.2129 - val_acc: 0.9253
Epoch 228/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3281 - acc: 0.8935 - val_loss: 0.3000 - val_acc: 0.9120
Epoch 229/300
9/9 [==============================] - 0s 20ms/step - loss: 0.3426 - acc: 0.8940 - val_loss: 0.2401 - val_acc: 0.9200
Epoch 230/300
9/9 [==============================] - 0s 21ms/step - loss: 0.3424 - acc: 0.8874 - val_loss: 0.2459 - val_acc: 0.9227
Epoch 231/300
9/9 [==============================] - 0s 22ms/step - loss: 0.3451 - acc: 0.8902 - val_loss: 0.2249 - val_acc: 0.9253
display_learning_curves(history)
_, test_accuracy = baseline_model.evaluate(x=x_test, y=y_test, verbose=0)
print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")
Test accuracy: 93.45%

Baseline model predictions

def generate_random_instances(num_instances):
    token_probability = x_train.mean(axis=0)
    instances = []
    for _ in range(num_instances):
        probabilities = np.random.uniform(size=len(token_probability))
        instance = (probabilities <= token_probability).astype(int)
        instances.append(instance)

    return np.array(instances)


def display_class_probabilities(probabilities):
    for instance_idx, probs in enumerate(probabilities):
        print(f"Instance {instance_idx + 1}:")
        for class_idx, prob in enumerate(probs):
            print(f"- {class_values[class_idx]}: {round(prob * 100, 2)}%")
new_instances = generate_random_instances(num_classes)
logits = baseline_model.predict(new_instances)
probabilities = keras.activations.softmax(tf.convert_to_tensor(logits)).numpy()
display_class_probabilities(probabilities)
Instance 1:
- 0: 0.77%
- 1: 9.74%
- 2: 8.62%
- 3: 2.62%
- 4: 30.39%
- 5: 21.94%
- 6: 0.28%
- 7: 22.4%
- 8: 2.6%
- 9: 0.64%
Instance 2:
- 0: 0.77%
- 1: 9.74%
- 2: 8.62%
- 3: 2.62%
- 4: 30.39%
- 5: 21.94%
- 6: 0.28%
- 7: 22.4%
- 8: 2.6%
- 9: 0.64%
Instance 3:
- 0: 0.77%
- 1: 9.74%
- 2: 8.62%
- 3: 2.62%
- 4: 30.39%
- 5: 21.94%
- 6: 0.28%
- 7: 22.4%
- 8: 2.6%
- 9: 0.64%
Instance 4:
- 0: 0.77%
- 1: 9.74%
- 2: 8.62%
- 3: 2.62%
- 4: 30.39%
- 5: 21.94%
- 6: 0.28%
- 7: 22.4%
- 8: 2.6%
- 9: 0.64%
Instance 5:
- 0: 0.0%
- 1: 0.0%
- 2: 0.0%
- 3: 0.01%
- 4: 0.2%
- 5: 99.79%
- 6: 0.0%
- 7: 0.0%
- 8: 0.0%
- 9: 0.0%
Instance 6:
- 0: 0.77%
- 1: 9.74%
- 2: 8.62%
- 3: 2.62%
- 4: 30.39%
- 5: 21.94%
- 6: 0.28%
- 7: 22.4%
- 8: 2.6%
- 9: 0.64%
Instance 7:
- 0: 0.0%
- 1: 99.98%
- 2: 0.0%
- 3: 0.0%
- 4: 0.0%
- 5: 0.02%
- 6: 0.0%
- 7: 0.0%
- 8: 0.0%
- 9: 0.0%
Instance 8:
- 0: 0.77%
- 1: 9.74%
- 2: 8.62%
- 3: 2.62%
- 4: 30.39%
- 5: 21.94%
- 6: 0.28%
- 7: 22.4%
- 8: 2.6%
- 9: 0.64%
Instance 9:
- 0: 0.0%
- 1: 0.0%
- 2: 99.53%
- 3: 0.0%
- 4: 0.47%
- 5: 0.0%
- 6: 0.0%
- 7: 0.0%
- 8: 0.0%
- 9: 0.0%
Instance 10:
- 0: 0.0%
- 1: 0.0%
- 2: 0.0%
- 3: 0.0%
- 4: 0.0%
- 5: 100.0%
- 6: 0.0%
- 7: 0.0%
- 8: 0.0%
- 9: 0.0%
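To turn these class probabilities into hard predictions, a short follow-up sketch using the probabilities array above:

predicted_classes = np.argmax(probabilities, axis=1)
print([class_values[idx] for idx in predicted_classes])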
theta = 5000
# Gaussian kernel on paper-ID distance: w = exp(-(target - source)**2 / theta),
# vectorized over all edges (a Python loop over ~25M rows would be extremely slow).
edge_weights = np.exp(-((citations.target - citations.source)**2)/theta).to_numpy()
edges = citations[["source", "target"]].to_numpy().T
# Alternative: an edge weights array of ones.
#edge_weights = tf.ones(shape=edges.shape[1])
# Cast to float32 so the weights can be multiplied with the float32 node messages.
edge_weights = tf.constant(edge_weights, shape=(edges.shape[1],), dtype=tf.float32)
# Create a node features array of shape [num_nodes, num_features].
node_features = tf.cast(
    papers.sort_values("paper_id")[feature_names].to_numpy(), dtype=tf.dtypes.float32
)
# Create graph info tuple with node_features, edges, and edge_weights.
graph_info = (node_features, edges, edge_weights)

print("Edges shape:", edges.shape)
print("Edge weight shape:", edge_weights.shape)
print("Nodes shape:", node_features.shape)
Edges shape: (2, 24990001)
Edge weight shape: (24990001,)
Nodes shape: (4999, 784)
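With theta = 5000, the weight w = exp(-(target - source)**2 / theta) is close to 1 for papers with adjacent IDs and underflows to 0 for distant ones, e.g.:

# Adjacent IDs: exp(-1/5000) ~ 0.9998; extreme pair (1, 4999): exp(-4998**2/5000) ~ exp(-4996) ~ 0.
print(np.exp(-np.array([1, 4998])**2 / theta))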
class GraphConvLayer(layers.Layer):
    def __init__(
        self,
        hidden_units,
        dropout_rate=0.2,
        aggregation_type="mean",
        combination_type="concat",
        normalize=False,
        *args,
        **kwargs,
    ):
        super(GraphConvLayer, self).__init__(*args, **kwargs)

        self.aggregation_type = aggregation_type
        self.combination_type = combination_type
        self.normalize = normalize

        self.ffn_prepare = create_ffn(hidden_units, dropout_rate)
        if self.combination_type == "gated":
            self.update_fn = layers.GRU(
                units=hidden_units,
                activation="tanh",
                recurrent_activation="sigmoid",
                dropout=dropout_rate,
                return_state=True,
                recurrent_dropout=dropout_rate,
            )
        else:
            self.update_fn = create_ffn(hidden_units, dropout_rate)

    def prepare(self, node_repesentations, weights=None):
        # node_repesentations shape is [num_edges, embedding_dim].
        messages = self.ffn_prepare(node_repesentations)
        if weights is not None:
            messages = messages * tf.expand_dims(weights, -1)
        return messages

    def aggregate(self, node_indices, neighbour_messages):
        # node_indices shape is [num_edges].
        # neighbour_messages shape: [num_edges, representation_dim].
        num_nodes = tf.math.reduce_max(node_indices) + 1
        if self.aggregation_type == "sum":
            aggregated_message = tf.math.unsorted_segment_sum(
                neighbour_messages, node_indices, num_segments=num_nodes
            )
        elif self.aggregation_type == "mean":
            aggregated_message = tf.math.unsorted_segment_mean(
                neighbour_messages, node_indices, num_segments=num_nodes
            )
        elif self.aggregation_type == "max":
            aggregated_message = tf.math.unsorted_segment_max(
                neighbour_messages, node_indices, num_segments=num_nodes
            )
        else:
            raise ValueError(f"Invalid aggregation type: {self.aggregation_type}.")

        return aggregated_message

    def update(self, node_repesentations, aggregated_messages):
        # node_repesentations shape is [num_nodes, representation_dim].
        # aggregated_messages shape is [num_nodes, representation_dim].
        if self.combination_type == "gru":
            # Create a sequence of two elements for the GRU layer.
            h = tf.stack([node_repesentations, aggregated_messages], axis=1)
        elif self.combination_type == "concat":
            # Concatenate the node_repesentations and aggregated_messages.
            h = tf.concat([node_repesentations, aggregated_messages], axis=1)
        elif self.combination_type == "add":
            # Add node_repesentations and aggregated_messages.
            h = node_repesentations + aggregated_messages
        else:
            raise ValueError(f"Invalid combination type: {self.combination_type}.")

        # Apply the processing function.
        node_embeddings = self.update_fn(h)
        if self.combination_type == "gru":
            node_embeddings = tf.unstack(node_embeddings, axis=1)[-1]

        if self.normalize:
            node_embeddings = tf.nn.l2_normalize(node_embeddings, axis=-1)
        return node_embeddings

    def call(self, inputs):
        """Process the inputs to produce the node_embeddings.

        inputs: a tuple of three elements: node_repesentations, edges, edge_weights.
        Returns: node_embeddings of shape [num_nodes, representation_dim].
        """

        node_repesentations, edges, edge_weights = inputs
        # Get node_indices (source) and neighbour_indices (target) from edges.
        node_indices, neighbour_indices = edges[0], edges[1]
        # neighbour_repesentations shape is [num_edges, representation_dim].
        neighbour_repesentations = tf.gather(node_repesentations, neighbour_indices)

        # Prepare the messages of the neighbours.
        neighbour_messages = self.prepare(neighbour_repesentations, edge_weights)
        # Aggregate the neighbour messages.
        aggregated_messages = self.aggregate(node_indices, neighbour_messages)
        # Update the node embedding with the neighbour messages.
        return self.update(node_repesentations, aggregated_messages)
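The core of aggregate() is a segment reduction: messages are grouped by their destination node index and reduced in one op. A standalone toy sketch of the mean case:

# Three edges point at node 0 and one at node 1; mean-aggregate their 2-D messages per node.
messages = tf.constant([[1.0, 0.0], [3.0, 0.0], [2.0, 2.0], [4.0, 4.0]])
node_indices = tf.constant([0, 0, 1, 0])
print(tf.math.unsorted_segment_mean(messages, node_indices, num_segments=2))
# -> [[2.6667, 1.3333], [2., 2.]]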
class GNNNodeClassifier(tf.keras.Model):
    def __init__(
        self,
        graph_info,
        num_classes,
        hidden_units,
        aggregation_type="sum",
        combination_type="concat",
        dropout_rate=0.2,
        normalize=True,
        *args,
        **kwargs,
    ):
        super(GNNNodeClassifier, self).__init__(*args, **kwargs)

        # Unpack graph_info to three elements: node_features, edges, and edge_weight.
        node_features, edges, edge_weights = graph_info
        self.node_features = node_features
        self.edges = edges
        self.edge_weights = edge_weights
        # Set edge_weights to ones if not provided.
        if self.edge_weights is None:
            self.edge_weights = tf.ones(shape=edges.shape[1])
        # Scale edge_weights to sum to 1.
        self.edge_weights = self.edge_weights / tf.math.reduce_sum(self.edge_weights)

        # Create a process layer.
        self.preprocess = create_ffn(hidden_units, dropout_rate, name="preprocess")
        # Create the first GraphConv layer.
        self.conv1 = GraphConvLayer(
            hidden_units,
            dropout_rate,
            aggregation_type,
            combination_type,
            normalize,
            name="graph_conv1",
        )
        # Create the second GraphConv layer.
        self.conv2 = GraphConvLayer(
            hidden_units,
            dropout_rate,
            aggregation_type,
            combination_type,
            normalize,
            name="graph_conv2",
        )
        # Create a postprocess layer.
        self.postprocess = create_ffn(hidden_units, dropout_rate, name="postprocess")
        # Create a compute logits layer.
        self.compute_logits = layers.Dense(units=num_classes, name="logits")

    def call(self, input_node_indices):
        # Preprocess the node_features to produce node representations.
        x = self.preprocess(self.node_features)
        # Apply the first graph conv layer.
        x1 = self.conv1((x, self.edges, self.edge_weights))
        # Skip connection.
        x = x1 + x
        # Apply the second graph conv layer.
        x2 = self.conv2((x, self.edges, self.edge_weights))
        # Skip connection.
        x = x2 + x
        # Postprocess node embedding.
        x = self.postprocess(x)
        # Fetch node embeddings for the input node_indices.
        node_embeddings = tf.gather(x, input_node_indices)
        # Compute logits
        return self.compute_logits(node_embeddings)
gnn_model = GNNNodeClassifier(
    graph_info=graph_info,
    num_classes=num_classes,
    hidden_units=hidden_units,
    dropout_rate=dropout_rate,
    name="gnn_model",
)

print("GNN output shape:", gnn_model([1, 10, 100]))

gnn_model.summary()
GNN output shape: tf.Tensor(
[[ 0.06893904 -0.0480812   0.02823199  0.04244309  0.08524816  0.171109
  -0.17538266  0.05950593 -0.06625694 -0.01614501]
 [ 0.03818522 -0.09067823  0.00322416 -0.04379443  0.04936944  0.02724189
  -0.0214003  -0.05344744 -0.0940266  -0.14025614]
 [ 0.00369792  0.03900653 -0.02962643  0.06179402  0.04200113  0.0408471
  -0.02868286  0.01675259 -0.03795212 -0.0916561 ]], shape=(3, 10), dtype=float32)
Model: "gnn_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 preprocess (Sequential)     (4999, 32)                29440     
                                                                 
 graph_conv1 (GraphConvLayer  multiple                 5888      
 )                                                               
                                                                 
 graph_conv2 (GraphConvLayer  multiple                 5888      
 )                                                               
                                                                 
 postprocess (Sequential)    (4999, 32)                2368      
                                                                 
 logits (Dense)              multiple                  330       
                                                                 
=================================================================
Total params: 43,914
Trainable params: 41,514
Non-trainable params: 2,400
_________________________________________________________________
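The 5,888 parameters of each GraphConvLayer also check out: the prepare and update FFNs use the same BatchNorm + Dense arithmetic as before, with the update path seeing the 64-unit concatenation of node and message vectors:

# prepare ffn over 32 units: 2368; update ffn over the 64-unit concat: 3520.
print((4*32 + 33*32 + 4*32 + 33*32) + (4*64 + 65*32 + 4*32 + 33*32))  # 2368 + 3520 = 5888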
x_train = train_data.paper_id.to_numpy()  # the GNN takes node indices (remapped paper IDs), not feature vectors
history = run_experiment(gnn_model, x_train, y_train)
Epoch 1/300
9/9 [==============================] - 348s 34s/step - loss: 2.4703 - acc: 0.1385 - val_loss: 2.3062 - val_acc: 0.0960
Epoch 2/300
9/9 [==============================] - 289s 32s/step - loss: 2.1113 - acc: 0.2399 - val_loss: 2.3019 - val_acc: 0.0960
Epoch 3/300
9/9 [==============================] - 288s 32s/step - loss: 1.9192 - acc: 0.3049 - val_loss: 2.3029 - val_acc: 0.0987
...
[Epochs 4-167 omitted: training loss falls from ~1.75 to ~0.60 and training accuracy rises to ~0.82, while validation accuracy, volatile for the first ~70 epochs, climbs from ~0.10 to a stable ~0.88-0.92.]
...
Epoch 168/300
9/9 [==============================] - 289s 32s/step - loss: 0.6047 - acc: 0.8190 - val_loss: 0.3539 - val_acc: 0.8960
Epoch 169/300
9/9 [==============================] - 289s 32s/step - loss: 0.6276 - acc: 0.8077 - val_loss: 0.3073 - val_acc: 0.9093
Epoch 170/300
8/9 [=========================>....] - ETA: 30s - loss: 0.6222 - acc: 0.8149
display_learning_curves(history)
x_test = test_data.paper_id.to_numpy()
_, test_accuracy = gnn_model.evaluate(x=x_test, y=y_test, verbose=0)
print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")
Test accuracy: 90.53%
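Note that the model is queried with node indices rather than raw feature vectors: node_features lives inside the GNN, so x_test is simply test_data.paper_id. The next cell attaches ten synthetic instances (one per digit class) to the graph and classifies them. new_instances itself is assumed to have been built earlier; a hypothetical sketch in the spirit of the Keras GNN tutorial, sampling binary pixel vectors from the per-pixel mean intensity of X (X and num_classes as defined earlier in the notebook), could look like this:

def generate_random_instances(num_instances):
    # Mean intensity of each of the 784 pixels across the training images.
    pixel_probability = X.mean(axis=0)
    # One uniform draw per pixel and instance; a pixel switches on with
    # probability equal to its mean intensity.
    probability = np.random.uniform(size=(num_instances, len(pixel_probability)))
    return (probability <= pixel_probability).astype(int)

new_instances = generate_random_instances(num_classes)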
# First we add the N new_instances as nodes to the graph,
# by appending them to node_features.
num_nodes = node_features.shape[0]
new_node_features = np.concatenate([node_features, new_instances])
# Second we add the M edges (citations) from each new node to a set
# of existing nodes in a particular subject.
new_node_indices = [i + num_nodes for i in range(num_classes)]
new_citations = []
for subject_idx, group in papers.groupby("subject"):
    subject_papers = list(group.paper_id)
    # Select x random papers from this specific subject.
    selected_paper_indices1 = np.random.choice(subject_papers, 5)
    # Select random y papers from any subject (where y < x).
    selected_paper_indices2 = np.random.choice(list(papers.paper_id), 2)
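    # Caveat: paper_id is 1-based (it was shifted by +1 when the data was
    # built), while rows of node_features are 0-based, so citing raw IDs as
    # node indices here may be off by one relative to the feature matrix.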
    # Merge the selected paper indices.
    selected_paper_indices = np.concatenate(
        [selected_paper_indices1, selected_paper_indices2], axis=0
    )
    # Create edges between the citing paper idx and the selected cited papers.
    citing_paper_idx = new_node_indices[subject_idx]
    for cited_paper_idx in selected_paper_indices:
        new_citations.append([citing_paper_idx, cited_paper_idx])

new_citations = np.array(new_citations).T
new_edges = np.concatenate([edges, new_citations], axis=1)
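# Quick sanity check: each new node cites 5 same-subject papers plus 2
# random ones, so (assuming every digit class appears in papers)
# new_citations has 7 * num_classes columns after the transpose above.
assert new_citations.shape == (2, 7 * num_classes)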
print("Original node_features shape:", gnn_model.node_features.shape)
print("Original edges shape:", gnn_model.edges.shape)
gnn_model.node_features = new_node_features
gnn_model.edges = new_edges
gnn_model.edge_weights = tf.ones(shape=new_edges.shape[1])
print("New node_features shape:", gnn_model.node_features.shape)
print("New edges shape:", gnn_model.edges.shape)

logits = gnn_model.predict(tf.convert_to_tensor(new_node_indices))
probabilities = keras.activations.softmax(tf.convert_to_tensor(logits)).numpy()
display_class_probabilities(probabilities)
Original node_features shape: (5009, 784)
Original edges shape: (2, 24990071)
New node_features shape: (5009, 784)
New edges shape: (2, 24990071)
Instance 1:
- 0: 0.13%
- 1: 74.08%
- 2: 7.11%
- 3: 1.89%
- 4: 2.82%
- 5: 1.28%
- 6: 3.23%
- 7: 3.35%
- 8: 4.91%
- 9: 1.2%
Instance 2:
- 0: 0.27%
- 1: 63.9%
- 2: 9.72%
- 3: 2.54%
- 4: 4.36%
- 5: 1.67%
- 6: 4.41%
- 7: 5.28%
- 8: 5.68%
- 9: 2.17%
Instance 3:
- 0: 0.2%
- 1: 64.57%
- 2: 7.98%
- 3: 2.21%
- 4: 4.56%
- 5: 2.08%
- 6: 4.93%
- 7: 3.61%
- 8: 7.78%
- 9: 2.06%
Instance 4:
- 0: 0.22%
- 1: 63.32%
- 2: 9.5%
- 3: 2.36%
- 4: 4.79%
- 5: 1.68%
- 6: 4.54%
- 7: 4.8%
- 8: 6.51%
- 9: 2.29%
Instance 5:
- 0: 0.0%
- 1: 0.0%
- 2: 0.0%
- 3: 0.0%
- 4: 0.0%
- 5: 0.0%
- 6: 99.99%
- 7: 0.0%
- 8: 0.0%
- 9: 0.0%
Instance 6:
- 0: 0.25%
- 1: 65.76%
- 2: 8.04%
- 3: 2.62%
- 4: 4.44%
- 5: 1.69%
- 6: 3.6%
- 7: 5.74%
- 8: 5.47%
- 9: 2.39%
Instance 7:
- 0: 0.02%
- 1: 0.08%
- 2: 0.01%
- 3: 68.99%
- 4: 0.0%
- 5: 30.73%
- 6: 0.0%
- 7: 0.0%
- 8: 0.13%
- 9: 0.02%
Instance 8:
- 0: 0.28%
- 1: 64.91%
- 2: 9.17%
- 3: 2.73%
- 4: 4.16%
- 5: 1.61%
- 6: 3.69%
- 7: 5.98%
- 8: 5.2%
- 9: 2.26%
Instance 9:
- 0: 0.0%
- 1: 0.0%
- 2: 0.0%
- 3: 0.0%
- 4: 0.0%
- 5: 0.0%
- 6: 100.0%
- 7: 0.0%
- 8: 0.0%
- 9: 0.0%
Instance 10:
- 0: 0.03%
- 1: 0.0%
- 2: 0.0%
- 3: 0.01%
- 4: 0.0%
- 5: 99.93%
- 6: 0.03%
- 7: 0.0%
- 8: 0.0%
- 9: 0.0%
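Two remarks on the output above. First, the "Original" and "New" shapes print identically, (5009, 784) and (2, 24990071); most likely the cell had already been run once, so gnn_model had been mutated in place before the "Original" prints executed. Second, display_class_probabilities is defined earlier in the notebook; judging from the printed format, it is presumably along these lines (a sketch, not the actual definition):

def display_class_probabilities(probabilities):
    # Print one block per instance, one percentage line per class.
    for instance_idx, probs in enumerate(probabilities):
        print(f"Instance {instance_idx + 1}:")
        for class_idx, prob in enumerate(probs):
            print(f"- {class_idx}: {round(prob * 100, 2)}%")

Most of the synthetic instances land near class 1 with broadly spread probabilities, while instances 5, 9, and 10 put almost all of their mass on a single class.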