8. Graph Neural Networks¶

歎史的に、機械孊習で分子を扱う際の最倧の課題は、蚘述子の遞定ずその蚈算でした。グラフニュヌラルネットワヌクGNNは、グラフを入力ずするディヌプニュヌラルネットワヌクの䞀皮であり、GNNは分子を盎接入力ずしお受け取るこずができるため、蚘述子に頭を悩たせる必芁がありたせん。

本章の想定読者ず目的

この章では、Standard Layers ず Regression & Model Assessment の内容は前提ずしおいたす。本章ではグラフやGNNの定矩から説明したすが、グラフやニュヌラルネットワヌクの基本的な抂念に぀いお予め慣れおいるずなお良いでしょう。この章を孊ぶこずで、以䞋ができるようになりたす

  • 分子をグラフで衚珟できる

  • 䞀般的なGNNのアヌキテクチャを議論したり、GNNの皮類を理解できる

  • GNNを構築し、ラベルの皮類に応じた読み出し関数read-out functionを遞択できる

  • グラフ、゚ッゞ、ノヌドの特城をそれぞれ区別できる

  • GNNを゚ッゞ曎新、ノヌド曎新、集玄の各ステップに分けお定匏化できる

GNNはグラフを入力および出力するために特別に蚭蚈されたレむダヌです。GNNに぀いおのレビュヌは耇数執筆されおおり、䟋えば Dwivedi et al.[DJL+20], Bronstein et al.[BBL+17], Wu et al.[WPC+20] などが挙げられたす。 GNNは、粗芖化分子動力孊シミュレヌション [LWC+20] からNMRの化孊シフト予枬 [YCW20] 、固䜓のダむナミクスのモデリング [XFLW+19] たで、あらゆるアプリケヌションに適甚できたす。 GNNに぀いお深く螏み蟌む前に、たずグラフがコンピュヌタ䞊でどのように衚珟され、分子がどのようにグラフに倉換されるか理解したしょう。

グラフずGNNに぀いおのむンタラクティブな入門資料が、 distill.pub [SLRPW21] で提䟛されおいたす。珟圚のGNNの研究のほずんどは、グラフに特化したディヌプラヌニングラむブラリを甚いお行われおおり、2022幎珟圚最も代衚的なラむブラリは PyTorch Geometric, Deep Graph library, Spektral, TensorFlow GNNS などです。

8.1. グラフの衚珟¶

グラフ \(\mathbf{G}\) は、ノヌド \(\mathbf{V}\) および゚ッゞ \(\mathbf{E}\) の集合です。 我々のセッティングでは、ノヌド \(i\) はベクトル \(\vec{v}_i\) で定矩されるので、ノヌドの集合はランク2のテン゜ルずしお衚珟できたす。 ゚ッゞは隣接行列adjacency matrix \(\mathbf{E}\) で衚珟され、もし \(e_{ij} = 1\) であればノヌド \(i\) ず \(j\) が゚ッゞで結合しおいるずみなされたす。 グラフを扱う倚くの分野においお、簡単のため、グラフはしばしば有向非巡回グラフ゚ッゞには向きがあるが、䞀呚しお元のノヌドには戻らないであるず仮定されたす。しかし、分子においお結合には向きが無く、茪を持぀巡回する堎合もあるこずに泚意しおください。化孊結合においお向きの抂念はないこずから、我々が扱う隣接行列は垞に察称\(e_{ij} = e_{ji}\)ずなりたす。たた、しばしば゚ッゞ自身も特城を持぀堎合があり、 \(e_{ij}\) 自䜓をベクトルずするこずで衚珟したす。この堎合は隣接行列はランク3のテン゜ルずなりたす。゚ッゞ特城の䟋ずしおは、共有結合の次数や、2぀のノヌド間の距離原子間距離などが挙げられたす。

../_images/methanol.jpg

Fig. 8.1 グラフに倉換できるよう、メタノヌルの各原子に番号を割り圓おた¶

では、分子からどのようにグラフを構築できるか芋おみたしょう。䟋ずしお、メタノヌルを考えたす Fig. 8.1 。ノヌドず゚ッゞを定矩するため、䟿宜的に各原子に番号を振りたした。たずはじめはノヌド特城を考えたす。ノヌドの特城量には䜕を䜿っおも良いのですが、倚くの堎合、one-hot゚ンコヌディングされた特城ベクトルを䜿うこずになるでしょう

Node

C

H

O

1

0

1

0

2

0

1

0

3

0

1

0

4

1

0

0

5

0

0

1

6

0

1

0

\(\mathbf{V}\) が、これらのノヌドに぀いおの結合された特城ベクトルになりたす。このグラフの近接行列 \(\mathbf{E}\) は次のようになるでしょう

1

2

3

4

5

6

1

0

0

0

1

0

0

2

0

0

0

1

0

0

3

0

0

0

1

0

0

4

1

1

1

0

1

0

5

0

0

0

1

0

1

6

0

0

0

0

1

0

倚少時間をかけおもよいので、これら2぀をしっかりず理解しおください。䟋えば、1,2,3行目に぀いおは、4列目の成分だけが0でないこずに泚目しおください。これは、原子1〜3は炭玠原子4にのみ結合しおいるからです。たた、原子は自分自身ずは結合できないので、察角成分は垞に0になりたす。

分子だけでなく、結晶構造に぀いおも䌌た方法でグラフ化できたす。これに぀いおは Xie et al.による [XG18] を参照しおください。

それでは、SMILESによる分子の文字列衚珟をグラフに倉換する関数を定矩するずころから始めたしょう。

8.2. このノヌトブックの動かし方¶

このペヌゞ䞊郚の    を抌すず、このノヌトブックがGoogle Colab.で開かれたす。必芁なパッケヌゞのむンストヌル方法に぀いおは以䞋を参照しおください。

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import tensorflow as tf
import pandas as pd
import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw
import networkx as nx
import dmol
soldata = pd.read_csv(
    "https://github.com/whitead/dmol-book/raw/master/data/curated-solubility-dataset.csv"
)
np.random.seed(0)
my_elements = {6: "C", 8: "O", 1: "H"}

䞋の非衚瀺セルでは、関数 smiles2graph を定矩しおいたす。この関数は元玠C, H, Oに぀いおone-hotなノヌド特城ベクトルを生成したす。たた同時に、このone-hotベクトルを特城ベクトルずする隣接テン゜ルを生成したす。

def smiles2graph(sml):
    """Argument for the RD2NX function should be a valid SMILES sequence
    returns: the graph
    """
    m = rdkit.Chem.MolFromSmiles(sml)
    m = rdkit.Chem.AddHs(m)
    order_string = {
        rdkit.Chem.rdchem.BondType.SINGLE: 1,
        rdkit.Chem.rdchem.BondType.DOUBLE: 2,
        rdkit.Chem.rdchem.BondType.TRIPLE: 3,
        rdkit.Chem.rdchem.BondType.AROMATIC: 4,
    }
    N = len(list(m.GetAtoms()))
    nodes = np.zeros((N, len(my_elements)))
    lookup = list(my_elements.keys())
    for i in m.GetAtoms():
        nodes[i.GetIdx(), lookup.index(i.GetAtomicNum())] = 1

    adj = np.zeros((N, N, 5))
    for j in m.GetBonds():
        u = min(j.GetBeginAtomIdx(), j.GetEndAtomIdx())
        v = max(j.GetBeginAtomIdx(), j.GetEndAtomIdx())
        order = j.GetBondType()
        if order in order_string:
            order = order_string[order]
        else:
            raise Warning("Ignoring bond order" + order)
        adj[u, v, order] = 1
        adj[v, u, order] = 1
    return nodes, adj
nodes, adj = smiles2graph("CO")
nodes
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.]])

8.3. グラフニュヌラルネットワヌク¶

グラフニュヌラルネットワヌク (GNN) は、次の2぀の特城を持぀ニュヌラルネットワヌクです。

  1. 入力がグラフである

  2. 出力は順序等䟡permutation equivariantである

最初の点は明らかですが、2぀めの特城は説明が必芁でしょう。たず、グラフの䞊べ替えずは、ノヌドを䞊べ替えるこずを意味したす。 䞊のメタノヌルの䟋では、炭玠を原子4ではなく原子1にするこずも簡単にできたす。その堎合、新しい隣接行列は次のようになりたす

1

2

3

4

5

6

1

0

1

1

1

1

0

2

1

0

0

0

0

0

3

1

0

0

0

0

0

4

1

0

0

0

1

0

5

1

0

0

0

0

1

6

0

0

0

0

1

0

この隣接行列の亀換に正しく察応しおGNNの出力が倉換するなら、そのGNNは順序等䟡であるず蚀えたす。郚分電荷や化孊シフトのように原子ごずに定矩される量をモデリングする堎合、このような順序䞍倉の仕組みは䞍可欠です。぀たり、もし原子を入力する順序を倉えれば、予枬される郚分電荷の順序も同様に倉わっおほしいのです。

もちろん、溶解床や゚ネルギヌのように分子党䜓の特性をモデリングしたい堎合もありたす。これらの量は原子の順番を倉えおも 䞍倉invariant であるべきです。順序に぀いお等䟡equivariantなモデルを䞍倉invariantにするため、埌に定矩するリヌドアりトread-outを䜿いたす。等䟡性に぀いおのより詳现な議論は Input Data & Equivariances を参照しおください。

8.3.1. シンプルなGNN¶

我々はこれからGNNに蚀及したすが、実際にはGNN党䜓ではなく特定のレむダヌのこずを指したす。倚くのGNNはグラフを取り扱うために特別に蚭蚈されたレむダヌを備えおおり、通垞はこのレむダヌに぀いおのみ関心を持ちたす。それでは、GNNの簡単なレむダヌの䟋を芋おみたしょう:

(8.1)¶\[\begin{equation} f_k = \sigma\left( \sum_i \sum_j v_{ij}w_{jk} \right) \end{equation}\]

この匏は、たず各ノヌド\(v_{jij}\)の特城に孊習可胜な重み \(w_{jk}\) をかけた埌、党おのノヌドの特城を合蚈し、掻性化を適甚するこずを衚しおいたす。この操䜜により、グラフに察しお1぀の特城ベクトルが埗られたす。では、この匏は順序等䟡でしょうか答えはYesです。なぜならこの匏においおノヌドむンデックスはむンデックス \(i\) であり、出力に圱響を䞎えるこずなく順序の䞊べ替えが可胜であるためです。

