12. 予枬を説明する¶

ニュヌラルネットワヌクの予枬は䞀般に解釈可胜ではありたせん。この章ではその予枬を説明する方法を探りたす。これは説明可胜なAIXAIずいう広範なトピックの䞀郚であり、なぜ特定の予枬がなされるのかを理解するにの圹立぀はずです。モデルの予枬を理解できるようになれば、実甚的、理論的、か぀芏制の芳点で正圓化するこずに぀ながりたす。そのため重芁なトピックずなっおいたす。たたその根拠を理解できれば、モデルの予枬を利甚する可胜性が高くなるこずが瀺されおおり[LS04]、確かに実甚的です。もう䞀぀の実甚的な関心事は、モデルがどのようにその予枬に至ったのかを理解できれば、方法を正確に実装するこずがはるかに容易になるずいうこずです。透明性の理論的な正圓性は、モデルドメむンの䞍完党性すなわち、共倉量シフト[DVK17]を識別するのに圹立ちたす。最近、欧州連合[GF17]ずG20[Dev19]は、機械予枬に関する説明を掚奚たたは芁求するガむドラむンを採択したした。そのため、これは珟圚、コンプラむアンスの問題になっおきおいたす。欧州連合はさらに螏み蟌んで、より厳しい法埋案を怜蚎しおいたす。

読者局ず目的

この章はStandard LayersずDeep Learning on Sequencesに基づいおいたす。たた、条件付き確率を含む確率論に関する十分な知識があるこずを前提ずしおいたす。そうでない堎合は、私のノヌトや入門的な確率のテキストを読んで抂芁を理解するこずをお勧めしたす。この章を読むず、以䞋のこずができるようになるず想定されたす。

  • なぜ説明が重芁かを正圓化できる

  • 正圓化、解釈、説明の区別ができる

  • 特城量重芁床ずシャヌプレむ倀を蚈算できる

  • 反実仮想を定矩し、それを蚈算できる

  • どのモデルが解釈可胜で、どのように解釈可胜なサロゲヌトモデルを適合させるかを理解する

説明可胜なAIの必芁性に関する有名な䟋ずしお、肺炎でERに運ばれた患者の死亡リスクを評䟡するML予枬噚を構築したCaruanaらの[CLG+15]が挙げられたす。この論文のアむデアは、肺炎の患者をこのツヌルでスクリヌニングすれば、医垫がどの患者がより死亡リスクが高いかを知るのに圹立぀ずいうものでした。そのモデルはかなり正確でしたが、その予枬の解釈を調べたずころ、医孊的におかしな掚論がなされおいたした。驚くべきこずに、そのモデルは喘息持ちの患者喘息患者ずいうが肺炎でERに来院した堎合、死亡リスクが䜎䞋するこずを瀺唆しおいたした。喘息ずいうのは呌吞が困難になる病気であるにも関わらず、肺炎患者の死亡率が䜎くなるずいうこずがわかったのです。

この結果は偶発的なものでした。喘息は実際には肺炎での死亡リスクが高いのですが、医垫はそのこずを実感しおいるため、圌らに察しおより積極的で䞁寧な察応をしおいたのです。すなわち、喘息患者に察しお医垫がより䞁寧なケアず配慮をしおいたため、死亡者数が少なくなっおいたのです。経隓則から蚀えば、モデルの予枬は正しいです。しかし、もしこのモデルが実甚化されれば、喘息患者を死亡リスクが䜎いず誀っお刀断し、喘息患者は本来受けられたはずのケアを受けられずに呜を萜ずす可胜性がありたした。幞いにも研究者はモデルの解釈可胜性によっおこの問題を特定し、喘息患者を呜の危険にさらすリスクを回避するこずができたした。このように、解釈は垞に予枬モデルを構築する際に考慮すべきステップであるこずがわかりたす。

12.1. 説明ずは䜕か¶

ここでは、Miller[Mil19]の説明の定矩を䜿いたす。Millerは、解釈可胜性、正圓化、説明を以䞋の定矩で区別しおいたす。

  • 解釈可胜性芳察者が刀断の原因を理解できる床合いを指したす。Millerはこれを説明可胜性ず同矩ずみなしたした。これは䞀般的にモデルの特性です。

  • 正圓化モデルのテスト゚ラヌや正確床のように、なぜその決定が良いのかに぀いおの蚌拠や説明です。これはモデルの特性です。

  • 説明結果の文脈ず原因を䞎える、人間を察象ずした情報の提瀺です。これらが本章の䞻芁な論点です。これは䞀般的にモデルの特性ではなく、新たに生成する远加的な情報です。

説明の構成芁玠に぀いお詳しく説明したすが、説明は予枬を正圓化するこずずは異なるこずに泚意しおください。正圓化ずは既に芋たように、モデルの予枬が正確であるず信じるべき理由に぀いおの経隓的な蚌拠です。䞀方、説明ずは予枬の原因を明らかにするこずであり、最終的に人に理解されるこずを目的ずしおいたす。

ディヌプラヌニングはそれだけではブラックボックス的なモデリング手法です。解釈可胜性も説明可胜性もありたせん。重みやモデル匏を調べおも、なぜその予枬がなされるのかに぀いおの掞察はほずんど埗られたせん。解釈可胜性はディヌプラヌニングに察する远加的なタスクであり、モデルの予枬に説明を加えるこずを意味したす。しかしこれは難しい問題です。ディヌプラヌニングのブラックボックス的な性質に加えお、モデルの予枬に察する説明ずは䜕かに぀いおのコンセンサスが取れおいないためです[DVK17]。ある人は解釈可胜性に各予枬を正圓化する自然蚀語の説明を期埅したすが、ある人はどの特城が予枬に最も貢献したかを瀺すだけで十分ず考えたす。

MLモデルの解釈には2぀のアプロヌチがありたす。説明による事埌解釈ず自己説明型モデル[MSK+19]です。自己説明型モデルは専門家がモデルの出力を芋お論理的に特城量ず結び぀けるこずができるように構築されおおり、本質的に解釈可胜です。ただしタスクモデルに匷く䟝存したす[MSMuller18]。身近な䟋では、分子動力孊や䞀点量子゚ネルギヌ蚈算のような物理孊に基づくシミュレヌションがありたす。分子動力孊の軌跡を調べ、出力された数倀を芋お、䟋えば薬物分子がタンパク質に結合するず予枬する理由を説明するこずができたす。

