18. GNNによるDFTエネルギーの予測

QM9は、9つ以下の原子(C, H, O, N, F)で構成された分子の、134,000件におよぶデータセットです[RDRVL14]。 特徴量には、分子を構成する各原子のxyz座標 (X) および元素 (e)が用いられ、各分子の構造は、B3LYP/6-31G(2df,p) レベルのDFT計算により構造緩和されています。 QM9データセットの各データには複数のラベルが付与されていますが(下表を参照)、ここでは生成エネルギー(298.15Kにおけるエンタルピー)に着目します。 この章の目的は、分子の座標が与えられたときに、グラフニューラルネットワークを回帰して生成エネルギーを予測することです。また本章では、これまでに学んだ以下の章の内容を基に進めていきます。

  1. Regression & Model Assessment

  2. Graph Neural Networks

  3. Input Data & Equivariances

QM9は、2014年に登場して以来、有機化合物の分子構造を扱う機械学習タスクにおける最も標準的なベンチマークデータセットの一つです。このデータセットが登場した当時、この生成エネルギーの回帰問題における予測誤差は10kcal/mol程度でしたが、現在は~1kcal/mol以下の精度の性能が得られるまでに改善されました。 このデータセットを扱う機械学習モデルは、分子の構造に対して並進、回転、順列不変である必要があります。

18.1. ラベルの説明

Index

Name

Units

Description

0

index

-

Consecutive, 1-based integer identifier of molecule

1

A

GHz

Rotational constant A

2

B

GHz

Rotational constant B

3

C

GHz

Rotational constant C

4

mu

Debye

Dipole moment

5

alpha

Bohr^3

Isotropic polarizability

6

homo

Hartree

Energy of Highest occupied molecular orbital (HOMO)

7

lumo

Hartree

Energy of Lowest unoccupied molecular orbital (LUMO)

8

gap

Hartree

Gap, difference between LUMO and HOMO

9

r2

Bohr^2

Electronic spatial extent

10

zpve

Hartree

Zero point vibrational energy

11

U0

Hartree

Internal energy at 0 K

12

U

Hartree

Internal energy at 298.15 K

13

H

Hartree

Enthalpy at 298.15 K

14

G

Hartree

Free energy at 298.15 K

15

Cv

cal/(mol K)

Heat capacity at 298.15 K

18.2. データの準備

便利なヘルパーコードを fetch_data.py に書きました。このコードはデータをダウンロードし、Pythonで使いやすい形式に変換します。またこの関数は、QM9データを特徴量 Xe に変換します。ここで X は原子の位置と原子の部分電荷の N×4 行列、 vece は分子内の各原子の原子番号のベクトルです。先述の通りQM9の各データは複数のラベル(=ラベルベクトル)を持つことから、このラベルベクトルから必要なラベルをスライスする必要があることに注意してください。

18.3. このノートブックの動かし方

このページ上部の    を押すと、このノートブックがGoogle Colab.で開かれます。必要なパッケージのインストール方法については以下を参照してください。

import tensorflow as tf
import numpy as np
from fetch_data import qm9_parse, qm9_fetch
import dmol

早速データを読み込みましょう。データをダウンロードおよび処理するため、この手順には数分かかります。

qm9_records = qm9_fetch()
data = qm9_parse(qm9_records)

data は133k件の分子データを含むイテラブル(訳注:for文でループできるオブジェクト)です。最初の分子を見てみましょう。

for d in data:
    print(d)
    break