では次に、この䟋ず䌌おいるが順序等䟡ではない䟋を芋おみたしょう。

(8.2)¶\[\begin{equation} f_k = \sigma\left( \sum_i v_{ij}w_{ik} \right) \end{equation}\]

これは小さな倉化です。いた、ノヌドごずに1぀の重みベクトルがありたす。したがっお、孊習可胜な重みはノヌドの順序に䟝存したす。次に、ノヌドの順序を入れ替えるず、孊習した重みはノヌドに察応しなくなりたす。よっお、ノヌド原子の順番を倉えた2぀のメタノヌル分子を入力するず、異なる出力が埗られたす。実際のずころ、この単玔な䟋は2぀の点で実際のGNNず異なりたす。1぀めは単䞀の特城ベクトルを出力しおノヌドごずの情報を捚おおしたっおいる点、2぀めは隣接行列を䜿甚しない点です。では、順序等䟡を維持し、か぀これら2぀の性質を備えた実際のGNNを芋おみたしょう。

8.4. Kipf & Welling GCN¶

初期に人気があったGNNの䞀぀は、Kipf & Welling graph convolutional network (GCN) [KW16] です。GCNをGNNの広いクラスの䞀぀ずしお考える人もいたすが、本曞でGCNずは特にKipf & Welling GCNを指すものずしたす。 Thomas Kipfは、GCNの優れた玹介蚘事を曞いおいたす。

GCNレむダヌぞの入力はノヌドおよび゚ッゞの集合蚳泚各ノヌドおよび゚ッゞはベクトルで衚珟されるので、これらの集合はテン゜ルです \(\mathbf{V}\), \(\mathbf{E}\) で、出力は曎新されたノヌドの集合 \(\mathbf{V}'\) です。 各ノヌド特城の曎新は、 \(\mathbf{E}\) により衚珟される近傍ノヌドの特城ベクトルを平均するこずでなされたす。

近傍ノヌドの情報を平均するこずで、GCNレむダヌはノヌドに぀いお順序等䟡になっおいたす。近傍に぀いお平均するずいう操䜜そのものは孊習できないので、特城ベクトルを加算しおからノヌドの次数結合しおいる近傍ノヌド数で陀算するこずで蚈算したす。平均をずる前に、孊習可胜な重み行列を近傍特城に掛けるこずにしたす。これによりGCNがデヌタから孊習するこずが可胜になりたす蚳泚蚀い換えれば、GCNの孊習ずはこの重み行列の芁玠を最適化するこずです。この操䜜は次のように蚘述されたす

(8.3)¶\[ v_{il} = \sigma\left(\frac{1}{d_i}e_{ij}v_{jk}w_{lk}\right) \]

\(i\) は着目しおいるノヌド、 \(j\) はその近傍むンデックス, \(k\) はノヌドの入力特城、 \(l\) はノヌドの出力特城、 \(d_i\) はノヌドの次数次数で割るこずで単なる加算ではなく平均になりたす、 \(e_{ij}\) は、党おの非近傍ノヌドが \(v_{jk}\) れロになるよう近傍ず非近傍ノヌドを分離する項、 \(\sigma\) は掻性化関数、 \(w_{lk}\) は孊習可胜な重みです。 この匏はずおも長いように芋えたすが、実際にやっおいるこずは、近傍同士の平均に孊習可胜な重み行列を远加しただけです。この匏のよくある拡匵ずしお、各ノヌドの近傍ずしお自分自身も加える堎合がありたす。これはノヌドの出力特城 \(v_{il}\) が入力特城 \(v_{ik}\) に䟝存するようにするためです。しかし、我々はこのために䞊の匏を修正する必芁はありたせん。もっずシンプルなやり方ずしお、デヌタの前凊理で恒等行列を加算し、近接行列の察角成分を \(0\) ではなく\(1\) にしおやればよいのです。

GCNに぀いおの理解を深めるこずは、他の皮類のGNNを理解するために重芁です。ここでは2぀のポむントを抌さえおください。たずGCNレむダヌは、ノヌドずその近傍の間で”通信”する方法ず芋るこずができたす。ノヌド \(i\) に぀いおの出力は、そのすぐ隣のノヌドにのみ䟝存するこずになりたすが、化孊の堎合、これでは䞍十分です。より遠方のノヌドの情報を取り蟌むために、我々は耇数のGCNレむダヌを重ねるこずができたす。もし2぀のGCNレむダヌがあれば、ノヌド \(i\) の出力は、その隣の隣のノヌドの情報も含むこずになりたす。 GCNで理解すべきもう䞀぀の点は、ノヌド特城を平均化するステップが2぀の目的を達しおいるこずですi)近傍ノヌドの順序を無芖するこずで順序に察しお等䟡ずなる、ii)ノヌド特城のデヌタのスケヌル倉化を防ぐ。単にノヌド特城の和を取った堎合、(i)は実珟したすが、各レむダヌを通すごずにノヌド特城のデヌタの倀が倧きくなっおしたいたす。もちろん、特城のスケヌルを揃えるために各GCNレむダヌの埌でバッチ平均化batch normalizationレむダヌを通すずいう察凊もありたすが、平均化はよりシンプルです。

../_images/gnn_11_0.png

Fig. 8.2 グラフ畳み蟌みレむダヌの䞭間ステップ。3次元ベクトルはノヌド特城であり、初期倀は氎玠を衚すone-hotベクトル [1.00, 0.00, 0.00] です。この䞭心ノヌドは、隣接するノヌドの特城を平均するこずで曎新されおいきたす。¶

GCNレむダヌを理解しやすくするため、 Fig. 8.2 を芋おください。これはGCNレむダヌの䞭間ステップを衚しおいたす。各ノヌドの特城は、ここではone-hot encodingされたベクトルずしお衚珟されおいたす。Fig. 8.3 のアニメヌションは、近傍特城に぀いおの平均化プロセスを衚しおいたす。このアニメヌションでは、わかりやすくするために孊習可胜な重みず掻性化関数は蚘述されおいたせん。このアニメヌションは2局目のGCNレむダヌでも繰り返されるこずに泚意しおください。分子䞭に酞玠原子が含たれるずいう”情報”が、2局目ではじめお各原子に䌝搬される様子をよく芋おください。党おのGNNは䌌たようなアプロヌチで動䜜するので、このアニメヌションの内容はずおも倧切です。ぜひ、よく理解しおください。

../_images/gcn.gif

Fig. 8.3 グラフ畳み蟌み局の動䜜のアニメヌション。巊が入力、右が出力ノヌドの特城です。2぀の局が衚瀺されおいるこずに泚意しおくださいタむトルが倉わるこずに泚意しお芋おください。アニメヌションが進むに぀れお、近傍ノヌドの平均化によっお、原子に぀いおの情報がどのように分子内を䌝播しおいくかがわかるこずでしょう。぀たり、酞玠は単なる酞玠から、CずHに結合した酞玠、HずCH3に結合した酞玠ぞず倉化しおいくのです。図䞭の色は数倀に察応しおいたす。¶

8.4.1. GCNの実装¶

それでは、GCNのテン゜ル実装を䜜りたしょう。ここでは䞀旊、掻性化関数および孊習可胜な重みに぀いおは省略したす。 たず最初に、ランク2の隣接行列を蚈算する必芁がありたす。䞊の smiles2graph コヌドは、特城ベクトルを甚いお隣接行列を蚈算したす。この蚈算は簡単です。同時に恒等行列を加えるこずにしたす蚳泚この恒等行列を加算する操䜜は、䞊で述べたように自分自身を隣接ノヌドずしお取り扱うための工倫です。

nodes, adj = smiles2graph("CO")
adj_mat = np.sum(adj, axis=-1) + np.eye(adj.shape[0])
adj_mat
array([[1., 1., 1., 1., 1., 0.],
       [1., 1., 0., 0., 0., 1.],
       [1., 0., 1., 0., 0., 0.],
       [1., 0., 0., 1., 0., 0.],
       [1., 0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0., 1.]])

各ノヌドの次数を蚈算するために、たた別な瞮玄操䜜を行いたす

degree = np.sum(adj_mat, axis=-1)
degree
array([5., 3., 2., 2., 2., 2.])

これでノヌドの曎新操䜜の準備ができたした。アむンシュタむンの瞮玄蚘法を䜿っお曎新操䜜を衚珟するず次のようになりたす

print(nodes[0])
# note to divide by degree, make the input 1 / degree
new_nodes = np.einsum("i,ij,jk->ik", 1 / degree, adj_mat, nodes)
print(new_nodes[0])
[1. 0. 0.]
[0.2 0.2 0.6]

これをKerasのLayerずしお実装するには、䞊蚘のコヌドを新しいLayerのサブクラスずしお蚘述する必芁がありたす。今回のコヌドは比范的簡単ですが、Kerasの関数名ずLayerクラスに぀いお、このチュヌトリアルを読んで孊ぶこずを掚奚したす。䞻な倉曎点は、孊習可胜なパラメヌタ self.w を䜜っお tf.einsum の䞭で甚いるこず、掻性化関数 self.activation を甚いるこず、そしお新しいノヌド特城ず隣接行列を出力するこずの3点です。隣接行列を出力する理由は、隣接行列を毎回枡すこずなく、耇数のGCNレむダヌをスタックできるようにするためです。

class GCNLayer(tf.keras.layers.Layer):
    """Implementation of GCN as layer"""

    def __init__(self, activation=None, **kwargs):
        # constructor, which just calls super constructor
        # and turns requested activation into a callable function
        super(GCNLayer, self).__init__(**kwargs)
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        # create trainable weights
        node_shape, adj_shape = input_shape
        self.w = self.add_weight(shape=(node_shape[2], node_shape[2]), name="w")

    def call(self, inputs):
        # split input into nodes, adj
        nodes, adj = inputs
        # compute degree
        degree = tf.reduce_sum(adj, axis=-1)
        # GCN equation
        new_nodes = tf.einsum("bi,bij,bjk,kl->bil", 1 / degree, adj, nodes, self.w)
        out = self.activation(new_nodes)
        return out, adj

