• A simple GNN implementation example
  • Implementing a graph convolution layer

Papers (paper): 2,708

  • 7 classes
    • 'Case_Based', 'Genetic_Algorithms', 'Neural_Networks', 'Probabilistic_Methods', 'Reinforcement_Learning', 'Rule_Learning', 'Theory'

Citations (cites): 5,429

  • cited paper id
  • citing paper id

Words (content): 1,433


Understanding the CORA dataset

  • The goal of our experiments is to predict the categories assigned to each paper.
  • The dataset has been pre-processed, removing papers that don't belong to at least one category. Papers without authors or without a title have also been discarded.
    • In other words, there are no missing values.
  • 11,881 papers belonging to 80 different categories
  • 16,114 authors
  • 34,648 citation relations between papers
  • 27,020 authorship relations between papers and authors
  • Each paper is associated with a vectorial representation containing its title represented as bag-of-words with TF-IDF weights.
  • Goals
    • link-prediction problem: predict citation relations between papers or authorship relations with authors
    • multi-label classification problem: predict the categories assigned to each paper
  • The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words.
  • In short: 2,708 scientific papers classified into one of seven classes, a citation network of 5,429 links, and for each paper a binary word vector of size 1,433 indicating the presence of the corresponding dictionary word.

import os
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
2022-06-01 16:46:00.074687: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-06-01 16:46:00.074715: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

Prepare the Dataset

The dataset has two tab-separated files: cora.cites and cora.content.

The cora.cites includes the citation records with two columns: cited_paper_id (target) and citing_paper_id (source).


The cora.content includes the paper content records with 1,435 columns: paper_id, subject, and 1,433 binary features.

zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)
data_dir = os.path.join(os.path.dirname(zip_file), "cora")
citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)
print("Citations shape:", citations.shape)
Citations shape: (5429, 2)

The target column includes the paper ids cited by the paper ids in the source column.

citations.sample(frac=1).head()
target source
3283 54129 1128291
447 1365 340299
2083 16476 12359
4054 117315 28336
4526 203646 1116569
column_names = ["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"]
papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"), sep="\t", header=None, names=column_names,
)
print("Papers shape:", papers.shape)
Papers shape: (2708, 1435)

Now we display a sample of the papers DataFrame. The DataFrame includes the paper_id and the subject columns, as well as 1,433 binary columns representing whether a term exists in the paper or not.

DataFrame.sample(n): randomly draws n rows from the data.

print(papers.sample(5).T)
                      1014           136                     804   \
paper_id            582139           8703                   75674   
term_0                   0              0                       0   
term_1                   0              0                       0   
term_2                   0              0                       0   
term_3                   0              0                       1   
...                    ...            ...                     ...   
term_1429                0              0                       0   
term_1430                0              0                       0   
term_1431                0              0                       0   
term_1432                0              0                       0   
subject    Neural_Networks  Rule_Learning  Reinforcement_Learning   

                            12          1401  
paper_id                  109323     1108329  
term_0                         0           0  
term_1                         0           1  
term_2                         1           0  
term_3                         0           0  
...                          ...         ...  
term_1429                      0           0  
term_1430                      0           0  
term_1431                      0           0  
term_1432                      0           0  
subject    Probabilistic_Methods  Case_Based  

[1435 rows x 5 columns]
print(papers.subject.value_counts())
Neural_Networks           818
Probabilistic_Methods     426
Genetic_Algorithms        418
Theory                    351
Case_Based                298
Reinforcement_Learning    217
Rule_Learning             180
Name: subject, dtype: int64

We convert the paper ids and the subjects into zero-based indices.

class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])

The isin method returns, element-wise, a boolean indicating whether each value in a DataFrame (or Series) is contained in the given values.
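
For example, a tiny sketch with made-up data (papers["paper_id"].isin(...) is used the same way in the visualization cell below):

import pandas as pd

# Hypothetical DataFrame: isin returns an element-wise boolean mask.
df = pd.DataFrame({"paper_id": [0, 1, 2, 3], "subject": [5, 2, 5, 1]})
mask = df["paper_id"].isin([1, 3])
print(mask.tolist())  # [False, True, False, True]
print(df[mask])       # only the rows whose paper_id is 1 or 3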

Now let's visualize the citation graph. Each node in the graph represents a paper, and the color of the node corresponds to its subject. Note that we only show a sample of the papers in the dataset.

plt.figure(figsize=(10, 10))
colors = papers["subject"].tolist()
cora_graph = nx.from_pandas_edgelist(citations.sample(n=1500))
subjects = list(papers[papers["paper_id"].isin(list(cora_graph.nodes))]["subject"])
nx.draw_spring(cora_graph, node_size=15, node_color=subjects)

Split the dataset into stratified train and test sets


Randomly sampling a fraction of rows from a DataFrame: df.sample(frac=0.5)

  • To draw a random sample of a given fraction from a DataFrame, pass a float between 0 and 1 to the frac parameter.
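
A tiny illustration with made-up data: frac=0.5 draws half of the rows at random, while frac=1 returns all rows in shuffled order, which is how train_data and test_data are shuffled below.

import pandas as pd

df = pd.DataFrame({"x": range(6)})
print(df.sample(frac=0.5))  # 3 randomly chosen rows
print(df.sample(frac=1))    # all 6 rows, shuffled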
train_data, test_data = [], []

for _, group_data in papers.groupby("subject"):
    # Select around 50% of the dataset for training.
    random_selection = np.random.rand(len(group_data.index)) <= 0.5
    train_data.append(group_data[random_selection])
    test_data.append(group_data[~random_selection])

train_data = pd.concat(train_data).sample(frac=1)
test_data = pd.concat(test_data).sample(frac=1)

print("Train data shape:", train_data.shape)
print("Test data shape:", test_data.shape)
Train data shape: (1350, 1435)
Test data shape: (1358, 1435)

Implement Train and Evaluate Experiment

  • Train on the training data, then evaluate on the test data
hidden_units = [32, 32]
learning_rate = 0.01
dropout_rate = 0.5
num_epochs = 300
batch_size = 256

This function compiles and trains an input model using the given training data.

  • Uses the Adam optimizer
  • Learning rate: 0.01
  • Loss function (losses.SparseCategoricalCrossentropy()) (https://www.tensorflow.org/api_docs/python/tf/keras/losses/SparseCategoricalCrossentropy) (a small numeric check follows after this list)
    • Use this crossentropy loss function when there are two or more label classes. We expect labels to be provided as integers. If you want to provide labels using one-hot representation, please use CategoricalCrossentropy loss. There should be # classes floating point values per feature for y_pred and a single floating point value per feature for y_true.
    • from_logits
      • Whether y_pred is expected to be a logits tensor. By default, we assume that y_pred encodes a probability distribution.
      • By default (from_logits=False), y_pred is assumed to encode a probability distribution; here we pass raw logits, so from_logits=True.
    • Both categorical cross entropy and sparse categorical cross entropy have the same loss function mentioned above. The only difference is the format in which you specify $Y_i$ (i.e., the true labels). (https://stats.stackexchange.com/questions/326065/cross-entropy-vs-sparse-cross-entropy-when-to-use-one-over-the-other)
      • If your $Y_i$'s are one-hot encoded, use categorical_crossentropy. Examples (for a 3-class classification): $[1,0,0] , [0,1,0], [0,0,1]$
      • But if your $Y_i$'s are integers, use sparse_categorical_crossentropy. Examples for above 3-class classification problem: $[1] , [2], [3]$
  • metric(tf.keras.metrics.SparseCategoricalAccuracy())
    • tf.keras.metrics.SparseCategoricalAccuracy(name='sparse_categorical_accuracy', dtype=None)
      • name $\to$ (Optional) string name of the metric instance.
  • callback(keras.callbacks.EarlyStopping()) (https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping) (Stop training when a monitored metric has stopped improving.)
    • tf.keras.callbacks.EarlyStopping(monitor='val_loss',min_delta=0,patience=0,verbose=0,mode='auto',baseline=None,restore_best_weights=False)
      • monitor Quantity to be monitored.
      • patience Number of epochs with no improvement after which training will be stopped.
      • restore_best_weights Whether to restore model weights from the epoch with the best value of the monitored quantity. If False, the model weights obtained at the last step of training are used. An epoch will be restored regardless of the performance relative to the baseline. If no epoch improves on baseline, training will run for patience epochs and restore weights from the best epoch in that set.
        • Whether to restore the model weights from the epoch with the best monitored value; if False, the weights from the last training step are kept. If no epoch improves on the baseline, training still runs for patience epochs and the weights from the best epoch in that window are restored.
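
As a quick sanity check of the sparse-vs-one-hot point above, here is a minimal sketch (with made-up logits and labels) showing that the two losses agree when only the label format differs:

import numpy as np
from tensorflow import keras

# Hypothetical 3-class example: integer labels vs. their one-hot encoding.
y_true_sparse = np.array([0, 2])
y_true_onehot = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])

sparse_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
onehot_loss = keras.losses.CategoricalCrossentropy(from_logits=True)

# Same loss value; only the way the true labels are written differs.
print(float(sparse_loss(y_true_sparse, logits)))
print(float(onehot_loss(y_true_onehot, logits)))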
def run_experiment(model, x_train, y_train):
    # Compile the model.
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
    )
    # Create an early stopping callback.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_acc", patience=50, restore_best_weights=True
    )
    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        epochs=num_epochs,
        batch_size=batch_size,
        validation_split=0.15,
        callbacks=[early_stopping],
    )

    return history

This function displays the loss and accuracy curves of the model during training.

def display_learning_curves(history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    ax1.plot(history.history["loss"])
    ax1.plot(history.history["val_loss"])
    ax1.legend(["train", "test"], loc="upper right")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")

    ax2.plot(history.history["acc"])
    ax2.plot(history.history["val_acc"])
    ax2.legend(["train", "test"], loc="upper right")
    ax2.set_xlabel("Epochs")
    ax2.set_ylabel("Accuracy")
    plt.show()

Implement Feedforward Network (FFN) Module

  • Activation function (tf.nn.gelu) (https://www.tensorflow.org/api_docs/python/tf/nn/gelu) (Gaussian Error Linear Unit (GELU) activation function.)
    • tf.nn.gelu(features, approximate=False, name=None)
    • Gaussian error linear unit (GELU) computes $x * P(X <= x)$, where $P(X) \sim N(0, 1)$. The (GELU) nonlinearity weights inputs by their value, rather than gates inputs by their sign as in ReLU.
      • GELU (Gaussian Error Linear Unit): whereas ReLU gates its input by sign, GELU weights the input by its value, $x \times P(X \le x)$ with $X \sim N(0, 1)$, i.e. x times the standard normal CDF evaluated at x.
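
A small sketch of the formula above, computing the exact GELU by hand via the standard normal CDF (written with tf.math.erf) and comparing it against tf.nn.gelu:

import tensorflow as tf

x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])

# Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) is the standard normal CDF,
# so GELU(x) = x * Phi(x).
manual_gelu = x * 0.5 * (1.0 + tf.math.erf(x / tf.sqrt(2.0)))

print(manual_gelu.numpy())
print(tf.nn.gelu(x).numpy())  # approximate=False (the default) uses the same erf form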
def create_ffn(hidden_units, dropout_rate, name=None):
    fnn_layers = []

    for units in hidden_units:
        fnn_layers.append(layers.BatchNormalization())
        fnn_layers.append(layers.Dropout(dropout_rate))
        fnn_layers.append(layers.Dense(units, activation=tf.nn.gelu))

    return keras.Sequential(fnn_layers, name=name)

Build a Baseline Neural Network Model

feature_names = set(papers.columns) - {"paper_id", "subject"}
num_features = len(feature_names)
num_classes = len(class_idx)

# Create train and test features as a numpy array.
x_train = train_data[feature_names].to_numpy()
x_test = test_data[feature_names].to_numpy()
# Create train and test targets as a numpy array.
y_train = train_data["subject"]
y_test = test_data["subject"]

Implement a baseline classifier

  • We add five FFN blocks with skip connections, so that we generate a baseline model with roughly the same number of parameters as the GNN models to be built later.

Above, hidden_units was defined as [32, 32] and num_classes as the number of class indices.

def create_baseline_model(hidden_units, num_classes, dropout_rate=0.2):
    inputs = layers.Input(shape=(num_features,), name="input_features")
    x = create_ffn(hidden_units, dropout_rate, name=f"ffn_block1")(inputs)
    for block_idx in range(4):
        # Create an FFN block.
        x1 = create_ffn(hidden_units, dropout_rate, name=f"ffn_block{block_idx + 2}")(x)
        # Add skip connection.
        x = layers.Add(name=f"skip_connection{block_idx + 2}")([x, x1])
    # Compute logits.
    logits = layers.Dense(num_classes, name="logits")(x)
    # Create the model.
    return keras.Model(inputs=inputs, outputs=logits, name="baseline")


baseline_model = create_baseline_model(hidden_units, num_classes, dropout_rate)
baseline_model.summary()
Model: "baseline"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_features (InputLayer)    [(None, 1433)]       0           []                               
                                                                                                  
 ffn_block1 (Sequential)        (None, 32)           52804       ['input_features[0][0]']         
                                                                                                  
 ffn_block2 (Sequential)        (None, 32)           2368        ['ffn_block1[0][0]']             
                                                                                                  
 skip_connection2 (Add)         (None, 32)           0           ['ffn_block1[0][0]',             
                                                                  'ffn_block2[0][0]']             
                                                                                                  
 ffn_block3 (Sequential)        (None, 32)           2368        ['skip_connection2[0][0]']       
                                                                                                  
 skip_connection3 (Add)         (None, 32)           0           ['skip_connection2[0][0]',       
                                                                  'ffn_block3[0][0]']             
                                                                                                  
 ffn_block4 (Sequential)        (None, 32)           2368        ['skip_connection3[0][0]']       
                                                                                                  
 skip_connection4 (Add)         (None, 32)           0           ['skip_connection3[0][0]',       
                                                                  'ffn_block4[0][0]']             
                                                                                                  
 ffn_block5 (Sequential)        (None, 32)           2368        ['skip_connection4[0][0]']       
                                                                                                  
 skip_connection5 (Add)         (None, 32)           0           ['skip_connection4[0][0]',       
                                                                  'ffn_block5[0][0]']             
                                                                                                  
 logits (Dense)                 (None, 7)            231         ['skip_connection5[0][0]']       
                                                                                                  
==================================================================================================
Total params: 62,507
Trainable params: 59,065
Non-trainable params: 3,442
__________________________________________________________________________________________________
history = run_experiment(baseline_model, x_train, y_train)
Epoch 1/300
5/5 [==============================] - 2s 81ms/step - loss: 4.7939 - acc: 0.1430 - val_loss: 1.8833 - val_acc: 0.1724
Epoch 2/300
5/5 [==============================] - 0s 25ms/step - loss: 3.0728 - acc: 0.2302 - val_loss: 1.8828 - val_acc: 0.2266
Epoch 3/300
5/5 [==============================] - 0s 25ms/step - loss: 2.6068 - acc: 0.1997 - val_loss: 1.8472 - val_acc: 0.3448
Epoch 4/300
5/5 [==============================] - 0s 27ms/step - loss: 2.3820 - acc: 0.2581 - val_loss: 1.8516 - val_acc: 0.3448
Epoch 5/300
5/5 [==============================] - 0s 29ms/step - loss: 2.1429 - acc: 0.3008 - val_loss: 1.8687 - val_acc: 0.3596
Epoch 6/300
5/5 [==============================] - 0s 27ms/step - loss: 1.9173 - acc: 0.2895 - val_loss: 1.8850 - val_acc: 0.3498
Epoch 7/300
5/5 [==============================] - 0s 26ms/step - loss: 1.7667 - acc: 0.3330 - val_loss: 1.8825 - val_acc: 0.2808
Epoch 8/300
5/5 [==============================] - 0s 27ms/step - loss: 1.7313 - acc: 0.3452 - val_loss: 1.8539 - val_acc: 0.3744
Epoch 9/300
5/5 [==============================] - 0s 28ms/step - loss: 1.7446 - acc: 0.3566 - val_loss: 1.8322 - val_acc: 0.3842
Epoch 10/300
5/5 [==============================] - 0s 29ms/step - loss: 1.6173 - acc: 0.4063 - val_loss: 1.7959 - val_acc: 0.4138
Epoch 11/300
5/5 [==============================] - 0s 27ms/step - loss: 1.5615 - acc: 0.4124 - val_loss: 1.7378 - val_acc: 0.5074
Epoch 12/300
5/5 [==============================] - 0s 27ms/step - loss: 1.5314 - acc: 0.4298 - val_loss: 1.6639 - val_acc: 0.5468
Epoch 13/300
5/5 [==============================] - 0s 25ms/step - loss: 1.3981 - acc: 0.4821 - val_loss: 1.5793 - val_acc: 0.5271
Epoch 14/300
5/5 [==============================] - 0s 24ms/step - loss: 1.3896 - acc: 0.4961 - val_loss: 1.4911 - val_acc: 0.4926
Epoch 15/300
5/5 [==============================] - 0s 28ms/step - loss: 1.2557 - acc: 0.5379 - val_loss: 1.4219 - val_acc: 0.4877
Epoch 16/300
5/5 [==============================] - 0s 27ms/step - loss: 1.2507 - acc: 0.5440 - val_loss: 1.3815 - val_acc: 0.4877
Epoch 17/300
5/5 [==============================] - 0s 26ms/step - loss: 1.2118 - acc: 0.5728 - val_loss: 1.3516 - val_acc: 0.4877
Epoch 18/300
5/5 [==============================] - 0s 25ms/step - loss: 1.1647 - acc: 0.5763 - val_loss: 1.3429 - val_acc: 0.4828
Epoch 19/300
5/5 [==============================] - 0s 25ms/step - loss: 1.1089 - acc: 0.5824 - val_loss: 1.3507 - val_acc: 0.4828
Epoch 20/300
5/5 [==============================] - 0s 24ms/step - loss: 1.0482 - acc: 0.6085 - val_loss: 1.3705 - val_acc: 0.4975
Epoch 21/300
5/5 [==============================] - 0s 26ms/step - loss: 1.0318 - acc: 0.6486 - val_loss: 1.3319 - val_acc: 0.5074
Epoch 22/300
5/5 [==============================] - 0s 28ms/step - loss: 0.9702 - acc: 0.6399 - val_loss: 1.3350 - val_acc: 0.5172
Epoch 23/300
5/5 [==============================] - 0s 26ms/step - loss: 0.9267 - acc: 0.6792 - val_loss: 1.4295 - val_acc: 0.4729
Epoch 24/300
5/5 [==============================] - 0s 24ms/step - loss: 0.8597 - acc: 0.6861 - val_loss: 1.3771 - val_acc: 0.4877
Epoch 25/300
5/5 [==============================] - 0s 24ms/step - loss: 0.8225 - acc: 0.7027 - val_loss: 1.3567 - val_acc: 0.5123
Epoch 26/300
5/5 [==============================] - 0s 26ms/step - loss: 0.8120 - acc: 0.6931 - val_loss: 1.2692 - val_acc: 0.5419
Epoch 27/300
5/5 [==============================] - 0s 27ms/step - loss: 0.7965 - acc: 0.7140 - val_loss: 1.1794 - val_acc: 0.5911
Epoch 28/300
5/5 [==============================] - 0s 27ms/step - loss: 0.7671 - acc: 0.7384 - val_loss: 1.1200 - val_acc: 0.6158
Epoch 29/300
5/5 [==============================] - 0s 25ms/step - loss: 0.7835 - acc: 0.7341 - val_loss: 1.1190 - val_acc: 0.6158
Epoch 30/300
5/5 [==============================] - 0s 28ms/step - loss: 0.7617 - acc: 0.7332 - val_loss: 1.1314 - val_acc: 0.6158
Epoch 31/300
5/5 [==============================] - 0s 25ms/step - loss: 0.6875 - acc: 0.7463 - val_loss: 1.1503 - val_acc: 0.5862
Epoch 32/300
5/5 [==============================] - 0s 26ms/step - loss: 0.7088 - acc: 0.7533 - val_loss: 1.1511 - val_acc: 0.5911
Epoch 33/300
5/5 [==============================] - 0s 26ms/step - loss: 0.6860 - acc: 0.7629 - val_loss: 1.0424 - val_acc: 0.6305
Epoch 34/300
5/5 [==============================] - 0s 29ms/step - loss: 0.6531 - acc: 0.7620 - val_loss: 0.9604 - val_acc: 0.6700
Epoch 35/300
5/5 [==============================] - 0s 30ms/step - loss: 0.6293 - acc: 0.7899 - val_loss: 0.9740 - val_acc: 0.6749
Epoch 36/300
5/5 [==============================] - 0s 26ms/step - loss: 0.6675 - acc: 0.7820 - val_loss: 0.9916 - val_acc: 0.6601
Epoch 37/300
5/5 [==============================] - 0s 25ms/step - loss: 0.6462 - acc: 0.7873 - val_loss: 0.9870 - val_acc: 0.6601
Epoch 38/300
5/5 [==============================] - 0s 25ms/step - loss: 0.6342 - acc: 0.7786 - val_loss: 0.9586 - val_acc: 0.6650
Epoch 39/300
5/5 [==============================] - 0s 24ms/step - loss: 0.6328 - acc: 0.7890 - val_loss: 0.9779 - val_acc: 0.6355
Epoch 40/300
5/5 [==============================] - 0s 24ms/step - loss: 0.5744 - acc: 0.7995 - val_loss: 0.9386 - val_acc: 0.6601
Epoch 41/300
5/5 [==============================] - 0s 24ms/step - loss: 0.5898 - acc: 0.7925 - val_loss: 0.9406 - val_acc: 0.6601
Epoch 42/300
5/5 [==============================] - 0s 25ms/step - loss: 0.5649 - acc: 0.7951 - val_loss: 0.9019 - val_acc: 0.6650
Epoch 43/300
5/5 [==============================] - 0s 25ms/step - loss: 0.5546 - acc: 0.8003 - val_loss: 0.8630 - val_acc: 0.6700
Epoch 44/300
5/5 [==============================] - 0s 25ms/step - loss: 0.5769 - acc: 0.8134 - val_loss: 0.8561 - val_acc: 0.6700
Epoch 45/300
5/5 [==============================] - 0s 27ms/step - loss: 0.5582 - acc: 0.8126 - val_loss: 0.8525 - val_acc: 0.6847
Epoch 46/300
5/5 [==============================] - 0s 26ms/step - loss: 0.5824 - acc: 0.7942 - val_loss: 0.8498 - val_acc: 0.6700
Epoch 47/300
5/5 [==============================] - 0s 27ms/step - loss: 0.5515 - acc: 0.8082 - val_loss: 0.8381 - val_acc: 0.6798
Epoch 48/300
5/5 [==============================] - 0s 28ms/step - loss: 0.5300 - acc: 0.8152 - val_loss: 0.8288 - val_acc: 0.6897
Epoch 49/300
5/5 [==============================] - 0s 28ms/step - loss: 0.5647 - acc: 0.7977 - val_loss: 0.8426 - val_acc: 0.6847
Epoch 50/300
5/5 [==============================] - 0s 29ms/step - loss: 0.5082 - acc: 0.8317 - val_loss: 0.8033 - val_acc: 0.7044
Epoch 51/300
5/5 [==============================] - 0s 28ms/step - loss: 0.4946 - acc: 0.8248 - val_loss: 0.7746 - val_acc: 0.7389
Epoch 52/300
5/5 [==============================] - 0s 25ms/step - loss: 0.5068 - acc: 0.8230 - val_loss: 0.7725 - val_acc: 0.7340
Epoch 53/300
5/5 [==============================] - 0s 26ms/step - loss: 0.5040 - acc: 0.8152 - val_loss: 0.8055 - val_acc: 0.7241
Epoch 54/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4913 - acc: 0.8300 - val_loss: 0.8529 - val_acc: 0.7094
Epoch 55/300
5/5 [==============================] - 0s 25ms/step - loss: 0.5104 - acc: 0.8152 - val_loss: 0.8488 - val_acc: 0.7143
Epoch 56/300
5/5 [==============================] - 0s 26ms/step - loss: 0.5524 - acc: 0.8082 - val_loss: 0.8712 - val_acc: 0.7340
Epoch 57/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4871 - acc: 0.8344 - val_loss: 0.8704 - val_acc: 0.7241
Epoch 58/300
5/5 [==============================] - 0s 26ms/step - loss: 0.5074 - acc: 0.8326 - val_loss: 0.9098 - val_acc: 0.7044
Epoch 59/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4985 - acc: 0.8187 - val_loss: 0.9677 - val_acc: 0.6897
Epoch 60/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4695 - acc: 0.8370 - val_loss: 0.9404 - val_acc: 0.6995
Epoch 61/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4883 - acc: 0.8291 - val_loss: 0.8529 - val_acc: 0.7094
Epoch 62/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4793 - acc: 0.8396 - val_loss: 0.7893 - val_acc: 0.7340
Epoch 63/300
5/5 [==============================] - 0s 24ms/step - loss: 0.4857 - acc: 0.8387 - val_loss: 0.8048 - val_acc: 0.7291
Epoch 64/300
5/5 [==============================] - 0s 24ms/step - loss: 0.4582 - acc: 0.8361 - val_loss: 0.8141 - val_acc: 0.7389
Epoch 65/300
5/5 [==============================] - 0s 24ms/step - loss: 0.4769 - acc: 0.8300 - val_loss: 0.8043 - val_acc: 0.7192
Epoch 66/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4644 - acc: 0.8474 - val_loss: 0.7873 - val_acc: 0.7389
Epoch 67/300
5/5 [==============================] - 0s 27ms/step - loss: 0.4759 - acc: 0.8248 - val_loss: 0.7970 - val_acc: 0.7537
Epoch 68/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4225 - acc: 0.8518 - val_loss: 0.8167 - val_acc: 0.7389
Epoch 69/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4430 - acc: 0.8518 - val_loss: 0.8184 - val_acc: 0.7488
Epoch 70/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4497 - acc: 0.8483 - val_loss: 0.7922 - val_acc: 0.7586
Epoch 71/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4655 - acc: 0.8274 - val_loss: 0.7635 - val_acc: 0.7586
Epoch 72/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4176 - acc: 0.8588 - val_loss: 0.7514 - val_acc: 0.7833
Epoch 73/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4578 - acc: 0.8344 - val_loss: 0.7541 - val_acc: 0.7833
Epoch 74/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4592 - acc: 0.8282 - val_loss: 0.7658 - val_acc: 0.7734
Epoch 75/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4321 - acc: 0.8527 - val_loss: 0.7881 - val_acc: 0.7685
Epoch 76/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4731 - acc: 0.8413 - val_loss: 0.7894 - val_acc: 0.7685
Epoch 77/300
5/5 [==============================] - 0s 28ms/step - loss: 0.4619 - acc: 0.8387 - val_loss: 0.7669 - val_acc: 0.7783
Epoch 78/300
5/5 [==============================] - 0s 27ms/step - loss: 0.4321 - acc: 0.8509 - val_loss: 0.7519 - val_acc: 0.7882
Epoch 79/300
5/5 [==============================] - 0s 24ms/step - loss: 0.4776 - acc: 0.8274 - val_loss: 0.7837 - val_acc: 0.7783
Epoch 80/300
5/5 [==============================] - 0s 24ms/step - loss: 0.3961 - acc: 0.8684 - val_loss: 0.8228 - val_acc: 0.7537
Epoch 81/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4610 - acc: 0.8361 - val_loss: 0.7969 - val_acc: 0.7635
Epoch 82/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4438 - acc: 0.8396 - val_loss: 0.7988 - val_acc: 0.7635
Epoch 83/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4513 - acc: 0.8387 - val_loss: 0.8329 - val_acc: 0.7488
Epoch 84/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4510 - acc: 0.8448 - val_loss: 0.8634 - val_acc: 0.7389
Epoch 85/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4243 - acc: 0.8561 - val_loss: 0.8297 - val_acc: 0.7291
Epoch 86/300
5/5 [==============================] - 0s 27ms/step - loss: 0.4222 - acc: 0.8466 - val_loss: 0.8234 - val_acc: 0.7438
Epoch 87/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4171 - acc: 0.8509 - val_loss: 0.8713 - val_acc: 0.7488
Epoch 88/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4379 - acc: 0.8561 - val_loss: 0.8733 - val_acc: 0.7635
Epoch 89/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4202 - acc: 0.8509 - val_loss: 0.8371 - val_acc: 0.7586
Epoch 90/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4242 - acc: 0.8509 - val_loss: 0.8241 - val_acc: 0.7734
Epoch 91/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4701 - acc: 0.8378 - val_loss: 0.8390 - val_acc: 0.7586
Epoch 92/300
5/5 [==============================] - 0s 24ms/step - loss: 0.4603 - acc: 0.8457 - val_loss: 0.8524 - val_acc: 0.7635
Epoch 93/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4188 - acc: 0.8657 - val_loss: 0.8900 - val_acc: 0.7635
Epoch 94/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4142 - acc: 0.8535 - val_loss: 0.8863 - val_acc: 0.7635
Epoch 95/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4093 - acc: 0.8588 - val_loss: 0.8624 - val_acc: 0.7635
Epoch 96/300
5/5 [==============================] - 0s 24ms/step - loss: 0.4106 - acc: 0.8500 - val_loss: 0.8389 - val_acc: 0.7586
Epoch 97/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4160 - acc: 0.8544 - val_loss: 0.8281 - val_acc: 0.7537
Epoch 98/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4101 - acc: 0.8535 - val_loss: 0.8266 - val_acc: 0.7537
Epoch 99/300
5/5 [==============================] - 0s 26ms/step - loss: 0.3917 - acc: 0.8727 - val_loss: 0.8415 - val_acc: 0.7438
Epoch 100/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4363 - acc: 0.8405 - val_loss: 0.8709 - val_acc: 0.7488
Epoch 101/300
5/5 [==============================] - 0s 24ms/step - loss: 0.4375 - acc: 0.8370 - val_loss: 0.8527 - val_acc: 0.7537
Epoch 102/300
5/5 [==============================] - 0s 24ms/step - loss: 0.4173 - acc: 0.8588 - val_loss: 0.8214 - val_acc: 0.7586
Epoch 103/300
5/5 [==============================] - 0s 25ms/step - loss: 0.3831 - acc: 0.8666 - val_loss: 0.8109 - val_acc: 0.7586
Epoch 104/300
5/5 [==============================] - 0s 25ms/step - loss: 0.3792 - acc: 0.8614 - val_loss: 0.8146 - val_acc: 0.7635
Epoch 105/300
5/5 [==============================] - 0s 24ms/step - loss: 0.3762 - acc: 0.8727 - val_loss: 0.8498 - val_acc: 0.7882
Epoch 106/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4443 - acc: 0.8474 - val_loss: 0.8972 - val_acc: 0.7783
Epoch 107/300
5/5 [==============================] - 0s 25ms/step - loss: 0.3573 - acc: 0.8806 - val_loss: 0.8806 - val_acc: 0.7783
Epoch 108/300
5/5 [==============================] - 0s 24ms/step - loss: 0.4002 - acc: 0.8649 - val_loss: 0.8795 - val_acc: 0.7734
Epoch 109/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4177 - acc: 0.8544 - val_loss: 0.8751 - val_acc: 0.7488
Epoch 110/300
5/5 [==============================] - 0s 25ms/step - loss: 0.3842 - acc: 0.8684 - val_loss: 0.8595 - val_acc: 0.7586
Epoch 111/300
5/5 [==============================] - 0s 24ms/step - loss: 0.4019 - acc: 0.8561 - val_loss: 0.8658 - val_acc: 0.7685
Epoch 112/300
5/5 [==============================] - 0s 25ms/step - loss: 0.3568 - acc: 0.8684 - val_loss: 0.8864 - val_acc: 0.7734
Epoch 113/300
5/5 [==============================] - 0s 28ms/step - loss: 0.3698 - acc: 0.8701 - val_loss: 0.9192 - val_acc: 0.7438
Epoch 114/300
5/5 [==============================] - 0s 26ms/step - loss: 0.4518 - acc: 0.8483 - val_loss: 0.9145 - val_acc: 0.7488
Epoch 115/300
5/5 [==============================] - 0s 25ms/step - loss: 0.3704 - acc: 0.8779 - val_loss: 0.9271 - val_acc: 0.7389
Epoch 116/300
5/5 [==============================] - 0s 25ms/step - loss: 0.4072 - acc: 0.8605 - val_loss: 0.9087 - val_acc: 0.7537
Epoch 117/300
5/5 [==============================] - 0s 27ms/step - loss: 0.3748 - acc: 0.8753 - val_loss: 0.9093 - val_acc: 0.7537
Epoch 118/300
5/5 [==============================] - 0s 25ms/step - loss: 0.3630 - acc: 0.8771 - val_loss: 0.9179 - val_acc: 0.7488
Epoch 119/300
5/5 [==============================] - 0s 26ms/step - loss: 0.3775 - acc: 0.8657 - val_loss: 0.9278 - val_acc: 0.7438
Epoch 120/300
5/5 [==============================] - 0s 25ms/step - loss: 0.3773 - acc: 0.8666 - val_loss: 0.9702 - val_acc: 0.7537
Epoch 121/300
5/5 [==============================] - 0s 24ms/step - loss: 0.3957 - acc: 0.8640 - val_loss: 1.0036 - val_acc: 0.7340
Epoch 122/300
5/5 [==============================] - 0s 24ms/step - loss: 0.3733 - acc: 0.8727 - val_loss: 1.0000 - val_acc: 0.7488
Epoch 123/300
5/5 [==============================] - 0s 24ms/step - loss: 0.3770 - acc: 0.8701 - val_loss: 0.9801 - val_acc: 0.7340
Epoch 124/300
5/5 [==============================] - 0s 23ms/step - loss: 0.3918 - acc: 0.8570 - val_loss: 0.9858 - val_acc: 0.7340
Epoch 125/300
5/5 [==============================] - 0s 24ms/step - loss: 0.3471 - acc: 0.8893 - val_loss: 1.0083 - val_acc: 0.7389
Epoch 126/300
5/5 [==============================] - 0s 24ms/step - loss: 0.3895 - acc: 0.8675 - val_loss: 1.0593 - val_acc: 0.7340
Epoch 127/300
5/5 [==============================] - 0s 25ms/step - loss: 0.3615 - acc: 0.8692 - val_loss: 1.0561 - val_acc: 0.7389
Epoch 128/300
5/5 [==============================] - 0s 29ms/step - loss: 0.3926 - acc: 0.8675 - val_loss: 1.0603 - val_acc: 0.7340
display_learning_curves(history)
_, test_accuracy = baseline_model.evaluate(x=x_test, y=y_test, verbose=0)
print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")
Test accuracy: 75.99%

Examine the baseline model predictions

Let's create new data instances by randomly generating binary word vectors with respect to the word presence probabilities.

def generate_random_instances(num_instances):
    token_probability = x_train.mean(axis=0)
    instances = []
    for _ in range(num_instances):
        probabilities = np.random.uniform(size=len(token_probability))
        instance = (probabilities <= token_probability).astype(int)
        instances.append(instance)

    return np.array(instances)


def display_class_probabilities(probabilities):
    for instance_idx, probs in enumerate(probabilities):
        print(f"Instance {instance_idx + 1}:")
        for class_idx, prob in enumerate(probs):
            print(f"- {class_values[class_idx]}: {round(prob * 100, 2)}%")
  • tf.convert_to_tensor(logits): converts the logits (a NumPy array here) to a tensor
new_instances = generate_random_instances(num_classes)
logits = baseline_model.predict(new_instances)
probabilities = keras.activations.softmax(tf.convert_to_tensor(logits)).numpy()
display_class_probabilities(probabilities)
Instance 1:
- Case_Based: 3.23%
- Genetic_Algorithms: 26.95%
- Neural_Networks: 18.55%
- Probabilistic_Methods: 1.72%
- Reinforcement_Learning: 47.51%
- Rule_Learning: 1.04%
- Theory: 1.0%
Instance 2:
- Case_Based: 0.8%
- Genetic_Algorithms: 0.62%
- Neural_Networks: 73.16%
- Probabilistic_Methods: 16.96%
- Reinforcement_Learning: 2.26%
- Rule_Learning: 1.37%
- Theory: 4.84%
Instance 3:
- Case_Based: 0.77%
- Genetic_Algorithms: 1.35%
- Neural_Networks: 93.03%
- Probabilistic_Methods: 0.86%
- Reinforcement_Learning: 2.42%
- Rule_Learning: 0.56%
- Theory: 1.02%
Instance 4:
- Case_Based: 0.74%
- Genetic_Algorithms: 1.68%
- Neural_Networks: 85.81%
- Probabilistic_Methods: 3.06%
- Reinforcement_Learning: 2.03%
- Rule_Learning: 1.14%
- Theory: 5.54%
Instance 5:
- Case_Based: 0.48%
- Genetic_Algorithms: 94.5%
- Neural_Networks: 2.5%
- Probabilistic_Methods: 0.37%
- Reinforcement_Learning: 0.56%
- Rule_Learning: 0.53%
- Theory: 1.06%
Instance 6:
- Case_Based: 2.25%
- Genetic_Algorithms: 0.15%
- Neural_Networks: 1.91%
- Probabilistic_Methods: 93.11%
- Reinforcement_Learning: 0.08%
- Rule_Learning: 0.93%
- Theory: 1.57%
Instance 7:
- Case_Based: 0.8%
- Genetic_Algorithms: 1.73%
- Neural_Networks: 21.15%
- Probabilistic_Methods: 3.58%
- Reinforcement_Learning: 0.51%
- Rule_Learning: 67.77%
- Theory: 4.47%

Build a Graph Neural Network Model

The graph data is represented by the graph_info tuple, which consists of the following three elements:

  1. node_features: This is a [num_nodes, num_features] NumPy array that includes the node features. In this dataset, the nodes are the papers, and the node_features are the word-presence binary vectors of each paper.
  2. edges: This is a [2, num_edges] NumPy array representing a sparse adjacency matrix of the links between the nodes. In this example, the links are the citations between the papers.
  3. edge_weights (optional): This is a [num_edges] NumPy array that includes the edge weights, which quantify the relationships between nodes in the graph. In this example, there are no weights for the paper citations.

sparse adjacency matrix

  • An adjacency matrix in which most entries are zero: the number of edges is far smaller than the number of possible node pairs, so only the existing edges are stored, here as a [2, num_edges] edge list (see the toy example below).
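
A toy illustration (a made-up 4-node graph) of why an edge list of shape [2, num_edges] is a compact way to store a sparse adjacency matrix:

import numpy as np

# Three directed edges: 0->1, 1->2, 2->3, stored as (source row, target row).
toy_edges = np.array([[0, 1, 2],
                      [1, 2, 3]])

# The equivalent dense adjacency matrix is mostly zeros.
dense = np.zeros((4, 4), dtype=int)
dense[toy_edges[0], toy_edges[1]] = 1
print(dense)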
edges = citations[["source", "target"]].to_numpy().T
# Create an edge weights array of ones.
edge_weights = tf.ones(shape=edges.shape[1])
# Create a node features array of shape [num_nodes, num_features].
node_features = tf.cast(
    papers.sort_values("paper_id")[feature_names].to_numpy(), dtype=tf.dtypes.float32
)
# Create graph info tuple with node_features, edges, and edge_weights.
graph_info = (node_features, edges, edge_weights)

print("Edges shape:", edges.shape)
print("Nodes shape:", node_features.shape)
Edges shape: (2, 5429)
Nodes shape: (2708, 1433)

Implement a graph convolution layer

What is a GRU? (see the GRU note further below)

We implement a graph convolution module as a Keras Layer. Our GraphConvLayer performs the following steps:

  1. Prepare: The input node representations are processed using a FFN to produce a message. You can simplify the processing by only applying a linear transformation to the representations.
  2. Aggregate: The messages of the neighbours of each node are aggregated with respect to the edge_weights using a permutation-invariant pooling operation, such as sum, mean, or max, to prepare a single aggregated message for each node. See, for example, the tf.math.unsorted_segment_sum APIs used to aggregate neighbour messages.
  3. Update: The node_repesentations and aggregated_messages (both of shape [num_nodes, representation_dim]) are combined and processed to produce the new state of the node representations (node embeddings). If combination_type is gru, the node_repesentations and aggregated_messages are stacked to create a sequence, then processed by a GRU layer. Otherwise, the node_repesentations and aggregated_messages are added or concatenated, then processed using a FFN.
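
In symbols, a sketch of one GraphConvLayer pass for node $v$ with neighbours $N(v)$, edge weights $w_{uv}$, and the concatenation-plus-FFN combination used in this example: $$m_{u \to v} = w_{uv}\,\mathrm{FFN}_{\text{prepare}}(h_u), \qquad a_v = \mathrm{AGG}_{u \in N(v)}\, m_{u \to v}, \qquad h_v' = \mathrm{FFN}_{\text{update}}\big([\,h_v \,\Vert\, a_v\,]\big)$$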

GRU: Gated Recurrent Unit (https://keras.io/api/layers/recurrent_layers/gru/)

  • Keeps LSTM's remedy for the long-term dependency problem while reducing the computation needed to update the hidden state
    • Trains faster, with comparable performance
  • LSTM has three gates (output, input, forget)
  • GRU has only two gates (update, reset)

*args

  • Positional arguments packed into a tuple
  • Used when you don't know in advance how many positional arguments the function will receive

**kwargs

  • Keyword arguments, passed into the function as a dictionary of the form {'keyword': value}

$\star$ Order

def f(regular_arg, *args, **kwargs):
...

super(CurrentClass, self).__init__()

  • Calls the parent class's constructor from within the subclass; the first argument is the class being defined (e.g. GraphConvLayer below), whose parent's __init__ is then invoked
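
A minimal sketch (hypothetical layer name) of how the two notes above come together in GraphConvLayer below: extra positional and keyword arguments such as name are simply forwarded to the parent keras.layers.Layer constructor via super().__init__:

from tensorflow.keras import layers

class MyLayer(layers.Layer):
    def __init__(self, units, *args, **kwargs):
        # Forward anything we don't handle ourselves (e.g. name="...") to Layer.
        super(MyLayer, self).__init__(*args, **kwargs)
        self.units = units

layer = MyLayer(32, name="my_layer")
print(layer.name)  # "my_layer"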

Hyperbolic Tangent = Hyperbolic Sine / Hyperbolic Cosine $$\tanh z = \frac{\sinh z}{\cosh z} = \frac{e^z - e^{-z}}{e^z + e^{-z}} = \frac{e^{2z} -1}{e^{2z} + 1}$$


tf.expand_dims(input, axis=...) (https://www.tensorflow.org/api_docs/python/tf/expand_dims)

  • Adds a dimension of length 1 to the tensor
  • axis specifies at which position the new dimension is inserted
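
A tiny sketch: this is how the [num_edges] edge weights are broadcast against [num_edges, representation_dim] messages in prepare() below.

import tensorflow as tf

weights = tf.constant([0.5, 1.0, 2.0])     # shape (3,)
weights_col = tf.expand_dims(weights, -1)  # shape (3, 1): a new trailing axis

messages = tf.ones((3, 4))                 # shape (3, 4)
print((messages * weights_col).shape)      # (3, 4): each row scaled by its weight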

tf.math.unsorted_segment_max (https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_max)

  • Computes the maximum within each segment

tf.math.unsorted_segment_mean

  • Computes the mean within each segment

tf.math.unsorted_segment_sum

  • Computes the sum within each segment
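
A toy aggregation in the spirit of GraphConvLayer.aggregate() below, with made-up messages: three edge messages are routed to target nodes 0, 0, and 2.

import tensorflow as tf

neighbour_messages = tf.constant([[1.0, 1.0],
                                  [3.0, 5.0],
                                  [2.0, 4.0]])
node_indices = tf.constant([0, 0, 2])  # which node each message belongs to
num_nodes = 3

# Node 0 receives two messages, node 1 none (all zeros), node 2 a single message.
print(tf.math.unsorted_segment_sum(neighbour_messages, node_indices, num_segments=num_nodes).numpy())
print(tf.math.unsorted_segment_mean(neighbour_messages, node_indices, num_segments=num_nodes).numpy())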

class GraphConvLayer(layers.Layer):
    def __init__(
        self,
        hidden_units,
        dropout_rate=0.2,
        aggregation_type="mean",
        combination_type="concat",
        normalize=False,
        *args,
        **kwargs,
    ):
        super(GraphConvLayer, self).__init__(*args, **kwargs)

        self.aggregation_type = aggregation_type
        self.combination_type = combination_type
        self.normalize = normalize

        self.ffn_prepare = create_ffn(hidden_units, dropout_rate)
        if self.combination_type == "gated":
            self.update_fn = layers.GRU(
                units=hidden_units,
                activation="tanh",
                recurrent_activation="sigmoid",
                dropout=dropout_rate,
                return_state=True,
                recurrent_dropout=dropout_rate,
            )
        else:
            self.update_fn = create_ffn(hidden_units, dropout_rate)

    def prepare(self, node_repesentations, weights=None):
        # node_repesentations shape is [num_edges, embedding_dim].
        messages = self.ffn_prepare(node_repesentations)
        if weights is not None:
            messages = messages * tf.expand_dims(weights, -1)
        return messages

    def aggregate(self, node_indices, neighbour_messages):
        # node_indices shape is [num_edges].
        # neighbour_messages shape: [num_edges, representation_dim].
        num_nodes = tf.math.reduce_max(node_indices) + 1
        if self.aggregation_type == "sum":
            aggregated_message = tf.math.unsorted_segment_sum(
                neighbour_messages, node_indices, num_segments=num_nodes
            )
        elif self.aggregation_type == "mean":
            aggregated_message = tf.math.unsorted_segment_mean(
                neighbour_messages, node_indices, num_segments=num_nodes
            )
        elif self.aggregation_type == "max":
            aggregated_message = tf.math.unsorted_segment_max(
                neighbour_messages, node_indices, num_segments=num_nodes
            )
        else:
            raise ValueError(f"Invalid aggregation type: {self.aggregation_type}.")

        return aggregated_message

    def update(self, node_repesentations, aggregated_messages):
        # node_repesentations shape is [num_nodes, representation_dim].
        # aggregated_messages shape is [num_nodes, representation_dim].
        if self.combination_type == "gru":
            # Create a sequence of two elements for the GRU layer.
            h = tf.stack([node_repesentations, aggregated_messages], axis=1)
        elif self.combination_type == "concat":
            # Concatenate the node_repesentations and aggregated_messages.
            h = tf.concat([node_repesentations, aggregated_messages], axis=1)
        elif self.combination_type == "add":
            # Add node_repesentations and aggregated_messages.
            h = node_repesentations + aggregated_messages
        else:
            raise ValueError(f"Invalid combination type: {self.combination_type}.")

        # Apply the processing function.
        node_embeddings = self.update_fn(h)
        if self.combination_type == "gru":
            node_embeddings = tf.unstack(node_embeddings, axis=1)[-1]

        if self.normalize:
            node_embeddings = tf.nn.l2_normalize(node_embeddings, axis=-1)
        return node_embeddings

    def call(self, inputs):
        """Process the inputs to produce the node_embeddings.

        inputs: a tuple of three elements: node_repesentations, edges, edge_weights.
        Returns: node_embeddings of shape [num_nodes, representation_dim].
        """

        node_repesentations, edges, edge_weights = inputs
        # Get node_indices (source) and neighbour_indices (target) from edges.
        node_indices, neighbour_indices = edges[0], edges[1]
        # neighbour_repesentations shape is [num_edges, representation_dim].
        neighbour_repesentations = tf.gather(node_repesentations, neighbour_indices)

        # Prepare the messages of the neighbours.
        neighbour_messages = self.prepare(neighbour_repesentations, edge_weights)
        # Aggregate the neighbour messages.
        aggregated_messages = self.aggregate(node_indices, neighbour_messages)
        # Update the node embedding with the neighbour messages.
        return self.update(node_repesentations, aggregated_messages)

The GNN classification model follows the Design Space for Graph Neural Networks approach, as follows:

  1. Apply preprocessing using FFN to the node features to generate initial node representations.
  2. Apply one or more graph convolutional layers, with skip connections, to the node representations to produce node embeddings.
  3. Apply post-processing using FFN to the node embeddings to generate the final node embeddings.
  4. Feed the node embeddings into a Softmax layer to predict the node class.

Each graph convolutional layer added captures information from a further level of neighbours. However, adding many graph convolutional layers can cause oversmoothing, where the model produces similar embeddings for all the nodes.


Note that graph_info is passed to the constructor of the Keras model and used as a property of the Keras model object, rather than as input data for training or prediction. The model will accept a batch of node_indices, which are used to look up the node features and neighbours from the graph_info.
class GNNNodeClassifier(tf.keras.Model):
    def __init__(
        self,
        graph_info,
        num_classes,
        hidden_units,
        aggregation_type="sum",
        combination_type="concat",
        dropout_rate=0.2,
        normalize=True,
        *args,
        **kwargs,
    ):
        super(GNNNodeClassifier, self).__init__(*args, **kwargs)

        # Unpack graph_info to three elements: node_features, edges, and edge_weight.
        node_features, edges, edge_weights = graph_info
        self.node_features = node_features
        self.edges = edges
        self.edge_weights = edge_weights
        # Set edge_weights to ones if not provided.
        if self.edge_weights is None:
            self.edge_weights = tf.ones(shape=edges.shape[1])
        # Scale edge_weights to sum to 1.
        self.edge_weights = self.edge_weights / tf.math.reduce_sum(self.edge_weights)

        # Create a process layer.
        self.preprocess = create_ffn(hidden_units, dropout_rate, name="preprocess")
        # Create the first GraphConv layer.
        self.conv1 = GraphConvLayer(
            hidden_units,
            dropout_rate,
            aggregation_type,
            combination_type,
            normalize,
            name="graph_conv1",
        )
        # Create the second GraphConv layer.
        self.conv2 = GraphConvLayer(
            hidden_units,
            dropout_rate,
            aggregation_type,
            combination_type,
            normalize,
            name="graph_conv2",
        )
        # Create a postprocess layer.
        self.postprocess = create_ffn(hidden_units, dropout_rate, name="postprocess")
        # Create a compute logits layer.
        self.compute_logits = layers.Dense(units=num_classes, name="logits")

    def call(self, input_node_indices):
        # Preprocess the node_features to produce node representations.
        x = self.preprocess(self.node_features)
        # Apply the first graph conv layer.
        x1 = self.conv1((x, self.edges, self.edge_weights))
        # Skip connection.
        x = x1 + x
        # Apply the second graph conv layer.
        x2 = self.conv2((x, self.edges, self.edge_weights))
        # Skip connection.
        x = x2 + x
        # Postprocess node embedding.
        x = self.postprocess(x)
        # Fetch node embeddings for the input node_indices.
        node_embeddings = tf.gather(x, input_node_indices)
        # Compute logits
        return self.compute_logits(node_embeddings)

Train the GNN model

Let's test instantiating and calling the GNN model. Notice that if you provide N node indices, the output will be a tensor of shape [N, num_classes], regardless of the size of the graph.

gnn_model = GNNNodeClassifier(
    graph_info=graph_info,
    num_classes=num_classes,
    hidden_units=hidden_units,
    dropout_rate=dropout_rate,
    name="gnn_model",
)

print("GNN output shape:", gnn_model([1, 10, 100]))

gnn_model.summary()
GNN output shape: tf.Tensor(
[[ 0.04372223 -0.0167095  -0.01827219  0.03111573  0.06829903  0.14005291
   0.03946935]
 [ 0.00511216  0.02377748 -0.26658255  0.1280128  -0.01203793  0.18010132
  -0.06544791]
 [ 0.05582289 -0.02315693 -0.12227201  0.16160506  0.1258692   0.13425142
   0.01099861]], shape=(3, 7), dtype=float32)
Model: "gnn_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 preprocess (Sequential)     (2708, 32)                52804     
                                                                 
 graph_conv1 (GraphConvLayer  multiple                 5888      
 )                                                               
                                                                 
 graph_conv2 (GraphConvLayer  multiple                 5888      
 )                                                               
                                                                 
 postprocess (Sequential)    (2708, 32)                2368      
                                                                 
 logits (Dense)              multiple                  231       
                                                                 
=================================================================
Total params: 67,179
Trainable params: 63,481
Non-trainable params: 3,698
_________________________________________________________________
x_train = train_data.paper_id.to_numpy()
history = run_experiment(gnn_model, x_train, y_train)
Epoch 1/300
5/5 [==============================] - 3s 136ms/step - loss: 2.1877 - acc: 0.1962 - val_loss: 1.8877 - val_acc: 0.3202
Epoch 2/300
5/5 [==============================] - 0s 67ms/step - loss: 1.9742 - acc: 0.2371 - val_loss: 1.8800 - val_acc: 0.3202
Epoch 3/300
5/5 [==============================] - 0s 74ms/step - loss: 1.9301 - acc: 0.2642 - val_loss: 1.8818 - val_acc: 0.3202
Epoch 4/300
5/5 [==============================] - 0s 71ms/step - loss: 1.9075 - acc: 0.2659 - val_loss: 1.8816 - val_acc: 0.3202
Epoch 5/300
5/5 [==============================] - 0s 74ms/step - loss: 1.8849 - acc: 0.2703 - val_loss: 1.8757 - val_acc: 0.3202
Epoch 6/300
5/5 [==============================] - 0s 75ms/step - loss: 1.8482 - acc: 0.2895 - val_loss: 1.8642 - val_acc: 0.3202
Epoch 7/300
5/5 [==============================] - 0s 75ms/step - loss: 1.8424 - acc: 0.2956 - val_loss: 1.8446 - val_acc: 0.3202
Epoch 8/300
5/5 [==============================] - 0s 72ms/step - loss: 1.8283 - acc: 0.3069 - val_loss: 1.8193 - val_acc: 0.3202
Epoch 9/300
5/5 [==============================] - 0s 74ms/step - loss: 1.8073 - acc: 0.3121 - val_loss: 1.7899 - val_acc: 0.3202
Epoch 10/300
5/5 [==============================] - 0s 69ms/step - loss: 1.7979 - acc: 0.3034 - val_loss: 1.7533 - val_acc: 0.3202
Epoch 11/300
5/5 [==============================] - 0s 68ms/step - loss: 1.7920 - acc: 0.3182 - val_loss: 1.7202 - val_acc: 0.3202
Epoch 12/300
5/5 [==============================] - 0s 72ms/step - loss: 1.7356 - acc: 0.3060 - val_loss: 1.6817 - val_acc: 0.3547
Epoch 13/300
5/5 [==============================] - 0s 73ms/step - loss: 1.7408 - acc: 0.3400 - val_loss: 1.6447 - val_acc: 0.3842
Epoch 14/300
5/5 [==============================] - 0s 72ms/step - loss: 1.6764 - acc: 0.3644 - val_loss: 1.5753 - val_acc: 0.4187
Epoch 15/300
5/5 [==============================] - 0s 74ms/step - loss: 1.6356 - acc: 0.3871 - val_loss: 1.5034 - val_acc: 0.4631
Epoch 16/300
5/5 [==============================] - 0s 74ms/step - loss: 1.5970 - acc: 0.4098 - val_loss: 1.4474 - val_acc: 0.4729
Epoch 17/300
5/5 [==============================] - 0s 68ms/step - loss: 1.5434 - acc: 0.4385 - val_loss: 1.4800 - val_acc: 0.4532
Epoch 18/300
5/5 [==============================] - 0s 71ms/step - loss: 1.5043 - acc: 0.4394 - val_loss: 1.5394 - val_acc: 0.4286
Epoch 19/300
5/5 [==============================] - 0s 68ms/step - loss: 1.4557 - acc: 0.4621 - val_loss: 1.4767 - val_acc: 0.4384
Epoch 20/300
5/5 [==============================] - 0s 69ms/step - loss: 1.3903 - acc: 0.4813 - val_loss: 1.4290 - val_acc: 0.4631
Epoch 21/300
5/5 [==============================] - 0s 76ms/step - loss: 1.3899 - acc: 0.4926 - val_loss: 1.3693 - val_acc: 0.5025
Epoch 22/300
5/5 [==============================] - 0s 73ms/step - loss: 1.3234 - acc: 0.5353 - val_loss: 1.2677 - val_acc: 0.5271
Epoch 23/300
5/5 [==============================] - 0s 71ms/step - loss: 1.2578 - acc: 0.5423 - val_loss: 1.1757 - val_acc: 0.5567
Epoch 24/300
5/5 [==============================] - 0s 72ms/step - loss: 1.2165 - acc: 0.5615 - val_loss: 1.2354 - val_acc: 0.6010
Epoch 25/300
5/5 [==============================] - 0s 71ms/step - loss: 1.1845 - acc: 0.5702 - val_loss: 1.4043 - val_acc: 0.5567
Epoch 26/300
5/5 [==============================] - 0s 70ms/step - loss: 1.1798 - acc: 0.5711 - val_loss: 1.5380 - val_acc: 0.5369
Epoch 27/300
5/5 [==============================] - 0s 70ms/step - loss: 1.1060 - acc: 0.6085 - val_loss: 1.4470 - val_acc: 0.5271
Epoch 28/300
5/5 [==============================] - 0s 71ms/step - loss: 1.1271 - acc: 0.6068 - val_loss: 1.3779 - val_acc: 0.5468
Epoch 29/300
5/5 [==============================] - 0s 71ms/step - loss: 1.1139 - acc: 0.6059 - val_loss: 1.2160 - val_acc: 0.5764
Epoch 30/300
5/5 [==============================] - 0s 70ms/step - loss: 1.0634 - acc: 0.6251 - val_loss: 1.1684 - val_acc: 0.6010
Epoch 31/300
5/5 [==============================] - 0s 75ms/step - loss: 1.0316 - acc: 0.6382 - val_loss: 1.0150 - val_acc: 0.6305
Epoch 32/300
5/5 [==============================] - 0s 75ms/step - loss: 1.0388 - acc: 0.6417 - val_loss: 0.9632 - val_acc: 0.6650
Epoch 33/300
5/5 [==============================] - 0s 71ms/step - loss: 1.0357 - acc: 0.6434 - val_loss: 1.0318 - val_acc: 0.6010
Epoch 34/300
5/5 [==============================] - 0s 75ms/step - loss: 0.9671 - acc: 0.6704 - val_loss: 0.9549 - val_acc: 0.6700
Epoch 35/300
5/5 [==============================] - 0s 74ms/step - loss: 0.9515 - acc: 0.6617 - val_loss: 0.9431 - val_acc: 0.6749
Epoch 36/300
5/5 [==============================] - 0s 72ms/step - loss: 0.9105 - acc: 0.6922 - val_loss: 1.0325 - val_acc: 0.6798
Epoch 37/300
5/5 [==============================] - 0s 71ms/step - loss: 0.8872 - acc: 0.6835 - val_loss: 1.0000 - val_acc: 0.6798
Epoch 38/300
5/5 [==============================] - 0s 70ms/step - loss: 0.8596 - acc: 0.6949 - val_loss: 0.8974 - val_acc: 0.7094
Epoch 39/300
5/5 [==============================] - 0s 74ms/step - loss: 0.8571 - acc: 0.6992 - val_loss: 0.8223 - val_acc: 0.7340
Epoch 40/300
5/5 [==============================] - 0s 70ms/step - loss: 0.8205 - acc: 0.7228 - val_loss: 0.8606 - val_acc: 0.7044
Epoch 41/300
5/5 [==============================] - 0s 71ms/step - loss: 0.8041 - acc: 0.7053 - val_loss: 0.8667 - val_acc: 0.6995
Epoch 42/300
5/5 [==============================] - 0s 69ms/step - loss: 0.8122 - acc: 0.7219 - val_loss: 0.8617 - val_acc: 0.7094
Epoch 43/300
5/5 [==============================] - 0s 70ms/step - loss: 0.7596 - acc: 0.7411 - val_loss: 0.8816 - val_acc: 0.6897
Epoch 44/300
5/5 [==============================] - 0s 71ms/step - loss: 0.8079 - acc: 0.7071 - val_loss: 0.8519 - val_acc: 0.7094
Epoch 45/300
5/5 [==============================] - 0s 70ms/step - loss: 0.7778 - acc: 0.7228 - val_loss: 0.8663 - val_acc: 0.6897
Epoch 46/300
5/5 [==============================] - 0s 69ms/step - loss: 0.7614 - acc: 0.7341 - val_loss: 0.8626 - val_acc: 0.6946
Epoch 47/300
5/5 [==============================] - 0s 71ms/step - loss: 0.7623 - acc: 0.7489 - val_loss: 0.9026 - val_acc: 0.6897
Epoch 48/300
5/5 [==============================] - 0s 70ms/step - loss: 0.7841 - acc: 0.7289 - val_loss: 0.8970 - val_acc: 0.6897
Epoch 49/300
5/5 [==============================] - 0s 68ms/step - loss: 0.7634 - acc: 0.7332 - val_loss: 0.8405 - val_acc: 0.7094
Epoch 50/300
5/5 [==============================] - 0s 73ms/step - loss: 0.7288 - acc: 0.7472 - val_loss: 0.7935 - val_acc: 0.7340
Epoch 51/300
5/5 [==============================] - 0s 77ms/step - loss: 0.7450 - acc: 0.7393 - val_loss: 0.7588 - val_acc: 0.7389
Epoch 52/300
5/5 [==============================] - 0s 72ms/step - loss: 0.7054 - acc: 0.7507 - val_loss: 0.7456 - val_acc: 0.7635
Epoch 53/300
5/5 [==============================] - 0s 72ms/step - loss: 0.6709 - acc: 0.7637 - val_loss: 0.7813 - val_acc: 0.7389
Epoch 54/300
5/5 [==============================] - 0s 73ms/step - loss: 0.7004 - acc: 0.7585 - val_loss: 0.7976 - val_acc: 0.7438
Epoch 55/300
5/5 [==============================] - 0s 70ms/step - loss: 0.6860 - acc: 0.7411 - val_loss: 0.7961 - val_acc: 0.7438
Epoch 56/300
5/5 [==============================] - 0s 71ms/step - loss: 0.6063 - acc: 0.7969 - val_loss: 0.7623 - val_acc: 0.7389
Epoch 57/300
5/5 [==============================] - 0s 71ms/step - loss: 0.6509 - acc: 0.7742 - val_loss: 0.7971 - val_acc: 0.7241
Epoch 58/300
5/5 [==============================] - 0s 73ms/step - loss: 0.6344 - acc: 0.7690 - val_loss: 0.7857 - val_acc: 0.7241
Epoch 59/300
5/5 [==============================] - 0s 71ms/step - loss: 0.6125 - acc: 0.7951 - val_loss: 0.7621 - val_acc: 0.7537
Epoch 60/300
5/5 [==============================] - 0s 70ms/step - loss: 0.6248 - acc: 0.7925 - val_loss: 0.7713 - val_acc: 0.7586
Epoch 61/300
5/5 [==============================] - 0s 74ms/step - loss: 0.6197 - acc: 0.7899 - val_loss: 0.7664 - val_acc: 0.7685
Epoch 62/300
5/5 [==============================] - 0s 72ms/step - loss: 0.6795 - acc: 0.7655 - val_loss: 0.8215 - val_acc: 0.7291
Epoch 63/300
5/5 [==============================] - 0s 71ms/step - loss: 0.6455 - acc: 0.7768 - val_loss: 0.8387 - val_acc: 0.7438
Epoch 64/300
5/5 [==============================] - 0s 71ms/step - loss: 0.6594 - acc: 0.7742 - val_loss: 0.8156 - val_acc: 0.7438
Epoch 65/300
5/5 [==============================] - 0s 74ms/step - loss: 0.6221 - acc: 0.7995 - val_loss: 0.7363 - val_acc: 0.7734
Epoch 66/300
5/5 [==============================] - 0s 72ms/step - loss: 0.5990 - acc: 0.7916 - val_loss: 0.6923 - val_acc: 0.7833
Epoch 67/300
5/5 [==============================] - 0s 68ms/step - loss: 0.6435 - acc: 0.7890 - val_loss: 0.7022 - val_acc: 0.7635
Epoch 68/300
5/5 [==============================] - 0s 75ms/step - loss: 0.5928 - acc: 0.7951 - val_loss: 0.7121 - val_acc: 0.7882
Epoch 69/300
5/5 [==============================] - 0s 73ms/step - loss: 0.5883 - acc: 0.7794 - val_loss: 0.7172 - val_acc: 0.7685
Epoch 70/300
5/5 [==============================] - 0s 70ms/step - loss: 0.6226 - acc: 0.7934 - val_loss: 0.7187 - val_acc: 0.7635
Epoch 71/300
5/5 [==============================] - 0s 73ms/step - loss: 0.5928 - acc: 0.7986 - val_loss: 0.7582 - val_acc: 0.7635
Epoch 72/300
5/5 [==============================] - 0s 72ms/step - loss: 0.5620 - acc: 0.8003 - val_loss: 0.7866 - val_acc: 0.7586
Epoch 73/300
5/5 [==============================] - 0s 70ms/step - loss: 0.5647 - acc: 0.8187 - val_loss: 0.7646 - val_acc: 0.7635
Epoch 74/300
5/5 [==============================] - 0s 69ms/step - loss: 0.5887 - acc: 0.7812 - val_loss: 0.6797 - val_acc: 0.7783
Epoch 75/300
5/5 [==============================] - 0s 74ms/step - loss: 0.4954 - acc: 0.8396 - val_loss: 0.6895 - val_acc: 0.7931
Epoch 76/300
5/5 [==============================] - 0s 68ms/step - loss: 0.5888 - acc: 0.8030 - val_loss: 0.7502 - val_acc: 0.7685
Epoch 77/300
5/5 [==============================] - 0s 66ms/step - loss: 0.5595 - acc: 0.8178 - val_loss: 0.8261 - val_acc: 0.7192
Epoch 78/300
5/5 [==============================] - 0s 70ms/step - loss: 0.5855 - acc: 0.7995 - val_loss: 0.8154 - val_acc: 0.7241
Epoch 79/300
5/5 [==============================] - 0s 71ms/step - loss: 0.5644 - acc: 0.8143 - val_loss: 0.6953 - val_acc: 0.7980
Epoch 80/300
5/5 [==============================] - 0s 72ms/step - loss: 0.5807 - acc: 0.8065 - val_loss: 0.6883 - val_acc: 0.7833
Epoch 81/300
5/5 [==============================] - 0s 72ms/step - loss: 0.5641 - acc: 0.8108 - val_loss: 0.7526 - val_acc: 0.7488
Epoch 82/300
5/5 [==============================] - 0s 73ms/step - loss: 0.5326 - acc: 0.8160 - val_loss: 0.7868 - val_acc: 0.7537
Epoch 83/300
5/5 [==============================] - 0s 69ms/step - loss: 0.5484 - acc: 0.8099 - val_loss: 0.7409 - val_acc: 0.7783
Epoch 84/300
5/5 [==============================] - 0s 76ms/step - loss: 0.5109 - acc: 0.8282 - val_loss: 0.6577 - val_acc: 0.8079
Epoch 85/300
5/5 [==============================] - 0s 74ms/step - loss: 0.5091 - acc: 0.8335 - val_loss: 0.6220 - val_acc: 0.8030
Epoch 86/300
5/5 [==============================] - 0s 77ms/step - loss: 0.5085 - acc: 0.8239 - val_loss: 0.6225 - val_acc: 0.8227
Epoch 87/300
5/5 [==============================] - 0s 72ms/step - loss: 0.5207 - acc: 0.8169 - val_loss: 0.6295 - val_acc: 0.8177
Epoch 88/300
5/5 [==============================] - 0s 72ms/step - loss: 0.5202 - acc: 0.8344 - val_loss: 0.6754 - val_acc: 0.7980
Epoch 89/300
5/5 [==============================] - 0s 72ms/step - loss: 0.5037 - acc: 0.8378 - val_loss: 0.6530 - val_acc: 0.8177
Epoch 90/300
5/5 [==============================] - 0s 73ms/step - loss: 0.4948 - acc: 0.8265 - val_loss: 0.6613 - val_acc: 0.8030
Epoch 91/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4624 - acc: 0.8396 - val_loss: 0.6758 - val_acc: 0.8030
Epoch 92/300
5/5 [==============================] - 0s 70ms/step - loss: 0.5198 - acc: 0.8230 - val_loss: 0.6615 - val_acc: 0.7931
Epoch 93/300
5/5 [==============================] - 0s 70ms/step - loss: 0.5176 - acc: 0.8300 - val_loss: 0.6377 - val_acc: 0.8227
Epoch 94/300
5/5 [==============================] - 0s 70ms/step - loss: 0.5125 - acc: 0.8309 - val_loss: 0.6621 - val_acc: 0.8079
Epoch 95/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4803 - acc: 0.8378 - val_loss: 0.6931 - val_acc: 0.7882
Epoch 96/300
5/5 [==============================] - 0s 73ms/step - loss: 0.4444 - acc: 0.8605 - val_loss: 0.6860 - val_acc: 0.8128
Epoch 97/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4569 - acc: 0.8457 - val_loss: 0.6506 - val_acc: 0.8128
Epoch 98/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4916 - acc: 0.8291 - val_loss: 0.6602 - val_acc: 0.8079
Epoch 99/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4997 - acc: 0.8448 - val_loss: 0.6574 - val_acc: 0.8128
Epoch 100/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4550 - acc: 0.8553 - val_loss: 0.6690 - val_acc: 0.8128
Epoch 101/300
5/5 [==============================] - 0s 72ms/step - loss: 0.5246 - acc: 0.8265 - val_loss: 0.6720 - val_acc: 0.7980
Epoch 102/300
5/5 [==============================] - 0s 69ms/step - loss: 0.5120 - acc: 0.8291 - val_loss: 0.6221 - val_acc: 0.8079
Epoch 103/300
5/5 [==============================] - 0s 76ms/step - loss: 0.4695 - acc: 0.8474 - val_loss: 0.6213 - val_acc: 0.8276
Epoch 104/300
5/5 [==============================] - 0s 74ms/step - loss: 0.5030 - acc: 0.8300 - val_loss: 0.6291 - val_acc: 0.8325
Epoch 105/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4667 - acc: 0.8361 - val_loss: 0.6234 - val_acc: 0.8227
Epoch 106/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4428 - acc: 0.8588 - val_loss: 0.6319 - val_acc: 0.8177
Epoch 107/300
5/5 [==============================] - 0s 68ms/step - loss: 0.4335 - acc: 0.8657 - val_loss: 0.6437 - val_acc: 0.8177
Epoch 108/300
5/5 [==============================] - 0s 69ms/step - loss: 0.4952 - acc: 0.8431 - val_loss: 0.6586 - val_acc: 0.8128
Epoch 109/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4548 - acc: 0.8439 - val_loss: 0.6301 - val_acc: 0.8276
Epoch 110/300
5/5 [==============================] - 0s 68ms/step - loss: 0.4712 - acc: 0.8553 - val_loss: 0.6346 - val_acc: 0.8227
Epoch 111/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4393 - acc: 0.8553 - val_loss: 0.6639 - val_acc: 0.7980
Epoch 112/300
5/5 [==============================] - 0s 73ms/step - loss: 0.4362 - acc: 0.8640 - val_loss: 0.6877 - val_acc: 0.7980
Epoch 113/300
5/5 [==============================] - 0s 69ms/step - loss: 0.4578 - acc: 0.8544 - val_loss: 0.7096 - val_acc: 0.8030
Epoch 114/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4213 - acc: 0.8570 - val_loss: 0.6775 - val_acc: 0.8177
Epoch 115/300
5/5 [==============================] - 0s 72ms/step - loss: 0.4376 - acc: 0.8500 - val_loss: 0.6466 - val_acc: 0.8128
Epoch 116/300
5/5 [==============================] - 0s 71ms/step - loss: 0.5007 - acc: 0.8361 - val_loss: 0.6364 - val_acc: 0.8079
Epoch 117/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4509 - acc: 0.8579 - val_loss: 0.6613 - val_acc: 0.8030
Epoch 118/300
5/5 [==============================] - 0s 64ms/step - loss: 0.4280 - acc: 0.8605 - val_loss: 0.6838 - val_acc: 0.8128
Epoch 119/300
5/5 [==============================] - 0s 68ms/step - loss: 0.4141 - acc: 0.8527 - val_loss: 0.6555 - val_acc: 0.8079
Epoch 120/300
5/5 [==============================] - 0s 72ms/step - loss: 0.4236 - acc: 0.8579 - val_loss: 0.6104 - val_acc: 0.8177
Epoch 121/300
5/5 [==============================] - 0s 72ms/step - loss: 0.3830 - acc: 0.8788 - val_loss: 0.6170 - val_acc: 0.8276
Epoch 122/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4386 - acc: 0.8500 - val_loss: 0.6813 - val_acc: 0.8128
Epoch 123/300
5/5 [==============================] - 0s 69ms/step - loss: 0.4543 - acc: 0.8561 - val_loss: 0.6756 - val_acc: 0.8128
Epoch 124/300
5/5 [==============================] - 0s 68ms/step - loss: 0.4453 - acc: 0.8666 - val_loss: 0.6211 - val_acc: 0.8128
Epoch 125/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4049 - acc: 0.8605 - val_loss: 0.5918 - val_acc: 0.8276
Epoch 126/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4375 - acc: 0.8466 - val_loss: 0.5950 - val_acc: 0.8276
Epoch 127/300
5/5 [==============================] - 0s 69ms/step - loss: 0.4776 - acc: 0.8448 - val_loss: 0.5687 - val_acc: 0.8325
Epoch 128/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4447 - acc: 0.8492 - val_loss: 0.5553 - val_acc: 0.8177
Epoch 129/300
5/5 [==============================] - 0s 67ms/step - loss: 0.4246 - acc: 0.8561 - val_loss: 0.5590 - val_acc: 0.8325
Epoch 130/300
5/5 [==============================] - 0s 69ms/step - loss: 0.4335 - acc: 0.8570 - val_loss: 0.5943 - val_acc: 0.8522
Epoch 131/300
5/5 [==============================] - 0s 72ms/step - loss: 0.4201 - acc: 0.8622 - val_loss: 0.6363 - val_acc: 0.8325
Epoch 132/300
5/5 [==============================] - 0s 76ms/step - loss: 0.4653 - acc: 0.8553 - val_loss: 0.6402 - val_acc: 0.8177
Epoch 133/300
5/5 [==============================] - 0s 72ms/step - loss: 0.4149 - acc: 0.8762 - val_loss: 0.6149 - val_acc: 0.8177
Epoch 134/300
5/5 [==============================] - 0s 68ms/step - loss: 0.4466 - acc: 0.8579 - val_loss: 0.6328 - val_acc: 0.8276
Epoch 135/300
5/5 [==============================] - 0s 68ms/step - loss: 0.4338 - acc: 0.8527 - val_loss: 0.6374 - val_acc: 0.8424
Epoch 136/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4378 - acc: 0.8570 - val_loss: 0.6232 - val_acc: 0.8424
Epoch 137/300
5/5 [==============================] - 0s 69ms/step - loss: 0.3791 - acc: 0.8736 - val_loss: 0.5961 - val_acc: 0.8325
Epoch 138/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4135 - acc: 0.8736 - val_loss: 0.5929 - val_acc: 0.8227
Epoch 139/300
5/5 [==============================] - 0s 72ms/step - loss: 0.4238 - acc: 0.8588 - val_loss: 0.5816 - val_acc: 0.8424
Epoch 140/300
5/5 [==============================] - 0s 74ms/step - loss: 0.4258 - acc: 0.8579 - val_loss: 0.6181 - val_acc: 0.8424
Epoch 141/300
5/5 [==============================] - 0s 72ms/step - loss: 0.4194 - acc: 0.8492 - val_loss: 0.6105 - val_acc: 0.8473
Epoch 142/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4059 - acc: 0.8684 - val_loss: 0.5710 - val_acc: 0.8424
Epoch 143/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4143 - acc: 0.8692 - val_loss: 0.5895 - val_acc: 0.8424
Epoch 144/300
5/5 [==============================] - 0s 69ms/step - loss: 0.4052 - acc: 0.8657 - val_loss: 0.6240 - val_acc: 0.8473
Epoch 145/300
5/5 [==============================] - 0s 71ms/step - loss: 0.3988 - acc: 0.8832 - val_loss: 0.6333 - val_acc: 0.8325
Epoch 146/300
5/5 [==============================] - 0s 69ms/step - loss: 0.4382 - acc: 0.8657 - val_loss: 0.6252 - val_acc: 0.8177
Epoch 147/300
5/5 [==============================] - 0s 68ms/step - loss: 0.3865 - acc: 0.8657 - val_loss: 0.5876 - val_acc: 0.8276
Epoch 148/300
5/5 [==============================] - 0s 72ms/step - loss: 0.4112 - acc: 0.8727 - val_loss: 0.5631 - val_acc: 0.8473
Epoch 149/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4102 - acc: 0.8701 - val_loss: 0.5624 - val_acc: 0.8374
Epoch 150/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4126 - acc: 0.8579 - val_loss: 0.5538 - val_acc: 0.8227
Epoch 151/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4102 - acc: 0.8649 - val_loss: 0.5608 - val_acc: 0.8227
Epoch 152/300
5/5 [==============================] - 0s 70ms/step - loss: 0.3632 - acc: 0.8840 - val_loss: 0.5811 - val_acc: 0.8227
Epoch 153/300
5/5 [==============================] - 0s 69ms/step - loss: 0.4210 - acc: 0.8762 - val_loss: 0.5918 - val_acc: 0.8424
Epoch 154/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4198 - acc: 0.8701 - val_loss: 0.5991 - val_acc: 0.8424
Epoch 155/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4325 - acc: 0.8596 - val_loss: 0.6005 - val_acc: 0.8424
Epoch 156/300
5/5 [==============================] - 0s 72ms/step - loss: 0.3665 - acc: 0.8736 - val_loss: 0.6045 - val_acc: 0.8473
Epoch 157/300
5/5 [==============================] - 0s 73ms/step - loss: 0.3607 - acc: 0.8971 - val_loss: 0.6044 - val_acc: 0.8424
Epoch 158/300
5/5 [==============================] - 0s 71ms/step - loss: 0.3911 - acc: 0.8710 - val_loss: 0.6179 - val_acc: 0.8128
Epoch 159/300
5/5 [==============================] - 0s 68ms/step - loss: 0.3758 - acc: 0.8788 - val_loss: 0.5890 - val_acc: 0.8325
Epoch 160/300
5/5 [==============================] - 0s 67ms/step - loss: 0.3846 - acc: 0.8823 - val_loss: 0.5841 - val_acc: 0.8276
Epoch 161/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4128 - acc: 0.8753 - val_loss: 0.5919 - val_acc: 0.8325
Epoch 162/300
5/5 [==============================] - 0s 72ms/step - loss: 0.4082 - acc: 0.8605 - val_loss: 0.6408 - val_acc: 0.8177
Epoch 163/300
5/5 [==============================] - 0s 72ms/step - loss: 0.3895 - acc: 0.8736 - val_loss: 0.6295 - val_acc: 0.8374
Epoch 164/300
5/5 [==============================] - 0s 72ms/step - loss: 0.3569 - acc: 0.8893 - val_loss: 0.6120 - val_acc: 0.8325
Epoch 165/300
5/5 [==============================] - 0s 71ms/step - loss: 0.3792 - acc: 0.8823 - val_loss: 0.6267 - val_acc: 0.8276
Epoch 166/300
5/5 [==============================] - 0s 70ms/step - loss: 0.3875 - acc: 0.8849 - val_loss: 0.6338 - val_acc: 0.8227
Epoch 167/300
5/5 [==============================] - 0s 71ms/step - loss: 0.3719 - acc: 0.8797 - val_loss: 0.6188 - val_acc: 0.8374
Epoch 168/300
5/5 [==============================] - 0s 73ms/step - loss: 0.3407 - acc: 0.8989 - val_loss: 0.6190 - val_acc: 0.8374
Epoch 169/300
5/5 [==============================] - 0s 72ms/step - loss: 0.4117 - acc: 0.8727 - val_loss: 0.6198 - val_acc: 0.8374
Epoch 170/300
5/5 [==============================] - 0s 72ms/step - loss: 0.3675 - acc: 0.8858 - val_loss: 0.6801 - val_acc: 0.8276
Epoch 171/300
5/5 [==============================] - 0s 70ms/step - loss: 0.4233 - acc: 0.8588 - val_loss: 0.6518 - val_acc: 0.8325
Epoch 172/300
5/5 [==============================] - 0s 68ms/step - loss: 0.3627 - acc: 0.8806 - val_loss: 0.6194 - val_acc: 0.8227
Epoch 173/300
5/5 [==============================] - 0s 69ms/step - loss: 0.3301 - acc: 0.8840 - val_loss: 0.6253 - val_acc: 0.8276
Epoch 174/300
5/5 [==============================] - 0s 71ms/step - loss: 0.4151 - acc: 0.8666 - val_loss: 0.6319 - val_acc: 0.8177
Epoch 175/300
5/5 [==============================] - 0s 70ms/step - loss: 0.3803 - acc: 0.8867 - val_loss: 0.6453 - val_acc: 0.8276
Epoch 176/300
5/5 [==============================] - 0s 69ms/step - loss: 0.3638 - acc: 0.8832 - val_loss: 0.6529 - val_acc: 0.8325
Epoch 177/300
5/5 [==============================] - 0s 67ms/step - loss: 0.3633 - acc: 0.8797 - val_loss: 0.6146 - val_acc: 0.8227
Epoch 178/300
5/5 [==============================] - 0s 72ms/step - loss: 0.3982 - acc: 0.8797 - val_loss: 0.6068 - val_acc: 0.8276
Epoch 179/300
5/5 [==============================] - 0s 70ms/step - loss: 0.3731 - acc: 0.8762 - val_loss: 0.6217 - val_acc: 0.8227
Epoch 180/300
5/5 [==============================] - 0s 75ms/step - loss: 0.3586 - acc: 0.8963 - val_loss: 0.6319 - val_acc: 0.8227

Note that we use the standard supervised cross-entropy loss to train the model. However, we can add another self-supervised loss term for the generated node embeddings that makes sure that neighbouring nodes in the graph have similar representations, while faraway nodes have dissimilar representations.

  • The model is trained with the standard cross-entropy loss, but an additional self-supervised loss term on the generated node embeddings can be added so that neighbouring nodes in the graph end up with similar representations while distant nodes end up with dissimilar ones; a minimal sketch of such a term is given below.

In other words, the embedding function can be seen as producing each node's representation.
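A minimal sketch of such a self-supervised term, assuming access to per-node embeddings and the (2, num_edges) `edges` array used earlier; the function name `neighborhood_contrastive_loss` and the negative-sampling scheme are illustrative, not part of the tutorial. It scores cited/citing pairs higher than randomly sampled pairs, in the spirit of DeepWalk/GraphSAGE-style objectives.

import tensorflow as tf

def neighborhood_contrastive_loss(node_embeddings, edges, num_neg_samples=5):
    # Pull embeddings of papers connected by a citation edge together,
    # and push embeddings of randomly sampled (likely unrelated) pairs apart.
    sources, targets = edges[0], edges[1]
    src_emb = tf.gather(node_embeddings, sources)             # [num_edges, dim]
    dst_emb = tf.gather(node_embeddings, targets)             # [num_edges, dim]
    pos_scores = tf.reduce_sum(src_emb * dst_emb, axis=-1)    # [num_edges]

    # Negative samples: random node indices for each edge.
    num_nodes = tf.shape(node_embeddings)[0]
    neg_targets = tf.random.uniform(
        shape=[tf.shape(pos_scores)[0], num_neg_samples],
        maxval=num_nodes,
        dtype=tf.int32,
    )
    neg_emb = tf.gather(node_embeddings, neg_targets)         # [num_edges, K, dim]
    neg_scores = tf.einsum("ed,ekd->ek", src_emb, neg_emb)    # [num_edges, K]

    pos_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(pos_scores), logits=pos_scores
    )
    neg_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(neg_scores), logits=neg_scores
    )
    return tf.reduce_mean(pos_loss) + tf.reduce_mean(neg_loss)

In practice such a term would be added to the supervised cross-entropy loss with a small weight; it is not used in the training run above.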

display_learning_curves(history)

Now we evaluate the GNN model on the test data split. The results may vary depending on the training sample; however, the GNN model always outperforms the baseline model in terms of test accuracy.

  • Evaluate the GNN model on the held-out test split. The results can vary with the training sample, but the GNN model always beats the baseline model on test accuracy.
x_test = test_data.paper_id.to_numpy()
_, test_accuracy = gnn_model.evaluate(x=x_test, y=y_test, verbose=0)
print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")
Test accuracy: 82.18%

Examine the GNN model predictions

Let's add the new instances as nodes to the node_features, and generate links (citations) to existing nodes.

  • Add the new instances to node_features and create citation links to the existing nodes.
# First we add the N new_instances as nodes to the graph
# by appending the new_instance to node_features.
num_nodes = node_features.shape[0]
new_node_features = np.concatenate([node_features, new_instances])
# Second we add the M edges (citations) from each new node to a set
# of existing nodes in a particular subject
new_node_indices = [i + num_nodes for i in range(num_classes)]
new_citations = []
for subject_idx, group in papers.groupby("subject"):
    subject_papers = list(group.paper_id)
    # Select x random papers from the specific subject.
    selected_paper_indices1 = np.random.choice(subject_papers, 5)
    # Select random y papers from any subject (where y < x).
    selected_paper_indices2 = np.random.choice(list(papers.paper_id), 2)
    # Merge the selected paper indices.
    selected_paper_indices = np.concatenate(
        [selected_paper_indices1, selected_paper_indices2], axis=0
    )
    # Create edges between a citing paper idx and the selected cited papers.
    citing_paper_idx = new_node_indices[subject_idx]
    for cited_paper_idx in selected_paper_indices:
        new_citations.append([citing_paper_idx, cited_paper_idx])

new_citations = np.array(new_citations).T
new_edges = np.concatenate([edges, new_citations], axis=1)
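
As a quick sanity check (assuming the 7 classes and the 5 + 2 citations per new node chosen above), the appended citation block should contain 7 × 7 = 49 edges, matching the shapes printed further below:

# Each new node cites 5 same-subject papers plus 2 random papers,
# so 7 * 7 = 49 citation edges are appended to the original 5,429.
print("new_citations shape:", new_citations.shape)  # expected: (2, 49)
print("new_edges shape:", new_edges.shape)          # expected: (2, 5478)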

Now let's update the node_features and the edges in the GNN model.

print("Original node_features shape:", gnn_model.node_features.shape)
print("Original edges shape:", gnn_model.edges.shape)
# Swap in the augmented graph; edge_weights is re-created to match the new number of edges.
gnn_model.node_features = new_node_features
gnn_model.edges = new_edges
gnn_model.edge_weights = tf.ones(shape=new_edges.shape[1])
print("New node_features shape:", gnn_model.node_features.shape)
print("New edges shape:", gnn_model.edges.shape)

# Predict logits for the new node indices and convert them to per-class probabilities.
logits = gnn_model.predict(tf.convert_to_tensor(new_node_indices))
probabilities = keras.activations.softmax(tf.convert_to_tensor(logits)).numpy()
display_class_probabilities(probabilities)
Original node_features shape: (2708, 1433)
Original edges shape: (2, 5429)
New node_features shape: (2715, 1433)
New edges shape: (2, 5478)
Instance 1:
- Case_Based: 52.86%
- Genetic_Algorithms: 1.57%
- Neural_Networks: 30.51%
- Probabilistic_Methods: 2.98%
- Reinforcement_Learning: 5.21%
- Rule_Learning: 3.7%
- Theory: 3.17%
Instance 2:
- Case_Based: 0.08%
- Genetic_Algorithms: 97.66%
- Neural_Networks: 1.29%
- Probabilistic_Methods: 0.24%
- Reinforcement_Learning: 0.37%
- Rule_Learning: 0.03%
- Theory: 0.33%
Instance 3:
- Case_Based: 0.26%
- Genetic_Algorithms: 0.12%
- Neural_Networks: 93.04%
- Probabilistic_Methods: 3.72%
- Reinforcement_Learning: 0.16%
- Rule_Learning: 0.05%
- Theory: 2.65%
Instance 4:
- Case_Based: 0.21%
- Genetic_Algorithms: 0.36%
- Neural_Networks: 86.56%
- Probabilistic_Methods: 9.43%
- Reinforcement_Learning: 0.49%
- Rule_Learning: 0.06%
- Theory: 2.89%
Instance 5:
- Case_Based: 0.18%
- Genetic_Algorithms: 96.78%
- Neural_Networks: 0.27%
- Probabilistic_Methods: 0.05%
- Reinforcement_Learning: 2.56%
- Rule_Learning: 0.04%
- Theory: 0.12%
Instance 6:
- Case_Based: 0.08%
- Genetic_Algorithms: 0.1%
- Neural_Networks: 0.93%
- Probabilistic_Methods: 97.9%
- Reinforcement_Learning: 0.15%
- Rule_Learning: 0.05%
- Theory: 0.79%
Instance 7:
- Case_Based: 0.09%
- Genetic_Algorithms: 96.84%
- Neural_Networks: 0.57%
- Probabilistic_Methods: 0.13%
- Reinforcement_Learning: 0.13%
- Rule_Learning: 0.25%
- Theory: 1.99%