((<tf.Tensor: shape=(5,), dtype=int64, numpy=array([6, 1, 1, 1, 1])>, <tf.Tensor: shape=(5, 4), dtype=float32, numpy=
array([[-1.2698136e-02,  1.0858041e+00,  8.0009960e-03, -5.3568900e-01],
       [ 2.1504159e-03, -6.0313176e-03,  1.9761203e-03,  1.3392100e-01],
       [ 1.0117308e+00,  1.4637512e+00,  2.7657481e-04,  1.3392200e-01],
       [-5.4081506e-01,  1.4475266e+00, -8.7664372e-01,  1.3392299e-01],
       [-5.2381361e-01,  1.4379326e+00,  9.0639728e-01,  1.3392299e-01]],
      dtype=float32)>), <tf.Tensor: shape=(16,), dtype=float32, numpy=
array([ 1.0000000e+00,  1.5771181e+02,  1.5770998e+02,  1.5770699e+02,
        0.0000000e+00,  1.3210000e+01, -3.8769999e-01,  1.1710000e-01,
        5.0480002e-01,  3.5364101e+01,  4.4748999e-02, -4.0478931e+01,
       -4.0476063e+01, -4.0475117e+01, -4.0498596e+01,  6.4689999e+00],
      dtype=float32)>)

これらはTensorflowのTensor型のデータです。これらのデータは、 x.numpy() によってNumPyのarrayに変換することができます。最初のアイテムは元素ベクトル 6,1,1,1,1 です。このベクトルがどの元素を表すかわかるでしょうか?そう、C, H, H, H, Hです。次のアイテムは位置です。原子の部分電荷を表す行が含まれますが、今回これは特徴量としては使わないことに注意してください。そして最後はラベルベクトルです。

ここで、これらのデータを加工してより扱いやすい形式にします。まずNumPyのarrayに変換した後、部分電荷を削除し、原子番号をone-hotベクトルに変換しましょう。

def convert_record(d):
    # break up record
    (e, x), y = d
    #
    e = e.numpy()
    x = x.numpy()
    r = x[:, :3]
    # make ohc size larger
    # so use same node feature
    # shape later
    ohc = np.zeros((len(e), 16))
    ohc[np.arange(len(e)), e - 1] = 1
    return (ohc, r), y.numpy()[13]


for d in data:
    (e, x), y = convert_record(d)
    print("Element one hots\n", e)
    print("Coordinates\n", x)
    print("Label:", y)
    break