䞊蚘のコヌドの倧半はKeras/TFに固有のもので、倉数を適切な堎所に配眮しおいたす。ここで重芁なのは2行だけです。1぀めは、隣接行列の列に぀いお合蚈するこずでグラフの次数を蚈算する操䜜です

degree = tf.reduce_sum(adj, axis=-1)

2぀めの重芁な行は、GCN方皋匏 (8.3) を蚈算する郚分ですここでは掻性化は省略しおいたす

new_nodes = tf.einsum("bi,bij,bjk,kl->bil", 1 / degree, adj, nodes, self.w)

これで、いた実装したGCNレむダヌを詊すこずができるようになりたした

gcnlayer = GCNLayer("relu")
# we insert a batch axis here
gcnlayer((nodes[np.newaxis, ...], adj_mat[np.newaxis, ...]))
(<tf.Tensor: shape=(1, 6, 3), dtype=float32, numpy=
 array([[[0.        , 0.46567526, 0.07535715],
         [0.        , 0.12714943, 0.05325063],
         [0.01475453, 0.295794  , 0.39316285],
         [0.01475453, 0.295794  , 0.39316285],
         [0.01475453, 0.295794  , 0.39316285],
         [0.        , 0.38166213, 0.        ]]], dtype=float32)>,
 <tf.Tensor: shape=(1, 6, 6), dtype=float32, numpy=
 array([[[1., 1., 1., 1., 1., 0.],
         [1., 1., 0., 0., 0., 1.],
         [1., 0., 1., 0., 0., 0.],
         [1., 0., 0., 1., 0., 0.],
         [1., 0., 0., 0., 1., 0.],
         [0., 1., 0., 0., 0., 1.]]], dtype=float32)>)

これにより (1) 新しいノヌド特城、(2) 隣接行列が出力されたす。このレむダヌを積み重ねお、GCNを耇数回適甚できるこずを確認したしょう。

x = (nodes[np.newaxis, ...], adj_mat[np.newaxis, ...])
for i in range(2):
    x = gcnlayer(x)
print(x)
(<tf.Tensor: shape=(1, 6, 3), dtype=float32, numpy=
array([[[0.        , 0.18908624, 0.        ],
        [0.        , 0.        , 0.        ],
        [0.        , 0.145219  , 0.        ],
        [0.        , 0.145219  , 0.        ],
        [0.        , 0.145219  , 0.        ],
        [0.        , 0.        , 0.        ]]], dtype=float32)>, <tf.Tensor: shape=(1, 6, 6), dtype=float32, numpy=
array([[[1., 1., 1., 1., 1., 0.],
        [1., 1., 0., 0., 0., 1.],
        [1., 0., 1., 0., 0., 0.],
        [1., 0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 1.]]], dtype=float32)>)

うたくいきたしたしかし、なぜれロの倀があるのでしょうかこれはおそらく、出力に負の倀が含たれおおり、それがReLU掻性化を通した際に0になったためでしょう。これはモデルの蚓緎が䞍十分なために起きおいるず考えられ、蚓緎を重ねるこずで解決するでしょう

8.5. 䟋溶解床予枬¶

次に、GCNによる溶解床の予枬に぀いお説明したす。以前に、分子デヌタセットに含たれおいる特城量を䜿っお予枬モデルを組んだこずを思い出しおください。いた我々はGCNを䜿えるようになったので、特城量に頭を悩たすこずなく、分子構造を盎接ニュヌラルネットワヌクに入力できるようになりたした。GCNレむダヌは各ノヌドに぀いおの特城を出力したすが、溶解床を予枬するためには、グラフ党䜓に぀いおの特城を埗る必芁がありたす。このプロセスをさらに掗緎する方法に぀いおは埌で説明したすが、ここでは、GCNレむダヌ埌のすべおのノヌド特城の平均を䜿うこずにしたす。これにより単玔か぀順序䞍倉に、ノヌド特城をグラフ特城に倉換するこずができたす。この実装は次のずおりです

class GRLayer(tf.keras.layers.Layer):
    """A GNN layer that computes average over all node features"""

    def __init__(self, name="GRLayer", **kwargs):
        super(GRLayer, self).__init__(name=name, **kwargs)

    def call(self, inputs):
        nodes, adj = inputs
        reduction = tf.reduce_mean(nodes, axis=1)
        return reduction

䞊のコヌドで重芁な点は、ノヌドに぀いお平均をずっおいる(axis=1)次の郚分だけです

reduction = tf.reduce_mean(nodes, axis=1)

この溶解床予枬噚を完成させるため、いく぀かの党結合局を远加しお、回垰が行えるこずを確認したしょう。回垰の堎合は最終局の出力がそのたた予枬結果ずなるため、最終局には掻性化を適甚しないこずに泚意しおください。このモデルは Keras functional API を䜿っお実装されおいたす。

ninput = tf.keras.Input(
    (
        None,
        100,
    )
)
ainput = tf.keras.Input(
    (
        None,
        None,
    )
)
# GCN block
x = GCNLayer("relu")([ninput, ainput])
x = GCNLayer("relu")(x)
x = GCNLayer("relu")(x)
x = GCNLayer("relu")(x)
# reduce to graph features
x = GRLayer()(x)
# standard layers (the readout)
x = tf.keras.layers.Dense(16, "tanh")(x)
x = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs=(ninput, ainput), outputs=x)

この100はどこから来たのでしょうかその答えはデヌタセットに含たれる元玠の数にありたす。このデヌタセットは倚数の元玠を含むため、以前䜿ったサむズ3のone-hot encodingでは、党おの元玠を衚珟できたせん。前回はC, H, Oの元玠さえ衚珟できれば十分でしたが、今回はより倚数の元玠を扱う必芁がありたす。そのため、one-hot encodingのサむズも100に増やすこずにしたした。これで最倧100皮類の元玠を衚珟できたす。この拡匵のために、モデルだけでなくsmiles2graph関数も曎新するこずにしたしょう。

def gen_smiles2graph(sml):
    """Argument for the RD2NX function should be a valid SMILES sequence
    returns: the graph
    """
    m = rdkit.Chem.MolFromSmiles(sml)
    m = rdkit.Chem.AddHs(m)
    order_string = {
        rdkit.Chem.rdchem.BondType.SINGLE: 1,
        rdkit.Chem.rdchem.BondType.DOUBLE: 2,
        rdkit.Chem.rdchem.BondType.TRIPLE: 3,
        rdkit.Chem.rdchem.BondType.AROMATIC: 4,
    }
    N = len(list(m.GetAtoms()))
    nodes = np.zeros((N, 100))
    for i in m.GetAtoms():
        nodes[i.GetIdx(), i.GetAtomicNum()] = 1

    adj = np.zeros((N, N))
    for j in m.GetBonds():
        u = min(j.GetBeginAtomIdx(), j.GetEndAtomIdx())
        v = max(j.GetBeginAtomIdx(), j.GetEndAtomIdx())
        order = j.GetBondType()
        if order in order_string:
            order = order_string[order]
        else:
            raise Warning("Ignoring bond order" + order)
        adj[u, v] = 1
        adj[v, u] = 1
    adj += np.eye(N)
    return nodes, adj
nodes, adj = gen_smiles2graph("CO")
model((nodes[np.newaxis], adj_mat[np.newaxis]))
<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.0107595]], dtype=float32)>

このモデルは1぀の数倀スカラヌを出力したす。 さお、孊習可胜なデヌタセットを埗るために、いく぀か䜜業が必芁です。このデヌタセットは少し耇雑で、特城はテン゜ル(\(mathbf{V}, \mathbf{E}\))のタプルなので、デヌタセットは次のようなタプルのタプルになりたす \(\left((\mathbf{V}, \mathbf{E}), y\right)\) generatorはPythonの関数で、倀を繰り返し返すこずができたす。ここでは、孊習デヌタを1぀ず぀取り出すためにgeneratorを䜿いたす。続いお、これを from_generator tf.data.Dataset コンストラクタに枡したす。このコンストラクタでは、入力デヌタのshapeを明瀺的に指定する必芁がありたす。

def example():
    for i in range(len(soldata)):
        graph = gen_smiles2graph(soldata.SMILES[i])
        sol = soldata.Solubility[i]
        yield graph, sol


data = tf.data.Dataset.from_generator(
    example,
    output_types=((tf.float32, tf.float32), tf.float32),
    output_shapes=(
        (tf.TensorShape([None, 100]), tf.TensorShape([None, None])),
        tf.TensorShape([]),
    ),
)

ここたで来たらもう少しです。これで、い぀ものようにデヌタセットをtrain/val/testに分割できたす。

test_data = data.take(200)
val_data = data.skip(200).take(200)
train_data = data.skip(400)

そしお、いよいよモデルの蚓緎です

model.compile("adam", loss="mean_squared_error")
result = model.fit(train_data.batch(1), validation_data=val_data.batch(1), epochs=10)
plt.plot(result.history["loss"], label="training")
plt.plot(result.history["val_loss"], label="validation")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()
../_images/gnn_39_0.png

このモデルは明らかにアンダヌフィットです。考えられる理由の䞀぀はバッチサむズが1であるこずです。このモデルでは、原子個数を可倉にした副䜜甚ずしおバッチサむズを1より倧きくできない制玄がありたす。もう少し詳しく説明するず、任意のバッチサむズを入力できるようにするず、原子個数ずバッチサむズの2぀が䞍定ずなりたす。Keras/tensorflowは入力デヌタのshapeに未知の次元が2぀以䞊ある堎合デヌタを凊理できないため、バッチサむズを1に固定するこずで察凊しおいたす。ここでは扱いたせんが、この問題を回避しおバッチサむズを1より倧きくするための暙準的トリックは、耇数の分子を分子間の結合がない1぀のグラフにたずめおしたうこずです。これによりデヌタの次元はそのたたに、分子をバッチ凊理するこずができたす。

それではパリティプロットで予枬粟床を確認しおみたしょう。

