12. Explaining Predictions
Neural network predictions are not generally interpretable. In this chapter we explore methods for explaining those predictions, part of the broader topic of explainable AI (XAI); these methods should help us understand why a particular prediction was made. Being able to understand a model's predictions matters for practical, theoretical, and regulatory reasons, which is why this has become an important topic. Practically, it has been shown that people are more likely to use a model's predictions when they can understand the rationale behind them [LS04]. Another practical concern is that verifying you have implemented a method correctly is much easier when you can follow how the model arrives at its predictions. A theoretical justification for transparency is that it helps identify incompleteness in the model's domain (i.e., covariate shift) [DVK17]. Recently, the European Union [GF17] and the G20 [Dev19] adopted guidelines that recommend or require explanations for machine predictions, so the topic is now also a compliance issue; the EU has gone further and is considering stricter legislation.
Audience & Objectives

This chapter builds on Standard Layers and Deep Learning on Sequences. It also assumes a working knowledge of probability theory, including conditional probability. If you lack that background, I recommend reading my notes or an introductory probability text to get an overview first. After reading this chapter, you should be able to:

Justify why explanations are important

Distinguish between justification, interpretation, and explanation

Compute feature importances and Shapley values

Define and compute counterfactuals

Understand which models are interpretable and how interpretable surrogate models are fitted
A famous example of the need for explainable AI is Caruana et al. [CLG+15], who built an ML predictor to estimate the mortality risk of patients arriving at the ER with pneumonia. The idea of the paper was that screening pneumonia patients with this tool would help doctors know which patients were at higher risk of death. The model was quite accurate, but when its predictions were examined for interpretation, it turned out to be making medically alarming inferences. Surprisingly, the model suggested that asthmatic patients (people with asthma) arriving at the ER with pneumonia had a lower risk of death. Asthma is a disease that makes breathing difficult, and yet pneumonia patients with asthma were dying at a lower rate.

This result was not a fluke in the data: asthmatics genuinely have a higher risk of dying from pneumonia, but doctors know this and give them more aggressive and attentive care, which is why fewer asthmatic patients were dying. Empirically, then, the model's predictions were correct. But had the model been deployed, asthmatic patients would have been wrongly judged low-risk and could have lost the care they would otherwise receive, possibly costing lives. Fortunately, the researchers caught this problem through model interpretation and the danger to asthmatic patients was avoided. The lesson is that interpretation should always be considered when building predictive models.
12.1. What is an Explanation?

Here we use Miller's definition of an explanation [Mil19]. Miller distinguishes interpretability, justification, and explanation as follows:

Interpretability: the degree to which an observer can understand the cause of a decision. Miller treats this as synonymous with explainability. It is generally a property of the model.

Justification: evidence for why a decision is a good one, such as the model's test error or accuracy. This is a property of the model.

Explanation: a presentation of information, intended for a human, that gives the context and cause of an outcome. This is the main topic of this chapter. It is generally not a property of the model, but additional information generated separately.

Let us expand on what constitutes an explanation. Note that explaining a prediction is different from justifying it. A justification, as we have seen, is empirical evidence for why we should believe a model's predictions are accurate. An explanation, by contrast, reveals the cause of a prediction, with the ultimate goal of being understood by a person.
Deep learning on its own is a black-box modeling approach: it has neither interpretability nor explainability. Inspecting the weights or the model equation yields little insight into why a given prediction was made. Interpretability is thus an extra task on top of deep learning: adding explanations to the model's predictions. It is a hard problem, both because of the black-box nature of deep learning and because there is no consensus on what counts as an explanation of a model's prediction [DVK17]. Some people expect interpretability to mean a natural-language explanation justifying each prediction; others consider it sufficient to show which features contributed most to the prediction.

There are two broad approaches to interpreting ML models: post-hoc explanation and self-explaining models [MSK+19]. Self-explaining models are built so that an expert examining the model's output can logically connect it to the features; they are inherently interpretable, although this depends strongly on the task and model [MSMuller18]. Familiar examples are physics-based simulations such as molecular dynamics or point-charge energy calculations: by examining a molecular dynamics trajectory and the numbers it outputs, you can explain, for example, why a drug molecule is predicted to bind to a protein.

Self-explaining models may seem irrelevant to interpreting deep learning, but they are related: in a later section we will train self-explaining surrogate models (or proxy models) to match a deep learning model. Why bother with the deep learning model at all instead of using the surrogate from the start? Because it reduces the cost of training the surrogate: a trained neural network can label arbitrary points, which means it can generate unlimited training data. Besides surrogate models, we can also build deep learning models with self-explaining features baked in, such as attention [BCB14], which lets us connect input features to predicted values. Machine learning also offers symbolic regression, which tries to produce directly interpretable equations and thereby self-explaining models [AGFW21, BD00, UT20]; because of this property, symbolic regression is also used to generate surrogate models [CSGB+20].
There are various approaches to post-hoc explanation. Representative ones are training-data importance, feature importance, and counterfactual explanations [WSW22, RSG16a, RSG16b, WMR17]. An example of post-hoc interpretation based on data importance is identifying the training examples that most influenced a prediction in order to explain it [KL17]. Whether that constitutes an explanation is debatable, but it certainly helps us understand which data are relevant to a prediction. Feature importance is probably the most common XAI approach; it appears frequently in computer vision research, for example highlighting the pixels most important to an image classification.

Counterfactual explanations are a newer post-hoc method. A counterfactual is a new data point that functions as an explanation: it gives insight into how important and significant each feature is. As an example, suppose we have a model that recommends loans. The model could generate a counterfactual explanation like the following (from [WMR17]):

You were denied a loan based on your annual income, zip code, and assets. If your annual income had been $45,000, you would have been offered a loan.

The second sentence is the counterfactual: it shows how the features would have to change to affect the model's outcome. Counterfactuals offer a good balance between complexity and explanatory power.

That concludes our overview of the broader XAI field. For a recent review of interpretable deep learning, see Samek et al. [SML+21]. Christoph Molnar also publishes a comprehensive online book on interpretable machine learning, including deep learning. Prediction error and prediction confidence lean more toward justification, so we do not cover them here, but the methods of Regression & Model Assessment apply — see that chapter.
12.2. Feature Importance

Feature importance is perhaps the most common method for interpreting machine learning models. Its output is a ranking or numeric value for each feature, usually for a single prediction. Model-wide feature importances are called global feature importances; those for a single prediction are called local feature importances. Having global feature importances with real interpretability is relatively rare, because an accurate deep learning model typically changes which features matter depending on where you are in feature space.
Let's first look at feature importance in a linear model:

\[ \hat{f}(\vec{x}) = \vec{w} \cdot \vec{x} + b \]

where \(\vec{x}\) is the feature vector. A simple way to assess feature importance is to just look at the weight \(w_i\) for a particular feature \(x_i\): \(w_i\) tells you how much the prediction changes when \(x_i\) increases by 1 with all other features held fixed. If the features all have similar magnitudes, this works as a feature ranking. However, if the features carry units, the ranking is affected by the choice of units and the relative magnitudes of the features. For example, if temperature is converted from Celsius to Fahrenheit, the effect of a 1-degree increase becomes smaller.

A slightly better way to assess feature importance, which removes the effect of feature magnitude and units, is to divide \(w_i\) by its standard error. The standard error is the square root of the sum of squared prediction errors divided by the sum of squared deviations of the feature — a ratio of the model's accuracy to the feature's spread. The weight divided by its standard error can be compared against a t-distribution, and is therefore called the \(t\)-statistic:

\[ t_i = \frac{w_i}{S_{w_i}}, \qquad S_{w_i} = \sqrt{\frac{\sum_k^N \left(y_k - \hat{f}(\vec{x}_k)\right)^2 / (N - D)}{\sum_k^N \left(x_{ki} - \bar{x}_i\right)^2}} \]

where \(N\) is the number of examples, \(D\) is the number of features, and \(\bar{x}_i\) is the mean of the \(i\)th feature. The \(t_i\) values can be used to rank features and for hypothesis testing: if \(P(t > t_i) < 0.05\), the feature is statistically significant, where \(P(t)\) is Student's \(t\)-distribution. Note that a feature's significance depends on the other features present in the model — adding a new feature can make some of them redundant.
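As a concrete sketch of the \(t\)-statistic computation above — the data and coefficients here are synthetic, invented purely for illustration:

```python
import numpy as np

# Toy data: 100 examples, 3 features; the middle feature is deliberately irrelevant.
rng = np.random.default_rng(0)
N, D = 100, 3
X = rng.normal(size=(N, D))
y = X @ np.array([2.0, 0.0, -1.0]) + rng.normal(scale=0.5, size=N)

# Ordinary least squares fit
w, *_ = np.linalg.lstsq(X, y, rcond=None)

# Standard error of each coefficient from the residual variance
resid = y - X @ w
s2 = resid @ resid / (N - D)                      # residual variance
se = np.sqrt(s2 * np.diag(np.linalg.inv(X.T @ X)))

t = w / se  # t-statistics: a unit-free feature ranking
```

The irrelevant middle feature gets a \(t\)-statistic near zero, while the informative features get large-magnitude values regardless of their scale or units.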
Now consider the nonlinear case. For a nonlinear learned function \(\hat{f}(\vec{x})\), we can approximate how the prediction changes when a feature increases by 1 using the derivative:

\[ \Delta \hat{f} \approx \frac{\partial \hat{f}(\vec{x})}{\partial x_i} \times 1 \]

that is, the effect of changing \(x_i\) by 1.

In practice we modify this expression a little. Instead of a Taylor series centered at 0, we center it at another root (a point where the function is zero). That way the series touches the decision boundary (the root), and we can see whether the prediction moves the example toward or away from the decision boundary. One approach, then, is to build a linear model from the first-order term of the Taylor expansion, treat it as in the linear case above, and use its coefficients as the feature "importances". Concretely, we use a surrogate for \(\hat{f}(\vec{x})\) of the form:

\[ \hat{f}(\vec{x}) \approx \sum_i \frac{\partial \hat{f}(\vec{x})}{\partial x_i} \left(x_i - x_i'\right) \]

where \(\vec{x}'\) is a root of \(\hat{f}(\vec{x})\). You might be tempted to choose the trivial root \(\vec{x}' = \vec{0}\), but a nearby root is ideal; this root is often called the baseline input. In contrast to the linear example above, we look at the product of the partial derivative \(\frac{\partial \hat{f}(\vec{x})}{\partial x_i}\) and the increase from the baseline \((x_i - x_i')\).
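Here is a minimal numeric sketch of this baseline-centered attribution, using a made-up function \(f(x) = x_0^2 + 2x_1 - 3\) and a finite-difference gradient:

```python
import numpy as np

# Toy function with a known root at x' = (1, 1): f(x') = 1 + 2 - 3 = 0
def f(x):
    return x[0] ** 2 + 2 * x[1] - 3.0

x = np.array([2.0, 1.0])        # point to explain; f(x) = 3
x_root = np.array([1.0, 1.0])   # baseline (a root of f)

# central-difference gradient at x
eps = 1e-6
grad = np.array(
    [(f(x + eps * np.eye(2)[i]) - f(x - eps * np.eye(2)[i])) / (2 * eps) for i in range(2)]
)

# importance of each feature: gradient times displacement from the baseline
attribution = grad * (x - x_root)
```

Note the attributions sum to 4 while \(f(x) = 3\): the first-order Taylor term misses the curvature in \(x_0\). That shortfall is exactly what integrated gradients (below) repairs by averaging the gradient along the whole path from the baseline.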
12.2.1. Feature Importance for Neural Networks

In neural networks, the derivative alone is a poor approximation of the actual change in the output. Because the effect of small changes to the input can be discontinuous (thanks to nonlinearities like ReLU), the derivative can end up with almost no explanatory power. This is known as the shattered gradients problem [BFL+17]. Splitting importance across individual features also misses correlations between features — a problem linear models don't have. So while the derivative approximation works well enough for locally linear models, it does not work for deep neural networks.
There are several ways to work around the shattered gradients problem in neural networks. Two widely used methods are integrated gradients [STY17] and SmoothGrad [STK+17]. Integrated gradients considers the straight-line path from \(\vec{x}'\) to \(\vec{x}\) and integrates the derivative of interest along that path:

\[ \textrm{IG}_i = \left(x_i - x_i'\right) \int_0^1 \frac{\partial \hat{f}\left(\vec{x}' + t\left(\vec{x} - \vec{x}'\right)\right)}{\partial x_i}\, dt \]

where \(t\) is the increment along the path: at \(t = 0\), \(\vec{x}' + t\left(\vec{x} - \vec{x}'\right) = \vec{x}'\), and at \(t = 1\), \(\vec{x}' + t\left(\vec{x} - \vec{x}'\right) = \vec{x}\). This gives an integrated gradient for each feature \(i\). The integrated gradient is that feature's importance, and it does not suffer from the pathologies of shattered gradients. Moreover, if the model \(f(\vec{x})\) is differentiable almost everywhere, then \(\sum_i \textrm{IG}_i = f(\vec{x}) - f(\vec{x}')\): the feature importances computed by integrated gradients sum exactly to the difference between the prediction and the baseline. In other words, the change from baseline to prediction is completely apportioned among the features [STY17].

Implementing integrated gradients is fairly straightforward. We approximate the path integral with a Riemann sum, splitting the path into a set of discrete inputs between the input features \(\vec{x}\) and the baseline \(\vec{x}'\). We compute the gradients of those inputs with the neural network, then multiply by the change in each feature from the baseline, \(\left(\vec{x} - \vec{x}'\right)\).
SmoothGrad follows a similar idea to integrated gradients, but instead of summing gradients along a path, it computes gradients at random points near the prediction:

\[ \textrm{SG}_i = \frac{1}{M} \sum_m^M \frac{\partial \hat{f}\left(\vec{x} + \vec{\epsilon}_m\right)}{\partial x_i} \]

where \(M\) is the chosen number of samples and \(\vec{\epsilon}\) is sampled from a zero-mean Gaussian over the \(D\) dimensions [STK+17]. The only implementation change here is that the path is replaced by a set of random perturbations.
Besides these gradient-based methods, Layer-wise Relevance Propagation (LRP) is a popular approach to feature importance analysis in neural networks. LRP works by back-propagating through the network, splitting each layer's output value among its input features — that is, distributing "relevance". The awkward part of LRP is that each layer type needs its own implementation: rather than relying on analytic derivatives, each layer is handled via a Taylor-series treatment of that layer's equation. LRP variants exist for GNNs and for sequence models, so LRP can be used in most materials and chemistry settings [MBL+19].
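To make the relevance-splitting idea concrete, here is a minimal sketch of the basic LRP-0 rule for a single bias-free dense layer (real LRP applies rules like this layer by layer, with stabilized variants per layer type; the numbers here are invented):

```python
import numpy as np

def lrp_dense(a, W, R_out, eps=1e-9):
    """Redistribute the relevance R_out of a dense layer's outputs onto its inputs a."""
    z = a[:, None] * W            # contribution of input j to output k
    denom = z.sum(axis=0) + eps   # total pre-activation of each output
    return (z / denom) @ R_out    # input j receives its proportional share

a = np.array([1.0, 2.0, 0.5])                         # input activations
W = np.array([[0.5, -1.0], [1.0, 0.5], [-0.5, 1.0]])  # layer weights
R_out = a @ W                                         # start from the layer's own outputs
R_in = lrp_dense(a, W, R_out)
```

The key property is conservation: the input relevances sum to the output relevance, so nothing is lost or invented as relevance flows backward through the network.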
12.2.2. Shapley Values

A model-agnostic way to obtain feature importances is the Shapley value. Shapley values come from game theory: they are the solution to the problem of how to pay out a reward to cooperating players according to each player's contribution. Here each feature is a player, and each is "paid" according to its contribution to the predicted value. The Shapley value \(\phi_i(x)\) is the payout to feature \(i\) for instance \(x\). We split the predicted function value \(\hat{f}(x)\) into Shapley values so that they sum back to it, \(\sum_i \phi_i(x) = \hat{f}(x)\); a feature's Shapley value can thus be interpreted as its numeric contribution to the prediction. The strong advantages of Shapley values are that they are model-agnostic, that they decompose the prediction over the features, and that they satisfy the axioms you would want of an explanation (symmetry, linearity, and permutation invariance). The drawbacks are that exact computation scales with the number of feature combinations, and that they are not sparse — so their usefulness degrades as the number of features grows. Most of the methods presented here lack sparsity; you can always obtain sparse explanations by making the model itself sparse, for example with L1 regularization (see Standard Layers).
Shapley values are computed as:

\[ \phi_i(x) = \frac{1}{Z} \sum_{S \in N \backslash x_i} \left[ v\left(S \cup x_i\right) - v\left(S\right) \right] \]

where \(S \in N \backslash x_i\) denotes the subsets of all features with feature \(x_i\) removed, \(S \cup x_i\) means adding feature \(x_i\) back to the subset, \(v(S)\) is the value of \(\hat{f}(x)\) using only the features in \(S\), and \(Z\) is a normalization (weighting each subset so that every ordering of the players counts equally). The equation can be interpreted as the average, over all subsets, of the change in \(\hat{f}\) caused by adding or removing feature \(i\).
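Here is a tiny brute-force sketch of the exact computation for a made-up three-player value function, using the standard factorial weights in place of \(1/Z\) (in the ML setting, \(v(S)\) would be the model evaluated with only the features in \(S\), the rest marginalized out):

```python
import itertools
import math

players = (0, 1, 2)

def v(S):
    # toy value function: player 0 is worth 1, player 1 is worth 2,
    # together they earn a synergy bonus of 0.5; player 2 is a dummy
    S = set(S)
    out = 1.0 * (0 in S) + 2.0 * (1 in S)
    if 0 in S and 1 in S:
        out += 0.5
    return out

def shapley(i):
    n = len(players)
    others = [p for p in players if p != i]
    phi = 0.0
    for r in range(n):
        for S in itertools.combinations(others, r):
            w = math.factorial(len(S)) * math.factorial(n - len(S) - 1) / math.factorial(n)
            phi += w * (v(S + (i,)) - v(S))
    return phi

phis = [shapley(i) for i in players]
```

The values sum exactly to \(v(\text{all players}) = 3.5\), the synergy is split evenly (0.25 each to players 0 and 1), and the dummy player gets exactly zero — the axioms mentioned above in action.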
But how can we "remove" feature \(i\) from the model equation? We can do it by treating feature \(i\) as unknown and marginalizing over it. Recall that marginalization integrates out a random variable: \(P(x) = \int\, P(x,y)\,dy\) integrates over all possible values of \(y\). Marginalization can also be applied to a function of random variables — the function is itself a random variable — by taking the expectation \(E_y[f | X = x] = \int\,f(X=x,y)P(X=x,y)\, dy\). I emphasize that \(E_y[f]\) is a function of \(x\), since \(X\) is held fixed inside the integral. Where \(x\) is fixed (as a function argument), \(y\) is removed by computing the expected value of \(f(x,y)\). Essentially, we replace \(f(x,y)\) with a new function \(E_y[f]\), the average over all possible values \(y\). If this still feels abstract, the code below should make it intuitive. One more note: the values are taken relative to the mean of \(\hat{f}\). The extra term could be ignored, but we include it for completeness. The value equation is therefore [vStrumbeljK14]:

\[ v(S) = E\left[\hat{f}(\vec{x}) \mid \vec{x}_S\right] - E\left[\hat{f}(\vec{x})\right] \]

How do we compute the marginal \(\int\,f(x_0, x_1, \ldots, x_i,\ldots, x_N)P(x_0, x_1, \ldots, x_i,\ldots, x_N)\, dx_i\)? We have no known probability distribution. In this case, we can treat the data as an empirical distribution and sample from \(P(\vec{x})\) — that is, sample from the data points themselves. It is slightly more complicated than that, though, since \(\vec{x}\) must be sampled jointly: if the removed feature is correlated with the others, we cannot just mix individual features at random.
Štrumbelj et al. [vStrumbeljK14] showed that the \(i\)th Shapley value can be estimated directly as:

\[ \phi_i(x) \approx \frac{1}{M} \sum_m^M \left[ \hat{f}\left(\vec{z}_{+i}^{\,m}\right) - \hat{f}\left(\vec{z}_{-i}^{\,m}\right) \right] \]

where \(\vec{z}\) is a chimeric example assembled from the actual example \(\vec{x}\) and a randomly chosen example \(\vec{x}'\): each feature of \(\vec{z}\) is picked at random from \(\vec{x}\) or \(\vec{x}'\). \(\vec{z}_{+i}\) takes its \(i\)th feature from the example \(\vec{x}\), while \(\vec{z}_{-i}\) takes its \(i\)th feature from the random example \(\vec{x}'\). \(M\) is chosen large enough to sample this value well; [vStrumbeljK14] give guidance on choosing \(M\), but basically you choose it as large as is computationally reasonable. One modification of this approximation is to use an explicit term for the expected value (sometimes written \(\phi_0\)). The equation with this "completeness" property is:

\[ \hat{f}(\vec{x}) = \phi_0 + \sum_i \phi_i(x) \]

When the expected value is included explicitly as \(\phi_0\), it does not depend on \(\vec{x}\).

With this efficient approximation, a strong theoretical foundation, and model-agnosticism, Shapley values are an excellent choice for describing the feature importances of a prediction.
12.3. Running This Notebook

Click the launch button at the top of this page to open this notebook in Google Colab. See below for how to install the required packages.
Tip
To install the required packages, create a new cell and run the following code.
!pip install dmol-book
If installation fails, it may be due to a package version mismatch. The latest versions confirmed to work can be found here.
import haiku as hk
import jax
import tensorflow as tf
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import urllib
from functools import partial
from jax.example_libraries import optimizers as opt
import dmol
np.random.seed(0)
tf.random.set_seed(0)
ALPHABET = [
"-",
"A",
"R",
"N",
"D",
"C",
"Q",
"E",
"G",
"H",
"I",
"L",
"K",
"M",
"F",
"P",
"S",
"T",
"W",
"Y",
"V",
]
First, let's define functions that convert between amino-acid sequences and one-hot vectors.
def seq2array(seq, L=200):
return np.pad(list(map(ALPHABET.index, seq)), (0, L - len(seq))).reshape(1, -1)
def array2oh(a):
a = np.squeeze(a)
o = np.zeros((len(a), 21))
o[np.arange(len(a)), a] = 1
return o.astype(np.float32).reshape(1, -1, 21)
urllib.request.urlretrieve(
"https://github.com/whitead/dmol-book/raw/master/data/hemolytic.npz",
"hemolytic.npz",
)
with np.load("hemolytic.npz", "rb") as r:
pos_data, neg_data = r["positives"], r["negatives"]
12.4. Feature Importance Example

Let's look at an example of feature importance methods on a peptide prediction task: predicting whether a peptide destroys red blood cells (hemolysis). This is similar to the solubility prediction example in Standard Layers. We use data from [BW21]. The model takes a peptide sequence (e.g., DDFRD) and predicts the probability that the peptide is hemolytic. The goal of the feature importance analysis here is to identify which amino acids matter most for hemolytic activity. The hidden cells below load the data and process it into a dataset.
# create labels and stitch it all into one
# tensor
labels = np.concatenate(
(
np.ones((pos_data.shape[0], 1), dtype=pos_data.dtype),
np.zeros((neg_data.shape[0], 1), dtype=pos_data.dtype),
),
axis=0,
)
features = np.concatenate((pos_data, neg_data), axis=0)
# we now need to shuffle before creating TF dataset
# so that our train/test/val splits are random
i = np.arange(len(labels))
np.random.shuffle(i)
labels = labels[i]
features = features[i]
L = pos_data.shape[-2]
# need to add token for empty amino acid
# dataset just has all zeros currently
features = np.concatenate((np.zeros((features.shape[0], L, 1)), features), axis=-1)
features[np.sum(features, -1) == 0, 0] = 1.0
batch_size = 16
full_data = tf.data.Dataset.from_tensor_slices((features.astype(np.float32), labels))
# now split into val, test, train
N = pos_data.shape[0] + neg_data.shape[0]
split = int(0.1 * N)
test_data = full_data.take(split).batch(batch_size)
nontest = full_data.skip(split)
val_data, train_data = nontest.take(split).batch(batch_size), nontest.skip(
split
).shuffle(1000).batch(batch_size)
We rebuild the convolutional model in Jax (using Haiku) so that gradients are a bit easier to work with. We have also made a few other changes to the model: in addition to the convolutions, we pass the sequence length and the amino-acid fractions as extra information.
def binary_cross_entropy(logits, y):
"""Binary cross entropy without sigmoid. Works with logits directly"""
return (
jnp.clip(logits, 0, None) - logits * y + jnp.log(1 + jnp.exp(-jnp.abs(logits)))
)
def model_fn(x):
# get fractions, excluding skip character
aa_fracs = jnp.mean(x, axis=1)[:, 1:]
# compute convolutions/poolings
mask = jnp.sum(x[..., 1:], axis=-1, keepdims=True)
for kernel, pool in zip([5, 3, 3], [4, 2, 2]):
x = hk.Conv1D(16, kernel)(x) * mask
x = jax.nn.tanh(x)
x = hk.MaxPool(pool, pool, "VALID")(x)
mask = hk.MaxPool(pool, pool, "VALID")(mask)
    # combine fractions, length, and convolution outputs
x = jnp.concatenate((hk.Flatten()(x), aa_fracs, jnp.sum(mask, axis=1)), axis=1)
# dense layers. no bias, so zeros give P=0.5
logits = hk.Sequential(
[
hk.Linear(256, with_bias=False),
jax.nn.tanh,
hk.Linear(64, with_bias=False),
jax.nn.tanh,
hk.Linear(1, with_bias=False),
]
)(x)
return logits
model = hk.without_apply_rng(hk.transform(model_fn))
def loss_fn(params, x, y):
logits = model.apply(params, x)
return jnp.mean(binary_cross_entropy(logits, y))
@jax.jit
def hemolytic_prob(params, x):
logits = model.apply(params, x)
return jax.nn.sigmoid(jnp.squeeze(logits))
@jax.jit
def accuracy_fn(params, x, y):
logits = model.apply(params, x)
return jnp.mean((logits >= 0) * y + (logits < 0) * (1 - y))
rng = jax.random.PRNGKey(0)
xi, yi = features[:batch_size], labels[:batch_size]
params = model.init(rng, xi)
opt_init, opt_update, get_params = opt.adam(1e-2)
opt_state = opt_init(params)
@jax.jit
def update(step, opt_state, x, y):
value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state), x, y)
opt_state = opt_update(step, grads, opt_state)
return value, opt_state
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
epochs = 32
for e in range(epochs):
avg_v = 0
for i, (xi, yi) in enumerate(train_data):
v, opt_state = update(i, opt_state, xi.numpy(), yi.numpy())
avg_v += v
opt_params = get_params(opt_state)
def predict(x):
return jnp.squeeze(model.apply(opt_params, x))
def predict_prob(x):
return hemolytic_prob(opt_params, x)
Don't worry if it was hard to follow that code! It's okay. The goal of this chapter is to show how to obtain explanations of a model, not necessarily how to build one. Do pay attention to the next few lines, though: they explain how the model is called to produce the predictions we want to explain. The model is called via predict(x) for logits, or via predict_prob for probabilities.

Let's try a few amino-acid sequences to get a feel for the model. The model outputs logits (log-odds), which a sigmoid converts into probabilities. Peptide sequences need to be converted into matrices of one-hot vectors. We'll try two known sequences here: Q (glutamine) is well known as a hemolytic residue, and the second sequence is poly-G, glycine being the simplest amino acid.
s = "QQQQQ"
sm = array2oh(seq2array(s))
p = predict_prob(sm)
print(f"Probability {s} of being hemolytic {p:.2f}")
s = "GGGGG"
sm = array2oh(seq2array(s))
p = predict_prob(sm)
print(f"Probability {s} of being hemolytic {p:.2f}")
Probability QQQQQ of being hemolytic 1.00
Probability GGGGG of being hemolytic 0.00
This looks reasonable: the model's outputs agree with our intuition about these two sequences.

Now let's compute the model's accuracy on the test data. It is quite good.
acc = []
for xi, yi in test_data:
acc.append(accuracy_fn(opt_params, xi.numpy(), yi.numpy()))
print(jnp.mean(np.array(acc)))
0.95208335
12.4.1. Gradients

Now let's investigate why this sequence is hemolytic. We'll start by computing the gradient with respect to the input. This is the naive approach that is vulnerable to shattered gradients, but since it is also part of the integrated gradients and SmoothGrad calculations that follow, it isn't wasted effort. We'll use a more complex peptide that is known to be hemolytic, to make for a more interesting analysis.
def plot_grad(g, s, ax=None):
# g = np.array(g)
if ax is None:
plt.figure()
ax = plt.gca()
if len(g.shape) == 3:
h = g[0, np.arange(len(s)), list(map(ALPHABET.index, s))]
else:
h = g
ax.bar(np.arange(len(s)), height=h)
ax.set_xticks(range(len(s)))
ax.set_xticklabels(s)
ax.set_xlabel("Amino Acid $x_i$")
ax.set_ylabel(r"Gradient $\frac{\partial \hat{f}(\vec{x})}{\partial x_i}$")
s = "RAGLQFPVGRLLRRLLRRLLR"
sm = array2oh(seq2array(s))
p = predict_prob(sm)
print(f"Probability {s} of being hemolytic {p:.2f}")
Probability RAGLQFPVGRLLRRLLRRLLR of being hemolytic 1.00
The code is quite simple. Now let's compute the gradient.
gradient = jax.grad(predict, 0)
g = gradient(sm)
plot_grad(g, s)

Remember that the model outputs logits: a positive gradient means the amino acid pushes toward a higher probability of being hemolytic, while a negative gradient means it pushes the sequence toward being non-hemolytic. Interestingly, leucine (L) and arginine (R) show strong position dependence.
12.4.2. Integrated Gradients

Next we implement the integrated gradients method. There are three basic steps:

Create the input sequences (the path) running from the baseline to the input.

Compute the gradient for each input.

Sum the gradients and multiply by the difference between the baseline and the peptide.

Our baseline is all zeros, which gives a probability of 0.5 (logits = 0, a model root). This baseline sits exactly on the decision boundary. You could use other baselines, like all-glycine or all-alanine, but they should give a probability at or near 0.5. See [SLL20] for details on baseline choice and an interactive exploration.
def integrated_gradients(sm, N):
baseline = jnp.zeros((1, L, 21))
t = jnp.linspace(0, 1, N).reshape(-1, 1, 1)
path = baseline * (1 - t) + sm * t
def get_grad(pi):
# compute gradient
# add/remove batch axes
return gradient(pi[jnp.newaxis, ...])[0]
gs = jax.vmap(get_grad)(path)
# sum pieces (Riemann sum), multiply by (x - x')
ig = jnp.mean(gs, axis=0, keepdims=True) * (sm - baseline)
return ig
ig = integrated_gradients(sm, 1024)
plot_grad(ig, s)

The position dependence is now more pronounced, and we can see that arginine is very sensitive to position. Compared with the plain gradients above, though, there is not much qualitative change.
12.4.3. SmoothGrad

The steps for SmoothGrad are almost the same as for integrated gradients:

Create the input sequences (the "path") by adding random perturbations to the input.

Compute the gradient for each input.

Average the gradients.

There is one hyperparameter, \(\sigma\), which in principle should be kept as small as possible while not appreciably changing the model's output.
def smooth_gradients(sm, N, rng, sigma=1e-3):
baseline = jnp.zeros((1, L, 21))
t = jax.random.normal(rng, shape=(N, sm.shape[1], sm.shape[2])) * sigma
path = sm + t
# remove examples that are negative and force summing to 1
path = jnp.clip(path, 0, 1)
path /= jnp.sum(path, axis=2, keepdims=True)
def get_grad(pi):
# compute gradient
# add/remove batch axes
return gradient(pi[jnp.newaxis, ...])[0]
gs = jax.vmap(get_grad)(path)
# mean
ig = jnp.mean(gs, axis=0, keepdims=True)
return ig
sg = smooth_gradients(sm, 1024, jax.random.PRNGKey(0))
plot_grad(sg, s)

The result looks very similar to the plain gradients. Perhaps with a 1D input and a shallow network, we are "not that sensitive" to shattered gradients.
12.4.4. Shapley Values

Now let's approximate the Shapley value of each feature using the estimator above. Computing Shapley values requires no gradients, so this differs from the previous approaches. The basic algorithm is:

Choose a random point x′.

Combine x and x′ to make a point \(z\).

Compute the change in the prediction function.

One trick we use for efficiency is to never alter the sequence inside the padding — basically, we never make the sequence longer.
def shapley(i, sm, sampled_x, rng, model):
M, F, *_ = sampled_x.shape
z_choice = jax.random.bernoulli(rng, shape=(M, F))
# only swap out features within length of sm
mask = jnp.sum(sm[..., 1:], -1)
z_choice *= mask
z_choice = 1 - z_choice
# construct with and w/o ith feature
z_choice = z_choice.at[:, i].set(0.0)
z_choice_i = z_choice.at[:, i].set(1.0)
# select them via multiplication
z = sm * z_choice[..., jnp.newaxis] + sampled_x * (1 - z_choice[..., jnp.newaxis])
z_i = sm * z_choice_i[..., jnp.newaxis] + sampled_x * (
1 - z_choice_i[..., jnp.newaxis]
)
v = model(z_i) - model(z)
return jnp.squeeze(jnp.mean(v, axis=0))
# assume data is already shuffled, so just take M
M = 4096
sl = len(s)
sampled_x = train_data.unbatch().batch(M).as_numpy_iterator().next()[0]
# make batched shapley so we can compute for all features
bshapley = jax.vmap(shapley, in_axes=(0, None, None, 0, None))
sv = bshapley(
jnp.arange(sl),
sm,
sampled_x,
jax.random.split(jax.random.PRNGKey(0), sl),
predict,
)
# compute global expectation
eyhat = 0
for xi, yi in full_data.batch(M).as_numpy_iterator():
eyhat += jnp.mean(predict(xi))
eyhat /= len(full_data)
A good check on Shapley values is that they should sum to the value of the model function minus its expectation over all instances. Since we are using the approximation of [vStrumbeljK14], we should not expect perfect agreement. This value is computed below.
print(np.sum(sv), predict(sm))
6.7373457 8.068422
As anticipated, they are quite different. This is an effect of the approximation we used; you can verify this by examining how the number of samples affects the sum of the Shapley values.

Fig. 12.1 Sum of the approximated Shapley values as a function of the number of samples, compared against the target value.

It converges slowly. Finally, let's plot the individual Shapley values — these are our explanation of the prediction.
plot_grad(sv, s)

The results of the four methods we have seen — gradients, integrated gradients, SmoothGrad, and Shapley values — are shown side by side below.
heights = []
plt.figure(figsize=(12, 4))
x = np.arange(len(s))
for i, (gi, l) in enumerate(zip([g, ig, sg], ["Gradient", "Integrated", "Smooth"])):
h = gi[0, np.arange(len(s)), list(map(ALPHABET.index, s))]
plt.bar(x + i / 5 - 1 / 4, h, width=1 / 5, edgecolor="black", label=l)
plt.bar(x + 3 / 5 - 1 / 4, sv, width=1 / 5, edgecolor="black", label="Shapley")
ax = plt.gca()
ax.set_xticks(range(len(s)))
ax.set_xticklabels(s)
ax.set_xlabel("Amino Acid $x_i$")
ax.set_ylabel(r"Importance [logits]")
plt.legend()
plt.show()

As someone who works with peptides regularly, I find the Shapley values the most accurate here. I had not thought the pattern of L and R was important, but the Shapley values show it. Also, unlike the other methods, the Shapley values do not indicate that phenylalanine (F) has an important effect.

What can we conclude from these results? Perhaps we could attach an explanation like: "This sequence is predicted to be hemolytic chiefly because of the glutamine, the proline, and the sequence of leucines and arginines."
12.5. What are Feature Importances For?

Feature importances rarely lead to clear explanations that provide practical, predictive insight. They carry no causal content, and they can tempt us to read meaning into explanations of features where none exists [CK18]. Another caveat is that we are explaining the model, not the real chemical system. For example, avoid interpretations like "the hemolytic activity is caused by the glutamine at position 5"; instead say "the model predicts hemolytic activity because of the glutamine at position 5."

A practically useful explanation shows how changing the features would change the outcome, which comes close to knowing the cause of the outcome. So for the reasons above, whether feature importances really have explanatory power remains a matter of debate [Lip18]. As an aside, the research area that connects feature importances to human concepts is called testing with concept activation vectors (TCAV) [KWG+18]. I personally make little use of feature importances for XAI, because the explanations are neither actionable nor causal and often just add to the confusion.
12.6. Training Data Importance

Another kind of explanation or interpretation we might hope for is which training data points contributed most to a prediction. This is a more direct answer to the question: why did my model predict this? A neural network is, after all, a product of its training data, so the answer to "why this prediction?" can be traced back to the training data. By ranking the training data for a given prediction, we gain insight into which training points influence the neural network's prediction. This takes the form of an influence function \(\mathcal{I}(x_i, x)\), which gives an influence score for training point \(i\) and input \(x\). The simplest way to compute influence is to train the network with \(x_i\) included (giving \(\hat{f}(x)\)) and without it (giving \(\hat{f}_{-x_i}(x)\)), and define the influence as:

\[ \mathcal{I}(x_i, x) = \hat{f}_{-x_i}(x) - \hat{f}(x) \]

For example, if the prediction goes up after removing training point \(x_i\) from the training data, that point has a positive influence. Computing this influence function normally requires retraining the model as many times as there are data points, which is usually infeasible. [KL17] show a way to approximate this by looking at infinitesimal changes to the weight given to each training point. Computing these influence functions requires the Hessian of the loss, so they are not commonly used — but if you are working in JAX, the computation is easy to do.
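For intuition, here is a brute-force leave-one-out sketch of \(\mathcal{I}(x_i, x)\) where the "model" is simple enough (least squares) that literal retraining is cheap; the data and query point are made up. [KL17]'s Hessian-based approximation exists precisely because you cannot do this with a real neural network:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 2))
y = X @ np.array([1.0, -2.0]) + rng.normal(scale=0.1, size=20)
x_query = np.array([1.0, 1.0])  # the prediction we want to explain

def fit_predict(X, y, x):
    # "train" a least-squares model and predict at x
    w, *_ = np.linalg.lstsq(X, y, rcond=None)
    return x @ w

base = fit_predict(X, y, x_query)
# influence of each training point: retrain without it, compare predictions
influence = np.array(
    [fit_predict(np.delete(X, i, axis=0), np.delete(y, i), x_query) - base for i in range(20)]
)
most_influential = int(np.argmax(np.abs(influence)))
```

Ranking |influence| identifies the training points this particular prediction leans on most.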
Training-data importance offers a useful interpretation for deep learning practitioners: it tells you which training points most influenced a given prediction. That helps when tackling data problems or tracing the explanation for a false positive. However, a general user consuming a deep learning model's predictions will probably not be satisfied with a ranking of training data alone.
12.7. Surrogate Models

One of the most general ideas in interpretability is to fit an interpretable model to a black-box model in the neighborhood of a specific example. Why only locally? Because an interpretable model can rarely fit a black-box model globally — and if it could, you would just use the interpretable model from the start and drop the black box. But if the interpretable model only has to fit a small region around an example of interest, we can provide explanations from a locally correct interpretable model. We call such a model a local surrogate model. Interpretable local surrogates include decision trees, linear models, sparse linear models (for concise explanations), and naive Bayes classifiers.
A widely known algorithm for local surrogate models is Local Interpretable Model-Agnostic Explanations (LIME) [RSG16a]. LIME fits the local surrogate near the example of interest using the loss function that the original black-box model was trained with. The surrogate's loss is weighted so that, in the regression, points close to the example of interest count more. The LIME paper also includes sparsification of the surrogate in its notation, but since that is a property of the surrogate model rather than of local surrogates in general, we set it aside here. The surrogate loss is therefore defined as:

\[ \mathcal{L} = \sum_{x'} w\left(x', x\right)\, \mathcal{l}\left(\hat{f}(x'), \hat{f}_s(x')\right) \]

where \(w(x', x)\) is a weighting kernel that up-weights points near the example of interest \(x\), \(\mathcal{l}(\cdot,\cdot)\) is the loss of the original black-box model, \(\hat{f}(\cdot)\) is the black-box model, and \(\hat{f}_s(\cdot)\) is the surrogate model.

The weighting function is a bit ad hoc — that is, it depends on the data type. For regression with scalar labels, a kernel function is used, and there are many choices: Gaussian, cosine, Epanechnikov, and so on. For text data, the LIME implementation uses Hamming distance, which simply counts the number of text tokens that differ between two strings. Images use a Hamming-style distance too, with superpixels either kept the same as the example or blanked out.

How are the points \(x'\) generated? For continuous values, \(x'\) is sampled uniformly — but feature spaces are often unbounded, which makes this quite difficult. If you instead sample \(x'\) according to the weighting function and drop the weights (since the sampling already follows them), you can avoid problems like unbounded feature spaces. In general, for continuous vector feature spaces, LIME is somewhat subjective. For images and text, \(x'\) is formed by masking tokens (words) or superpixels (blacking them out). This should yield explanations quite close to Shapley values — and indeed, with a few small changes of notation, LIME can be shown to be equivalent to Shapley values.
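Here is a minimal LIME-style sketch for the continuous case: a weighted linear surrogate fit around a point of interest for a stand-in black-box function (the function, kernel width, and sample count are all invented for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
f = lambda X: np.sin(X[:, 0]) + X[:, 1] ** 2   # stand-in "black-box" model
x0 = np.array([0.5, 1.0])                      # example of interest

# sample perturbed points near x0 and weight them with a Gaussian kernel
Xp = x0 + rng.normal(scale=0.3, size=(500, 2))
w = np.exp(-np.sum((Xp - x0) ** 2, axis=1) / (2 * 0.3**2))

# weighted least squares for the surrogate f_s(x) = b + m . (x - x0)
A = np.hstack([np.ones((len(Xp), 1)), Xp - x0])
sw = np.sqrt(w)
coef, *_ = np.linalg.lstsq(sw[:, None] * A, sw * f(Xp), rcond=None)
b, m = coef[0], coef[1:]
```

The surrogate's coefficients m recover the black box's local behavior (roughly \(\cos(0.5)\) and \(2\)) and are directly readable as a local explanation.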
12.8. Counterfactuals

A counterfactual is the solution to an optimization problem: find the example \(x'\) closest to \(x\) that has a different label [WMR17]. This can be formalized as:

\[ \min_{x'} d\left(x, x'\right) \quad \textrm{such that} \quad \hat{f}(x) \neq \hat{f}(x') \]

where \(d(\cdot,\cdot)\) is a distance. For regression, where \(\hat{f}(x)\) outputs a scalar, the constraint must be modified so that \(\hat{f}(x')\) differs from \(\hat{f}(x)\) by at least some \(\Delta\). The \(x'\) that satisfies this optimization is called the counterfactual — the condition that did not occur ("had the features been different, the outcome would have been different"). Finding \(x'\) is usually treated as derivative-free optimization. You can compute \(\frac{\partial \hat{f}}{\partial x'}\) and do constrained optimization, but in practice it is often faster to perturb \(x\) randomly, Monte Carlo-style, until \(\hat{f}(x) \neq \hat{f}(x')\). You can also use a generative model that proposes new \(x'\) via unsupervised learning. See [WSW22] for a universal counterfactual generator for molecules, and [NB20] for a method specific to graph neural networks on molecules.

The definition of distance is, as noted in the discussion of LIME, an important subjective choice. A common distance in the molecular-structure setting is the Tanimoto (or Jaccard) coefficient between molecular fingerprints/descriptors, such as Morgan fingerprints [RH10].

Counterfactuals have one drawback compared with Shapley values: they do not give a complete explanation. Shapley values sum to the predicted value, meaning no part of the explanation is overlooked, whereas a counterfactual changes as few features as possible (minimizing the distance) and can therefore miss information about other features that contribute to the prediction. On the other hand, counterfactuals have the advantage of being actionable: they can be used directly.
12.8.1. Example

We can quickly implement this idea for the peptide example above. We define the distance as Hamming distance, so \(x'\) is a single amino-acid substitution, and we check whether any substitution flips the label. First, a function that performs one substitution.
def check_cf(x, i, j):
# copy
x = jnp.array(x)
# substitute
x = x.at[:, i].set(0)
x = x.at[:, i, j].set(1)
return predict(x)
check_cf(sm, 0, 0)
DeviceArray(8.552943, dtype=float32)
Next we create every possible substitution with jnp.meshgrid and apply the function we just defined with vmap. ravel() flattens the index arrays to one dimension, so we don't need a complicated vmap.
ii, jj = jnp.meshgrid(jnp.arange(sl), jnp.arange(21))
ii, jj = ii.ravel(), jj.ravel()
x = jax.vmap(check_cf, in_axes=(None, 0, 0))(sm, ii, jj)
Next, we display all amino-acid substitutions for which the predicted value became negative (that is, logits less than 0).
from IPython.core.display import display, HTML
out = ["<tt>"]
for i, j in zip(ii[jnp.squeeze(x) < 0], jj[jnp.squeeze(x) < 0]):
out.append(f'{s[:i]}<span style="color:red;">{ALPHABET[j]}</span>{s[i+1:]}<br/>')
out.append("</tt>")
display(HTML("".join(out)))
RAGLQF-VGRLLRRLLRRLLR
RAGLAFPVGRLLRRLLRRLLR
RAGLQFAVGRLLRRLLRRLLR
RAGLCFPVGRLLRRLLRRLLR
RAGLQFCVGRLLRRLLRRLLR
RAGLQFPCGRLLRRLLRRLLR
RAGLIFPVGRLLRRLLRRLLR
RAGLQFIVGRLLRRLLRRLLR
RAGLLFPVGRLLRRLLRRLLR
RAGLQFLVGRLLRRLLRRLLR
RAGLFFPVGRLLRRLLRRLLR
RAGLQFFVGRLLRRLLRRLLR
RAGLQFPFGRLLRRLLRRLLR
RAGLPFPVGRLLRRLLRRLLR
RAGLTFPVGRLLRRLLRRLLR
RAGLWFPVGRLLRRLLRRLLR
RAGLVFPVGRLLRRLLRRLLR
RAGLQFVVGRLLRRLLRRLLR
There are a few possible interpretations, but basically, swapping the glutamine for a hydrophobic residue, or replacing the proline with V, F, A, or C, makes the peptide non-hemolytic. Stated as a counterfactual: "If the glutamine had been swapped for a hydrophobic amino acid, this peptide would be non-hemolytic."
12.9. Explaining Specific Architectures
The same principles described above also apply to GNNs, but there are various ideas about how best to translate them to operate on graphs. See [AZL21] for theory on interpretability specific to GNNs, and [YYGJ20] for methods that can be used to construct explanations with GNNs.
NLP is another field with specialized approaches for constructing explanations and interpretations. See [MRC21] for a recent survey.
12.10. Model-Agnostic Molecular Counterfactual Explanations
The main challenge with counterfactuals in chemistry is the difficulty of computing the derivative in (12.14). Consequently, most methods that focus on this task are, as we have seen, specific to a model architecture. Wellawatte et al. [WSW22] introduced Molecular Model Agnostic Counterfactual Explanations (MMACE), a method that does this for molecules regardless of the model architecture.
The MMACE method is implemented in the exmol package. Given a molecule and a model, exmol can generate local counterfactual explanations. MMACE has two main steps. First, it expands a local chemical space around the given base molecule. Then it labels each sampled point with the user-provided model, whatever its architecture. These labels are used to identify counterfactuals in the local chemical space. MMACE is model agnostic, and the exmol package can generate counterfactuals for both classification and regression tasks.
Let's see how counterfactuals are generated with exmol. In this example, we train a model that predicts the clinical toxicity of molecules. For this binary classification task, we use the same dataset, published by the MoleculeNet group [WRF+18], that we used in the Classification chapter.
12.11. Running This Notebook
Click the 🚀 above to launch this page in an interactive Google Colab. See below for details on installing the packages, either in your own environment or on Google Colab.
Tip
To install the packages, create a new cell and run the following code.
!pip install exmol jupyter-book matplotlib numpy pandas seaborn scikit-learn mordred[full] rdkit-pypi
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import rdkit, rdkit.Chem, rdkit.Chem.Draw
from rdkit.Chem.Draw import IPythonConsole
import numpy as np
import mordred, mordred.descriptors
import warnings
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
import exmol
IPythonConsole.ipython_useSVG = True
toxdata = pd.read_csv(
"https://github.com/whitead/dmol-book/raw/master/data/clintox.csv.gz"
)
# make object that can compute descriptors
calc = mordred.Calculator(mordred.descriptors, ignore_3D=True)
# make subsample from pandas df
molecules = [rdkit.Chem.MolFromSmiles(smi) for smi in toxdata.smiles]
# view one molecule to make sure things look good.
molecules[0]
After importing the data, we generate the input descriptors with the Mordred package.
# Get valid molecules from the sample
valid_mol_idx = [bool(m) for m in molecules]
valid_mols = [m for m in molecules if m]
# Compute molecular descriptors using Mordred
features = calc.pandas(valid_mols, quiet=True)
labels = toxdata[valid_mol_idx].FDA_APPROVED
# Standardize the features
features -= features.mean()
features /= features.std()
# we have some nans in features, likely because std was 0
features = features.values.astype(float)
features_select = np.all(np.isfinite(features), axis=0)
features = features[:, features_select]
print(f"We have {len(features)} molecules")
We have 1478 molecules
In this example, we use a simple dense neural network classifier implemented in Keras. We first train this simple classifier and then use it to generate the labels for exmol's counterfactuals. More accurate results could be expected by improving the trained model's performance, but the example below is sufficient for understanding how exmol works.
# Train and test split
X_train, X_test, y_train, y_test = train_test_split(
features, labels, test_size=0.2, shuffle=True
)
ft_shape = X_train.shape[-1]
# reshape data
X_train = X_train.reshape(-1, ft_shape)
X_test = X_test.reshape(-1, ft_shape)
Now let's build and run the model. There is a detailed introduction to dense models in the Deep Learning Overview chapter.
model = tf.keras.models.Sequential()
model.add(tf.keras.Input(shape=(ft_shape,)))
model.add(tf.keras.layers.Dense(32, activation="relu"))
model.add(tf.keras.layers.Dense(32))
model.add(tf.keras.layers.Dense(1, activation="sigmoid"))
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
# Model training
model.fit(X_train, y_train, epochs=50, batch_size=32, verbose=0)
_, accuracy = model.evaluate(X_test, y_test)
print(f"Model accuracy: {accuracy*100:.2f}%")
10/10 [==============================] - 0s 1ms/step - loss: 0.2520 - accuracy: 0.9392
Model accuracy: 93.92%
The accuracy of our model looks good!
Next, we write a wrapper function that takes the SMILES and/or SELFIES representation of a molecule and outputs the label predicted by our trained classifier. A detailed explanation of SELFIES can be found in the Deep Learning on Sequences chapter. This wrapper function is given as input to exmol.sample_space, which builds a local chemical space around a given base molecule. exmol uses the Superfast Traversal, Optimization, Novelty, Exploration and Discovery (STONED) algorithm [NPK+21] as the generative algorithm to expand the local space. Given a base molecule, the STONED algorithm randomly mutates its SELFIES representation; these mutations are character substitutions, insertions, and deletions.
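The mutation step can be illustrated on a generic token sequence. This is a simplified stdlib-only sketch: the real STONED algorithm mutates SELFIES tokens drawn from a chemically valid alphabet, while the tokens and alphabet below are stand-ins.

```python
import random


def mutate(tokens, alphabet, rng):
    """Apply one random STONED-style mutation:
    token substitution, insertion, or deletion."""
    tokens = list(tokens)
    op = rng.choice(["substitute", "insert", "delete"])
    i = rng.randrange(len(tokens))
    if op == "substitute":
        tokens[i] = rng.choice(alphabet)
    elif op == "insert":
        tokens.insert(i, rng.choice(alphabet))
    elif len(tokens) > 1:
        del tokens[i]  # delete, but never empty the sequence
    return tokens


rng = random.Random(0)
base = ["[C]", "[C]", "[O]"]  # stand-ins for SELFIES tokens
alphabet = ["[C]", "[O]", "[N]", "[F]"]
for _ in range(5):
    print(mutate(base, alphabet, rng))
```

Because every SELFIES string decodes to a valid molecule, these random edits always yield syntactically valid neighbors, which is exactly what makes SELFIES attractive for this kind of local-space expansion.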
def model_eval(smiles, selfies):
    # featurize each SMILES string with the same Mordred descriptors used in training
    molecules = [rdkit.Chem.MolFromSmiles(smi) for smi in smiles]
    features = calc.pandas(molecules)
    features = features.values.astype(float)
    # keep only the finite descriptor columns selected during training
    features = features[:, features_select]
    # round the sigmoid output to get a 0/1 label
    labels = np.round(model.predict(np.nan_to_num(features).reshape(-1, ft_shape)))
    return labels
Now let's sample the local chemical space with exmol.sample_space, which uses STONED. The size of the sample space can be changed with the num_samples argument. The base molecule chosen here is an FDA-unapproved molecule.
space = exmol.sample_space("C1CC(=O)NC(=O)C1N2CC3=C(C2=O)C=CC=C3N", model_eval);
Once the sample space has been created, we can identify counterfactuals in the local chemical space with the exmol.cf_explain function. Each counterfactual is a Python dataclass containing additional information.
exps = exmol.cf_explain(space, 2)
# the first element is the base molecule itself; any counterfactuals follow
if len(exps) > 1:
    print(exps[1])
The generated counterfactuals can be easily visualized with exmol's plotting functions exmol.plot_cf and exmol.plot_space. The similarity between the base molecule and a counterfactual is the Tanimoto coefficient of ECFP4 fingerprints. The top counterfactuals are shown below.
exmol.plot_cf(exps, nrows=1)
The base molecule chosen here is FDA-unapproved. Looking at the generated counterfactuals, we can conclude that the heterocyclic substituent affects toxicity. Thus, according to our model, modifying the heterocyclic substituent might render the base molecule non-toxic. This shows how counterfactual explanations give actionable insight into what modifications could change a prediction.
Finally, let's visualize the generated chemical space!
exmol.plot_space(space, exps)
12.12. Summary
- Interpreting deep learning models is essential for ensuring model correctness and for making predictions useful to people. It is sometimes required for regulatory compliance.
- Interpretability of neural networks is part of the broader topic of explainable AI (XAI); this topic is still in its infancy.
- "Explanation" is still loosely defined and is often expressed in terms of the model's features.
- Strategies for explanation include feature importance, training-data importance, counterfactuals, and locally accurate surrogate models.
- Most explanations are generated per example (at inference time).
- The most systematic, though computationally costly, explanations are Shapley values.
- Counterfactuals provide the most intuitive and satisfying explanations, although some feel they may not be complete explanations.
- exmol is software that generates model-agnostic counterfactual explanations for molecules.
12.13. Cited References
- WRF+18
Zhenqin Wu, Bharath Ramsundar, Evan N Feinberg, Joseph Gomes, Caleb Geniesse, Aneesh S Pappu, Karl Leswing, and Vijay Pande. MoleculeNet: a benchmark for molecular machine learning. Chemical Science, 9(2):513–530, 2018.
- LS04
John D Lee and Katrina A See. Trust in automation: designing for appropriate reliance. Human Factors, 46(1):50–80, 2004.
- DVK17
Finale Doshi-Velez and Been Kim. Towards a rigorous science of interpretable machine learning. arXiv preprint arXiv:1702.08608, 2017.
- GF17
Bryce Goodman and Seth Flaxman. European Union regulations on algorithmic decision-making and a "right to explanation". AI Magazine, 38(3):50–57, 2017.
- Dev19
Organisation for Economic Co-operation and Development. Recommendation of the Council on Artificial Intelligence. 2019. URL: https://legalinstruments.oecd.org/en/instruments/OECD-LEGAL-0449.
- CLG+15
Rich Caruana, Yin Lou, Johannes Gehrke, Paul Koch, Marc Sturm, and Noemie Elhadad. Intelligible models for healthcare: predicting pneumonia risk and hospital 30-day readmission. In Proceedings of the 21st ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 1721–1730. ACM, 2015.
- Mil19
Tim Miller. Explanation in artificial intelligence: insights from the social sciences. Artificial Intelligence, 267:1–38, 2019.
- MSK+19
James W Murdoch, Chandan Singh, Karl Kumbier, Reza Abbasi-Asl, and Bin Yu. Interpretable machine learning: definitions, methods, and applications. eprint arXiv, pages 1–11, 2019. URL: http://arxiv.org/abs/1901.04592.
- MSMuller18
Grégoire Montavon, Wojciech Samek, and Klaus-Robert Müller. Methods for interpreting and understanding deep neural networks. Digital Signal Processing, 73:1–15, 2018.
- BCB14
Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.
- AGFW21
Mehrad Ansari, Heta A Gandhi, David G Foster, and Andrew D White. Iterative symbolic regression for learning transport equations. arXiv preprint arXiv:2108.03293, 2021.
- BD00
Lynne Billard and Edwin Diday. Regression analysis for interval-valued data. In Data Analysis, Classification, and Related Methods, pages 369–374. Springer, 2000.
- UT20
Silviu-Marian Udrescu and Max Tegmark. AI Feynman: a physics-inspired method for symbolic regression. Science Advances, 6(16):eaay2631, 2020.
- CSGB+20
Miles Cranmer, Alvaro Sanchez Gonzalez, Peter Battaglia, Rui Xu, Kyle Cranmer, David Spergel, and Shirley Ho. Discovering symbolic models from deep learning with inductive biases. Advances in Neural Information Processing Systems, 33:17429–17442, 2020.
- WSW22
Geemi P Wellawatte, Aditi Seshadri, and Andrew D White. Model agnostic generation of counterfactual explanations for molecules. Chem. Sci., 2022. URL: http://dx.doi.org/10.1039/D1SC05259D, doi:10.1039/D1SC05259D.
- RSG16a
Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. "Why should I trust you?" Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 1135–1144. 2016.
- RSG16b
Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. Model-agnostic interpretability of machine learning. arXiv preprint arXiv:1606.05386, 2016.
- WMR17
Sandra Wachter, Brent Mittelstadt, and Chris Russell. Counterfactual explanations without opening the black box: automated decisions and the GDPR. Harv. JL & Tech., 31:841, 2017.
- KL17
Pang Wei Koh and Percy Liang. Understanding black-box predictions via influence functions. In International Conference on Machine Learning, 1885–1894. PMLR, 2017.
- SML+21
Wojciech Samek, Grégoire Montavon, Sebastian Lapuschkin, Christopher J. Anders, and Klaus-Robert Müller. Explaining deep neural networks and beyond: a review of methods and applications. Proceedings of the IEEE, 109(3):247–278, 2021. doi:10.1109/JPROC.2021.3060483.
- BFL+17
David Balduzzi, Marcus Frean, Lennox Leary, J. P. Lewis, Kurt Wan-Duo Ma, and Brian McWilliams. The shattered gradients problem: if resnets are the answer, then what is the question? In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, 342–350. PMLR, 06–11 Aug 2017. URL: http://proceedings.mlr.press/v70/balduzzi17b.html.
- STY17
Mukund Sundararajan, Ankur Taly, and Qiqi Yan. Axiomatic attribution for deep networks. In International Conference on Machine Learning, 3319–3328. PMLR, 2017.
- STK+17
Daniel Smilkov, Nikhil Thorat, Been Kim, Fernanda Viégas, and Martin Wattenberg. SmoothGrad: removing noise by adding noise. arXiv preprint arXiv:1706.03825, 2017.
- MBL+19
Grégoire Montavon, Alexander Binder, Sebastian Lapuschkin, Wojciech Samek, and Klaus-Robert Müller. Layer-Wise Relevance Propagation: An Overview, pages 193–209. Springer International Publishing, Cham, 2019. URL: https://link.springer.com/chapter/10.1007%2F978-3-030-28954-6_10.
- vStrumbeljK14
Erik Štrumbelj and Igor Kononenko. Explaining prediction models and individual predictions with feature contributions. Knowledge and Information Systems, 41(3):647–665, 2014.
- BW21
Rainier Barrett and Andrew D. White. Investigating active learning and meta-learning for iterative peptide design. Journal of Chemical Information and Modeling, 61(1):95–105, 2021. URL: https://doi.org/10.1021/acs.jcim.0c00946, doi:10.1021/acs.jcim.0c00946.
- SLL20
Pascal Sturmfels, Scott Lundberg, and Su-In Lee. Visualizing the impact of feature attribution baselines. Distill, 2020. https://distill.pub/2020/attribution-baselines. doi:10.23915/distill.00022.
- CK18
Kangway V Chuang and Michael J Keiser. Comment on "Predicting reaction performance in C–N cross-coupling using machine learning". Science, 362(6416):eaat8603, 2018.
- Lip18
Zachary C Lipton. The mythos of model interpretability: in machine learning, the concept of interpretability is both important and slippery. Queue, 16(3):31–57, 2018.
- KWG+18
Been Kim, Martin Wattenberg, Justin Gilmer, Carrie Cai, James Wexler, Fernanda Viegas, and others. Interpretability beyond feature attribution: quantitative testing with concept activation vectors (TCAV). In International Conference on Machine Learning, 2668–2677. PMLR, 2018.
- NB20
Danilo Numeroso and Davide Bacciu. Explaining deep graph networks with molecular counterfactuals. arXiv preprint arXiv:2011.05134, 2020.
- RH10
David Rogers and Mathew Hahn. Extended-connectivity fingerprints. Journal of Chemical Information and Modeling, 50(5):742–754, 2010.
- AZL21
Chirag Agarwal, Marinka Zitnik, and Himabindu Lakkaraju. Towards a rigorous theoretical analysis and evaluation of GNN explanations. arXiv preprint arXiv:2106.09078, 2021.
- YYGJ20
Hao Yuan, Haiyang Yu, Shurui Gui, and Shuiwang Ji. Explainability in graph neural networks: a taxonomic survey. arXiv preprint arXiv:2012.15445, 2020.
- MRC21
Andreas Madsen, Siva Reddy, and Sarath Chandar. Post-hoc interpretability for neural NLP: a survey. arXiv preprint arXiv:2108.04840, 2021.
- NPK+21
AkshatKumar Nigam, Robert Pollice, Mario Krenn, Gabriel dos Passos Gomes, and Alán Aspuru-Guzik. Beyond generative models: superfast traversal, optimization, novelty, exploration and discovery (STONED) algorithm for molecules using SELFIES. Chem. Sci., 12:7079–7090, 2021. URL: http://dx.doi.org/10.1039/D1SC00231G, doi:10.1039/D1SC00231G.