首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >DGL-LifeSci:面向化学和生物领域的 GNN 算法库

DGL-LifeSci:面向化学和生物领域的 GNN 算法库

作者头像
DrugOne
发布2021-02-01 11:15:25
发布2021-02-01 11:15:25
3.3K0
举报
文章被收录于专栏:DrugOneDrugOne

作者 | 王建民

DGL团队发布了以生命科学为重点的软件包DGL-LifeSci。

尝试使用新的DGL--LifeSci并建立Attentive FP模型并可视化其预测结果。

基于深度图学习框架DGL

环境准备

  • PyTorch:深度学习框架
  • DGL:基于PyTorch的库,支持深度学习以处理图形
  • RDKit:用于构建分子图并从字符串表示形式绘制结构式
  • DGL-LifeSci:面向化学和生物领域的 GNN 算法库

DGL安装

conda install -c dglteam dgl #DGLv0.4.3

DGL-LifeSci安装

pip install dgllife

基于Attentive FP可视化训练模型

导入库

代码语言:javascript
复制
import matplotlib.pyplot as plt
import os
from rdkit import Chem
from rdkit.Chem import rdmolops, rdmolfiles
from rdkit import RDPaths
 
import dgl
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from dgl import model_zoo
from dgllife.model import AttentiveFPPredictor
from dgllife.utils import mol_to_complete_graph, mol_to_bigraph
from dgllife.utils import atom_type_one_hot
from dgllife.utils import atom_degree_one_hot
from dgllife.utils import atom_formal_charge
from dgllife.utils import atom_num_radical_electrons
from dgllife.utils import atom_hybridization_one_hot
from dgllife.utils import atom_total_num_H_one_hot
from dgllife.utils import one_hot_encoding
from dgllife.utils import CanonicalAtomFeaturizer
from dgllife.utils import CanonicalBondFeaturizer
from dgllife.utils import ConcatFeaturizer
from dgllife.utils import BaseAtomFeaturizer
from dgllife.utils import BaseBondFeaturizer
from dgllife.utils import one_hot_encoding 
from dgl.data.utils import split_dataset
 
from functools import partial
from sklearn.metrics import roc_auc_score

定义辅助函数

代码来源于dgl/example。

代码语言:javascript
复制
def chirality(atom):
    try:
        return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \
               [atom.HasProp('_ChiralityPossible')]
    except:
        return [False, False] + [atom.HasProp('_ChiralityPossible')]
     
def collate_molgraphs(data):
    """Batching a list of datapoints for dataloader.
    Parameters
    ----------
    data : list of 3-tuples or 4-tuples.
        Each tuple is for a single datapoint, consisting of
        a SMILES, a DGLGraph, all-task labels and optionally
        a binary mask indicating the existence of labels.
    Returns
    -------
    smiles : list
        List of smiles
    bg : BatchedDGLGraph
        Batched DGLGraphs
    labels : Tensor of dtype float32 and shape (B, T)
        Batched datapoint labels. B is len(data) and
        T is the number of total tasks.
    masks : Tensor of dtype float32 and shape (B, T)
        Batched datapoint binary mask, indicating the
        existence of labels. If binary masks are not
        provided, return a tensor with ones.
    """
    assert len(data[0]) in [3, 4], \
        'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))
    if len(data[0]) == 3:
        smiles, graphs, labels = map(list, zip(*data))
        masks = None
    else:
        smiles, graphs, labels, masks = map(list, zip(*data))
 
    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    labels = torch.stack(labels, dim=0)
     
    if masks is None:
        masks = torch.ones(labels.shape)
    else:
        masks = torch.stack(masks, dim=0)
    return smiles, bg, labels, masks

原子和键特征化器

代码语言:javascript
复制
atom_featurizer = BaseAtomFeaturizer(
                 {'hv': ConcatFeaturizer([
                  partial(atom_type_one_hot, allowable_set=[
                          'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
                    encode_unknown=True),
                  partial(atom_degree_one_hot, allowable_set=list(range(6))),
                  atom_formal_charge, atom_num_radical_electrons,
                  partial(atom_hybridization_one_hot, encode_unknown=True),
                  lambda atom: [0], # A placeholder for aromatic information,
                    atom_total_num_H_one_hot, chirality
                 ],
                )})
bond_featurizer = BaseBondFeaturizer({
                                     'he': lambda bond: [0 for _ in range(10)]
    })

加载数据集,rdkit mol对象转换为图对象

带有featurizer的mol_to_bigraph方法将rdkit mol对象转换为图对象。此外,smiles_to_bigraph方法可以将smiles转换为图。

代码语言:javascript
复制
train_mols = Chem.SDMolSupplier('solubility.train.sdf')
train_smi =[Chem.MolToSmiles(m) for m in train_mols]
train_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in train_mols]).reshape(-1,1)
 