yhat = model.predict(test_data.batch(1), verbose=0)[:, 0]
test_y = [y for x, y in test_data]
plt.figure()
plt.plot(test_y, test_y, "-")
plt.plot(test_y, yhat, ".")
plt.text(
    min(test_y) + 1,
    max(test_y) - 2,
    f"correlation = {np.corrcoef(test_y, yhat)[0,1]:.3f}",
)
plt.text(
    min(test_y) + 1,
    max(test_y) - 3,
    f"loss = {np.sqrt(np.mean((test_y - yhat)**2)):.3f}",
)
plt.title("Testing Data")
plt.show()
../_images/gnn_41_0.png

8.6. Message PassingずGCN¶

より広い意味でGCNレむダヌを捉えるず、GCNレむダヌは”message-passing”レむダヌの䞀぀ず蚀えたす。GCNでは、たず近傍ノヌドからやっおくるメッセヌゞを凊理したす

(8.4)¶\[\begin{equation} \vec{e}_{{s_i}j} = \vec{v}_{{s_i}j} \mathbf{W} \end{equation}\]

ここで \(v_{{s_i}j}\) は ノヌド \(i\) の \(j\) 番目の近傍です。 \(s_i\) は \(i\) に察するセンダヌ送信者です。 これはGCNがどのようにメッセヌゞを蚈算するか瀺したものですが、やっおいるこずは単玔で、各近傍ノヌドの特城に重み行列をかけおいるだけです。ノヌド \(i\) に向かうメッセヌゞ \(\vec{e}_{{s_i}j}\) を埗た埌、これらのメッセヌゞをノヌドの順番に察しお䞍倉な関数を甚いお集玄したす

(8.5)¶\[\begin{equation} \vec{e}_{i} = \frac{1}{|\vec{e}_{{s_i}j}|}\sum \vec{e}_{{s_i}j} \end{equation}\]

䞊で扱ったように、GCNではこの集玄は単なる平均ですが、任意の䟋えば孊習可胜な順序䞍倉の関数を䜿うこずもできたす