Element one hots
 [[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
Coordinates
 [[-1.2698136e-02  1.0858041e+00  8.0009960e-03]
 [ 2.1504159e-03 -6.0313176e-03  1.9761203e-03]
 [ 1.0117308e+00  1.4637512e+00  2.7657481e-04]
 [-5.4081506e-01  1.4475266e+00 -8.7664372e-01]
 [-5.2381361e-01  1.4379326e+00  9.0639728e-01]]
Label: -40.475117

18.4. ベースラインモデル

モデリングに深入りする前に、まずはシンプルなモデルでどこまで精度を上げられるか試してみましょう。これは後でより洗練された手法を作った際に、その手法と比較するためのベースラインモデルを準備しておくという意味でも役に立ちます。シンプルなモデルには多くの選択肢がありますが、ここでは含まれる元素の総数を使った線形回帰モデルを組んでみます。

import jax.numpy as jnp
import jax.experimental.optimizers as optimizers
import jax
import warnings
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")


@jax.jit
def baseline_model(nodes, w, b):
    # get sum of each element type
    atom_count = jnp.sum(nodes, axis=0)
    yhat = atom_count @ w + b
    return yhat


def baseline_loss(nodes, y, w, b):
    return (baseline_model(nodes, w, b) - y) ** 2


baseline_loss_grad = jax.grad(baseline_loss, (2, 3))
w = np.ones(16)
b = 0.0
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In [5], line 2
      1 import jax.numpy as jnp
----> 2 import jax.experimental.optimizers as optimizers
      3 import jax
      4 import warnings

ModuleNotFoundError: No module named 'jax.experimental.optimizers'

これでシンプルな回帰モデルができました。少し難しい点は、各分子が異なる数の原子からなる、つまり分子ごとにtensorのshapeが異なるために、通常のように分子をバッチ処理できない点です。

# we'll just train on 5,000 and use 1,000 for test
# shuffle
shuffled_data = data.shuffle(7000)
test_set = shuffled_data.take(1000)
valid_set = shuffled_data.skip(1000).take(1000)
train_set = shuffled_data.skip(2000).take(5000)

このデータのラベルはかなり大きな値をもつため、そのままモデルに入力するとlossも非常に大きくなり、学習が不安定になる可能性があります。学習がスムーズに進むように、ラベルのスケールを学習率やその他のパラメータに合わせて変換することにします。

ys = [convert_record(d)[1] for d in train_set]
train_ym = np.mean(ys)
train_ys = np.std(ys)
print("Mean = ", train_ym, "Std =", train_ys)

あとは学習時に次の変換 ys=yμyσy を加え、出力に対しては逆に y^=f^(e,x)σy+μy とすることで、出力のレンジを一定の範囲に収めることができます(正規化と呼びます)。

def transform_label(y):
    return (y - train_ym) / train_ys


def transform_prediction(y):
    return y * train_ys + train_ym
epochs = 16
eta = 1e-3
baseline_val_loss = [0.0 for _ in range(epochs)]
for epoch in range(epochs):
    for d in train_set:
        (e, x), y_raw = convert_record(d)
        y = transform_label(y_raw)
        grad_est = baseline_loss_grad(e, y, w, b)
        # update regression weights
        w -= eta * grad_est[0]
        b -= eta * grad_est[1]
    # compute validation loss
    for v in valid_set:
        (e, x), y_raw = convert_record(v)
        y = transform_label(y_raw)
        # convert SE to RMSE
        baseline_val_loss[epoch] += baseline_loss(e, y, w, b)
    baseline_val_loss[epoch] = jnp.sqrt(baseline_val_loss[epoch] / 1000)
    eta *= 0.9
plt.plot(baseline_val_loss)
plt.xlabel("Epoch")
plt.ylabel("Val Loss")
plt.show()

この性能はかなり低いですが、他のより洗練されたモデルが超えるべきベースライン性能として良い基準となるでしょう。この訓練で用いた変わった点としては、学習率を徐々に下げていったことです。これは、我々の特徴量及びラベルの値がすべて異なる大きさであるためです。このような場合、モデルの重みは、はじめは正しいオーダーにするために大きく動かし、その後で少し微調整する必要があります。そのため、大きな学習率でスタートして徐々に下げていくのです。

ys = []
yhats = []
for v in valid_set:
    (e, x), y = convert_record(v)
    ys.append(y)
    yhat_raw = baseline_model(e, w, b)
    yhat = transform_prediction(yhat_raw)
    yhats.append(yhat)


plt.plot(ys, ys, "-")
plt.plot(ys, yhats, ".")
plt.xlabel("Energy")
plt.ylabel("Predicted Energy")
plt.show()

かなりバラつきはありますが、このシンプルなモデルでも、おおよその傾向を捉えた予測はできているようです。それではモデルを改良していきましょう。

18.5. GNNモデルの例

さて、我々はQM9のデータを使ってエネルギーの予測モデルを組めるようになりました。はじめに述べたように、この問題には分子の構造における回転・並進・順序に対して不変性を備えたモデルが必要です。 まず順序に対して不変とするため、Graph Neural Network(GNN)を用いることにします。そして、座標/元素ベクトルに基づいて、分子を表すグラフを作成します。すなわち、各原子が隣接原子と結合しているとみなし、これを原子をノード、結合をエッジとするグラフで表現します。この時、各原子同士の距離(pairwise距離)の逆数をエッジの重みとして用います。pairwise距離は原子同士の相対的な位置関係にのみ依存するため、これを用いることで、平行移動と回転に対する不変性を得ることができます。また距離の逆数による重み付けによって、近くにある原子はエッジの重みが大きく、逆に遠くにある原子では小さくなります。

それでは、Battaglia方程式 [BHB+18] を使ってモデルを定義します。(授業で機械学習を学んだことがある人にとって)授業で習ったであろうほとんどの例とは対象的に、エッジやノードではなく、ここではグラフ全体の特徴ベクトル u に着目します。 学習が進むにつれて、u は我々が予測したいターゲットであるエネルギーと結び付けられていきます。エッジの更新では、sender(エッジと結合したノード)および、学習可能な重みを持つエッジのみを考慮します。

(18.1)ek=ϕe(ek,vrk,vsk,u)=σ(vskweek+be)

ここで、入力エッジ ek は単一の数値(pairwise距離の逆数)であり、 be は学習可能なバイアスベクトルです。ここでは示しませんが、エッジの更新では和をとって特徴を集約します。 σ はleaky ReLU関数です(訳注:活性化関数の一つであるReLUの亜種のこと)。原著者は、ReLUではなくleakly ReLUを用いることで、勾配消失を起こさず性能が上がることを発見しました。よって、ノードの更新は次のようになります:

(18.2)vi=ϕv(e¯i,vi,u)=σ(Wve¯i)+vi

全てのノードのグローバルな集約でも同様に和を取ります。よって、グラフ特徴ベクトルのアップデートは次のようになります:

(18.3)u=ϕu(e¯,v¯,u)=σ(Wuv¯)+u

そして、final energyの計算のため、次の回帰関数を使います:

(18.4)E^=wu+b

この実装におけるポイントの最後は、各GNNレイヤーではエッジに u およびノードベクトルを渡すものの、エッジそのものの重みは全てのGNNレイヤーで同じとすることである。この実装はexampleモデルであり、上で述べた様々な事柄については変更の余地があることに注意してください。また、分子グラフを扱うためには、ここで用いているGNNのカーネル(ノード・エッジ・グラフの更新方法)ではない他の方法のほうがうまくいくことが知られている。とはいえ、まずはこのモデルを実装して、うまくいくかどうか見てみることにしましょう。

18.5.1. JAXによるモデルの実装

def x2e(x):
    """convert xyz coordinates to inverse pairwise distance"""
    r2 = jnp.sum((x - x[:, jnp.newaxis, :]) ** 2, axis=-1)
    e = jnp.where(r2 != 0, 1 / r2, 0.0)
    return e


def gnn_layer(nodes, edges, features, we, web, wv, wu):
    """Implementation of the GNN"""
    # make nodes be N x N so we can just multiply directly
    # ek is now shaped N x N x features
    ek = jax.nn.leaky_relu(
        web
        + jnp.repeat(nodes[jnp.newaxis, ...], nodes.shape[0], axis=0)
        @ we
        * edges[..., jnp.newaxis]
    )
    # sum over neighbors to get N x features
    ebar = jnp.sum(ek, axis=1)
    # dense layer for new nodes to get N x features
    new_nodes = jax.nn.leaky_relu(ebar @ wv) + nodes
    # sum over nodes to get shape features
    global_node_features = jnp.sum(new_nodes, axis=0)
    # dense layer for new features
    new_features = jax.nn.leaky_relu(global_node_features @ wu) + features
    # just return features for ease of use
    return new_nodes, edges, new_features

ここまでの説明をもとにして、座標をpairwise距離の逆数に変換するコードと、上記のGNN方程式を実装しました。実装がどうなっているのか見てみましょう:

graph_feature_len = 8
node_feature_len = 16
msg_feature_len = 16

# make our weights
def init_weights(g, n, m):
    we = np.random.normal(size=(n, m), scale=1e-1)
    wb = np.random.normal(size=(m), scale=1e-1)
    wv = np.random.normal(size=(m, n), scale=1e-1)
    wu = np.random.normal(size=(n, g), scale=1e-1)
    return [we, wb, wv, wu]


# make a graph
nodes = e
edges = x2e(x)
features = jnp.zeros(graph_feature_len)

# eval
out = gnn_layer(
    nodes,
    edges,
    features,
    *init_weights(graph_feature_len, node_feature_len, msg_feature_len),
)
print("input feautres", features)
print("output features", out[2])

バッチリです!グラフ特徴を処理できていますね。次に、この方程式を使って、2層のGNNレイヤーからなるGNNモデルを定義しましょう。合わせてlossも定義します。

# get weights for both layers
w1 = init_weights(graph_feature_len, node_feature_len, msg_feature_len)
w2 = init_weights(graph_feature_len, node_feature_len, msg_feature_len)
w3 = np.random.normal(size=(graph_feature_len))
b = 0.0


@jax.jit
def model(nodes, coords, w1, w2, w3, b):
    f0 = jnp.zeros(graph_feature_len)
    e0 = x2e(coords)
    n0 = nodes
    n1, e1, f1 = gnn_layer(n0, e0, f0, *w1)
    n2, e2, f2 = gnn_layer(n1, e1, f1, *w2)
    yhat = f2 @ w3 + b
    return yhat


def loss(nodes, coords, y, w1, w2, w3, b):
    return (model(nodes, coords, w1, w2, w3, b) - y) ** 2


loss_grad = jax.grad(loss, (3, 4, 5, 6))

以下のコードでは、GNNの学習率を回帰関数の学習率の1/10にスケーリングする小さな変更を取り入れました。このハックはモデルにとって本質的に重要な点ではありませんが、原著者が試行錯誤で見つけた学習をうまくいくようにするための小さなテクニックです(訳注:このように、DNNの学習ではごく小さな変更がモデルの性能に大きな影響を与える場合がしばしばあります)。

eta = 1e-3
val_loss = [0.0 for _ in range(epochs)]
for epoch in range(epochs):
    for d in train_set:
        (e, x), y_raw = convert_record(d)
        y = transform_label(y_raw)
        grad = loss_grad(e, x, y, w1, w2, w3, b)
        # update regression weights
        w3 -= eta * grad[2]
        b -= eta * grad[3]
        # update GNN weights
        for i, w in [(0, w1), (1, w2)]:
            for j in range(len(w)):
                w[j] -= eta * grad[i][j] / 10
    # compute validation loss
    for v in valid_set:
        (e, x), y_raw = convert_record(v)
        y = transform_label(y_raw)
        # convert SE to RMSE
        val_loss[epoch] += loss(e, x, y, w1, w2, w3, b)
    val_loss[epoch] = jnp.sqrt(val_loss[epoch] / 1000)
    eta *= 0.9
plt.plot(baseline_val_loss, label="baseline")
plt.plot(val_loss, label="GNN")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Val Loss")
plt.show()

これは大きなデータセットで、学習には時間がかかるかもしれません。まずは、GNNを使ったモデル構築と学習の原理が伝われば幸いです。最後に、モデルの予測性能を調べましょう。

ys = []
yhats = []
for v in valid_set:
    (e, x), y = convert_record(v)
    ys.append(y)
    yhat_raw = model(e, x, w1, w2, w3, b)
    yhats.append(transform_prediction(yhat_raw))


plt.plot(ys, ys, "-")
plt.plot(ys, yhats, ".")
plt.xlabel("Energy")
plt.ylabel("Predicted Energy")
plt.show()

いくつか点が密集したクラスターがありますが、これらは分子の種類や大きさに対応しています。このプロットからは、ここまでに構築したモデルが正しく学習され、クラスター内部で正しい傾向を学習しつつあることはわかりますが、対角線から大きく外れたクラスターを修正し、さらに精度を上げるためには、まだまだモデルの工夫と追加学習が必要そうです。

18.6. QM9データの学習についての関連資料

18.7. 参考文献

BHB+18

Peter W Battaglia, Jessica B Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, and others. Relational inductive biases, deep learning, and graph networks. arXiv preprint arXiv:1806.01261, 2018.

RDRVL14

Raghunathan Ramakrishnan, Pavlo O Dral, Matthias Rupp, and O Anatole Von Lilienfeld. Quantum chemistry structures and properties of 134 kilo molecules. Scientific data, 1(1):1–7, 2014.