GNNによるDFTエネルギーの予測
Contents
18. GNNによるDFTエネルギーの予測¶
QM9は、9つ以下の原子(C, H, O, N, F)で構成された分子の、134,000件におよぶデータセットです[RDRVL14]。
特徴量には、分子を構成する各原子のxyz座標 (
QM9は、2014年に登場して以来、有機化合物の分子構造を扱う機械学習タスクにおける最も標準的なベンチマークデータセットの一つです。このデータセットが登場した当時、この生成エネルギーの回帰問題における予測誤差は10kcal/mol程度でしたが、現在は~1kcal/mol以下の精度の性能が得られるまでに改善されました。 このデータセットを扱う機械学習モデルは、分子の構造に対して並進、回転、順列不変である必要があります。
18.1. ラベルの説明¶
Index |
Name |
Units |
Description |
---|---|---|---|
0 |
index |
- |
Consecutive, 1-based integer identifier of molecule |
1 |
A |
GHz |
Rotational constant A |
2 |
B |
GHz |
Rotational constant B |
3 |
C |
GHz |
Rotational constant C |
4 |
mu |
Debye |
Dipole moment |
5 |
alpha |
Bohr^3 |
Isotropic polarizability |
6 |
homo |
Hartree |
Energy of Highest occupied molecular orbital (HOMO) |
7 |
lumo |
Hartree |
Energy of Lowest unoccupied molecular orbital (LUMO) |
8 |
gap |
Hartree |
Gap, difference between LUMO and HOMO |
9 |
r2 |
Bohr^2 |
Electronic spatial extent |
10 |
zpve |
Hartree |
Zero point vibrational energy |
11 |
U0 |
Hartree |
Internal energy at 0 K |
12 |
U |
Hartree |
Internal energy at 298.15 K |
13 |
H |
Hartree |
Enthalpy at 298.15 K |
14 |
G |
Hartree |
Free energy at 298.15 K |
15 |
Cv |
cal/(mol K) |
Heat capacity at 298.15 K |
18.2. データの準備¶
便利なヘルパーコードを fetch_data.py
に書きました。このコードはデータをダウンロードし、Pythonで使いやすい形式に変換します。またこの関数は、QM9データを特徴量
18.3. このノートブックの動かし方¶
このページ上部の を押すと、このノートブックがGoogle Colab.で開かれます。必要なパッケージのインストール方法については以下を参照してください。
Tip
必要なパッケージをインストールするには、新規セルを作成して次のコードを実行してください。
!pip install dmol-book
もしインストールがうまくいかない場合、パッケージのバージョン不一致が原因である可能性があります。動作確認がとれた最新バージョンの一覧はここから参照できます
import tensorflow as tf
import numpy as np
from fetch_data import qm9_parse, qm9_fetch
import dmol
早速データを読み込みましょう。データをダウンロードおよび処理するため、この手順には数分かかります。
qm9_records = qm9_fetch()
data = qm9_parse(qm9_records)
data
は133k件の分子データを含むイテラブル(訳注:for文でループできるオブジェクト)です。最初の分子を見てみましょう。
for d in data:
print(d)
break
((<tf.Tensor: shape=(5,), dtype=int64, numpy=array([6, 1, 1, 1, 1])>, <tf.Tensor: shape=(5, 4), dtype=float32, numpy=
array([[-1.2698136e-02, 1.0858041e+00, 8.0009960e-03, -5.3568900e-01],
[ 2.1504159e-03, -6.0313176e-03, 1.9761203e-03, 1.3392100e-01],
[ 1.0117308e+00, 1.4637512e+00, 2.7657481e-04, 1.3392200e-01],
[-5.4081506e-01, 1.4475266e+00, -8.7664372e-01, 1.3392299e-01],
[-5.2381361e-01, 1.4379326e+00, 9.0639728e-01, 1.3392299e-01]],
dtype=float32)>), <tf.Tensor: shape=(16,), dtype=float32, numpy=
array([ 1.0000000e+00, 1.5771181e+02, 1.5770998e+02, 1.5770699e+02,
0.0000000e+00, 1.3210000e+01, -3.8769999e-01, 1.1710000e-01,
5.0480002e-01, 3.5364101e+01, 4.4748999e-02, -4.0478931e+01,
-4.0476063e+01, -4.0475117e+01, -4.0498596e+01, 6.4689999e+00],
dtype=float32)>)
これらはTensorflowのTensor型のデータです。これらのデータは、 x.numpy()
によってNumPyのarrayに変換することができます。最初のアイテムは元素ベクトル 6,1,1,1,1
です。このベクトルがどの元素を表すかわかるでしょうか?そう、C, H, H, H, Hです。次のアイテムは位置です。原子の部分電荷を表す行が含まれますが、今回これは特徴量としては使わないことに注意してください。そして最後はラベルベクトルです。
ここで、これらのデータを加工してより扱いやすい形式にします。まずNumPyのarrayに変換した後、部分電荷を削除し、原子番号をone-hotベクトルに変換しましょう。
def convert_record(d):
# break up record
(e, x), y = d
#
e = e.numpy()
x = x.numpy()
r = x[:, :3]
# make ohc size larger
# so use same node feature
# shape later
ohc = np.zeros((len(e), 16))
ohc[np.arange(len(e)), e - 1] = 1
return (ohc, r), y.numpy()[13]
for d in data:
(e, x), y = convert_record(d)
print("Element one hots\n", e)
print("Coordinates\n", x)
print("Label:", y)
break
Element one hots
[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
Coordinates
[[-1.2698136e-02 1.0858041e+00 8.0009960e-03]
[ 2.1504159e-03 -6.0313176e-03 1.9761203e-03]
[ 1.0117308e+00 1.4637512e+00 2.7657481e-04]
[-5.4081506e-01 1.4475266e+00 -8.7664372e-01]
[-5.2381361e-01 1.4379326e+00 9.0639728e-01]]
Label: -40.475117
18.4. ベースラインモデル¶
モデリングに深入りする前に、まずはシンプルなモデルでどこまで精度を上げられるか試してみましょう。これは後でより洗練された手法を作った際に、その手法と比較するためのベースラインモデルを準備しておくという意味でも役に立ちます。シンプルなモデルには多くの選択肢がありますが、ここでは含まれる元素の総数を使った線形回帰モデルを組んでみます。
import jax.numpy as jnp
import jax.experimental.optimizers as optimizers
import jax
import warnings
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
@jax.jit
def baseline_model(nodes, w, b):
# get sum of each element type
atom_count = jnp.sum(nodes, axis=0)
yhat = atom_count @ w + b
return yhat
def baseline_loss(nodes, y, w, b):
return (baseline_model(nodes, w, b) - y) ** 2
baseline_loss_grad = jax.grad(baseline_loss, (2, 3))
w = np.ones(16)
b = 0.0
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
Cell In [5], line 2
1 import jax.numpy as jnp
----> 2 import jax.experimental.optimizers as optimizers
3 import jax
4 import warnings
ModuleNotFoundError: No module named 'jax.experimental.optimizers'
これでシンプルな回帰モデルができました。少し難しい点は、各分子が異なる数の原子からなる、つまり分子ごとにtensorのshapeが異なるために、通常のように分子をバッチ処理できない点です。
# we'll just train on 5,000 and use 1,000 for test
# shuffle
shuffled_data = data.shuffle(7000)
test_set = shuffled_data.take(1000)
valid_set = shuffled_data.skip(1000).take(1000)
train_set = shuffled_data.skip(2000).take(5000)
このデータのラベルはかなり大きな値をもつため、そのままモデルに入力するとlossも非常に大きくなり、学習が不安定になる可能性があります。学習がスムーズに進むように、ラベルのスケールを学習率やその他のパラメータに合わせて変換することにします。
ys = [convert_record(d)[1] for d in train_set]
train_ym = np.mean(ys)
train_ys = np.std(ys)
print("Mean = ", train_ym, "Std =", train_ys)
あとは学習時に次の変換
def transform_label(y):
return (y - train_ym) / train_ys
def transform_prediction(y):
return y * train_ys + train_ym
epochs = 16
eta = 1e-3
baseline_val_loss = [0.0 for _ in range(epochs)]
for epoch in range(epochs):
for d in train_set:
(e, x), y_raw = convert_record(d)
y = transform_label(y_raw)
grad_est = baseline_loss_grad(e, y, w, b)
# update regression weights
w -= eta * grad_est[0]
b -= eta * grad_est[1]
# compute validation loss
for v in valid_set:
(e, x), y_raw = convert_record(v)
y = transform_label(y_raw)
# convert SE to RMSE
baseline_val_loss[epoch] += baseline_loss(e, y, w, b)
baseline_val_loss[epoch] = jnp.sqrt(baseline_val_loss[epoch] / 1000)
eta *= 0.9
plt.plot(baseline_val_loss)
plt.xlabel("Epoch")
plt.ylabel("Val Loss")
plt.show()
この性能はかなり低いですが、他のより洗練されたモデルが超えるべきベースライン性能として良い基準となるでしょう。この訓練で用いた変わった点としては、学習率を徐々に下げていったことです。これは、我々の特徴量及びラベルの値がすべて異なる大きさであるためです。このような場合、モデルの重みは、はじめは正しいオーダーにするために大きく動かし、その後で少し微調整する必要があります。そのため、大きな学習率でスタートして徐々に下げていくのです。
ys = []
yhats = []
for v in valid_set:
(e, x), y = convert_record(v)
ys.append(y)
yhat_raw = baseline_model(e, w, b)
yhat = transform_prediction(yhat_raw)
yhats.append(yhat)
plt.plot(ys, ys, "-")
plt.plot(ys, yhats, ".")
plt.xlabel("Energy")
plt.ylabel("Predicted Energy")
plt.show()
かなりバラつきはありますが、このシンプルなモデルでも、おおよその傾向を捉えた予測はできているようです。それではモデルを改良していきましょう。
18.5. GNNモデルの例¶
さて、我々はQM9のデータを使ってエネルギーの予測モデルを組めるようになりました。はじめに述べたように、この問題には分子の構造における回転・並進・順序に対して不変性を備えたモデルが必要です。 まず順序に対して不変とするため、Graph Neural Network(GNN)を用いることにします。そして、座標/元素ベクトルに基づいて、分子を表すグラフを作成します。すなわち、各原子が隣接原子と結合しているとみなし、これを原子をノード、結合をエッジとするグラフで表現します。この時、各原子同士の距離(pairwise距離)の逆数をエッジの重みとして用います。pairwise距離は原子同士の相対的な位置関係にのみ依存するため、これを用いることで、平行移動と回転に対する不変性を得ることができます。また距離の逆数による重み付けによって、近くにある原子はエッジの重みが大きく、逆に遠くにある原子では小さくなります。
それでは、Battaglia方程式 [BHB+18] を使ってモデルを定義します。(授業で機械学習を学んだことがある人にとって)授業で習ったであろうほとんどの例とは対象的に、エッジやノードではなく、ここではグラフ全体の特徴ベクトル
ここで、入力エッジ
全てのノードのグローバルな集約でも同様に和を取ります。よって、グラフ特徴ベクトルのアップデートは次のようになります:
そして、final energyの計算のため、次の回帰関数を使います:
この実装におけるポイントの最後は、各GNNレイヤーではエッジに
18.5.1. JAXによるモデルの実装¶
def x2e(x):
"""convert xyz coordinates to inverse pairwise distance"""
r2 = jnp.sum((x - x[:, jnp.newaxis, :]) ** 2, axis=-1)
e = jnp.where(r2 != 0, 1 / r2, 0.0)
return e
def gnn_layer(nodes, edges, features, we, web, wv, wu):
"""Implementation of the GNN"""
# make nodes be N x N so we can just multiply directly
# ek is now shaped N x N x features
ek = jax.nn.leaky_relu(
web
+ jnp.repeat(nodes[jnp.newaxis, ...], nodes.shape[0], axis=0)
@ we
* edges[..., jnp.newaxis]
)
# sum over neighbors to get N x features
ebar = jnp.sum(ek, axis=1)
# dense layer for new nodes to get N x features
new_nodes = jax.nn.leaky_relu(ebar @ wv) + nodes
# sum over nodes to get shape features
global_node_features = jnp.sum(new_nodes, axis=0)
# dense layer for new features
new_features = jax.nn.leaky_relu(global_node_features @ wu) + features
# just return features for ease of use
return new_nodes, edges, new_features
ここまでの説明をもとにして、座標をpairwise距離の逆数に変換するコードと、上記のGNN方程式を実装しました。実装がどうなっているのか見てみましょう:
graph_feature_len = 8
node_feature_len = 16
msg_feature_len = 16
# make our weights
def init_weights(g, n, m):
we = np.random.normal(size=(n, m), scale=1e-1)
wb = np.random.normal(size=(m), scale=1e-1)
wv = np.random.normal(size=(m, n), scale=1e-1)
wu = np.random.normal(size=(n, g), scale=1e-1)
return [we, wb, wv, wu]
# make a graph
nodes = e
edges = x2e(x)
features = jnp.zeros(graph_feature_len)
# eval
out = gnn_layer(
nodes,
edges,
features,
*init_weights(graph_feature_len, node_feature_len, msg_feature_len),
)
print("input feautres", features)
print("output features", out[2])
バッチリです!グラフ特徴を処理できていますね。次に、この方程式を使って、2層のGNNレイヤーからなるGNNモデルを定義しましょう。合わせてlossも定義します。
# get weights for both layers
w1 = init_weights(graph_feature_len, node_feature_len, msg_feature_len)
w2 = init_weights(graph_feature_len, node_feature_len, msg_feature_len)
w3 = np.random.normal(size=(graph_feature_len))
b = 0.0
@jax.jit
def model(nodes, coords, w1, w2, w3, b):
f0 = jnp.zeros(graph_feature_len)
e0 = x2e(coords)
n0 = nodes
n1, e1, f1 = gnn_layer(n0, e0, f0, *w1)
n2, e2, f2 = gnn_layer(n1, e1, f1, *w2)
yhat = f2 @ w3 + b
return yhat
def loss(nodes, coords, y, w1, w2, w3, b):
return (model(nodes, coords, w1, w2, w3, b) - y) ** 2
loss_grad = jax.grad(loss, (3, 4, 5, 6))
以下のコードでは、GNNの学習率を回帰関数の学習率の1/10にスケーリングする小さな変更を取り入れました。このハックはモデルにとって本質的に重要な点ではありませんが、原著者が試行錯誤で見つけた学習をうまくいくようにするための小さなテクニックです(訳注:このように、DNNの学習ではごく小さな変更がモデルの性能に大きな影響を与える場合がしばしばあります)。
eta = 1e-3
val_loss = [0.0 for _ in range(epochs)]
for epoch in range(epochs):
for d in train_set:
(e, x), y_raw = convert_record(d)
y = transform_label(y_raw)
grad = loss_grad(e, x, y, w1, w2, w3, b)
# update regression weights
w3 -= eta * grad[2]
b -= eta * grad[3]
# update GNN weights
for i, w in [(0, w1), (1, w2)]:
for j in range(len(w)):
w[j] -= eta * grad[i][j] / 10
# compute validation loss
for v in valid_set:
(e, x), y_raw = convert_record(v)
y = transform_label(y_raw)
# convert SE to RMSE
val_loss[epoch] += loss(e, x, y, w1, w2, w3, b)
val_loss[epoch] = jnp.sqrt(val_loss[epoch] / 1000)
eta *= 0.9
plt.plot(baseline_val_loss, label="baseline")
plt.plot(val_loss, label="GNN")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Val Loss")
plt.show()
これは大きなデータセットで、学習には時間がかかるかもしれません。まずは、GNNを使ったモデル構築と学習の原理が伝われば幸いです。最後に、モデルの予測性能を調べましょう。
ys = []
yhats = []
for v in valid_set:
(e, x), y = convert_record(v)
ys.append(y)
yhat_raw = model(e, x, w1, w2, w3, b)
yhats.append(transform_prediction(yhat_raw))
plt.plot(ys, ys, "-")
plt.plot(ys, yhats, ".")
plt.xlabel("Energy")
plt.ylabel("Predicted Energy")
plt.show()
いくつか点が密集したクラスターがありますが、これらは分子の種類や大きさに対応しています。このプロットからは、ここまでに構築したモデルが正しく学習され、クラスター内部で正しい傾向を学習しつつあることはわかりますが、対角線から大きく外れたクラスターを修正し、さらに精度を上げるためには、まだまだモデルの工夫と追加学習が必要そうです。
18.6. QM9データの学習についての関連資料¶
18.7. 参考文献¶
- BHB+18
Peter W Battaglia, Jessica B Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, and others. Relational inductive biases, deep learning, and graph networks. arXiv preprint arXiv:1806.01261, 2018.
- RDRVL14
Raghunathan Ramakrishnan, Pavlo O Dral, Matthias Rupp, and O Anatole Von Lilienfeld. Quantum chemistry structures and properties of 134 kilo molecules. Scientific data, 1(1):1–7, 2014.