(8.6)¶\[\begin{equation} \vec{v}^{'}_{i} = \sigma(\vec{e}_i) \end{equation}\]

\(v^{'}\) は新しいノヌド特城を瀺しおいたす。これは単玔に、集玄された埌で掻性化関数を適甚したメッセヌゞです。このように曞き出すこずで、これらの手順にいくらでも小さな倉曎が加えられるこずに気づいたのではないでしょうか。Gilmerらによる重芁な論文 [GSR+17] では、いく぀かの遞択肢を怜蚎し、このメッセヌゞパッシングレむダヌの基本的なアむディアが、量子力孊に基づいお分子゚ネルギヌを予枬するタスクでうたくいくこずが述べられおいたす。GCN匏に倉曎を加えた䟋ずしおは、近傍メッセヌゞの蚈算においお゚ッゞ特城を含めたり、単に \(\sigma\) で和をずる代わりに党結合局を䜿うずいった詊みが挙げられたす。 これらから、GCNは、メッセヌゞパッシンググラフニュヌラルネットワヌクMPNNず略されるこずもありたすの䞀皮ず考えるこずができたす。

8.7. Gated Graph Neural Network¶

メッセヌゞパッシングレむダヌの最も有名な亜皮の䞀぀は、 gated graph neural network (GGN) [LTBZ15] です。これは最埌の匏のノヌド曎新を次で眮き換えたものです

(8.7)¶\[\begin{equation} \vec{v}^{'}_{i} = \textrm{GRU}(\vec{v}_i, \vec{e}_i) \end{equation}\]

\(\textrm{GRU}(\cdot, \cdot)\) はゲヌト再垰ナニットgated recurrent unit[CGCB14] です。 GRU はバむナリ入力匕数を2぀持぀ニュヌラルネットワヌクで、兞型的には系列デヌタのモデリングに䜿甚されたす。GCNず比范しお、GGNの興味深い特城は、GRUからのノヌド曎新においお孊習可胜なパラメヌタを持぀こずで、より柔軟性を備えたモデルずなっおいるこずです。GGNでは、GRUのパラメヌタは各局で共有されたすGRUを䜿っお系列デヌタをモデリングする方法ず同じです。パラメヌタ共有によるメリットは、孊習すべきパラメヌタを増やすこずなくGGNの局を無限に積み重ねられるこずです各局で \(\mathbf{W}\) を揃えるこずが前提です。このため、GGNは倧きなタンパク質や倧きなナニットセルをも぀結晶構造のような、倧きなグラフに適しおいたす蚳泚GGNレむダヌを䜕局もスタックするこずで、より遠くのノヌドからの情報を取り蟌むこずができるためです。

8.8. Pooling¶

メッセヌゞパッシングの芳点、および䞀般にGNNSでは、近傍からのメッセヌゞを結合する方法が重芁なステップずなりたす。このステップは畳み蟌みニュヌラルネットワヌクで䜿われるプヌリング局に䌌おいるため、プヌリングず呌ばれるこずもありたす。畳み蟌みニュヌラルネットワヌクのプヌリングず同じように、このために䜿甚できる瞮玄操䜜は耇数ありたす。䞀般的に、GNNのプヌリングにはメッセヌゞの合蚈たたは平均が䜿われたすが、グラフ同型ネットワヌクGraph Isomorphism Networks [XHLJ18] のように非垞に掗緎された操䜜を甚いるこずもできたす。この他にも、䟋えば泚意attentionの章では自己泚意self-attentionを䜿う䟋を扱いたすが、これらの操䜜もプヌリングに䜿えたす。プヌリングのステップを色々ず工倫したくなるこずもありたすが、プヌリング操䜜の遞択はモデルの性胜にずっおそれほど重芁ではないこずが経隓的に分かっおいたす [LDLio19, MSK20]。プヌリングの重芁な特性は順序䞍倉性で、集玄操䜜はノヌドプヌリングの堎合ぱッゞの順序に䟝存しない性質を備えるこずが望たれたす。Grattarolaら [GZBA21]が、プヌリング手法に関する最近のレビュヌを出版しおいたす。

Daigavaneらの論文 [DRA21] では、様々なプヌリング戊略の比范ず抂芁が芖芚的に玹介されおいたす。

8.9. Readout Function¶

GNNの出力はグラフですそう蚭蚈されおいるので圓たり前ですが。しかし、予枬したいラベルもグラフであるこずは皀で、䞀般的にはラベルは各ノヌドたたはグラフ党䜓に察しお付䞎されおいたす。ノヌドラベルの䟋は原子の郚分電荷、グラフラベルの䟋は分子の゚ネルギヌです。GNNから出力されるグラフを、予枬タヌゲットであるノヌドラベルやグラフラベルに倉換するプロセスを 読み出し readout ず呌びたす。ノヌドラベルを予枬する堎合であれば、単玔に゚ッゞ特城を捚おお、GNNから出力されるノヌド特城ベクトルを予枬結果ずしお扱うこずができたす。この堎合、出力局の前にいく぀かの党結合局を挟むこずが倚いでしょう。

分子の゚ネルギヌや実効電荷のようなグラフレベルのラベルを予枬する堎合、ノヌド/゚ッゞ特城をグラフラベルに倉換するプロセスに泚意が必芁です。所望のshapeのグラフラベルを埗るために、単玔にノヌド特城を党結合局に入力した堎合、順序等䟡性が倱われおしたいたす出力はノヌドラベルではなくグラフラベルなので、厳密には、順序等䟡性ではなく順序䞍倉性です。溶解床の䟋で甚いた読み出しは、ノヌド特城量に察しお瞮玄操䜜を行うこずでした。その埌でグラフ特城を党結合局に入力しお予枬結果を埗たした。実は、これがグラフ特城の読み出しを行う唯䞀の方法であるこずが瀺されおいたす [ZKR+17] 。すなわち、グラフ特城を埗るためにノヌド特城の瞮玄操䜜を行い、このグラフ特城を党結合局に入力するこずで予枬結果であるグラフラベルを埗たす。各ノヌドの特城量に察しおそれぞれ党結合局を通す操䜜もできたすが、ノヌド特城ぞの党結合局の適甚はGNN内郚で既に行われおいるので、あたりお勧めしたせん。このグラフ特城の読み出しはDeepSetsず呌ばれるこずもありたす。これは、特城が集合蚳泚順序を持たず、個数が䞍定ずしお䞎えられる堎合のために蚭蚈された、順序䞍倉なアヌキテクチャであるDeepSets [ZKR+17] ず同じ圢であるためです。

プヌリングも読み出しも順序䞍倉の関数が䜿われおいるこずにお気づきでしょうか。したがっお、DeepSetsはプヌリングに、attentionは読み出しに䜿甚するこずもできたす。

8.9.1. Intensive vs Extensive¶

回垰タスクでの読み出しにおいお考慮すべき重芁な点の1぀は、ラベルが intensive か extensive かです。Intensiveラベルは、ノヌド原子の数に䟝存しない倀を持぀ラベルです。䟋えば、屈折率や溶解床などはIntensiveラベルです。Intensiveラベルの読み出しは、䞀般にノヌドの数に䟝存しないこずが芁請されたす。したがっお、この堎合の読み出しにおける瞮玄操䜜ずしお、平均や最倧をずる操䜜は適甚可胜ですが、ノヌド数により倀が倉わるため合蚈は適したせん。察照的に、Extensiveラベルでは、䞀般的には読み出しの瞮玄操䜜にはノヌド数を反映できるため合蚈が適したす。extensiveな分子特性の䟋には、生成゚ンタルピヌが挙げられたす。

8.10. Battaglia General Equations¶

ここたでで孊んだように、GNNレむダヌはメッセヌゞパッシングレむダヌずしお䞀般化するこずができたした。Battagliaら [BHB+18] はさらに進んで、ほがすべおのGNNを蚘述できる䞀般的な方皋匏の集合を考案したした。圌らはGNNレむダヌの方皋匏を、メッセヌゞパッシングレむダヌの方皋匏におけるノヌド曎新匏のような3぀の曎新方皋匏ず、3぀の集玄方皋匏ずいう合蚈6぀の匏に分解したした。これらの匏では、グラフ特城ベクトルずいう新しい抂念が導入されおいたす。このアむディアでは、ネットワヌクに2぀の郚分GNNず読み出しを持たせる代わりに、グラフレベルの特城を各GNNレむダヌで曎新するアプロヌチをずりたす。グラフ特城ベクトルは、グラフ党䜓を衚す特城の集合です。䟋えば溶解床を蚈算する堎合は、読み出し関数を持぀代わりに分子党䜓の特城ベクトルを構築し、これを曎新しお最終的に溶解床を予枬する方法が有効だった可胜性もありたす。このように、分子党䜓に぀いお定矩されるあらゆる皮類の量䟋溶解床、゚ネルギヌは、グラフレベルの特城ベクトルを甚いお予枬できるでしょう。

これらの匏の最初のステップぱッゞの特城ベクトルの曎新であり、新たに導入する倉数である \(\vec{e}_k\) に぀いおの匏ずしお蚘述されたす

(8.8)¶\[ \vec{e}^{'}_k = \phi^e\left( \vec{e}_k, \vec{v}_{rk}, \vec{v}_{sk}, \vec{u}\right) \]

\(\vec{e}_k\) ぱッゞ \(k\) の特城ベクトル、 \(\vec{v}_{rk}\) ぱッゞ \(k\) に぀いお受信されたノヌド特城ベクトル、 \(\vec{v}_{sk}\) ぱッゞ \(k\) に぀いおノヌド特城ベクトルを送信したノヌド、 \(\vec{u}\) はグラフ特城ベクトル、 \(\phi^e\) はGNNレむダヌの定矩に䜿われる3皮類の曎新関数のうちの䞀぀です。ただし、ここで蚀う3皮類の曎新関数ずは䞀般化した衚珟であり、必ずしも3皮類を定矩する必芁はありたせん。ここではそのうちの䞀぀である \(\phi^e\) を甚いおGNNレむダヌを定矩したす。

ここで扱う分子グラフは無向グラフなので、どのノヌドが \(\vec{v}_{rk}\) を受信し、どのノヌドが \(\vec{v}_{sk}\) を送信するかをどのように決めれば良いでしょうか それぞれの \(\vec{e}^{'}_k\) は次のステップでノヌド \(v_{rk}\) ぞの入力ずしお集玄されたす。分子グラフでは、党おの結合は原子からの「入力」ず「出力」の䞡方を兌ねるため、党おの結合を2぀の有向゚ッゞずしお取り扱うこずにしたす他に良い方法がないのですC-H結合はCからHぞの蟺ずHからCぞの゚ッゞで構成されるこずになりたす。最初の疑問に戻りたすが、 \(\vec{v}_{rk} \)ず \(\vec{v}_{sk}\) ずは䜕でしょうか隣接行列のすべおの芁玠\(k\)を考え、\(k = \{ij\}\) すなわち芁玠 \(A_{ij}\) に぀いおは、受信ノヌドが \(j\) 、送信ノヌドが \(i\) であるこずを衚したす。逆向きの゚ッゞにおける隣接行列の芁玠 \(A_{ji}\) を考えるず、受信ノヌドが \(i\) 、送信ノヌドが \(j\) ずなりたす。

\(\vec{e}^{'}_k\) はGCNからのメッセヌゞのようなものですが、より䞀般的で、受信ノヌドずグラフ特城ベクトル \(\vec{u}\) の情報を反映するこずができたす。日垞的な意味での「メッセヌゞ」は䞀床送信されれば誰あるいは䜕に受信されるかによっお内容が倉わるわけではないので、 \(\vec{e}^{'}_k\) をメッセヌゞの比喩で説明しようずするずおかしなこずになりたす。ずもかく、新しい゚ッゞの曎新は、最初の集玄関数で集玄されたす

(8.9)¶\[ \bar{e}^{'}_i = \rho^{e\rightarrow v}\left( E_i^{'}\right) \]

\(\rho^{e\rightarrow v}\) は我々が定矩した関数、 \(E_i^{'}\) はノヌド \(i\) に 向かう 党おの゚ッゞからの \(\vec{e}^{'}_k\) をスタックしたものです。集玄された゚ッゞを䜿っお、ノヌドの曎新を蚈算できたす

(8.10)¶\[ \vec{v}^{'}_i = \phi^v\left( \bar{e}^{'}_i, \vec{v}_i, \vec{u}\right) \]

以䞊で新しいノヌドず゚ッゞが埗られたので、GNNレむダヌの通垞のステップは完了です。もしグラフ特城 (\(\vec{u}\)) を曎新する堎合、以䞋のステップが远加で定矩されるこずがありたす

(8.11)¶\[ \bar{e}^{'} = \rho^{e\rightarrow u}\left( E^{'}\right) \]

この匏は、グラフ党䜓に぀いお党おのメッセヌゞ/集玄された゚ッゞを集玄したす。これにより、新しいノヌドをグラフ党䜓に぀いお集玄できたす

(8.12)¶\[ \bar{v}^{'} = \rho^{v\rightarrow u}\left( V^{'}\right) \]

そしお最埌に、次のようにしおグラフ特城ベクトルを曎新できたす

(8.13)¶\[ \vec{u}^{'} = \phi^u\left( \bar{e}^{'},\bar{v}^{'}, \vec{u}\right) \]

8.10.1. Battaglia equationsによるGCNの再定匏化¶

Battagliaの匏によりGCNがどのように蚘述されるか芋おみたしょう。たず (8.8) を䜿っお、隣接する可胜性のあるすべおの隣接ノヌドに察しおメッセヌゞを蚈算したす。GCNでは、メッセヌゞは送信者にのみ䟝存し、受信者には䟝らないこずに泚意しおください。

\[ \vec{e}^{'}_k = \phi^e\left( \vec{e}_k, \vec{v}_{rk}, \vec{v}_{sk}, \vec{u}\right) = \vec{v}_{sk} \mathbf{W} \]

(8.9) においおノヌド \(i\) にやっおくるメッセヌゞを集玄するために、これらのメッセヌゞの平均をずりたす

\[ \bar{e}^{'}_i = \rho^{e\rightarrow v}\left( E_i^{'}\right) = \frac{1}{|E_i^{'}|}\sum E_i^{'} \]

続いお、ノヌドの曎新を行いたすが、これは単にメッセヌゞに぀いお掻性化関数を適甚するだけです (8.10)

\[ \vec{v}^{'}_i = \phi^v\left( \bar{e}^{'}_i, \vec{v}_i, \vec{u}\right) = \sigma(\bar{e}^{'}_i) \]

䞊匏においお \(\sigma(\bar{e}^{'}_i + \vec{v}_i)\) ず倉曎を加えるこずで、グラフに自己ルヌプを持たせるこずも可胜です。GCNでは他の関数は必芁ないので、これら3぀の匏だけでGCNを定矩するこずができたす。

8.11. The SchNet Architecture¶

最も叀く、か぀よく甚いられるGNNの1぀に、SchNetネットワヌク [SchuttSK+18] がありたす。発衚圓時はあたりGNNずしおは認識されおいたせんでしたが、珟圚ではその䞀぀ずしお認識され、ベヌスラむンモデルずしおよく䜿われおいたす。ベヌスラむンモデルずは、新手法ずの比范に甚いられるモデルのこずで、広く受け入れられ、か぀様々な実隓を通じお良い性胜を瀺すこずが確認されおいるモデルが䜿われたす。

これたでに扱った党おの䟋では、分子をグラフずしおモデルぞ入力しおいたした。䞀方SchNetでは、分子グラフではなく、原子をxyz座暙点ずしお衚珟し入力し、xyz座暙をグラフに倉換するこずでGNNNを適甚したす。SchNetは、結合情報なしで原子の配眮のみから゚ネルギヌや力を予枬するために開発されたした。したがっお、SchNetを理解するために、たず原子ずその䜍眮のセットがどのようにグラフに倉換されるかを確認したしょう。各原子をノヌド化する手順は簡単で、䞊蚘ず同様の凊理を行った埌、原子番号をembeddingレむダヌに枡したす。これは、各原子番号に孊習可胜なベクトルを割り圓おるこず蚳泚孊習により、各原子の抜象的特城を捉えた高次元ベクトルを埗るこずを意味したすembeddingに぀いおの埩習は Standard Layers を参照しおください。

隣接行列の蚈算も簡単で、党おの原子が党おの原子に接続されるようにするだけです。単に党原子が盞互に接続するのだずしたら、GNNを䜿う意味がよくわからないず思われるかもしれたせん。このような操䜜をする理由は、GNNは順序等䟡であるからです。もし原子をxyz座暙ずしお孊習しようずするず原子の䞊び方によっお重みが倉わっおしたう䞊に、構造ごずの原子数の違いをうたく取り扱えないこずでしょう。

SchNetを理解するためにもう䞀぀抌さえるべき点は、各原子のxyz座暙の情報はどう扱われるのか、ずいうこずです。SchNetでは、xyz座暙から゚ッゞ特城を構築するこずにより、モデルに座暙の情報を取り蟌んでいたす。原子 \(i\) ず \(j\) の間の゚ッゞ \(\vec{e}\) は、シンプルにこれらの原子間距離 \(r\) から蚈算されたす。

(8.14)¶\[ e_k = \exp\left(-\gamma \left(r - \mu_k\right)^2\right) \]

\(\gamma\) はハむパヌパラメヌタ䟋 10Å \(\mu_k\) は [0, 5, 10, 15 , 20] のようなスカラヌの等間隔グリッドです。 (8.14) の操䜜は、原子番号や共有結合の皮類のようなカテゎリ特城をone-hotベクトルに倉換するこずに䌌おいたす。しかし、カテゎリカルな量ず異なり、距離は連続倀で無限にあるので、one-hotベクトルずしお衚珟するこずはできたせん。そこで、䞀皮の「スムヌゞング」によっお、距離を擬䌌的にone-hot衚珟しおいるのです。この感芚を぀かむために、䟋を芋おみたしょう

gamma = 1
mu = np.linspace(0, 10, 5)


def rbf(r):
    return np.exp(-gamma * (r - mu) ** 2)


print("input", 2)
print("output", np.round(rbf(2), 2))
input 2
output [0.02 0.78 0.   0.   0.  ]

距離 \(r=2\) は、 \(k = 1\) の䜍眮蚳泚kはこのlistのindexのこずですが匷く掻性化したベクトルを䞎えるこずがわかりたす。これは \(\mu_1 = 2\) であるこずに察応したす。

ここたででノヌドず゚ッゞおよび、GNNの曎新方皋匏を定矩したした。さらにもう少しだけ蚘号を定矩しおおく必芁がありたす。ここでは、MLPMultilayer perceptronを衚すために \(h(\vec{x})\) を䜿甚したす。MLPは、基本的に1局あるいは2局の党結合レむダヌからなるニュヌラルネットワヌクです。これらのMLPにおける党結合レむダヌの正確な数や、い぀・どこで掻性化を行うかずいった詳现は、重芁な点を理解する䞊で䞍芁なため説明は省略したすこれらの詳现は以䞋の実装䟋を参照しおください。ではここで、党結合レむダヌの定矩を思い出したしょう

\[ h(\vec{x}) = \sigma\left(Wx + b\right) \]

たた、SchNetでは “shifted softplus” ず呌ばれる新たな掻性化関数 \(\sigma\) を甚いたす \(\sigma = \ln\left(0.5e^{x} + 0.5\right)\) 。 Fig. 8.4 においお、 \(\sigma(x)\) ず通垞のReLU掻性化を比范した分析が報告されおいたす。shifted softplusを䜿う理由は、入力に察しお滑らかであるためです。このため、粒子同士の距離pairwise distanceに察しお滑らかな埮分が必芁ずされる分子動力孊シミュレヌションのようなアプリケヌションにおいお、フォヌス蚳泚粒子同士に働く力を蚈算するためにSchNetを䜿うこずができたす。

../_images/gnn_49_0.png

Fig. 8.4 䞀般的なReLU掻性化関数ず、SchNetで䜿甚されおいるshifted softplusの比范¶

さお、前眮きが続きたしたが、ようやくGNN方皋匏に話を移したす。゚ッゞの曎新方皋匏 (8.8) は2぀の郚分から成りたす。たず、やっおくる゚ッゞ結合特城ず、ノヌド原子の特城をMLPに通したす。続いお、それらの結果を次のMLPに通したす

\[ \vec{e}^{'}_k = \phi^e\left( \vec{e}_k, \vec{v}_{rk}, \vec{v}_{sk}, \vec{u}\right) =h_1\left(\vec{v}_{sk}\right) \cdot h_2\left(\vec{e}_k\right) \]

次に゚ッゞの集玄関数 (8.9) を考えたしょう。SchNetでは、゚ッゞ集玄は近傍の原子特城量に察する和です。

\[ \bar{e}^{'}_i = \sum E_i^{'} \]

最埌に、SchNetのノヌド曎新関数は以䞋のようになりたす

\[ \vec{v}^{'}_i = \phi^v\left( \bar{e}^{'}_i, \vec{v}_i, \vec{u}\right) = \vec{v}_i + h_3\left(\bar{e}^{'}_i\right) \]

通垞、GNNの曎新は3〜6回適甚されたす。䞊蚘のSchNetの説明においお、゚ッゞの曎新匏を定矩したしたが、GCN同様に実際にぱッゞ特城を䞊曞きせず、各局で同じ゚ッゞ特城が保たれたす。元々のSchNetぱネルギヌや力を予枬するためのものなので、読み出しはsum-poolingや䞊蚘のような戊略で行うこずが可胜です。

これらの匏の詳现は倉曎されるこずもありたすが、オリゞナルのSchNetの論文では、 \(h_1\) は掻性化なしのdenseな1局、\(h_2\) は掻性化ありの2局、\(h_3\) は1局に掻性化あり・2局目は掻性化無しのdenseな2局の構成が甚いられたした。

SchNetずは?

SchNetベヌスのGNNの䞻な特城は、1゚ッゞの曎新メッセヌゞの組み立おに゚ッゞずノヌドの特城を甚いるこずです

\[ \vec{e}^{'}_k = h_1(\vec{v}_{sk}) \cdot h_2(\vec{e}_k) \]

ここで \(h_i()\) は䜕らかの孊習可胜な関数です。特城2は、ノヌド曎新に残差を利甚するこずです

\[ \vec{v}^{'}_i = \vec{v}_i + h_3\left(\bar{e}^{'}_i\right) \]

その他、゚ッゞ特城の䜜り方、\(h_i\) の局数、掻性化関数の遞択、読み出しの方法、ポむントクラりドをグラフに倉換する方法など、詳现は党お [SchuttSK+18] で提案されたSchNetモデルの定矩に準拠したす。

8.12. SchNet Example: Predicting Space Groups¶

Our next example will be a SchNet model that predict space groups of points. Identifying the space group of atoms is an important part of crystal structure identification, and when doing simulations of crystallization. Our SchNet model will take as input points and output the predicted space group. This is a classification problem; specifically it is multi-class becase a set of points should only be in one space group. To simplify our plots and analysis, we will work in 2D where there are 17 possible space groups.

Our data for this is a set of points from various point groups. The features are xyz coordinates and the label is the space group. We will not have multiple atom types for this problem. The hidden cell below loads the data and reshapes it for the example.

import gzip
import pickle
import urllib

urllib.request.urlretrieve(
    "https://github.com/whitead/dmol-book/raw/master/data/sym_trajs.pb.gz",
    "sym_trajs.pb.gz",
)
with gzip.open("sym_trajs.pb.gz", "rb") as f:
    trajs = pickle.load(f)

label_str = list(set([k.split("-")[0] for k in trajs]))

# now build dataset
def generator():
    for k, v in trajs.items():
        ls = k.split("-")[0]
        label = label_str.index(ls)
        traj = v
        for i in range(traj.shape[0]):
            yield traj[i], label


data = tf.data.Dataset.from_generator(
    generator,
    output_signature=(
        tf.TensorSpec(shape=(None, 2), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
).shuffle(1000)

# The shuffling above is really important because this dataset is in order of labels!

val_data = data.take(100)
test_data = data.skip(100).take(100)
train_data = data.skip(200)

Let’s take a look at a few examples from the dataset

fig, axs = plt.subplots(4, 5, figsize=(12, 8))
axs = axs.flatten()

# get a few example and plot them
for i, (x, y) in enumerate(data):
    if i == 20:
        break
    axs[i].plot(x[:, 0], x[:, 1], ".")
    axs[i].set_title(label_str[y.numpy()])
    axs[i].axis("off")
../_images/gnn_55_0.png

You can see that there is a variable number of points and a few examples for each space group. The goal is to infer those titles on the plot from the points alone.

8.12.1. Building the graphs¶

We now need to build the graphs for the points. The nodes are all identical - so they can just be 1s (we’ll reserve 0 in case we want to mask or pad at some point in the future). As described in the SchNet section above, the edges should be distance to every other atom. In most implementations of SchNet, we practically add a cut-off on either distance or maximum degree (edges per node). We’ll do maximum degree for this work of 16.

I have a function below that is a bit sophisticated. It takes a matrix of point positions in arbitrary dimension and returns the distances and indices to the nearest k neighbors - exactly what we need. It uses some tricks from Tensors and Shapes. However, it is not so important for you to understand this function. Just know it takes in points and gives us the edge features and edge nodes.

# this decorator speeds up the function by "compiling" it (tracing it)
# to run efficienty
@tf.function(
    reduce_retracing=True,
)
def get_edges(positions, NN, sorted=True):
    M = tf.shape(input=positions)[0]
    # adjust NN
    NN = tf.minimum(NN, M)
    qexpand = tf.expand_dims(positions, 1)  # one column
    qTexpand = tf.expand_dims(positions, 0)  # one row
    # repeat it to make matrix of all positions
    qtile = tf.tile(qexpand, [1, M, 1])
    qTtile = tf.tile(qTexpand, [M, 1, 1])
    # subtract them to get distance matrix
    dist_mat = qTtile - qtile
    # mask distance matrix to remove zros (self-interactions)
    dist = tf.norm(tensor=dist_mat, axis=2)
    mask = dist >= 5e-4
    mask_cast = tf.cast(mask, dtype=dist.dtype)
    # make masked things be really far
    dist_mat_r = dist * mask_cast + (1 - mask_cast) * 1000
    topk = tf.math.top_k(-dist_mat_r, k=NN, sorted=sorted)
    return -topk.values, topk.indices

Let’s see how this function works by showing the connections between points in one of our examples. I’ve hidden the code below. It shows some point’s neighbors and connects them so you can get a sense of how a set of points is converted into a graph. The complete graph will have all points’ neighborhoods.

from matplotlib import collections

fig, axs = plt.subplots(2, 3, figsize=(12, 8))
axs = axs.flatten()
for i, (x, y) in enumerate(data):
    if i == 6:
        break
    e_f, e_i = get_edges(x, 8)

    # make things easier for plotting
    e_i = e_i.numpy()
    x = x.numpy()
    y = y.numpy()

    # make lines from origin to its neigbhors
    lines = []
    colors = []
    for j in range(0, x.shape[0], 23):
        # lines are [(xstart, ystart), (xend, yend)]
        lines.extend([[(x[j, 0], x[j, 1]), (x[k, 0], x[k, 1])] for k in e_i[j]])
        colors.extend([f"C{j}"] * len(e_i[j]))
    lc = collections.LineCollection(lines, linewidths=2, colors=colors)
    axs[i].add_collection(lc)
    axs[i].plot(x[:, 0], x[:, 1], ".")
    axs[i].axis("off")
    axs[i].set_title(label_str[y])
plt.show()
../_images/gnn_59_0.png

We will now add this function and the edge featurization of SchNet (8.14) to get the graphs for the GNN steps.

MAX_DEGREE = 16
EDGE_FEATURES = 8
MAX_R = 20

gamma = 1
mu = np.linspace(0, MAX_R, EDGE_FEATURES)


def rbf(r):
    return tf.exp(-gamma * (r[..., tf.newaxis] - mu) ** 2)


def make_graph(x, y):
    edge_r, edge_i = get_edges(x, MAX_DEGREE)
    edge_features = rbf(edge_r)
    return (tf.ones(tf.shape(x)[0], dtype=tf.int32), edge_features, edge_i), y[None]


graph_train_data = train_data.map(make_graph)
graph_val_data = val_data.map(make_graph)
graph_test_data = test_data.map(make_graph)

Let’s examine one graph to see what it looks like. We’ll slice out only the first nodes.

for (n, e, nn), y in graph_train_data:
    print("first node:", n[1].numpy())
    print("first node, first edge features:", e[1, 1].numpy())
    print("first node, all neighbors", nn[1].numpy())
    print("label", y.numpy())
    break
first node: 1
first node, first edge features: [2.8479335e-01 4.9036104e-02 6.8545725e-10 7.7790052e-25 0.0000000e+00
 0.0000000e+00 0.0000000e+00 0.0000000e+00]
first node, all neighbors [  7  11  10 206   4  12 197   9  13  15   3 192 200   2 130 195]
label [13]

8.12.2. Implementing the MLPs¶

Now we can implement the SchNet model! Let’s start with the \(h_1,h_2,h_3\) MLPs that are used in the GNN update equations. In the SchNet paper these each had different numbers of layers and different decisions about which layers had activation. Let’s create them now.

def ssp(x):
    # shifted softplus activation
    return tf.math.log(0.5 * tf.math.exp(x) + 0.5)


def make_h1(units):
    return tf.keras.Sequential([tf.keras.layers.Dense(units)])


def make_h2(units):
    return tf.keras.Sequential(
        [
            tf.keras.layers.Dense(units, activation=ssp),
            tf.keras.layers.Dense(units, activation=ssp),
        ]
    )


def make_h3(units):
    return tf.keras.Sequential(
        [tf.keras.layers.Dense(units, activation=ssp), tf.keras.layers.Dense(units)]
    )

One detail that can be missed is that the weights in each MLP should change in each layer of SchNet. Thus, we’ve written the functions above to always return a new MLP. This means that a new set of trainable weights is generated on each call, meaning there is no way we could erroneously have the same weights in multiple layers.

8.12.3. Implementing the GNN¶

Now we have all the pieces to make the GNN. This code will be very similar to the GCN example above, except we now have edge features. One more detail is that our readout will be an MLP as well, following the SchNet paper. The only change we’ll make is that we want our output property to be (1) multi-class classification and (2) intensive (independent of number of atoms). So we’ll end with an average (intensive) and end with an output vector of logits the size of our labels.

class SchNetModel(tf.keras.Model):
    """Implementation of SchNet Model"""

    def __init__(self, gnn_blocks, channels, label_dim, **kwargs):
        super(SchNetModel, self).__init__(**kwargs)
        self.gnn_blocks = gnn_blocks

        # build our layers
        self.embedding = tf.keras.layers.Embedding(2, channels)
        self.h1s = [make_h1(channels) for _ in range(self.gnn_blocks)]
        self.h2s = [make_h2(channels) for _ in range(self.gnn_blocks)]
        self.h3s = [make_h3(channels) for _ in range(self.gnn_blocks)]
        self.readout_l1 = tf.keras.layers.Dense(channels // 2, activation=ssp)
        self.readout_l2 = tf.keras.layers.Dense(label_dim)

    def call(self, inputs):
        nodes, edge_features, edge_i = inputs
        # turn node types as index to features
        nodes = self.embedding(nodes)
        for i in range(self.gnn_blocks):
            # get the node features per edge
            v_sk = tf.gather(nodes, edge_i)
            e_k = self.h1s[i](v_sk) * self.h2s[i](edge_features)
            e_i = tf.reduce_sum(e_k, axis=1)
            nodes += self.h3s[i](e_i)
        # readout now
        nodes = self.readout_l1(nodes)
        nodes = self.readout_l2(nodes)
        return tf.reduce_mean(nodes, axis=0)

Remember that the key attributes of a SchNet GNN are the way that we use edge and node features. We can see the mixing of these two in the key line for computing the edge update (computing message values):

e_k = self.h1s[i](v_sk) * self.h2s[i](edge_features)

followed by aggregation of the edges updates (pooling messages):

e_i = tf.reduce_sum(e_k, axis=1)

and the node update

nodes += self.h3s[i](e_i)

Also of note is how we go from node features to multi-classs. We use dense layers that get the shape per-node into the number of classes

self.readout_l1 = tf.keras.layers.Dense(channels // 2, activation=ssp)
self.readout_l2 = tf.keras.layers.Dense(label_dim)

and then we take the average over all nodes

return tf.reduce_mean(nodes, axis=0)

Let’s give now use the model on some data.

small_schnet = SchNetModel(3, 32, len(label_str))
for x, y in graph_train_data:
    yhat = small_schnet(x)
    break
print(yhat.numpy())
[ 0.01385795  0.00906286  0.00222304 -0.00563365  0.00490831  0.00621619
  0.02778318  0.0169939  -0.00951115 -0.00167371 -0.0171227  -0.00270679
 -0.00358437  0.00626842 -0.00611755  0.01474886  0.01494901]

The output is the correct shape and remember it is logits. To get a class prediction that sums to probability 1, we need to use a softmax:

print("predicted class", tf.nn.softmax(yhat).numpy())
predicted class [0.05939336 0.05910925 0.05870633 0.0582469  0.05886419 0.05894123
 0.06022622 0.05957991 0.05802149 0.05847801 0.05758153 0.05841763
 0.05836639 0.0589443  0.05821873 0.0594463  0.0594582 ]

8.12.4. Training¶

Great! It is untrained though. Now we can set-up training. Our loss will be cross-entropy from logits, but we need to be careful on the form. Our labels are integers - which is called “sparse” labels because they are not full one-hots. Mult-class classification is also known as categorical classification. Thus, the loss we want is sparse categorical cross entropy from logits.

small_schnet.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics="sparse_categorical_accuracy",
)
result = small_schnet.fit(graph_train_data, validation_data=graph_val_data, epochs=20)
plt.plot(result.history["sparse_categorical_accuracy"], label="training accuracy")
plt.plot(result.history["val_sparse_categorical_accuracy"], label="validation accuracy")
plt.axhline(y=1 / 17, linestyle="--", label="random")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.show()
../_images/gnn_75_0.png

The accuracy is not great, but it looks like we could keep training. We have a very small SchNet here. Standard SchNet described in [SchuttSK+18] uses 6 layers and 64 channels and 300 edge features. We have 3 layers and 32 channels. Nevertheless, we’re able to get some learning. Let’s visually see what’s going on with the trained model on some test data

fig, axs = plt.subplots(4, 5, figsize=(12, 8))
axs = axs.flatten()

for i, ((x, y), (gx, _)) in enumerate(zip(test_data, graph_test_data)):
    if i == 20:
        break
    axs[i].plot(x[:, 0], x[:, 1], ".")
    yhat = small_schnet(gx)
    yhat_i = tf.math.argmax(tf.nn.softmax(yhat)).numpy()
    axs[i].set_title(f"True: {label_str[y.numpy()]}\nPredicted: {label_str[yhat_i]}")
    axs[i].axis("off")
plt.tight_layout()
plt.show()
../_images/gnn_77_0.png

We’ll revisit this example later! One unique fact about this dataset is that it is synthetic, meaning there is no label noise. As discussed in Regression & Model Assessment, that removes the possibility of overfitting and leads us to favor high variance models. The goal of teaching a model to predict space groups is to apply it on real simulations or microscopy data, which will certainly have noise. We could have mimicked this by adding noise to the labels in the data and/or by randomly removing atoms to simulate defects. This would better help our model work in a real setting.

8.13. Current Research Directions¶

8.13.1. Common Architecture Motifs and Comparisons¶

We’ve now seen message passing layer GNNs, GCNs, GGNs, and the generalized Battaglia equations. You’ll find common motifs in the architectures, like gating, アテンションレむダヌ, and pooling strategies. For example, Gated GNNS (GGNs) can be combined with attention pooling to create Gated Attention GNNs (GAANs)[ZSX+18]. GraphSAGE is a similar to a GCN but it samples when pooling, making the neighbor-updates of fixed dimension[HYL17]. So you’ll see the suffix “sage” when you sample over neighbors while pooling. These can all be represented in the Battaglia equations, but you should be aware of these names.

The enormous variety of architectures has led to work on identifying the “best” or most general GNN architecture [DJL+20, EPBM19, SMBGunnemann18]. Unfortunately, the question of which GNN architecture is best is as difficult as “what benchmark problems are best?” Thus there are no agreed-upon conclusions on the best architecture. However, those papers are great resources on training, hyperparameters, and reasonable starting guesses and I highly recommend reading them before designing your own GNN. There has been some theoretical work to show that simple architectures, like GCNs, cannot distinguish between certain simple graphs [XHLJ18]. How much this practically matters depends on your data. Ultimately, there is so much variety in hyperparameters, data equivariances, and training decisions that you should think carefully about how much the GNN architecture matters before exploring it with too much depth.

8.13.2. Nodes, Edges, and Features¶

You’ll find that most GNNs use the node-update equation in the Battaglia equations but do not update edges. For example, the GCN will update nodes at each layer but the edges are constant. Some recent work has shown that updating edges can be important for learning when the edges have geometric information, like if the input graph is a molecule and the edges are distance between the atoms [KGrossGunnemann20]. As we’ll see in the chapter on equivariances (Input Data & Equivariances), one of the key properties of neural networks with point clouds (i.e., Cartesian xyz coordinates) is to have rotation equivariance. [KGrossGunnemann20] showed that you can achieve this if you do edge updates and encode the edge vectors using a rotation equivariant basis set with spherical harmonics and Bessel functions. These kind of edge updating GNNs can be used to predict protein structure [JES+20].

Another common variation on node features is to pack more into node features than just element identity. In many examples, you will see people inserting valence, elemental mass, electronegativity, a bit indicating if the atom is in a ring, a bit indicating if the atom is aromatic, etc. Typically these are unnecessary, since a model should be able to learn any of these features which are computed from the graph and node elements. However, we and others have empirically found that some can help, specifically indicating if an atom is in a ring [LWC+20]. Choosing extra features to include though should be at the bottom of your list of things to explore when designing and using GNNs.

8.13.3. Beyond Message Passing¶

One of the common themes of GNN research is moving “beyond message passing,” where message passing is the message construction, aggregation, and node update with messages. Some view this as impossible – claiming that all GNNs can be recast as message passing [Velivckovic22]. Another direction is on disconnecting the underlying graph being input to the GNN and the graph used to compute updates. We sort of saw this above with SchNet, where we restricted the maximum degree for the message passing. More useful are ideas like “lifting” the graphs into more structured objects like simplicial complexes [BFO+21]. Finally, you can also choose where to send the messages beyond just neighbors [TZK21]. For example, all nodes on a path could communicate messages or all nodes in a clique.

8.13.4. Do we need graphs?¶

It is possible to convert a graph into a string if you’re working with an adjacency matrix without continuous values. Molecules specifically can be converted into a string. This means you can use layers for sequences/strings (e.g., recurrent neural networks or 1D convolutions) and avoid the complexities of a graph neural network. SMILES is one way to convert molecular graphs into strings. With SMILES, you cannot predict a per-atom quantity and thus a graph neural network is required for atom/bond labels. However, the choice is less clear for per-molecule properties like toxicity or solubility. There is no consensus about if a graph or string/SMILES representation is better. SMILES can exceed certain graph neural networks in accuracy on some tasks. SMILES is typically better on generative tasks. Graphs obviously beat SMILES in label representations, because they have granularity of bonds/edges. We’ll see how to model SMILES in Deep Learning on Sequences, but it is an open question of which is better.

8.13.5. Stereochemistry/Chiral Molecules¶

Stereochemistry is fundamentally a 3D property of molecules and thus not present in the covalent bonding. It is measured experimentally by seeing if molecules rotate polarized light and a molecule is called chiral or “optically active” if it is experimentally known to have this property. Stereochemistry is the categorization of how molecules can preferentially rotate polarized light through asymmetries with respect to their mirror images. In organic chemistry, the majority of stereochemistry is of enantiomers. Enantiomers are “handedness” around specific atoms called chiral centers which have 4 or more different bonded atoms. These may be treated in a graph by indicating which nodes are chiral centers (nodes) and what their state or mixture of states (racemic) are. This can be treated as an extra processing step. Amino acids and thus all proteins are entaniomers with only one form present. This chirality of proteins means many drug molecules can be more or less potent depending on their stereochemistry.

../_images/helicene.mp4

Fig. 8.5 This is a molecule with axial stereochemistry. Its small helix could be either left or right-handed.¶

Adding node labels is not enough generally. Molecules can interconvert between stereoisomers at chiral centers through a process called tautomerization. There are also types of stereochemistry that are not at a specific atom, like rotamers that are around a bond. Then there is stereochemistry that involves multiple atoms like axial helecene. As shown in Fig. 8.5, the molecule has no chiral centers but is “optically active” (experimentally measured to be chiral) because of its helix which can be left- or right-handed.

8.14. Relevant Videos¶

8.14.1. Intro to GNNs¶

8.14.2. Overview of GNN with Molecule, Compiler Examples¶

8.15. Chapter Summary¶

  • Molecules can be represented by graphs by using one-hot encoded feature vectors that show the elemental identity of each node (atom) and an adjacency matrix that show immediate neighbors (bonded atoms).

  • Graph neural networks are a category of deep neural networks that have graphs as inputs.

  • One of the early GNNs is the Kipf & Welling GCN. The input to the GCN is the node feature vector and the adjacency matrix, and returns the updated node feature vector. The GCN is permutation invariant because it averages over the neighbors.

  • A GCN can be viewed as a message-passing layer, in which we have senders and receivers. Messages are computed from neighboring nodes, which when aggregated update that node.

  • A gated graph neural network is a variant of the message passing layer, for which the nodes are updated according to a gated recurrent unit function.

  • The aggregation of messages is sometimes called pooling, for which there are multiple reduction operations.

  • GNNs output a graph. To get a per-atom or per-molecule property, use a readout function. The readout depends on if your property is intensive vs extensive

  • The Battaglia equations encompasses almost all GNNs into a set of 6 update and aggregation equations.

  • You can convert xyz coordinates into a graph and use a GNN like SchNet

8.16. Cited References¶

DJL+20(1,2)

Vijay Prakash Dwivedi, Chaitanya K Joshi, Thomas Laurent, Yoshua Bengio, and Xavier Bresson. Benchmarking graph neural networks. arXiv preprint arXiv:2003.00982, 2020.

BBL+17

Michael M Bronstein, Joan Bruna, Yann LeCun, Arthur Szlam, and Pierre Vandergheynst. Geometric deep learning: going beyond euclidean data. IEEE Signal Processing Magazine, 34(4):18–42, 2017.

WPC+20

Zonghan Wu, Shirui Pan, Fengwen Chen, Guodong Long, Chengqi Zhang, and S Yu Philip. A comprehensive survey on graph neural networks. IEEE Transactions on Neural Networks and Learning Systems, 2020.

LWC+20(1,2)

Zhiheng Li, Geemi P Wellawatte, Maghesree Chakraborty, Heta A Gandhi, Chenliang Xu, and Andrew D White. Graph neural network based coarse-grained mapping prediction. Chemical Science, 11(35):9524–9531, 2020.

YCW20

Ziyue Yang, Maghesree Chakraborty, and Andrew D White. Predicting chemical shifts with graph neural networks. bioRxiv, 2020.

XFLW+19

Tian Xie, Arthur France-Lanord, Yanming Wang, Yang Shao-Horn, and Jeffrey C Grossman. Graph dynamical networks for unsupervised learning of atomic scale dynamics in materials. Nature communications, 10(1):1–9, 2019.

SLRPW21

Benjamin Sanchez-Lengeling, Emily Reif, Adam Pearce, and Alex Wiltschko. A gentle introduction to graph neural networks. Distill, 2021. https://distill.pub/2021/gnn-intro. doi:10.23915/distill.00033.

XG18

Tian Xie and Jeffrey C. Grossman. Crystal graph convolutional neural networks for an accurate and interpretable prediction of material properties. Phys. Rev. Lett., 120:145301, Apr 2018. URL: https://link.aps.org/doi/10.1103/PhysRevLett.120.145301, doi:10.1103/PhysRevLett.120.145301.

KW16

Thomas N Kipf and Max Welling. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907, 2016.

GSR+17

Justin Gilmer, Samuel S Schoenholz, Patrick F Riley, Oriol Vinyals, and George E Dahl. Neural message passing for quantum chemistry. arXiv preprint arXiv:1704.01212, 2017.

LTBZ15

Yujia Li, Daniel Tarlow, Marc Brockschmidt, and Richard Zemel. Gated graph sequence neural networks. arXiv preprint arXiv:1511.05493, 2015.

CGCB14

Junyoung Chung, Caglar Gulcehre, KyungHyun Cho, and Yoshua Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555, 2014.

XHLJ18(1,2)

Keyulu Xu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. How powerful are graph neural networks? In International Conference on Learning Representations. 2018.

LDLio19

Enxhell Luzhnica, Ben Day, and Pietro Liò. On graph classification networks, datasets and baselines. arXiv preprint arXiv:1905.04682, 2019.

MSK20

Diego Mesquita, Amauri Souza, and Samuel Kaski. Rethinking pooling in graph neural networks. Advances in Neural Information Processing Systems, 2020.

GZBA21

Daniele Grattarola, Daniele Zambon, Filippo Maria Bianchi, and Cesare Alippi. Understanding pooling in graph neural networks. arXiv preprint arXiv:2110.05292, 2021.

DRA21

Ameya Daigavane, Balaraman Ravindran, and Gaurav Aggarwal. Understanding convolutions on graphs. Distill, 2021. https://distill.pub/2021/understanding-gnns. doi:10.23915/distill.00032.

ZKR+17(1,2)

Manzil Zaheer, Satwik Kottur, Siamak Ravanbakhsh, Barnabas Poczos, Russ R Salakhutdinov, and Alexander J Smola. Deep sets. In Advances in neural information processing systems, 3391–3401. 2017.

BHB+18

Peter W Battaglia, Jessica B Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, and others. Relational inductive biases, deep learning, and graph networks. arXiv preprint arXiv:1806.01261, 2018.

SchuttSK+18(1,2,3)

Kristof T SchÃŒtt, Huziel E Sauceda, P-J Kindermans, Alexandre Tkatchenko, and K-R MÃŒller. Schnet–a deep learning architecture for molecules and materials. The Journal of Chemical Physics, 148(24):241722, 2018.

CW22

Sam Cox and Andrew D White. Symmetric molecular dynamics. arXiv preprint arXiv:2204.01114, 2022.

ZSX+18

Jiani Zhang, Xingjian Shi, Junyuan Xie, Hao Ma, Irwin King, and Dit-Yan Yeung. Gaan: gated attention networks for learning on large and spatiotemporal graphs. arXiv preprint arXiv:1803.07294, 2018.

HYL17

Will Hamilton, Zhitao Ying, and Jure Leskovec. Inductive representation learning on large graphs. In Advances in neural information processing systems, 1024–1034. 2017.

EPBM19

Federico Errica, Marco Podda, Davide Bacciu, and Alessio Micheli. A fair comparison of graph neural networks for graph classification. In International Conference on Learning Representations. 2019.

SMBGunnemann18

Oleksandr Shchur, Maximilian Mumme, Aleksandar Bojchevski, and Stephan GÃŒnnemann. Pitfalls of graph neural network evaluation. arXiv preprint arXiv:1811.05868, 2018.

KGrossGunnemann20(1,2)

Johannes Klicpera, Janek Groß, and Stephan GÃŒnnemann. Directional message passing for molecular graphs. In International Conference on Learning Representations. 2020.

JES+20

Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael JL Townshend, and Ron Dror. Learning from protein structure with geometric vector perceptrons. arXiv preprint arXiv:2009.01411, 2020.

Velivckovic22

Petar Veličković. Message passing all the way up. arXiv preprint arXiv:2202.11097, 2022.

BFO+21

Cristian Bodnar, Fabrizio Frasca, Nina Otter, Yuguang Wang, Pietro Lio, Guido F Montufar, and Michael Bronstein. Weisfeiler and lehman go cellular: cw networks. Advances in Neural Information Processing Systems, 34:2625–2640, 2021.

TZK21

Erik Thiede, Wenda Zhou, and Risi Kondor. Autobahn: automorphism-based graph neural nets. Advances in Neural Information Processing Systems, 2021.