自己説明型モデルはディヌプラヌニングの解釈には圹に立たない/関連付かないように思えるかもしれたせん。しかし、埌の節で、自己説明型のサロゲヌトモデル (たたはプロキシモデルを䜜り、ディヌプラヌニングモデルず䞀臎するように蚓緎すればよいこずがわかるでしょう。最初からサロゲヌトモデルを䜿わずにディヌプラヌニングモデルを介するのは蚓緎コストを枛らせるからですが、それはなぜでしょうか。孊習枈みニュヌラルネットワヌクは任意の点をラベル付けできる、すなわち孊習デヌタを無限に生成するこずができるからです。サロゲヌトモデルの他に、Attention機構[BCB14]のように自己説明的な特城を内包したディヌプラヌニングモデルを構築するこずもできたす。Attention機構に基づいお入力特城量ず予枬倀を結び぀けるこずができたす。たた機械孊習にはシンボリック回垰ずいうものがあり、盎接解釈できる数匏を扱うこずで自己説明的なモデルを構築しようずしたす[AGFW21, BD00, UT20]。その特性から、シンボリック回垰はサロゲヌトモデルを生成するために甚いられたす[CSGB+20]。

説明による事埌解釈には様々なアプロヌチがありたす。代衚的なものは孊習デヌタの重芁床、特城量重芁床、反実仮想的な説明[WSW22, RSG16a, RSG16b, WMR17]です。デヌタの重芁床に基づく事埌解釈の䟋は、予枬を説明する最も圱響力のある孊習デヌタを特定するこずです[KL17]。それによっお説明が぀くかどうかは議論の䜙地がありたすが、どのデヌタが予枬に関連しおいるかを理解するのに圹立぀こずは確かです。特城量重芁床はおそらく最も䞀般的なXAIアプロヌチで、コンピュヌタビゞョンの研究に頻繁に登堎し、䟋えば画像の分類にずっお最も重芁なピクセルをハむラむトしたす。

反実仮想的な説明は事埌解釈の新しい方法です。反実仮想は説明ずしお機胜する新しいデヌタ点です。反実仮想は、その特城量がどれほど重芁で敏感であるかに぀いおの掞察を䞎えたす。䟋ずしお融資を勧めるモデルがあるずしたす。そのモデルは以䞋の反実仮想的な説明を生成するこずができたす[WMR17]より。

あなたは、幎収、郵䟿番号、資産に基づいおロヌンを拒吊されたした。もし、あなたの幎収が45,000ドルであれば、あなたはロヌンを提䟛されたでしょう。

2番目の文が反実仮想であり、特城量をどのように倉えればモデルの結果に圱響を䞎えるかを瀺しおいたす。反実仮想は耇雑さず説明力を良いバランスで提䟛したす。

以䞊が広範なXAI分野に関する抂芳でした。解釈可胜なディヌプラヌニングに぀いおの最近のレビュヌはSamekらの[SML+21]を芋おください。たたディヌプラヌニングを含む解釈可胜な機械孊習に関する網矅的な情報はChristopher Molnarがオンラむンブックで公開しおいたす。予枬誀差や予枬の信頌性は正圓化の意味合いが匷いのでここでは扱いたせんが、Regression & Model Assessmentの手法が適甚できるので参照しおください。

12.2. 特城量重芁床¶

特城量重芁床は、機械孊習モデルを解釈する䞊で最もわかりやすく、最も䞀般的な方法です。特城量重芁床の出力は各特城量に察するランキングたたは数倀であり、通垞は単䞀の予枬に察するものです。モデル党䜓の特城量重芁床は倧域的特城量重芁床ず呌ばれ、単䞀の予枬に察しおは局所的特城量重芁床ず呌ばれたす。倧域的な特城量重芁床ず解釈可胜性を持぀こずは比范的たれです。正確なディヌプラヌニングモデルは特城空間の䜍眮によっお重芁な特城量が倉わるためです。

たずは線圢モデルで特城量重芁床を芋おみたしょう。

(12.1)¶\[\begin{equation} \hat{y} = \vec{w}\vec{x} + b \end{equation}\]

ここで \(\vec{x}\)は特城量ベクトルです。特城量重芁床を評䟡する簡単な方法は、特定の特城量\(x_i\)に関する重み\(w_i\)を単に芋るこずです。この重み\(w_i\)は、他のすべおの特城が䞀定で、\(x_i\)が1増加した堎合にどの皋床倉化するかを瀺しおいたす。もし、特城量の倧きさが同皋床であれば、この方法は特城量の順䜍付けずしお機胜するでしょう。しかし、特城量が単䜍を持぀堎合、単䜍の遞択ず特城量の盞察的な倧きさに圱響されたす。䟋えば、気枩が摂氏から華氏に倉曎された堎合、1床䞊昇した時の圱響は小さくなりたす。

特城量の倧きさや単䜍の圱響を排陀しお特城量重芁床を評䟡する少し良い方法は、\(w_i\)を特城量の暙準偏差で割るこずです。暙準偏差ずは、予枬倀の二乗誀差の総和を偏差平方和で割った倀です。すなわち暙準偏差は予枬の正確床ず特城量の分散の比です。暙準偏差で割った\(w_i\)はt-分垃ず比范できるため、\(t\)-統蚈量ず呌ばれたす。

(12.2)¶\[\begin{equation} t_i = \frac{w_i}{S_{w_i}},\; S^2_{w_i} = \frac{1}{N - D}\sum_j \frac{\left(\hat{y}_j - y_j\right)^2}{\left(x_{ij} - \bar{x}_i\right)^2} \end{equation}\]

ここで、Nは䟋数、 Dは特城量数、 \(\bar{x}_i\)はi番目の特城量の平均倀です。\(t_i\)倀は、特城量の順䜍付けず仮説怜定に利甚できたす。もし\(P(t > t_i) < 0.05\)であれば、その特城量は統蚈的に有意で、\(P(t)\)はStudent’s \(t\)-分垃です。特城量の有意性はモデルに存圚する他の特城量に䟝存するこずに泚意しおください。぀たり新しい特城量を远加するず、䞀郚が冗長になる可胜性がありたす。

次に非線圢の堎合を芋おみたしょう。非線圢孊習関数\(\hat{f}(\vec{x})\)では、特城量が1増加した堎合に予枬がどのように倉化するかを埮分近䌌で蚈算したす。

\[ \frac{\Delta \hat{f}(\vec{x})}{\Delta x_i} \approx \frac{\partial \hat{f}(\vec{x})}{\partial x_i} \]

1だけ倉化させるず

(12.3)¶\[\begin{equation} \Delta \hat{f}(\vec{x}) \approx \frac{\partial \hat{f}(\vec{x})}{\partial x_i}. \end{equation}\]

実際にはこの匏を少し倉圢したす。0を䞭心ずしたテむラヌ玚数を䜿う代わりに、他のルヌト関数が0ずなる点を䞭心にしたす。そうするこずで、決定境界ルヌトで系列を「接地」し、予枬されるクラスを決定境界から「遠ざけ」たり「近づけ」たりするこずができたす。もう1぀の方法は、テむラヌ玚数の1次の項を䜿甚しお線圢モデルを構築するずいうものです。そしお、その線圢モデルに䞊蚘ず同様のこずを行い、その係数を特城量の「重芁床」ずしお䜿甚したす。具䜓的には、\(\hat{f}(\vec{x})\)に察しお以䞋のようなサロゲヌト関数を䜿甚したす。

(12.4)¶\[\begin{equation} \require{cancel} \hat{f}(\vec{x}) \approx \cancelto{0}{f(\vec{x}')} + \nabla\hat{f}(\vec{x}')\cdot\left(\vec{x} - \vec{x}'\right) \end{equation}\]

ここで、\(\vec{x}'\)は\(\hat{f}(\vec{x})\)のルヌトです。たいおい自明なルヌト\(\vec{x}' = \vec{0}\)を遞択するかもしれたせんが、近傍のルヌトが理想的です。このルヌトはしばしばベヌスラむン入力ず呌ばれたす。䞊蚘の線圢の䟋ずは察照的に、郚分的な\(\frac{\partial \hat{f}(\vec{x})}{\partial x_i}\)の積ずベヌスラむンからの増分\((x_i - x_i')\)を考えたす。

12.2.1. ニュヌラルネットワヌクの特城量重芁床¶

ニュヌラルネットワヌクでは、偏導関数は出力に察する実際の倉化を近䌌するには䞍十分です。入力に察する小さな倉化が䞍連続な堎合ReLUのような非線圢性のため、ほずんど説明力を持たなくなるこずがありたす。これはshattered gradients問題[BFL+17]ず呌ばれおいたす。たた個々の特城量に分けるず、特城量間の盞関も欠萜しおしたいたす。これは線圢モデルにはない問題です。したがっお、埮分近䌌は局所的な線圢モデルでは十分に機胜したすが、ディヌプニュヌラルネットワヌクでは機胜したせん。

ニュヌラルネットワヌクにおけるshattered gradients問題を回避する方法はいろいろありたす。よく䜿われるのはintegrated gradients [STY17] ずSmoothGrad[STK+17]の2぀の方法です。integrated gradientsは\(\vec{x}'\)から\(\vec{x}\)たで盎線で結ぶ経路を考え、この経路䞊で察象ずなる倉数の埮分倀を積分で統合したす。

(12.5)¶\[\begin{equation} \textrm{IG}_i = \left(\vec{x} - \vec{x}'\right) \int_0^1\left[\nabla\hat{f}\left(\vec{x}' + t\left(\vec{x} - \vec{x}'\right)\right)\right]_i\,dt \end{equation}\]

\(t\)は経路に沿ったある増分で、\(t = 0\) のずき \(\vec{x}' + t\left(\vec{x} - \vec{x}'\right) = \vec{x}'\)、\(t = 1\) のずき \(\vec{x}' + t\left(\vec{x} - \vec{x}'\right) = \vec{x}\) です。この匏により各特城量\(i\)のintegrated gradientを埗たす。integrated gradientは各特城量の重芁床ですが、shattered gradientsの持぀耇雑さはありたせん。たたモデル \(f(\vec{x})\) がほずんど至るずころで埮分可胜であれば、 \(\sum_i \textrm{IG}_i = f(\vec{x}) - f(\vec{x}')\) ずいう匏が成立したす。これはintegrated gradientsで蚈算された各特城量重芁床の合蚈倀が、ベヌスラむンず予枬倀の差に等しくなるこずを意味しおいたす。すなわちベヌスラむンから予枬倀の倉化量を完党に分離しおくれたす[STY17]

integrated gradientsの実装は比范的簡単です。経路を入力特城量 \(\vec{x}\)ずベヌスラむン \(\vec{x}'\)の間にある離散入力の集合に分割するこずにより、リヌマン和で経路の積分を近䌌したす。これらの入力の募配をニュヌラルネットワヌクで蚈算したす。そしお、ベヌスラむンからの特城量の倉化量\(\left(\vec{x} - \vec{x}'\right)\)を乗じたす。

SmmothGradはintegrated gradientsず同様の考え方です。しかし経路にそった募配を合蚈するのではなく、予枬の近くにあるランダムな点から募配を蚈算したす。匏は以䞋の通りです。

(12.6)¶\[\begin{equation} \textrm{SG}_i = \sum_j^M\left[\nabla\hat{f}\left(\vec{x}' + \vec{\epsilon}\right)\right]_i \end{equation}\]

\(M\)はサンプル数の遞択であり、\(\vec{\epsilon}\)は\(D\)れロ平均ガりシアンからサンプリングされたす[STK+17]。 ここでの実装䞊の唯䞀の倉曎点は、経路を䞀連のランダムな摂動に眮き換えるこずです。

これらの募配ベヌスの方法以倖にも、Layer-wise Relevance Propagation (LRP)はニュヌラルネットワヌクにおける特城量重芁床の解析の䞀般的な方法です。LPRは、1぀の局の出力倀を入力特城量に分割するニュヌラルネットワヌクを介した逆䌝播を行うこずで機胜したす。これは「関連性を分散させる」ずいうこずです。LPRの倉わったずころは、各局の皮類毎に独自の実装が必芁なこずです。解析的な導関数に頌らず、局の方皋匏のテむラヌ玚数展開で察応したす。GNNやシヌケンスモデル甚のLRPもあり、LRPは材料や化孊のほずんどの堎面で䜿うこずができたす[MBL+19]。

12.2.2. シャヌプレむ倀¶

モデル非䟝存的に特城量重芁床を扱う方法ずしお、シャヌプレむ倀がありたす。シャヌプレむ倀はゲヌム理論に由来するもので、協力的なプレヌダヌに、その貢献床に応じお報酬を支払う方法に぀いおの解決策です。各特城量がプレヌダヌであり、予枬倀ぞの貢献床に応じお「支払う」こずを想定しおいたす。シャヌプレむ倀 \(\phi_i(x)\)は、むンスタンス\(x\)の特城量\(i\)に察する支払いです。予枬関数倀 \(\hat{f}(x)\)をシャヌプレむ倀に分割しお、その和が関数倀\(\sum_i \phi_i(x) = \hat{f}(x)\)ずなるようにしたす。぀たり、ある特城量のシャヌプレむ倀は予枬に察する数倀的な貢献床ず解釈できたす。シャヌプレむ倀の匷力な利点は、モデルに䟝存せず、予枬倀を各特城量に分割でき、予枬の説明に必芁な属性察称性、線圢性、順序䞍倉性などを持぀こずです。欠点は、厳密な蚈算には特城量の組み合わせの数だけコストがかかるこず、スパヌス性を持たないこずであり、結果的に特城量数の増加に䌎っお有甚性が䜎くなりたす。ここで玹介する手法もスパヌス性を持たないものがほずんどです。L1正則化(Standard Layers参照)のように、垞にモデルをスパヌスにするこずでスパヌスな説明を実珟するこずができたす。

シャヌプレむ倀は次のように蚈算されたす。

(12.7)¶\[\begin{equation} \phi_i(x) = \frac{1}{Z}\sum_{S \in N \backslash x_i}v(S\cup x_i) - v(S) \end{equation}\]
\[ Z = \frac{|S|!\left(N - |S| - 1\right)!}{N!} \]

\(S \in N \backslash x_i\)は特城量\(x_i\)を陀いた党おの特城量の集合を意味し、\(S\cup x_i\)は特城量 \(x_i\)を集合に戻すこずを意味したす。たた\(v(S)\)は\(S\)に含たれる特城量のみを䜿甚した堎合の\(\hat{f}(x)\)の倀であり、\(Z\)は正芏化甚の倀です。この匏は、特城量\(i\)を远加/削陀するこずによっお圢成される\(\hat{f}\)の取りうるすべおの差の平均ず解釈するこずができたす。

しかし、特城量\(i\)をモデル匏からどのように「取り陀く」こずができるでしょうか。特城量\(i\)を無甚のものずしお扱う呚蟺化するこずでできたす。呚蟺化ずは確率倉数\(P(x) = \int\, P(x,y)\,dy\)を積分する方法であるこずを思い出しおください。これは取りうるすべおの倀\(x\)を積分したす。呚蟺化は確率倉数の関数にも䜿うこずができたす。それは明らかに確率倉数でもあるのですが、期埅倀\(E_y[f | X = x] = \int\,f(X=x,y)P(X=x,y)\, dy\)を取るこずによっお䜿うこずができたす。積分では確率倉数\(X\)が固定されおいるため、\(E_y[f]\)は\(x\)の関数であるこずを匷調したしたが、\(x\)が固定されおいる堎所関数の匕数\(f(x,y)\)の期埅倀を蚈算するこずによっお陀かれたす。本質的には、すべおの取り埗る倀\(y\)の平均である\(f(x,y)\)を新しい関数\(E_y[f]\)に眮き換えおいたす。ここたでかなり詳现に説明しおきたしたが、䞋のコヌドを芋れば盎感的に理解できたす。もう䞀぀付け加えるずすれば、倀が\(\hat{f}\)の平均倀に察する盞察的な倉化であるずいうこずです。䜙分な項は無芖しおもかたいたせんが、念のため入れおおきたす。したがっお、倀の方皋匏は [vStrumbeljK14]ずなりたす。

(12.8)¶\[\begin{equation} v(x_i) = \int\,f(x_0, x_1, \ldots, x_i,\ldots, x_N)P(x_0, x_1, \ldots, x_i,\ldots, x_N)\, dx_i - E\left[\hat{f}(\vec{x})\right] \end{equation}\]

呚蟺化\(\int\,f(x_0, x_1, \ldots, x_i,\ldots, x_N)P(x_0, x_1, \ldots, x_i,\ldots, x_N)\, dx_i\)はどのように蚈算するのでしょうか。既知の確率分垃はありたせん。その堎合、デヌタを経隓的な分垃ずしお考えるこずで\(P(\vec{x})\)からサンプリングできたす。すなわち、デヌタ点をサンプリングするこずで\(P(\vec{x})\)からサンプリングできたす。ただし、\(\vec{x}\)を共にサンプリングする必芁があるため少し耇雑です。陀かれる特城量ずの間に盞関がある堎合、個々の特城をランダムに混ぜるこずはできたせん。

Strumbeljら[vStrumbeljK14]は\(i\)番目のシャヌプレむ倀を盎接掚定できるこずを瀺したした。

(12.9)¶\[\begin{equation} \phi_i(\vec{x}) = \frac{1}{M}\sum^M \hat{f}\left(\vec{z}_{+i}\right) - \hat{f}\left(\vec{z}_{-i}\right) \end{equation}\]

\(\vec{z}\)は、実際の䟋\(\vec{x}\)ずランダムな挫然ずした䟋\(\vec{x}'\)から成る「キメラ」の䟋です。\(\vec{x}\)ず\(\vec{x}'\)からランダムに遞択し、\(\vec{z}\)を構成したす。\(\vec{z}_{+i}\)は䟋\(\vec{x}\)の\(i\)番目の特城量を持ち、\(\vec{z}_{-i}\)はランダムな䟋\(\vec{x}'\)の\(i\)番目の特城量を持ちたす。\(M\)はこの倀に察しお良いサンプルを埗るために十分に倧きく遞ばれたす。[vStrumbeljK14]は\(M\)の遞択方法に関する指針を瀺しおいたすが、基本的には蚈算可胜で劥圓な範囲で倧きな\(M\)を遞択したす。この近䌌の䞀぀の倉曎点は、期埅倀ずきには\(\phi_0\)ず衚蚘されるを衚す明瀺的な項を䜿っおいるこずで、「完党性」を有する方皋匏は次のようになりたす。

(12.10)¶\[\begin{equation} \sum_i \phi_i(\vec{x}) = \hat{f}(\vec{x}) - E[\hat{f}(\vec{x})] \end{equation}\]

期埅倀を\phi_0\(ずしお明瀺的に含める堎合、それは\)\vec{x}$に䟝存したせん。

(12.11)¶\[\begin{equation} \phi_0 + \sum_{i=1} \phi_i(\vec{x}) = \hat{f}(\vec{x}) \end{equation}\]

この効率的な近䌌方法、匷力な理論、モデル非䟝存生により、シャヌプレむ倀は予枬倀に察する特城量重芁床を蚘述するのに優れた遞択肢ずなりたす。

12.3. Notebookの実行¶

このペヌゞ䞊郚の    を抌すず、このノヌトブックがGoogle Colab.で開かれたす。必芁なパッケヌゞのむンストヌル方法に぀いおは以䞋を参照しおください。

import haiku as hk
import jax
import tensorflow as tf
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import urllib
from functools import partial
from jax.example_libraries import optimizers as opt
import dmol
np.random.seed(0)
tf.random.set_seed(0)

ALPHABET = [
    "-",
    "A",
    "R",
    "N",
    "D",
    "C",
    "Q",
    "E",
    "G",
    "H",
    "I",
    "L",
    "K",
    "M",
    "F",
    "P",
    "S",
    "T",
    "W",
    "Y",
    "V",
]

たず、アミノ酞配列ずone-hotベクトルを盞互に倉換する関数を定矩したしょう。

def seq2array(seq, L=200):
    return np.pad(list(map(ALPHABET.index, seq)), (0, L - len(seq))).reshape(1, -1)


def array2oh(a):
    a = np.squeeze(a)
    o = np.zeros((len(a), 21))
    o[np.arange(len(a)), a] = 1
    return o.astype(np.float32).reshape(1, -1, 21)


urllib.request.urlretrieve(
    "https://github.com/whitead/dmol-book/raw/master/data/hemolytic.npz",
    "hemolytic.npz",
)
with np.load("hemolytic.npz", "rb") as r:
    pos_data, neg_data = r["positives"], r["negatives"]

12.4. 特城量重芁床の䟋¶

ペプチドが赀血球を砎壊するかどうか溶血性を予枬するペプチド予枬タスクで、特城量重芁床法の䟋を芋おみたしょう。 これはStandard Layersの溶解床予枬の䟋に䌌おいたす。デヌタは[BW21]を利甚したす。 モデルはペプチド配列䟋DDFRDを取り蟌み、そのペプチドが溶血性である確率を予枬したす。 ここでの特城量重芁床法の目暙は、どのアミノ酞が溶血掻性に最も重芁であるかを特定するこずです。 䞋の閉じたセルはデヌタをロヌドし凊理しおデヌタセットにしたす。

# create labels and stich it all into one
# tensor
labels = np.concatenate(
    (
        np.ones((pos_data.shape[0], 1), dtype=pos_data.dtype),
        np.zeros((neg_data.shape[0], 1), dtype=pos_data.dtype),
    ),
    axis=0,
)
features = np.concatenate((pos_data, neg_data), axis=0)
# we now need to shuffle before creating TF dataset
# so that our train/test/val splits are random
i = np.arange(len(labels))
np.random.shuffle(i)
labels = labels[i]
features = features[i]
L = pos_data.shape[-2]

# need to add token for empty amino acid
# dataset just has all zeros currently
features = np.concatenate((np.zeros((features.shape[0], L, 1)), features), axis=-1)
features[np.sum(features, -1) == 0, 0] = 1.0

batch_size = 16
full_data = tf.data.Dataset.from_tensor_slices((features.astype(np.float32), labels))

# now split into val, test, train
N = pos_data.shape[0] + neg_data.shape[0]
split = int(0.1 * N)
test_data = full_data.take(split).batch(batch_size)
nontest = full_data.skip(split)
val_data, train_data = nontest.take(split).batch(batch_size), nontest.skip(
    split
).shuffle(1000).batch(batch_size)

JaxHaikuを䜿甚で畳み蟌みモデルを再構築し、募配をもう少し簡単に扱えるようにしたす。たたその他にもいく぀かモデルに倉曎を加えおいたす。畳み蟌みに加えお、配列の長さずアミノ酞の割合も远加情報ずしお枡しおいたす。

def binary_cross_entropy(logits, y):
    """Binary cross entropy without sigmoid. Works with logits directly"""
    return (
        jnp.clip(logits, 0, None) - logits * y + jnp.log(1 + jnp.exp(-jnp.abs(logits)))
    )


def model_fn(x):
    # get fractions, excluding skip character
    aa_fracs = jnp.mean(x, axis=1)[:, 1:]
    # compute convolutions/poolings
    mask = jnp.sum(x[..., 1:], axis=-1, keepdims=True)
    for kernel, pool in zip([5, 3, 3], [4, 2, 2]):
        x = hk.Conv1D(16, kernel)(x) * mask
        x = jax.nn.tanh(x)
        x = hk.MaxPool(pool, pool, "VALID")(x)
        mask = hk.MaxPool(pool, pool, "VALID")(mask)
    # combine fractions, length, and convolution ouputs
    x = jnp.concatenate((hk.Flatten()(x), aa_fracs, jnp.sum(mask, axis=1)), axis=1)
    # dense layers. no bias, so zeros give P=0.5
    logits = hk.Sequential(
        [
            hk.Linear(256, with_bias=False),
            jax.nn.tanh,
            hk.Linear(64, with_bias=False),
            jax.nn.tanh,
            hk.Linear(1, with_bias=False),
        ]
    )(x)
    return logits


model = hk.without_apply_rng(hk.transform(model_fn))


def loss_fn(params, x, y):
    logits = model.apply(params, x)
    return jnp.mean(binary_cross_entropy(logits, y))


@jax.jit
def hemolytic_prob(params, x):
    logits = model.apply(params, x)
    return jax.nn.sigmoid(jnp.squeeze(logits))


@jax.jit
def accuracy_fn(params, x, y):
    logits = model.apply(params, x)
    return jnp.mean((logits >= 0) * y + (logits < 0) * (1 - y))
rng = jax.random.PRNGKey(0)
xi, yi = features[:batch_size], labels[:batch_size]
params = model.init(rng, xi)

opt_init, opt_update, get_params = opt.adam(1e-2)
opt_state = opt_init(params)


@jax.jit
def update(step, opt_state, x, y):
    value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state), x, y)
    opt_state = opt_update(step, grads, opt_state)
    return value, opt_state
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
epochs = 32
for e in range(epochs):
    avg_v = 0
    for i, (xi, yi) in enumerate(train_data):
        v, opt_state = update(i, opt_state, xi.numpy(), yi.numpy())
        avg_v += v
opt_params = get_params(opt_state)


def predict(x):
    return jnp.squeeze(model.apply(opt_params, x))


def predict_prob(x):
    return hemolytic_prob(opt_params, x)

もし、コヌドを远うのが倧倉でも倧䞈倫ですこの章のゎヌルはモデルの説明を埗る方法を瀺すこずであり、必ずしもモデルを構築する方法ではありたせん。ですから、次の数行に泚目しおください。ここでは、予枬倀を埗おそれを説明するために、どのようにモデルを䜿うかを説明したす。モデルは、logitsの堎合は predict(x) 、確率の堎合は predict_prob を介しお呌び出されたす。

アミノ酞配列、ペプチドを詊しお、モデルのむメヌゞを぀かんでみたしょう。モデルはlogitsオッズの察数を出力し、これをシグモむド関数にかけるず確率が埗られたす。ペプチドは配列からone-hotベクトルの行列に倉換する必芁がありたす。ここでは、2぀の既知の配列を詊しおみたしょう。Qは溶血生残基ずしおよく知られおおり、2番目の配列はポリGで、これは最も単玔なアミノ酞です。

s = "QQQQQ"
sm = array2oh(seq2array(s))
p = predict_prob(sm)
print(f"Probability {s} of being hemolytic {p:.2f}")

s = "GGGGG"
sm = array2oh(seq2array(s))
p = predict_prob(sm)
print(f"Probability {s} of being hemolytic {p:.2f}")
Probability QQQQQ of being hemolytic 1.00
Probability GGGGG of being hemolytic 0.00

これは劥圓な結果に芋えたす。モデルの出力は、これらの2぀の配列に関する私たちの盎感に䞀臎しおいたす。

さお、このモデルの正確床を蚈算したしょう。非垞に良い結果が埗られたす。

acc = []
for xi, yi in test_data:
    acc.append(accuracy_fn(opt_params, xi.numpy(), yi.numpy()))
print(jnp.mean(np.array(acc)))
0.95208335

12.4.1. 募配¶

では、なぜある配列が溶血性ずなるのか調べおみたしょう。たず、入力に察する募配を蚈算するこずから始めたす。これは玠朎なアプロヌチでshattered gradientsになりやすいですが、この埌蚈算するintegrated gradientsずsmooth gradientsのプロセスの䞀郚であり、無駄ではありたせん。それでは溶血性であるこずが知られおいる、より耇雑なペプチド配列を䜿っお、より興味深い解析をしおみたしょう。

def plot_grad(g, s, ax=None):
    # g = np.array(g)
    if ax is None:
        plt.figure()
        ax = plt.gca()
    if len(g.shape) == 3:
        h = g[0, np.arange(len(s)), list(map(ALPHABET.index, s))]
    else:
        h = g
    ax.bar(np.arange(len(s)), height=h)
    ax.set_xticks(range(len(s)))
    ax.set_xticklabels(s)
    ax.set_xlabel("Amino Acid $x_i$")
    ax.set_ylabel(r"Gradient $\frac{\partial \hat{f}(\vec{x})}{\partial x_i}$")
s = "RAGLQFPVGRLLRRLLRRLLR"
sm = array2oh(seq2array(s))
p = predict_prob(sm)
print(f"Probability {s} of being hemolytic {p:.2f}")
Probability RAGLQFPVGRLLRRLLRRLLR of being hemolytic 1.00

コヌドは至っおシンプルです。それでは募配を蚈算しおみたしょう。

gradient = jax.grad(predict, 0)
g = gradient(sm)
plot_grad(g, s)
../_images/xai_24_0.png

モデルはlogitsを出力するこずを忘れないでください。募配が正の倀であれば、そのアミノ酞が溶血性である確率を高め、負の倀であればそのアミノ酞配列が非溶血性である確率を高めるこずを意味したす。興味深いこずに、ロむシン(L)ずアルギニン(R)は匷い䜍眮䟝存性があるこずがわかりたす。

12.4.2. Integrated Gradients¶

次にintegrated gradients法を実装したす。3぀の基本的なステップを螏みたす。

  1. ベヌスラむンから入力ペプチドに向かう入力配列経路を䜜りたす。

  2. 各入力に察しお募配を蚈算したす。

  3. 募配の合蚈を蚈算し、それにベヌスラむンずペプチドの差を乗じたす。

ベヌスラむンはすべお0であり、0.5の確率を䞎えたす (logits = 0, a model root)。このベヌスラむンはたさに決定境界䞊にありたす。すべおのグリシンやアラニンのように他のベヌスラむンを䜿うこずもできたすが、それらは0.5の確率かそれに近いものであるべきです。ベヌスラむン遞択の詳现ずむンタラクティブな探玢は[SLL20]を芋おください。

def integrated_gradients(sm, N):
    baseline = jnp.zeros((1, L, 21))
    t = jnp.linspace(0, 1, N).reshape(-1, 1, 1)
    path = baseline * (1 - t) + sm * t

    def get_grad(pi):
        # compute gradient
        # add/remove batch axes
        return gradient(pi[jnp.newaxis, ...])[0]

    gs = jax.vmap(get_grad)(path)
    # sum pieces (Riemann sum), multiply by (x - x')
    ig = jnp.mean(gs, axis=0, keepdims=True) * (sm - baseline)
    return ig


ig = integrated_gradients(sm, 1024)
plot_grad(ig, s)
../_images/xai_27_0.png

䜍眮䟝存性がより顕著になり、アルギニンは䜍眮に察しお非垞に敏感であるこずがわかりたす。先皋の通垞の募配ず比范するず、質的な倉化はあたりありたせん。

12.4.3. SmoothGrad¶

SmmotGradを行うステップはintegrated gradientsずほずんど同じです。

  1. 入力ペプチドにランダムな摂動を加えた入力配列経路を䜜成したす。

  2. 各入力の募配を蚈算したす。

  3. 募配の平均を蚈算したす。

ハむパヌパラメヌタずしお、\(\sigma\)があり、原理的にはモデルの出力を倉化させながら、できるだけ小さくする必芁がありたす。

def smooth_gradients(sm, N, rng, sigma=1e-3):
    baseline = jnp.zeros((1, L, 21))
    t = jax.random.normal(rng, shape=(N, sm.shape[1], sm.shape[2])) * sigma
    path = sm + t
    # remove examples that are negative and force summing to 1
    path = jnp.clip(path, 0, 1)
    path /= jnp.sum(path, axis=2, keepdims=True)

    def get_grad(pi):
        # compute gradient
        # add/remove batch axes
        return gradient(pi[jnp.newaxis, ...])[0]

    gs = jax.vmap(get_grad)(path)
    # mean
    ig = jnp.mean(gs, axis=0, keepdims=True)
    return ig


sg = smooth_gradients(sm, 1024, jax.random.PRNGKey(0))
plot_grad(sg, s)
../_images/xai_29_0.png

通垞の募配の結果に酷䌌しおいるように芋えたす。これは、1次元の入力ず浅いネットワヌクが、shattered gradientsに察しおそれほど敏感ではないためず思われたす。

12.4.4. シャヌプレむ倀¶

次に、匏10.9を䜿っお各特城量に察するシャヌプレむ倀を近䌌しおみたしょう。シャヌプレむ倀の蚈算は募配を必芁ずしないため、これたでのアプロヌチずは異なりたす。基本的なアルゎリズムは次のようになりたす。

  1. ランダムな点x’を遞択したす。

  2. xずx’を組み合わせお点\(z\)を䜜りたす。

  3. 予枬関数の倉化を蚈算したす。

効率化のために行った1぀の工倫は、パディングで配列を倉曎しないようにするこずです。基本的に配列を長くするようなこずはしないようにしおいたす。

def shapley(i, sm, sampled_x, rng, model):
    M, F, *_ = sampled_x.shape
    z_choice = jax.random.bernoulli(rng, shape=(M, F))
    # only swap out features within length of sm
    mask = jnp.sum(sm[..., 1:], -1)
    z_choice *= mask
    z_choice = 1 - z_choice
    # construct with and w/o ith feature
    z_choice = z_choice.at[:, i].set(0.0)
    z_choice_i = z_choice.at[:, i].set(1.0)
    # select them via multiplication
    z = sm * z_choice[..., jnp.newaxis] + sampled_x * (1 - z_choice[..., jnp.newaxis])
    z_i = sm * z_choice_i[..., jnp.newaxis] + sampled_x * (
        1 - z_choice_i[..., jnp.newaxis]
    )
    v = model(z_i) - model(z)
    return jnp.squeeze(jnp.mean(v, axis=0))


# assume data is alrady shuffled, so just take M
M = 4096
sl = len(s)
sampled_x = train_data.unbatch().batch(M).as_numpy_iterator().next()[0]
# make batched shapley so we can compute for all features
bshapley = jax.vmap(shapley, in_axes=(0, None, None, 0, None))
sv = bshapley(
    jnp.arange(sl),
    sm,
    sampled_x,
    jax.random.split(jax.random.PRNGKey(0), sl),
    predict,
)

# compute global expectation
eyhat = 0
for xi, yi in full_data.batch(M).as_numpy_iterator():
    eyhat += jnp.mean(predict(xi))
eyhat /= len(full_data)

シャヌプレむ倀に関する䞀぀の良いチェックは、それらの合蚈がモデル関数の倀からすべおのむンスタンスにわたる期埅倀を匕いたものに等しいこずを確認するこずです。ただし [vStrumbeljK14]の匏を近䌌しお䜿甚しおいるので、完党な䞀臎は期埅できたせん。この倀は次のように蚈算されたす。

print(np.sum(sv), predict(sm))
6.7373457 8.068422

予想通り、いくらか違いたす。これは今回䜿っおいる近䌌法の圱響です。サンプル数がシャヌプレむ倀の合蚈にどのように圱響するかを調べるこずで、それを確認するこずができたす。

../_images/xai_33_0.png

Fig. 12.1 シャヌプレむ倀近䌌における、シャヌプレむ倀の総和ずサンプル数の関数の間数倀の比范¶

埐々に収束しおいたす。最埌に、個々のシャヌプレむ倀を衚瀺しお芋たしょう。それが予枬に察する説明ずなりたす。

plot_grad(sv, s)
../_images/xai_37_0.png

ここたでに芋おきた4぀の手法、募配法、Integrated Gradient法、SmoothGrad法、シャヌプレむ倀の結果を䞊べお瀺したす。

heights = []
plt.figure(figsize=(12, 4))
x = np.arange(len(s))
for i, (gi, l) in enumerate(zip([g, ig, sg], ["Gradient", "Integrated", "Smooth"])):
    h = gi[0, np.arange(len(s)), list(map(ALPHABET.index, s))]
    plt.bar(x + i / 5 - 1 / 4, h, width=1 / 5, edgecolor="black", label=l)
plt.bar(x + 3 / 5 - 1 / 4, sv, width=1 / 5, edgecolor="black", label="Shapley")
ax = plt.gca()
ax.set_xticks(range(len(s)))
ax.set_xticklabels(s)
ax.set_xlabel("Amino Acid $x_i$")
ax.set_ylabel(r"Importance [logits]")
plt.legend()
plt.show()
../_images/xai_39_0.png

普段からペプチドを扱っおいる者ずしお、ここではシャヌプレむ倀が最も正確だず思いたす。LずRのパタヌンが重芁だずは考えおいたせんでしたが、シャヌプレむ倀はそう瀺しおいたす。たた他の手法の結果ず異なり、シャヌプレむ倀はフェニルアラニン(F)が重芁な効果を持぀ずは瀺しおいたせん。

この結果から䜕を結論づけるこずができるでしょうか。おそらく次のような説明を加えるこずができるでしょう。「この配列は、䞻にグルタミン、プロリン、そしおロむシンずアルギニンの配列によっお溶血性であるず予枬されおいたす」。

12.5. 特城量重芁床は䜕のためにあるのか¶

特城量重芁床は、実甚的な予枬や掞察を䞎える明確な説明に぀ながるこずはほずんどありたせん。因果関係がないため、存圚しない特城量の説明に意味を芋出すこずに繋がりかねたせん[CK18]。もう䞀぀の泚意点は、実際の化孊物質の䜓系ではなく、モデルを説明しおいるずいうこずです。䟋えば、「溶血掻性は5番目のグルタミンによるものです」ず解釈するのは避けたしょう。代わりに「モデルは5番目にグルタミンが䜍眮するため溶血掻性であるず予枬したした」ずしおください。

実甚的な説明は、特城量をどのように倉えれば結果に圱響するかを瀺すもので、結果の原因を知っおいるこずに䌌おいたす。したがっお、䞊述の理由から、特城量重芁床に説明性があるかどうかに぀いおは議論が続いおいたす[Lip18]。参考たでに、特城量重芁床を人の抂念に結び぀けようずする研究分野は、testing with concept activation vectorsTCAV[KWG+18]ず呌ばれおいたす。ちなみに私自身はXAIのために特城量重芁床をあたり䜿っおいたせん。それは、説明が実甚的でも因果関係を瀺すものでもなく、しばしば他の混乱を招くからです。

12.6. 孊習デヌタの重芁床¶

もう䞀぀の私たちが期埅する説明や解釈は、どの孊習デヌタ点が予枬に最も貢献しおいるかずいうこずです。これは次の質問に察するより盎接的な回答になりたす。「なぜ私のモデルはこれを予枬したのでしょうか」。ニュヌラルネットワヌクは孊習デヌタの結果であり、なぜその予枬がなされたのかに察する答えは孊習デヌタを蟿るこずで埗られたす。ある予枬に察しお孊習デヌタをランク付けするこずで、どの孊習デヌタ点がニュヌラルネットワヌクの予枬に圱響を䞎えおいるのかに関する掞察を埗るこずができたす。これは圱響関数\(\mathcal{I}(x_i, x)\)のようであり、孊習デヌタ点\(i\)ず入力\(x\)に察する圱響床スコアを䞎えたす。圱響床を蚈算する最も簡単な方法は、ニュヌラルネットワヌクに\(x_i\)がある堎合぀たり\(\hat{f}(x)\)ずない堎合぀たり\(\hat{f}_{-x_i}(x)\)を孊習しお、圱響床を以䞋のように定矩したす。

(12.12)¶\[\begin{equation} \mathcal{I}(x_i, x) = \hat{f}_{-x_i}(x) - \hat{f}(x) \end{equation}\]

䟋えば、孊習デヌタから孊習デヌタ点\(x_i\)を陀いた埌、予枬倀が高くなれば、その点は正の圱響力を持っおいるずいうこずになりたす。この圱響関数の蚈算は通垞デヌタ点の数だけモデルを繰り返し孊習する必芁がありたすが、通垞は蚈算できたせん。 [KL17] show a way to approximate this by looking at infinitesimal changes to the weights of each training point. これらの圱響関数を蚈算するには損倱関数に関するHessianを蚈算する必芁があるため、䞀般的には䜿甚されたせん。しかし、JAXを䜿っおいる堎合は、その蚈算を簡単に行うこずができたす。

孊習デヌタの重芁床はディヌプラヌニングの専門家にずっお有甚な解釈を提䟛したす。ある予枬に察しおどの孊習デヌタ点が最も圱響力を持っおいるのかを教えおくれたす。これはデヌタに関する問題に察凊する堎合や停陜性に察する説明を蟿るのに圹立ちたす。しかし、ディヌプラヌニングモデルの予枬結果を利甚する䞀般の利甚者は、おそらく孊習デヌタのランク付けだけでは満足しないでしょう。

12.7. サロゲヌトモデル¶

解釈可胜性におけるより䞀般的な考え方の䞀぀は、解釈可胜なモデルをブラックボックスモデルに特定の䟋の近傍で適合させるこずでしょう。なぜなら、解釈可胜なモデルはたいおい倧域的にブラックボックスモデルに適合させるこずはできないからです。そうでなければ、最初から解釈可胜なモデルを䜿い、ブラックボックスモデルは䜿わないでしょう。しかし、解釈可胜なモデルは興味ある䟋の呚蟺の小さな領域にだけ圓おはめれば、局所的に正しい解釈可胜なモデルを䜿っお説明を䞎えるこずができたす。この解釈可胜なモデルをロヌカルサロゲヌトモデルず呌びたす。解釈可胜なロヌカルサロゲヌトモデルには、決定朚、線圢モデル、簡朔な説明のためのスパヌス線圢モデル、ナむヌブベむズ分類噚などがありたす。

ロヌカルサロゲヌトモデルずしお䞀般的に知られおいるアルゎリズムはLocal Interpretable Model-Agnostic Explanations (LIME) [RSG16a]ず呌ばれおいたす。LIMEは、元のブラックボックスモデルを孊習させた損倱関数を利甚しお、ロヌカルサロゲヌトモデルを興味ある䟋の近傍にフィットさせたす。ロヌカルサロゲヌトモデルの損倱関数は、サロゲヌトモデルを回垰する際に、興味ある䟋に近い点を評䟡するよう重み付けされたす。LIMEの論文では、サロゲヌトモデルのスパヌス化を衚蚘に含めおいたすが、それはロヌカルサロゲヌトモデルの特性ではないため、ここでは䞀旊省きたす。よっお、サロゲヌトモデルの損倱は次のように定矩されたす。

(12.13)¶\[\begin{equation} \mathcal{l^s}\left(x'\right) = w(x', x)\mathcal{l}\left(\hat{f}_s(x'), \hat{f}(x')\right) \end{equation}\]

\(w(x', x)\)は興味ある䟋\(x\)の近くにある点に重みを付ける重みカヌネル関数、 \(\mathcal{l}(\cdot,\cdot)\)は元のブラックボックスモデルの損倱、\(\hat{f}(\cdot)\) はブラックボックスモデル、\(\hat{f}_s(\cdot)\)はサロゲヌトモデルを衚したす。

重み関数は少しアドホックです。぀たりデヌタ型に䟝存したす。スカラヌラベルの回垰タスクでは、カヌネル関数を䜿いたすが、様々な遞択肢がありたす。ガりシアン、コサむン、゚パネチコフなどです。テキストデヌタでは、LIMEの実装では、ハミング距離を䜿っおいたす。これは単に2぀の文字列の間で䞀臎しないテキストトヌクンの数をカりントするものです。画像もハミング距離を䜿いたすが、スヌパヌピクセルは䟋ず同じか空癜ずしたす。

点 \(x'\)はどのように生成されるのでしょうか。連続倀の堎合、\(x'\)は䞀様にサンプリングされたすが、特城空間はしばしば閉じおいないため、これは非垞に難しいこずです。重み付き関数に埓っお\(x'\)をサンプリングし、重み付けを省略すれば、それに埓っおサンプリングされたので閉じおいない特城空間のような問題を避けるこずができたす。䞀般に、連続ベクトル特城空間では、LIMEは少し䞻芳的です。画像やテキストの堎合、\(x'\)はトヌクン単語をマスキングする、スヌパヌピクセルをれロ化黒化するこずにより圢成されたす。これは、シャヌプレむ倀にかなり近い説明ずなるはずで、実際、LIMEがシャヌプレむ倀ず同等であるこずを、いく぀かの小さな衚蚘法の倉曎で瀺すこずができたす。

12.8. 反実仮想¶

反実仮想は最適化問題の解です。\(x\)ず異なるラベルを持ち、\(x\)にできるだけ近い䟋\(x'\)を芋぀けたす[WMR17]。これは次のように定匏化できたす。

(12.14)¶\[\begin{split}\textrm{minimize}\qquad d(x, x')\\ \textrm{such that}\qquad \hat{f}(x) \neq \hat{f}(x')\end{split}\]

\(\hat{f}(x)\)がスカラヌを出力する回垰問題では、制玄条件を\(\hat{f}(x)\)からある\(\Delta\)だけ離すように修正する必芁がありたす。この最適化問題を満たす\(x'\) は反実仮想発生しなかった条件、異なる結果を導いたであろう条件ず呌ばれたす。通垞、\(x'\)を求めるこずは、無埮分最適化ずしお扱われたす。\(\frac{\partial \hat{f}}{\partial x'}\)を蚈算しお制玄付き最適化したすが、実際にはモンテカルロ最適化のように \(\hat{f}(x) \neq \hat{f}(x')\)たでランダムに\(x\)を摂動させた方が速い堎合がありたす。教垫なし孊習で新しい\(x'\)を提案できる生成モデルを䜿甚するこずもできたす。分子に関する普遍的な反実仮想生成噚に぀いおは[WSW22]を参照しおください。分子のグラフニュヌラルネットワヌクに特化した手法に぀いおは[NB20]を参照しおください。

距離の定矩は、LIMEの説明の䞭でも述べたように、重芁な䞻芳的関心ごずです。分子構造の文脈で䜿われる䞀般的な距離の䟋は、Moragnフィンガヌプリント[RH10]のような分子フィンガヌプリント/蚘述子のタニモト係数たたはJaccard係数です。

反実仮想はシャヌプレむ倀ず比范しお䞀぀欠点がありたす。それは完党な説明を䞎えおはくれないこずです。シャヌプレむ倀は予枬倀の合蚈であり、説明のどのような郚分も芋逃しおいないこずを意味しおいたす。䞀方、反実仮想はできるだけ少ない特城量を倉える距離を最小化するため、予枬に寄䞎しおいる䞀郚の特城量に぀いおの情報を芋逃しおしたうこずがありたす。たたシャヌプレむ倀の利点は実甚的であるこずですが、反実仮想は盎接䜿甚するこずができたす。

12.8.1. 䟋¶

䞊蚘のペプチドの䟋でこのアむデアを玠早く実装するこずができたす。距離はハミング距離ず定矩したす。そしお\(x'\)は䞀぀のアミノ酞眮換です。これを列挙しおラベルの眮換ができるかどうか詊しおみたしょう。たず1回の眮換を行う関数を定矩したす。

def check_cf(x, i, j):
    # copy
    x = jnp.array(x)
    # substitute
    x = x.at[:, i].set(0)
    x = x.at[:, i, j].set(1)
    return predict(x)


check_cf(sm, 0, 0)
DeviceArray(8.552943, dtype=float32)

次に、jnp.meshgridで可胜なすべおの眮換を䜜り、vmapで先ほど定矩した関数を適応したす。ravel()<jax.numpy.ravel>`はむンデックスの配列を䞀次元にするため、耇雑なvmapを行う必芁はありたせん。

ii, jj = jnp.meshgrid(jnp.arange(sl), jnp.arange(21))
ii, jj = ii.ravel(), jj.ravel()
x = jax.vmap(check_cf, in_axes=(None, 0, 0))(sm, ii, jj)

次に、予枬倀が負になったすなわちlogitsが0より小さいアミノ酞眮換をすべお衚瀺したす。

from IPython.core.display import display, HTML

out = ["<tt>"]
for i, j in zip(ii[jnp.squeeze(x) < 0], jj[jnp.squeeze(x) < 0]):
    out.append(f'{s[:i]}<span style="color:red;">{ALPHABET[j]}</span>{s[i+1:]}<br/>')
out.append("</tt>")
display(HTML("".join(out)))
RAGL-FPVGRLLRRLLRRLLR
RAGLQF-VGRLLRRLLRRLLR
RAGLAFPVGRLLRRLLRRLLR
RAGLQFAVGRLLRRLLRRLLR
RAGLCFPVGRLLRRLLRRLLR
RAGLQFCVGRLLRRLLRRLLR
RAGLQFPCGRLLRRLLRRLLR
RAGLIFPVGRLLRRLLRRLLR
RAGLQFIVGRLLRRLLRRLLR
RAGLLFPVGRLLRRLLRRLLR
RAGLQFLVGRLLRRLLRRLLR
RAGLFFPVGRLLRRLLRRLLR
RAGLQFFVGRLLRRLLRRLLR
RAGLQFPFGRLLRRLLRRLLR
RAGLPFPVGRLLRRLLRRLLR
RAGLTFPVGRLLRRLLRRLLR
RAGLWFPVGRLLRRLLRRLLR
RAGLVFPVGRLLRRLLRRLLR
RAGLQFVVGRLLRRLLRRLLR

解釈はいく぀かありたすが、基本的にはグルタミンを疏氎基ず亀換するか、プロリンをV、F、A、Cに眮き換えるこずでペプチドを非溶血性にするずいう解釈です。反実仮想ずしお述べるず、「もしグルタミンを疎氎性アミノ酞に亀換すれば、そのペプチドは非溶血性になるでしょう」ずいうこずになりたす。

12.9. 特定のアヌキテクチャの説明¶

䞊蚘ず同じ原則がGNNにも適甚されたすが、これらのアむデアをグラフ䞊で動䜜するように倉換する最適な方法に぀いおは様々なアむデアがありたす。GNNに特化した解釈可胜性の理論に぀いおは[AZL21]を、GNNで説明を構築するために利甚できる手法に぀いおは[YYGJ20]を参照しおください。

NLPは説明ず解釈を構築するための特別なアプロヌチが存圚するもう䞀぀の分野です。最近の調査ずしお[MRC21]を参照しおください。

12.10. モデル非䟝存的な分子の反実仮想の説明¶

化孊における反実仮想に関連する䞻な課題は(12.14)の埮分を蚈算するこずの難しさです。したがっお、このタスクに焊点を圓おたほずんどの手法は、これたで芋おきたようにモデルのアヌキテクチャに特化しおいたす。Wellawatteら[WSW22]はモデルのアヌキテクチャに関係なく分子に察しおこれを行うMolecular Model Agnostic Counterfactual ExplanationsMMACEずいう方法を導入しおいたす。

MMACE法はexmolパッケヌゞで実装されおいたす。分子ずモデルを䞎えるず、exmolは局所的な反実仮想的説明を生成するこずができたす。MMACE法には2぀の䞻芁なステップがありたす。たず、䞎えられた基本分子を䞭心に局所的な化孊空間を展開したす。次に、各サンプル点に、ナヌザが指定したモデルアヌキテクチャのラベルを付けたす。これらのラベルは、局所的な化孊空間における反実仮想を特定するために䜿甚されたす。MMACE法はモデル非䟝存的で、exmolパッケヌゞは分類ず回垰の䞡方のタスクに察しお反実仮想を生成するこずができたす。

それでは、exmolを䜿っおどのように反実仮想を生成するのか芋おみたしょう。この䟋では、分子の臚床毒性を予枬するランダムフォレストモデルを孊習したす。この二倀分類タスクでは、MoleculeNetグルヌプ[WRF+18]が発衚したClassification章で䜿甚したのず同じデヌタセットを䜿いたす。

12.11. Notebookの実行¶

䞊の    をクリックしおむンタラクティブなGoogle Colabでこのペヌゞを開始したしょう。ご自身の環境でもGoogle Colabでも、パッケヌゞのむンストヌル関する詳现は以䞋を参照しおください。

import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import rdkit, rdkit.Chem, rdkit.Chem.Draw
from rdkit.Chem.Draw import IPythonConsole
import numpy as np
import mordred, mordred.descriptors
import warnings
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
import exmol

IPythonConsole.ipython_useSVG = True


toxdata = pd.read_csv(
    "https://github.com/whitead/dmol-book/raw/master/data/clintox.csv.gz"
)
# make object that can compute descriptors
calc = mordred.Calculator(mordred.descriptors, ignore_3D=True)
# make subsample from pandas df
molecules = [rdkit.Chem.MolFromSmiles(smi) for smi in toxdata.smiles]

# view one molecule to make sure things look good.
molecules[0]

デヌタをむンポヌトしたら、Mordredパッケヌゞで入力蚘述子を生成したす。

# Get valid molecules from the sample
valid_mol_idx = [bool(m) for m in molecules]
valid_mols = [m for m in molecules if m]
# Compute molecular descriptors using Mordred
features = calc.pandas(valid_mols, quiet=True)
labels = toxdata[valid_mol_idx].FDA_APPROVED
# Standardize the features
features -= features.mean()
features /= features.std()

# we have some nans in features, likely because std was 0
features = features.values.astype(float)
features_select = np.all(np.isfinite(features), axis=0)
features = features[:, features_select]
print(f"We have {len(features)} features per molecule")
We have 1478 features per molecule

この䟋では、Kerasで実装されたシンプルで密なニュヌラルネットワヌク分類噚を䜿甚したす。たず、このシンプルな分類噚を孊習し、それを䜿っおexmolの反実仮想のラベルを生成したす。孊習枈みモデルの性胜を改善するこずで、より正確な結果を期埅するこずができたすが、exmolの仕組みを理解するには、今のずころ以䞋の䟋で十分です。

# Train and test spit
X_train, X_test, y_train, y_test = train_test_split(
    features, labels, test_size=0.2, shuffle=True
)
ft_shape = X_train.shape[-1]

# reshape data
X_train = X_train.reshape(-1, ft_shape)
X_test = X_test.reshape(-1, ft_shape)

それではモデルを構築しお実行しおみたしょう。ディヌプラヌニングの抂芁 章に密なモデルに関する詳しいむントロダクションがありたす。

model = tf.keras.models.Sequential()
model.add(tf.keras.Input(shape=(ft_shape,)))
model.add(tf.keras.layers.Dense(32, activation="relu"))
model.add(tf.keras.layers.Dense(32))
model.add(Dense(1, activation="sigmoid"))
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
# Model training
model.fit(X_train, y_train, epochs=50, batch_size=32, verbose=0)
_, accuracy = model.evaluate(X_test, y_test)
print(f"Model accuracy: {accuracy*100:.2f}%")
 1/10 [==>...........................] - ETA: 1s - loss: 0.1194 - accuracy: 0.9688

10/10 [==============================] - 0s 1ms/step - loss: 0.2520 - accuracy: 0.9392
Model accuracy: 93.92%

私たちが䜜ったモデルの正確床は良さそうですね

次に、SMILES及び/たたはSEFLIESの分子衚珟を取り蟌み、孊習枈み分類噚からラベルの予枬を出力するラッパヌ関数を曞きたす。SELFIESの詳しい説明はDeep Learning on Sequencesの章にありたす。このラッパヌ関数は exmolのexmol.sample_spaceに入力ずしお䞎えられ、䞎えられたベヌスずなる分子の呚りに局所的な化孊空間を䜜りたす。exmolは、Superfast Traversal, Optimization, Novelty, Exploration and Discovery (STONED)アルゎリズム[NPK+21]を生成アルゎリズムずしお䜿甚しお、局所空間を拡匵しおいきたす。ベヌスずなる分子が䞎えられるず、STONEDアルゎリズムは分子のSELFIES衚珟をランダムに倉異させたす。これらの倉異は文字列眮換、挿入、欠損です。

def model_eval(smiles, selfies):
    molecules = [rdkit.Chem.MolFromSmiles(smi) for smi in smiles]
    features = calc.pandas(molecules)
    features = features.values.astype(float)
    features = features[:, features_select]
    labels = np.round(model.predict(np.nan_to_num(features).reshape(-1, ft_shape)))
    return labels

次に、STONEDを䜿っお、exmol.sample_spaceで局所的な化孊空間をサンプリングしおみたす。この䟋では、匕数num_samplesでサンプル空間の倧きさを倉曎したす。ここで遞択したベヌスずなる分子はFDA非承認分子です。

space = exmol.sample_space("C1CC(=O)NC(=O)C1N2CC3=C(C2=O)C=CC=C3N", model_eval);

いったんサンプル空間を䜜成したら、exmol.sample_space関数を䜿っお局所的な化孊空間の反実仮想を特定できたす。各反実仮想は、付加情報を含むpythonのdataclassです。

exps = exmol.cf_explain(space, 2)
exps[1]
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In [32], line 2
      1 exps = exmol.cf_explain(space, 2)
----> 2 exps[1]

IndexError: list index out of range

生成された反実仮想はexmolのプロットコヌドexmol.sample_spaceずexmol.sample_spaceを䜿っお簡単に可芖化できたす。ベヌスず反実仮想の間の類䌌床はECFP4フィンガヌプリントのタニモト係数です。䞊䜍3぀の反実仮想をここに瀺したす。

exmol.plot_cf(exps, nrows=1)

ここで遞択したベヌスずなる分子はFDA非承認です。生成された反実仮想を芋るず、耇玠環匏基は毒性に圱響を䞎えるず結論づけるこずができたす。したがっお、我々のモデルによるず、耇玠環匏基を倉曎するこずでベヌスずなる分子は非毒性化されるかもしれたせん。このこずは、反実仮想の説明がどのように修正を加えるこずができるかに぀いおの実甚的な掞察を䞎える理由も瀺しおいたす。

最埌に、生成した化孊空間も可芖化しおみたしょう

exmol.plot_space(space, exps)

12.12. たずめ¶

  • ディヌプラヌニングモデルの解釈は、モデルの正確性を保蚌し、予枬を人にずっお有甚なものにするために必芁䞍可欠です。法什順守のために芁求されるこずもありたす。

  • ニュヌラルネットワヌクの解釈可胜性は、より広範なトピックであるAIにおける説明可胜性XAIの䞀郚であり、このトピックはただ初期段階です。

  • 説明はただ定矩が曖昧ですが、倚くの堎合、モデルの特城量で衚珟されたす。

  • 説明の戊略ずしおは、特城量重芁床、孊習デヌタの重芁床、反実仮想、局所的に正確なサロゲヌトモデルなどがありたす。

  • ほずんどの説明は䟋ごずに掚論時に生成されたす。

  • 最も䜓系的ですが、蚈算コストのかかる説明はシャヌプレむ倀です。

  • 反実仮想は最も盎感的で満足のいく説明を提䟛したすが、完党な説明にはならないかもしれないずいう意芋がありたす。

  • exmolはモデル非䟝存的な分子の反実仮想の説明を生成する゜フトりェアです。

12.13. Cited References¶

WRF+18

Zhenqin Wu, Bharath Ramsundar, Evan N Feinberg, Joseph Gomes, Caleb Geniesse, Aneesh S Pappu, Karl Leswing, and Vijay Pande. Moleculenet: a benchmark for molecular machine learning. Chemical science, 9(2):513–530, 2018.

LS04

John D Lee and Katrina A See. Trust in automation: designing for appropriate reliance. Human factors, 46(1):50–80, 2004.

DVK17(1,2)

Finale Doshi-Velez and Been Kim. Towards a rigorous science of interpretable machine learning. arXiv preprint arXiv:1702.08608, 2017.

GF17

Bryce Goodman and Seth Flaxman. European Union regulations on algorithmic decision-making and a “right to explanation”. AI Magazine, 38(3):50–57, 2017.

Dev19

Organisation for Economic Co-operation and Development. Recommendation of the Council on Artificial Intelligence. 2019. URL: https://legalinstruments.oecd.org/en/instruments/OECD-LEGAL-0449.

CLG+15

Rich Caruana, Yin Lou, Johannes Gehrke, Paul Koch, Marc Sturm, and Noemie Elhadad. Intelligible models for healthcare: Predicting pneumonia risk and hospital 30-day readmission. In Proceedings of the 21th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 1721–1730. ACM, 2015.

Mil19

Tim Miller. Explanation in artificial intelligence: insights from the social sciences. Artificial intelligence, 267:1–38, 2019.

MSK+19

James W Murdoch, Chandan Singh, Karl Kumbier, Reza Abbasi-Asl, and Bin Yu. Interpretable machine learning: definitions, methods, and applications. eprint arXiv, pages 1–11, 2019. URL: http://arxiv.org/abs/1901.04592.

MSMuller18

Grégoire Montavon, Wojciech Samek, and Klaus-Robert MÃŒller. Methods for interpreting and understanding deep neural networks. Digital Signal Processing, 73:1–15, 2018.

BCB14

Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.

AGFW21

Mehrad Ansari, Heta A Gandhi, David G Foster, and Andrew D White. Iterative symbolic regression for learning transport equations. arXiv preprint arXiv:2108.03293, 2021.

BD00

Lynne Billard and Edwin Diday. Regression analysis for interval-valued data. In Data analysis, classification, and related methods, pages 369–374. Springer, 2000.

UT20

Silviu-Marian Udrescu and Max Tegmark. Ai feynman: a physics-inspired method for symbolic regression. Science Advances, 6(16):eaay2631, 2020.

CSGB+20

Miles Cranmer, Alvaro Sanchez Gonzalez, Peter Battaglia, Rui Xu, Kyle Cranmer, David Spergel, and Shirley Ho. Discovering symbolic models from deep learning with inductive biases. Advances in Neural Information Processing Systems, 33:17429–17442, 2020.

WSW22(1,2,3)

Geemi P Wellawatte, Aditi Seshadri, and Andrew D White. Model agnostic generation of counterfactual explanations for molecules. Chem. Sci., pages –, 2022. URL: http://dx.doi.org/10.1039/D1SC05259D, doi:10.1039/D1SC05259D.

RSG16a(1,2,3)

Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. " why should i trust you?" explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining, 1135–1144. 2016.

RSG16b

Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. Model-agnostic interpretability of machine learning. arXiv preprint arXiv:1606.05386, 2016.

WMR17(1,2,3)

Sandra Wachter, Brent Mittelstadt, and Chris Russell. Counterfactual explanations without opening the black box: automated decisions and the gdpr. Harv. JL & Tech., 31:841, 2017.

KL17(1,2)

Pang Wei Koh and Percy Liang. Understanding black-box predictions via influence functions. In International Conference on Machine Learning, 1885–1894. PMLR, 2017.

SML+21

Wojciech Samek, Grégoire Montavon, Sebastian Lapuschkin, Christopher J. Anders, and Klaus-Robert MÃŒller. Explaining deep neural networks and beyond: a review of methods and applications. Proceedings of the IEEE, 109(3):247–278, 2021. doi:10.1109/JPROC.2021.3060483.

BFL+17

David Balduzzi, Marcus Frean, Lennox Leary, J. P. Lewis, Kurt Wan-Duo Ma, and Brian McWilliams. The shattered gradients problem: if resnets are the answer, then what is the question? In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, 342–350. PMLR, 06–11 Aug 2017. URL: http://proceedings.mlr.press/v70/balduzzi17b.html.

STY17(1,2)

Mukund Sundararajan, Ankur Taly, and Qiqi Yan. Axiomatic attribution for deep networks. In International Conference on Machine Learning, 3319–3328. PMLR, 2017.

STK+17(1,2)

Daniel Smilkov, Nikhil Thorat, Been Kim, Fernanda Viégas, and Martin Wattenberg. Smoothgrad: removing noise by adding noise. arXiv preprint arXiv:1706.03825, 2017.

MBL+19

Grégoire Montavon, Alexander Binder, Sebastian Lapuschkin, Wojciech Samek, and Klaus-Robert MÃŒller. Layer-Wise Relevance Propagation: An Overview, pages 193–209. Springer International Publishing, Cham, 2019. URL: https://link.springer.com/chapter/10.1007%2F978-3-030-28954-6_10.

vStrumbeljK14(1,2,3,4)

Erik Štrumbelj and Igor Kononenko. Explaining prediction models and individual predictions with feature contributions. Knowledge and information systems, 41(3):647–665, 2014.

BW21

Rainier Barrett and Andrew D. White. Investigating active learning and meta-learning for iterative peptide design. Journal of Chemical Information and Modeling, 61(1):95–105, 2021. URL: https://doi.org/10.1021/acs.jcim.0c00946, doi:10.1021/acs.jcim.0c00946.

SLL20

Pascal Sturmfels, Scott Lundberg, and Su-In Lee. Visualizing the impact of feature attribution baselines. Distill, 2020. https://distill.pub/2020/attribution-baselines. doi:10.23915/distill.00022.

CK18

Kangway V Chuang and Michael J Keiser. Comment on “predicting reaction performance in c–n cross-coupling using machine learning”. Science, 362(6416):eaat8603, 2018.

Lip18

Zachary C Lipton. The mythos of model interpretability: in machine learning, the concept of interpretability is both important and slippery. Queue, 16(3):31–57, 2018.

KWG+18

Been Kim, Martin Wattenberg, Justin Gilmer, Carrie Cai, James Wexler, Fernanda Viegas, and others. Interpretability beyond feature attribution: quantitative testing with concept activation vectors (tcav). In International conference on machine learning, 2668–2677. PMLR, 2018.

NB20

Danilo Numeroso and Davide Bacciu. Explaining deep graph networks with molecular counterfactuals. arXiv preprint arXiv:2011.05134, 2020.

RH10

David Rogers and Mathew Hahn. Extended-connectivity fingerprints. Journal of chemical information and modeling, 50(5):742–754, 2010.

AZL21

Chirag Agarwal, Marinka Zitnik, and Himabindu Lakkaraju. Towards a rigorous theoretical analysis and evaluation of gnn explanations. arXiv preprint arXiv:2106.09078, 2021.

YYGJ20

Hao Yuan, Haiyang Yu, Shurui Gui, and Shuiwang Ji. Explainability in graph neural networks: a taxonomic survey. arXiv preprint arXiv:2012.15445, 2020.

MRC21

Andreas Madsen, Siva Reddy, and Sarath Chandar. Post-hoc interpretability for neural nlp: a survey. arXiv preprint arXiv:2108.04840, 2021.

NPK+21

AkshatKumar Nigam, Robert Pollice, Mario Krenn, Gabriel dos Passos Gomes, and Alán Aspuru-Guzik. Beyond generative models: superfast traversal, optimization, novelty, exploration and discovery (stoned) algorithm for molecules using selfies. Chem. Sci., 12:7079–7090, 2021. URL: http://dx.doi.org/10.1039/D1SC00231G, doi:10.1039/D1SC00231G.