import torch
from torch_geometric.data import Data
# TORCH_GEOMETRIC.NN
# GCN
# 221207
# Reference: https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html
# Toy undirected path graph 0 - 1 - 2. Each undirected edge is listed once per
# direction ((0,1), (1,0), (1,2), (2,1)), so edge_index has 4 columns.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

# Node feature matrix: 3 nodes with one scalar feature each -> shape [3, 1].
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
print(data)
# Expected output:
# Data(x=[3, 1], edge_index=[2, 4])
import networkx as nx
import matplotlib.pyplot as plt
# Draw the same 3-node path graph (0 - 1 - 2) with networkx.
G = nx.Graph()
G.add_nodes_from(['0', '1', '2'])
G.add_edges_from([('0', '1'), ('1', '2')])

# Fixed node -> (x, y) layout so the drawing is reproducible.
pos = {'0': (0, 0), '1': (1, 1), '2': (2, 0)}

nx.draw(G, pos, with_labels=True)
plt.show()
from torch.nn import Linear, ReLU
from torch_geometric.nn import Sequential, GCNConv
# Example:
def build_gcn(in_channels, out_channels, hidden_channels=64):
    """Return a two-layer GCN followed by a linear read-out.

    Generic template: the caller supplies the widths, so nothing here depends
    on undefined module-level names.

    Args:
        in_channels: Number of input features per node.
        out_channels: Width of the final linear layer's output.
        hidden_channels: Width of both GCN layers (64 in the original).

    Returns:
        A ``torch_geometric.nn.Sequential`` called as ``model(x, edge_index)``.
    """
    return Sequential('x, edge_index', [
        (GCNConv(in_channels, hidden_channels), 'x, edge_index -> x'),
        ReLU(inplace=True),
        (GCNConv(hidden_channels, hidden_channels), 'x, edge_index -> x'),
        ReLU(inplace=True),
        Linear(hidden_channels, out_channels),
    ])
# Concrete GCN for the toy graph. The first layer's input width must equal
# the number of node FEATURES (x.size(-1) == 1), not the number of nodes (3).
model = Sequential('x, edge_index', [
    (GCNConv(x.size(-1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    Linear(64, 3),
])
out = model(x, edge_index)
print(out.shape)  # one row per node, 3 outputs each: torch.Size([3, 3])
from torch.nn import Linear, ReLU, Dropout
from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge
from torch_geometric.nn import global_mean_pool
def build_jk_gcn(num_features, num_classes, hidden_channels=64):
    """Return a graph-level GCN classifier with a concatenating JumpingKnowledge read-out.

    Pass ``dataset.num_features`` / ``dataset.num_classes`` when a dataset
    object is available; they are explicit arguments so the builder does not
    depend on a global ``dataset``.

    Args:
        num_features: Number of input features per node.
        num_classes: Number of output logits per graph.
        hidden_channels: Width of each GCN layer (64 in the original).

    Returns:
        A ``torch_geometric.nn.Sequential`` called as
        ``model(x, edge_index, batch)``.
    """
    return Sequential('x, edge_index, batch', [
        (Dropout(p=0.5), 'x -> x'),
        (GCNConv(num_features, hidden_channels), 'x, edge_index -> x1'),
        ReLU(inplace=True),
        (GCNConv(hidden_channels, hidden_channels), 'x1, edge_index -> x2'),
        ReLU(inplace=True),
        # Collect both layers' outputs so JumpingKnowledge can combine them.
        (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),
        # "cat" mode concatenates the two layer outputs -> 2 * hidden_channels.
        (JumpingKnowledge("cat", hidden_channels, num_layers=2), 'xs -> x'),
        # Mean-pool node embeddings into one vector per graph (grouped by batch).
        (global_mean_pool, 'x, batch -> x'),
        Linear(2 * hidden_channels, num_classes),
    ])
# JumpingKnowledge GCN instantiated for the toy graph: the input width comes
# from the node feature matrix instead of an undefined `dataset`.
num_classes = 2  # NOTE(review): placeholder - the toy graph carries no labels
model = Sequential('x, edge_index, batch', [
    (Dropout(p=0.5), 'x -> x'),
    (GCNConv(x.size(-1), 64), 'x, edge_index -> x1'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x1, edge_index -> x2'),
    ReLU(inplace=True),
    # Collect both layers' outputs so JumpingKnowledge can combine them.
    (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),
    # "cat" mode concatenates the two 64-wide outputs -> 2 * 64 features.
    (JumpingKnowledge("cat", 64, num_layers=2), 'xs -> x'),
    (global_mean_pool, 'x, batch -> x'),
    Linear(2 * 64, num_classes),
])

# All nodes belong to one graph, so every entry of `batch` is graph index 0.
batch = torch.zeros(x.size(0), dtype=torch.long)
model.eval()  # disable Dropout so the demo forward pass is deterministic
out = model(x, edge_index, batch)
print(out.shape)  # one graph -> torch.Size([1, num_classes])
# torch_geometric.nn.Linear (unlike torch.nn.Linear) accepts in_channels=-1,
# which defers weight creation until the first forward call infers the input
# width. Imported under an alias so the torch.nn `Linear` used above is not
# shadowed.
from torch_geometric.nn import Linear as PyGLinear

lazy_lin = PyGLinear(-1, 64)