test_mols =  Chem.SDMolSupplier('solubility.test.sdf')
test_smi = [Chem.MolToSmiles(m) for m in test_mols]
test_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in test_mols]).reshape(-1,1)
 
train_graph =[mol_to_bigraph(mol,
                           node_featurizer=atom_featurizer, 
                           edge_featurizer=bond_featurizer) for mol in train_mols]
 
test_graph =[mol_to_bigraph(mol,
                           node_featurizer=atom_featurizer, 
                           edge_featurizer=bond_featurizer) for mol in test_mols]

AttentivFp模型

并定义用于训练和测试的数据加载器。

代码语言:javascript
复制
model = AttentiveFPPredictor(node_feat_size=39,
                                  edge_feat_size=10,
                                  num_layers=2,
                                  num_timesteps=2,
                                  graph_feat_size=200,
                                  n_tasks=1,
                                  dropout=0.2)
#model = model.to('cuda:0')
 
train_loader = DataLoader(dataset=list(zip(train_smi, train_graph, train_sol)), batch_size=128, collate_fn=collate_molgraphs)
test_loader = DataLoader(dataset=list(zip(test_smi, test_graph, test_sol)), batch_size=128, collate_fn=collate_molgraphs)

定义可视化函数

  • 代码来源于DGL库。
  • DGL模型具有get_node_weight选项,该选项返回图形的node_weight。该模型具有两层GRU,因此以下代码我将0用作时间步长,因此时间步长必须为0或1。
代码语言:javascript
复制
def drawmol(idx, dataset, timestep):
    smiles, graph, _ = dataset[idx]
    print(smiles)
    bg = dgl.batch([graph])
    atom_feats, bond_feats = bg.ndata['hv'], bg.edata['he']
    if torch.cuda.is_available():
        print('use cuda')
        bg.to(torch.device('cuda:0'))
        atom_feats = atom_feats.to('cuda:0')
        bond_feats = bond_feats.to('cuda:0')
    
    _, atom_weights = model(bg, atom_feats, bond_feats, get_node_weight=True)
    assert timestep < len(atom_weights), 'Unexpected id for the readout round'
    atom_weights = atom_weights[timestep]
    min_value = torch.min(atom_weights)
    max_value = torch.max(atom_weights)
    atom_weights = (atom_weights - min_value) / (max_value - min_value)
    
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1.28)
    cmap = cm.get_cmap('bwr')
    plt_colors = cm.ScalarMappable(norm=norm, cmap=cmap)
    atom_colors = {i: plt_colors.to_rgba(atom_weights[i].data.item()) for i in range(bg.number_of_nodes())}
    
    mol = Chem.MolFromSmiles(smiles)
    rdDepictor.Compute2DCoords(mol)
    drawer = rdMolDraw2D.MolDraw2DSVG(280, 280)
    drawer.SetFontSize(1)
    op = drawer.drawOptions()
    
    mol = rdMolDraw2D.PrepareMolForDrawing(mol)
    drawer.DrawMolecule(mol, highlightAtoms=range(bg.number_of_nodes()),
                             highlightBonds=[],
                             highlightAtomColors=atom_colors)
    drawer.FinishDrawing()
    svg = drawer.GetDrawingText()
    svg = svg.replace('svg:', '')
    if torch.cuda.is_available():
        atom_weights = atom_weights.to('cpu')
        
    a = np.array([[0,1]])
    plt.figure(figsize=(9, 1.5))
    img = plt.imshow(a, cmap="bwr")
    plt.gca().set_visible(False)
    cax = plt.axes([0.1, 0.2, 0.8, 0.2])
    plt.colorbar(orientation='horizontal', cax=cax)
    plt.show()
    return (Chem.MolFromSmiles(smiles), atom_weights.data.numpy(), svg)

绘制测试数据集分子

该模型预测溶解度,颜色表示红色是溶解度的积极影响,蓝色是负面影响。

代码语言:javascript
复制
target = test_loader.dataset
 
for i in range(len(target))[:5]:
    mol, aw, svg = drawmol(i, target, 0)
    print(aw.min(), aw.max())
    display(SVG(svg))

参考资料

  1. https://github.com/dmlc/dgl/tree/master/apps/life_sci
  2. https://github.com/dmlc/dgl/blob/master/python/dgl/model_zoo/chem/attentive_fp.py
  3. https://pubs.acs.org/doi/full/10.1021/acs.jcim.9b00387
  4. https://github.com/awslabs/dgl-lifesci
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-06-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 DrugAI 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • DGL-LifeSci安装
  • 基于Attentive FP可视化训练模型
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档