diff --git a/build/lib/torchdrug/__init__.py b/build/lib/torchdrug/__init__.py new file mode 100644 index 00000000..03e5aacc --- /dev/null +++ b/build/lib/torchdrug/__init__.py @@ -0,0 +1,15 @@ +from . import patch +from .data.constant import * + +import sys +import logging + +logger = logging.getLogger("") +logger.setLevel(logging.INFO) +format = logging.Formatter("%(asctime)-10s %(message)s", "%H:%M:%S") + +handler = logging.StreamHandler(sys.stdout) +handler.setFormatter(format) +logger.addHandler(handler) + +__version__ = "0.2.1" diff --git a/build/lib/torchdrug/core/__init__.py b/build/lib/torchdrug/core/__init__.py new file mode 100644 index 00000000..2e0abba2 --- /dev/null +++ b/build/lib/torchdrug/core/__init__.py @@ -0,0 +1,9 @@ +from .core import _MetaContainer, Registry, Configurable, make_configurable +from .engine import Engine +from .meter import Meter +from .logger import LoggerBase, LoggingLogger, WandbLogger + +__all__ = [ + "_MetaContainer", "Registry", "Configurable", + "Engine", "Meter", "LoggerBase", "LoggingLogger", "WandbLogger", +] \ No newline at end of file diff --git a/build/lib/torchdrug/core/core.py b/build/lib/torchdrug/core/core.py new file mode 100644 index 00000000..0227b66b --- /dev/null +++ b/build/lib/torchdrug/core/core.py @@ -0,0 +1,372 @@ +import re +import types +import inspect +from collections import defaultdict +from contextlib import contextmanager + +from decorator import decorator + + +class _MetaContainer(object): + """ + Meta container that maintains meta types about members. + + The meta type of each member is tracked when a member is assigned. + We use a context manager to define the meta types for a bunch of assignment. + + The meta types are stored as a dict in ``instance.meta_dict``, + where keys are member names and values are meta types. + + >>> class MyClass(_MetaContainer): + >>> ... + + >>> instance = MyClass() + >>> with instance.context("important"): + >>> instance.value = 1 + >>> assert instance.meta_dict["value"] == "important" + + Members assigned with :meth:`context(None) ` or without a context won't be tracked. + + >>> instance.random = 0 + >>> assert "random" not in instance.meta_dict + + You can also restrict available meta types by defining a set :attr:`_meta_types` in the derived class. + + .. note:: + + Meta container also supports auto inference of meta types. + This can be enabled by setting :attr:`enable_auto_context` to ``True`` in the derived class. + + Once auto inference is on, any member without an explicit context will be recognized through their name prefix. + For example, ``instance.node_value`` will be recognized as ``node`` if ``node`` is defined in ``meta_types``. + + This may make code hard to maintain. Use with caution. + """ + + _meta_types = set() + enable_auto_context = False + + def __init__(self, meta_dict=None, **kwargs): + if meta_dict is None: + meta_dict = {} + else: + meta_dict = meta_dict.copy() + + self._setattr("_meta_contexts", set()) + self._setattr("meta_dict", meta_dict) + for k, v in kwargs.items(): + self._setattr(k, v) + + @contextmanager + def context(self, type): + """ + Context manager for assigning members with a specific meta type. 
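+
+        An illustrative sketch (assumes ``node`` is listed in :attr:`_meta_types`):
+
+        >>> with instance.context("node"):
+        >>>     instance.node_label = 1
+        >>> assert "node" in instance.meta_dict["node_label"]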
+ """ + if type is not None and self._meta_types and type not in self._meta_types: + raise ValueError("Expect context type in %s, but got `%s`" % (self._meta_types, type)) + self._meta_contexts.add(type) + yield + self._meta_contexts.remove(type) + + def __setattr__(self, key, value): + if hasattr(self, "meta_dict"): + types = self._meta_contexts + if not types and self.enable_auto_context: + for type in self._meta_types: + if key.startswith(type): + types.append(type) + if len(types) > 1: + raise ValueError("Auto context found multiple contexts for key `%s`. " + "If this is desired, set `enable_auto_context` to False " + "and manually specify the context. " % key) + if types: + self.meta_dict[key] = types.copy() + self._setattr(key, value) + + def __delattr__(self, key): + if hasattr(self, "meta_dict") and key in self.meta_dict: + del self.meta_dict[key] + super(_MetaContainer, self).__delattr__(self, key) + + def _setattr(self, key, value): + return super(_MetaContainer, self).__setattr__(key, value) + + @property + def data_dict(self): + """A dict that maps tracked names to members.""" + return {k: getattr(self, k) for k in self.meta_dict} + + def data_by_meta(self, include=None, exclude=None): + """ + Return members based on the specific meta types. + + Parameters: + include (list of string, optional): meta types to include + exclude (list of string, optional): meta types to exclude + + Returns: + (dict, dict): data member dict and meta type dict + """ + if include is None and exclude is None: + return self.data_dict, self.meta_dict + + include = self._standarize_type(include) + exclude = self._standarize_type(exclude) + types = include or set().union(*self.meta_dict.values()) + types = types - exclude + data_dict = {} + meta_dict = {} + for k, v in self.meta_dict.items(): + if v.issubset(types): + data_dict[k] = getattr(self, k) + meta_dict[k] = v + return data_dict, meta_dict + + def _standarize_type(self, types): + if types is None: + types = set() + elif isinstance(types, str): + types = {types} + else: + types = set(types) + return types + + +class Tree(defaultdict): + + def __init__(self): + super(Tree, self).__init__(Tree) + + def flatten(self, prefix=None, result=None): + if prefix is None: + prefix = "" + else: + prefix = prefix + "." + if result is None: + result = {} + for k, v in self.items(): + if isinstance(v, Tree): + v.flatten(prefix + k, result) + else: + result[prefix + k] = v + return result + + +class Registry(object): + """ + Registry class for managing all call-by-name access to objects. + + Typical scenarios: + + 1. Create a model according to a string. + + >>> gcn = R.search("GCN")(128, [128]) + + 2. Register a customize hook to the package. + + >>> @R.register("features.atom.my_feature") + >>> def my_featurizer(atom): + >>> ... + >>> + >>> data.Molecule.from_smiles("C1=CC=CC=C1", atom_feature="my_feature") + """ + + table = Tree() + + def __new__(cls): + raise ValueError("Registry shouldn't be instantiated.") + + @classmethod + def register(cls, name): + """ + Register an object with a canonical name. Hierarchical names are separated by ``.``. + """ + + def wrapper(obj): + entry = cls.table + keys = name.split(".") + for key in keys[:-1]: + entry = entry[key] + if keys[-1] in entry: + raise KeyError("`%s` has already been registered by %s" % (name, entry[keys[-1]])) + + entry[keys[-1]] = obj + obj._registry_key = name + + return obj + + return wrapper + + @classmethod + def get(cls, name): + """ + Get an object with a canonical name. 
Hierarchical names are separated by ``.``. + """ + entry = cls.table + keys = name.split(".") + for i, key in enumerate(keys): + if key not in entry: + raise KeyError("Can't find `%s` in `%s`" % (key, ".".join(keys[:i]))) + entry = entry[key] + return entry + + @classmethod + def search(cls, name): + """ + Search an object with the given name. The name doesn't need to be canonical. + + For example, we can search ``GCN`` and get the object of ``models.GCN``. + """ + keys = [] + pattern = re.compile(r"\b%s\b" % name) + for k, v in cls.table.flatten().items(): + if pattern.search(k): + keys.append(k) + value = v + if len(keys) == 0: + raise KeyError("Can't find any registered key containing `%s`" % name) + if len(keys) > 1: + keys = ["`%s`" % key for key in keys] + raise KeyError("Ambiguous key `%s`. Found %s" % (name, ", ".join(keys))) + return value + + +class _Configurable(type): + + def config_dict(self): + + def unroll_config_dict(obj): + if isinstance(type(obj), _Configurable): + obj = obj.config_dict() + elif isinstance(obj, (str, bytes)): + return obj + elif isinstance(obj, dict): + return type(obj)({k: unroll_config_dict(v) for k, v in obj.items()}) + elif isinstance(obj, (list, tuple)): + return type(obj)(unroll_config_dict(x) for x in obj) + return obj + + cls = getattr(self, "_registry_key", self.__class__.__name__) + config = {"class": cls} + for k, v in self._config.items(): + config[k] = unroll_config_dict(v) + return config + + @classmethod + def load_config_dict(cls, config): + if cls == _Configurable: + real_cls = Registry.search(config["class"]) + custom_load_func = real_cls.load_config_dict.__func__ != cls.load_config_dict.__func__ + if custom_load_func: + return real_cls.load_config_dict(config) + cls = real_cls + elif getattr(cls, "_registry_key", cls.__name__) != config["class"]: + raise ValueError("Expect config class to be `%s`, but found `%s`" % (cls.__name__, config["class"])) + + new_config = {} + for k, v in config.items(): + if isinstance(v, dict) and "class" in v: + v = _Configurable.load_config_dict(v) + elif isinstance(v, list): + v = [_Configurable.load_config_dict(_v) + if isinstance(_v, dict) and "class" in _v else _v + for _v in v] + if k != "class": + new_config[k] = v + + return cls(**new_config) + + def __new__(typ, *args, **kwargs): + + cls = type.__new__(typ, *args, **kwargs) + + @decorator + def wrapper(init, self, *args, **kwargs): + sig = inspect.signature(init) + func = sig.bind(self, *args, **kwargs) + func.apply_defaults() + config = {} + keys = list(sig.parameters.keys()) + for k, v in zip(keys[1:], func.args[1:]): # exclude self + config[k] = v + config.update(func.kwargs) + for k in getattr(self, "_ignore_args", {}): + config.pop(k) + self._config = dict(config) + return init(self, *args, **kwargs) + + def get_function(method): + if isinstance(method, types.MethodType): + return method.__func__ + return method + + if isinstance(cls.__init__, types.FunctionType): + cls.__init__ = wrapper(cls.__init__) + custom_load_func = hasattr(cls, "load_config_dict") and \ + get_function(cls.load_config_dict) != get_function(typ.load_config_dict) + custom_config_func = hasattr(cls, "config_dict") and \ + get_function(cls.config_dict) != get_function(typ.config_dict) + if not custom_load_func: + cls.load_config_dict = _Configurable.load_config_dict + if not custom_config_func: + cls.config_dict = _Configurable.config_dict + + return cls + + +class Configurable(metaclass=_Configurable): + """ + Class for load/save configuration. 
+ It will automatically record every argument passed to the ``__init__`` function. + + This class is inspired by :meth:`state_dict()` in PyTorch, but designed for hyperparameters. + + Inherit this class to construct a configurable class. + + >>> class MyClass(nn.Module, core.Configurable): + + Note :class:`Configurable` only applies to the current class rather than any derived class. + For example, the following definition only records the arguments of ``MyClass``. + + >>> class DerivedClass(MyClass): + + In order to record the arguments of ``DerivedClass``, explicitly specify the inheritance. + + >>> class DerivedClass(MyClass, core.Configurable): + + To get the configuration of an instance, use :meth:`config_dict()`, + which returns a dict of argument names and values. + If an argument is also an instance of :class:`Configurable`, it will be recursively expanded in the dict. + The configuration dict can be passed to :meth:`load_config_dict()` to create a copy of the instance. + + For classes already registered in :class:`Registry`, + they can be directly created from the :class:`Configurable` class. + This is convenient for building models from configuration files. + + >>> config = models.GCN(128, [128]).config_dict() + >>> gcn = Configurable.load_config_dict(config) + """ + pass + + +def make_configurable(cls, module=None, ignore_args=()): + """ + Make a configurable class out of an existing class. + The configurable class will automatically record every argument passed to its ``__init__`` function. + + Parameters: + cls (type): input class + module (str, optional): bind the output class to this module. + By default, bind to the original module of the input class. + ignore_args (set of str, optional): arguments to ignore in the ``__init__`` function + """ + ignore_args = set(ignore_args) + module = module or cls.__module__ + Metaclass = type(cls) + if issubclass(Metaclass, _Configurable): # already a configurable class + return cls + if Metaclass != type: # already have a meta class + MetaClass = type(_Configurable.__name__, (Metaclass, _Configurable), {}) + else: + MetaClass = _Configurable + return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module}) diff --git a/build/lib/torchdrug/core/engine.py b/build/lib/torchdrug/core/engine.py new file mode 100644 index 00000000..701fec02 --- /dev/null +++ b/build/lib/torchdrug/core/engine.py @@ -0,0 +1,297 @@ +import os +import sys +import logging +from itertools import islice + +import torch +from torch import distributed as dist +from torch import nn +from torch.utils import data as torch_data + +from torchdrug import data, core, utils +from torchdrug.core import Registry as R +from torchdrug.utils import comm, pretty + + +module = sys.modules[__name__] +logger = logging.getLogger(__name__) + + +@R.register("core.Engine") +class Engine(core.Configurable): + """ + General class that handles everything about training and test of a task. + + This class can perform synchronous distributed parallel training over multiple CPUs or GPUs. + To invoke parallel training, launch with one of the following commands. + + 1. Single-node multi-process case. + + .. code-block:: bash + + python -m torch.distributed.launch --nproc_per_node={number_of_gpus} {your_script.py} {your_arguments...} + + 2. Multi-node multi-process case. + + .. 
code-block:: bash + + python -m torch.distributed.launch --nnodes={number_of_nodes} --node_rank={rank_of_this_node} + --nproc_per_node={number_of_gpus} {your_script.py} {your_arguments...} + + If :meth:`preprocess` is defined by the task, it will be applied to ``train_set``, ``valid_set`` and ``test_set``. + + Parameters: + task (nn.Module): task + train_set (data.Dataset): training set + valid_set (data.Dataset): validation set + test_set (data.Dataset): test set + optimizer (optim.Optimizer): optimizer + scheduler (lr_scheduler._LRScheduler, optional): scheduler + gpus (list of int, optional): GPU ids. By default, CPUs will be used. + For multi-node multi-process case, repeat the GPU ids for each node. + batch_size (int, optional): batch size of a single CPU / GPU + gradient_interval (int, optional): perform a gradient update every n batches. + This creates an equivalent batch size of ``batch_size * gradient_interval`` for optimization. + num_worker (int, optional): number of CPU workers per GPU + logger (str or core.LoggerBase, optional): logger type or logger instance. + Available types are ``logging`` and ``wandb``. + log_interval (int, optional): log every n gradient updates + """ + + def __init__(self, task, train_set, valid_set, test_set, optimizer, scheduler=None, gpus=None, batch_size=1, + gradient_interval=1, num_worker=0, logger="logging", log_interval=100): + self.rank = comm.get_rank() + self.world_size = comm.get_world_size() + self.gpus = gpus + self.batch_size = batch_size + self.gradient_interval = gradient_interval + self.num_worker = num_worker + + if gpus is None: + self.device = torch.device("cpu") + else: + if len(gpus) != self.world_size: + error_msg = "World size is %d but found %d GPUs in the argument" + if self.world_size == 1: + error_msg += ". Did you launch with `python -m torch.distributed.launch`?" 
+ raise ValueError(error_msg % (self.world_size, len(gpus))) + self.device = torch.device(gpus[self.rank % len(gpus)]) + + if self.world_size > 1 and not dist.is_initialized(): + if self.rank == 0: + module.logger.info("Initializing distributed process group") + backend = "gloo" if gpus is None else "nccl" + comm.init_process_group(backend, init_method="env://") + + if hasattr(task, "preprocess"): + if self.rank == 0: + module.logger.warning("Preprocess training set") + # TODO: more elegant implementation + # handle dynamic parameters in optimizer + old_params = list(task.parameters()) + result = task.preprocess(train_set, valid_set, test_set) + if result is not None: + train_set, valid_set, test_set = result + new_params = list(task.parameters()) + if len(new_params) != len(old_params): + optimizer.add_param_group({"params": new_params[len(old_params):]}) + if self.world_size > 1: + task = nn.SyncBatchNorm.convert_sync_batchnorm(task) + buffers_to_ignore = [] + for name, buffer in task.named_buffers(): + if not isinstance(buffer, torch.Tensor): + buffers_to_ignore.append(name) + task._ddp_params_and_buffers_to_ignore = set(buffers_to_ignore) + if self.device.type == "cuda": + task = task.cuda(self.device) + + self.model = task + self.train_set = train_set + self.valid_set = valid_set + self.test_set = test_set + self.optimizer = optimizer + self.scheduler = scheduler + + if isinstance(logger, str): + if logger == "logging": + logger = core.LoggingLogger() + elif logger == "wandb": + logger = core.WandbLogger(project=task.__class__.__name__) + else: + raise ValueError("Unknown logger `%s`" % logger) + self.meter = core.Meter(log_interval=log_interval, silent=self.rank > 0, logger=logger) + self.meter.log_config(self.config_dict()) + + def train(self, num_epoch=1, batch_per_epoch=None): + """ + Train the model. + + If ``batch_per_epoch`` is specified, randomly draw a subset of the training set for each epoch. + Otherwise, the whole training set is used for each epoch. + + Parameters: + num_epoch (int, optional): number of epochs + batch_per_epoch (int, optional): number of batches per epoch + """ + sampler = torch_data.DistributedSampler(self.train_set, self.world_size, self.rank) + dataloader = data.DataLoader(self.train_set, self.batch_size, sampler=sampler, num_workers=self.num_worker) + batch_per_epoch = batch_per_epoch or len(dataloader) + model = self.model + model.split = "train" + if self.world_size > 1: + if self.device.type == "cuda": + model = nn.parallel.DistributedDataParallel(model, device_ids=[self.device], + find_unused_parameters=True) + else: + model = nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) + model.train() + + for epoch in self.meter(num_epoch): + sampler.set_epoch(epoch) + + metrics = [] + start_id = 0 + # the last gradient update may contain less than gradient_interval batches + gradient_interval = min(batch_per_epoch - start_id, self.gradient_interval) + + for batch_id, batch in enumerate(islice(dataloader, batch_per_epoch)): + if self.device.type == "cuda": + batch = utils.cuda(batch, device=self.device) + + loss, metric = model(batch) + if not loss.requires_grad: + raise RuntimeError("Loss doesn't require grad. 
Did you define any loss in the task?")
+                loss = loss / gradient_interval
+                loss.backward()
+                metrics.append(metric)
+
+                if batch_id - start_id + 1 == gradient_interval:
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+
+                    metric = utils.stack(metrics, dim=0)
+                    metric = utils.mean(metric, dim=0)
+                    if self.world_size > 1:
+                        metric = comm.reduce(metric, op="mean")
+                    self.meter.update(metric)
+
+                    metrics = []
+                    start_id = batch_id + 1
+                    gradient_interval = min(batch_per_epoch - start_id, self.gradient_interval)
+
+            if self.scheduler:
+                self.scheduler.step()
+
+    @torch.no_grad()
+    def evaluate(self, split, log=True):
+        """
+        Evaluate the model.
+
+        Parameters:
+            split (str): split to evaluate. Can be ``train``, ``valid`` or ``test``.
+            log (bool, optional): log metrics or not
+
+        Returns:
+            dict: metrics
+        """
+        if comm.get_rank() == 0:
+            logger.warning(pretty.separator)
+            logger.warning("Evaluate on %s" % split)
+        test_set = getattr(self, "%s_set" % split)
+        sampler = torch_data.DistributedSampler(test_set, self.world_size, self.rank)
+        dataloader = data.DataLoader(test_set, self.batch_size, sampler=sampler, num_workers=self.num_worker)
+        model = self.model
+        model.split = split
+
+        model.eval()
+        preds = []
+        targets = []
+        for batch in dataloader:
+            if self.device.type == "cuda":
+                batch = utils.cuda(batch, device=self.device)
+
+            pred, target = model.predict_and_target(batch)
+            preds.append(pred)
+            targets.append(target)
+
+        pred = utils.cat(preds)
+        target = utils.cat(targets)
+        if self.world_size > 1:
+            pred = comm.cat(pred)
+            target = comm.cat(target)
+        metric = model.evaluate(pred, target)
+        if log:
+            self.meter.log(metric, category="%s/epoch" % split)
+
+        return metric
+
+    def load(self, checkpoint, load_optimizer=True, strict=True):
+        """
+        Load a checkpoint from file.
+
+        Parameters:
+            checkpoint (file-like): checkpoint file
+            load_optimizer (bool, optional): load optimizer state or not
+            strict (bool, optional): whether to strictly check the checkpoint matches the model parameters
+        """
+        if comm.get_rank() == 0:
+            logger.warning("Load checkpoint from %s" % checkpoint)
+        checkpoint = os.path.expanduser(checkpoint)
+        state = torch.load(checkpoint, map_location=self.device)
+        # drop the graph buffers as per Issue #89; use a default so models without them still load
+        state["model"].pop("graph", None)
+        state["model"].pop("fact_graph", None)
+        # non-strict loading is required because the buffers above were removed
+        self.model.load_state_dict(state["model"], strict=False)
+
+        if load_optimizer:
+            self.optimizer.load_state_dict(state["optimizer"])
+            for state in self.optimizer.state.values():
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor):
+                        state[k] = v.to(self.device)
+
+        comm.synchronize()
+
+    def save(self, checkpoint):
+        """
+        Save checkpoint to file.
+
+        Parameters:
+            checkpoint (file-like): checkpoint file
+        """
+        if comm.get_rank() == 0:
+            logger.warning("Save checkpoint to %s" % checkpoint)
+        checkpoint = os.path.expanduser(checkpoint)
+        if self.rank == 0:
+            state = {
+                "model": self.model.state_dict(),
+                "optimizer": self.optimizer.state_dict()
+            }
+            torch.save(state, checkpoint)
+
+        comm.synchronize()
+
+    @classmethod
+    def load_config_dict(cls, config):
+        """
+        Construct an instance from the configuration dict. 
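+
+        A minimal usage sketch (assumes the dict was produced by :meth:`config_dict`, or parsed
+        from an equivalent YAML file; the file name is illustrative):
+
+        >>> import yaml
+        >>> with open("engine.yaml", "r") as fin:
+        >>>     config = yaml.safe_load(fin)
+        >>> engine = Engine.load_config_dict(config)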
+ """ + if getattr(cls, "_registry_key", cls.__name__) != config["class"]: + raise ValueError("Expect config class to be `%s`, but found `%s`" % (cls.__name__, config["class"])) + + optimizer_config = config.pop("optimizer") + new_config = {} + for k, v in config.items(): + if isinstance(v, dict) and "class" in v: + v = core.Configurable.load_config_dict(v) + if k != "class": + new_config[k] = v + optimizer_config["params"] = new_config["task"].parameters() + new_config["optimizer"] = core.Configurable.load_config_dict(optimizer_config) + + return cls(**new_config) + + @property + def epoch(self): + """Current epoch.""" + return self.meter.epoch_id diff --git a/build/lib/torchdrug/core/logger.py b/build/lib/torchdrug/core/logger.py new file mode 100644 index 00000000..ce69de8c --- /dev/null +++ b/build/lib/torchdrug/core/logger.py @@ -0,0 +1,127 @@ +import logging +import warnings + +from torchdrug.core import Registry as R +from torchdrug.utils import pretty + + +class LoggerBase(object): + """ + Base class for loggers. + + Any custom logger should be derived from this class. + """ + + def log(self, record, step_id, category="train/batch"): + """ + Log a record. + + Parameters: + record (dict): dict of any metric + step_id (int): index of this log step + category (str, optional): log category. + Available types are ``train/batch``, ``train/epoch``, ``valid/epoch`` and ``test/epoch``. + """ + raise NotImplementedError + + def log_config(self, config): + """ + Log a hyperparameter config. + + Parameters: + config (dict): hyperparameter config + """ + raise NotImplementedError + + +@R.register("core.LoggingLogger") +class LoggingLogger(LoggerBase): + """ + Log outputs with the builtin logging module of Python. + + By default, the logs will be printed to the console. To additionally log outputs to a file, + add the following lines in the beginning of your code. + + .. code-block: python + + import logging + + format = logging.Formatter("%(asctime)-10s %(message)s", "%H:%M:%S") + handler = logging.FileHandler("log.txt") + handler.setFormatter(format) + logger = logging.getLogger("") + logger.addHandler(handler) + """ + + def __init__(self): + self.logger = logging.getLogger(__name__) + + def log(self, record, step_id, category="train/batch"): + if category.endswith("batch"): + self.logger.warning(pretty.separator) + elif category.endswith("epoch"): + self.logger.warning(pretty.line) + if category == "train/epoch": + for k in sorted(record.keys()): + self.logger.warning("average %s: %g" % (k, record[k])) + else: + for k in sorted(record.keys()): + self.logger.warning("%s: %g" % (k, record[k])) + + def log_config(self, config): + self.logger.warning(pretty.format(config, compact=True)) + + +@R.register("core.WandbLogger") +class WandbLogger(LoggingLogger): + """ + Log outputs with `Weights and Biases`_ and track the experiment progress. + + Note this class also output logs with the builtin logging module. + + See `wandb.init`_ for more details. + + .. _Weights and Biases: + https://docs.wandb.ai + + .. _wandb.init: + https://docs.wandb.ai/ref/python/init + + Parameters: + project (str, optional): name of the project + name (str, optional): name of this run + dir (str, optional): path to store meta data. Default is `./wandb`. + kwargs: keyword arguments for `wandb.init`_ + """ + + def __init__(self, project=None, name=None, dir=None, **kwargs): + super(WandbLogger, self).__init__() + try: + import wandb + except ModuleNotFoundError: + raise ModuleNotFoundError("Wandb is not found. 
Please install it with `pip install wandb`") + + if wandb.run is not None: + warnings.warn( + "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse " + "this run. If this is not desired, call `wandb.finish()` or `WandbLogger.finish()` before " + "instantiating `WandbLogger`." + ) + self.run = wandb.run + else: + self.run = wandb.init(project=project, name=name, dir=dir, **kwargs) + + self.run.define_metric("train/batch/*", step_metric="batch", summary="none") + for split in ["train", "valid", "test"]: + self.run.define_metric("%s/epoch/*" % split, step_metric="epoch") + + def log(self, record, step_id, category="train/batch"): + super(WandbLogger, self).log(record, step_id, category) + record = {"%s/%s" % (category, k): v for k, v in record.items()} + step_name = category.split("/")[-1] + record[step_name] = step_id + self.run.log(record) + + def log_config(self, confg_dict): + super(WandbLogger, self).log_config(confg_dict) + self.run.config.update(confg_dict) \ No newline at end of file diff --git a/build/lib/torchdrug/core/meter.py b/build/lib/torchdrug/core/meter.py new file mode 100644 index 00000000..41bd0a49 --- /dev/null +++ b/build/lib/torchdrug/core/meter.py @@ -0,0 +1,124 @@ +import time +import logging +from collections import defaultdict + +import numpy as np +import torch + +from torchdrug import core +from torchdrug.utils import pretty + +logger = logging.getLogger(__name__) + + +class Meter(object): + """ + Meter for recording metrics and training progress. + + Parameters: + log_interval (int, optional): log every n updates + silent (int, optional): surpress all outputs or not + logger (core.LoggerBase, optional): log handler + """ + def __init__(self, log_interval=100, silent=False, logger=None): + self.records = defaultdict(list) + self.log_interval = log_interval + self.epoch2batch = [0] + self.time = [time.time()] + self.epoch_id = 0 + self.batch_id = 0 + self.silent = silent + self.logger = logger + + def log(self, record, category="train/batch"): + """ + Log a record. + + Parameters: + record (dict): dict of any metric + category (str, optional): log category. + Available types are ``train/batch``, ``train/epoch``, ``valid/epoch`` and ``test/epoch``. + """ + if self.silent: + return + + if category.endswith("batch"): + step_id = self.batch_id + elif category.endswith("epoch"): + step_id = self.epoch_id + self.logger.log(record, step_id=step_id, category=category) + + def log_config(self, config): + """ + Log a hyperparameter config. + + Parameters: + config (dict): hyperparameter config + """ + if self.silent: + return + + self.logger.log_config(config) + + def update(self, record): + """ + Update the meter with a record. + + Parameters: + record (dict): dict of any metric + """ + if self.batch_id % self.log_interval == 0: + self.log(record, category="train/batch") + self.batch_id += 1 + + for k, v in record.items(): + if isinstance(v, torch.Tensor): + v = v.item() + self.records[k].append(v) + + def step(self): + """ + Step an epoch for this meter. 
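+        Stepping logs the epoch duration, batch speed, ETA, and the epoch average of every recorded metric.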
+
+        Instead of manually invoking :meth:`step()`, it is suggested to use the following line
+
+        >>> for epoch in meter(num_epoch):
+        >>>     # do something
+        """
+        self.epoch_id += 1
+        self.epoch2batch.append(self.batch_id)
+        self.time.append(time.time())
+        index = slice(self.epoch2batch[-2], self.epoch2batch[-1])
+        duration = self.time[-1] - self.time[-2]
+        speed = (self.epoch2batch[-1] - self.epoch2batch[-2]) / duration
+        if self.silent:
+            return
+
+        logger.warning("duration: %s" % pretty.time(duration))
+        logger.warning("speed: %.2f batch / sec" % speed)
+
+        eta = (self.time[-1] - self.time[self.start_epoch]) \
+              / (self.epoch_id - self.start_epoch) * (self.end_epoch - self.epoch_id)
+        logger.warning("ETA: %s" % pretty.time(eta))
+        if torch.cuda.is_available():
+            logger.warning("max GPU memory: %.1f MB" % (torch.cuda.max_memory_allocated() / 1e6))
+            torch.cuda.reset_peak_memory_stats()
+
+        record = {}
+        for k, v in self.records.items():
+            record[k] = np.mean(v[index])
+        self.log(record, category="train/epoch")
+
+    def __call__(self, num_epoch):
+        self.start_epoch = self.epoch_id
+        self.end_epoch = self.start_epoch + num_epoch
+
+        for epoch in range(self.start_epoch, self.end_epoch):
+            if not self.silent:
+                logger.warning(pretty.separator)
+                logger.warning("Epoch %d begin" % epoch)
+            yield epoch
+            if not self.silent:
+                logger.warning(pretty.separator)
+                logger.warning("Epoch %d end" % epoch)
+            self.step()
diff --git a/build/lib/torchdrug/data/__init__.py b/build/lib/torchdrug/data/__init__.py
new file mode 100644
index 00000000..508da327
--- /dev/null
+++ b/build/lib/torchdrug/data/__init__.py
@@ -0,0 +1,19 @@
+from .dictionary import PerfectHash, Dictionary
+from .graph import Graph, PackedGraph, cat
+from .molecule import Molecule, PackedMolecule
+from .protein import Protein, PackedProtein
+from .dataset import MoleculeDataset, ReactionDataset, ProteinDataset, \
+    ProteinPairDataset, ProteinLigandDataset, \
+    NodeClassificationDataset, KnowledgeGraphDataset, SemiSupervised, \
+    semisupervised, key_split, scaffold_split, ordered_scaffold_split
+from .dataloader import DataLoader, graph_collate
+from . import constant
+from . 
import feature
+
+__all__ = [
+    "Graph", "PackedGraph", "Molecule", "PackedMolecule", "Protein", "PackedProtein", "PerfectHash", "Dictionary",
+    "MoleculeDataset", "ReactionDataset", "NodeClassificationDataset", "KnowledgeGraphDataset", "SemiSupervised",
+    "ProteinDataset", "ProteinPairDataset", "ProteinLigandDataset",
+    "semisupervised", "key_split", "scaffold_split", "ordered_scaffold_split",
+    "DataLoader", "graph_collate", "feature", "constant",
+]
diff --git a/build/lib/torchdrug/data/constant.py b/build/lib/torchdrug/data/constant.py
new file mode 100644
index 00000000..ea3cba97
--- /dev/null
+++ b/build/lib/torchdrug/data/constant.py
@@ -0,0 +1,52 @@
+import sys
+
+module = sys.modules[__name__]
+
+# ordered by periodic table
+ATOM_NAME = ["Null",
+             "Hydrogen", "Helium", "Lithium", "Beryllium", "Boron", "Carbon", "Nitrogen", "Oxygen", "Fluorine",
+             "Neon", "Sodium", "Magnesium", "Aluminium", "Silicon", "Phosphorus", "Sulfur", "Chlorine", "Argon",
+             "Potassium", "Calcium", "Scandium", "Titanium", "Vanadium", "Chromium", "Manganese", "Iron", "Cobalt",
+             "Nickel", "Copper", "Zinc", "Gallium", "Germanium", "Arsenic", "Selenium", "Bromine", "Krypton",
+             "Rubidium", "Strontium", "Yttrium", "Zirconium", "Niobium", "Molybdenum", "Technetium", "Ruthenium",
+             "Rhodium", "Palladium", "Silver", "Cadmium", "Indium", "Tin", "Antimony", "Tellurium", "Iodine",
+             "Xenon", "Cesium", "Barium", "Lanthanum", "Cerium", "Praseodymium", "Neodymium", "Promethium",
+             "Samarium", "Europium", "Gadolinium", "Terbium", "Dysprosium", "Holmium", "Erbium", "Thulium",
+             "Ytterbium", "Lutetium", "Hafnium", "Tantalum", "Tungsten", "Rhenium", "Osmium", "Iridium",
+             "Platinum", "Gold", "Mercury", "Thallium", "Lead", "Bismuth", "Polonium", "Astatine", "Radon",
+             "Francium", "Radium", "Actinium", "Thorium", "Protactinium", "Uranium", "Neptunium", "Plutonium",
+             "Americium", "Curium", "Berkelium", "Californium", "Einsteinium", "Fermium", "Mendelevium",
+             "Nobelium", "Lawrencium", "Rutherfordium", "Dubnium", "Seaborgium", "Bohrium", "Hassium",
+             "Meitnerium", "Darmstadtium", "Roentgenium", "Copernicium", "Nihonium", "Flerovium", "Moscovium",
+             "Livermorium", "Tennessine", "Oganesson"]
+
+ATOM_SYMBOL = ["Null",
+               "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K",
+               "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr",
+               "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe",
+               "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf",
+               "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
+               "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs",
+               "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"]
+
+# ordered by molecular mass
+RESIDUE_NAME = ["Glycine", "Alanine", "Serine", "Proline", "Valine", "Threonine", "Cysteine", "Isoleucine",
+                "Leucine", "Asparagine", "Aspartic acid", "Glutamine", "Lysine", "Glutamic acid", "Methionine",
+                "Histidine", "Phenylalanine", "Arginine", "Tyrosine", "Tryptophan"]
+
+RESIDUE_INITIAL = ["G", "A", "S", "P", "V", "T", "C", "I", "L", "N", "D", "Q", "K", "E", "M", "H", "F", "R", "Y", "W"]
+
+RESIDUE_ATOM_NAME = ["C", "CA", "CB", "CD", "CD1", "CD2", "CE", "CE1", "CE2", "CE3", "CG", "CG1", "CG2", "CH2",
+                     "CZ", "CZ2", "CZ3", "N", "ND1", "ND2", "NE", "NE1", "NE2", "NH1", "NH2", "NZ", "O", 
"OD1", + "OD2", "OE1", "OE2", "OG", "OG1", "OH", "OXT", "SD", "SG"] + +NUM_ATOM = len(ATOM_NAME) +NUM_AMINO_ACID = len(RESIDUE_NAME) + +for i, name in enumerate(ATOM_NAME): + if i == 0: + continue + setattr(module, name.upper(), i) + +for i, name in enumerate(RESIDUE_NAME): + setattr(module, name.upper(), i) \ No newline at end of file diff --git a/build/lib/torchdrug/data/dataloader.py b/build/lib/torchdrug/data/dataloader.py new file mode 100644 index 00000000..64cfbe3d --- /dev/null +++ b/build/lib/torchdrug/data/dataloader.py @@ -0,0 +1,104 @@ +from collections import deque +from collections.abc import Mapping, Sequence + +import torch + +from torchdrug import data + + +def graph_collate(batch): + """ + Convert any list of same nested container into a container of tensors. + + For instances of :class:`data.Graph `, they are collated + by :meth:`data.Graph.pack `. + + Parameters: + batch (list): list of samples with the same nested container + """ + elem = batch[0] + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + numel = sum([x.numel() for x in batch]) + storage = elem.storage()._new_shared(numel) + out = elem.new(storage) + return torch.stack(batch, 0, out=out) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, (str, bytes)): + return batch + elif isinstance(elem, data.Graph): + return elem.pack(batch) + elif isinstance(elem, Mapping): + return {key: graph_collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, Sequence): + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError('Each element in list of batch should be of equal size') + return [graph_collate(samples) for samples in zip(*batch)] + + raise TypeError("Can't collate data with type `%s`" % type(elem)) + + +class DataLoader(torch.utils.data.DataLoader): + """ + Extended data loader for batching graph structured data. + + See `torch.utils.data.DataLoader`_ for more details. + + .. 
_torch.utils.data.DataLoader:
+        https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
+
+    Parameters:
+        dataset (Dataset): dataset from which to load the data
+        batch_size (int, optional): how many samples per batch to load
+        shuffle (bool, optional): set to ``True`` to have the data reshuffled at every epoch
+        sampler (Sampler, optional): sampler that draws single sample from the dataset
+        batch_sampler (Sampler, optional): sampler that draws a mini-batch of data from the dataset
+        num_workers (int, optional): how many subprocesses to use for data loading
+        collate_fn (callable, optional): merge a list of samples into a mini-batch
+        kwargs: keyword arguments for `torch.utils.data.DataLoader`_
+    """
+    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0,
+                 collate_fn=graph_collate, **kwargs):
+        super(DataLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn,
+                                         **kwargs)
+
+
+class DataQueue(torch.utils.data.Dataset):
+
+    def __init__(self):
+        self.queue = deque()
+
+    def append(self, item):
+        self.queue.append(item)
+
+    def pop(self):
+        # return the popped item so callers can use it
+        return self.queue.popleft()
+
+    def __getitem__(self, index):
+        return self.queue[index]
+
+    def __len__(self):
+        return len(self.queue)
+
+
+class ExperienceReplay(torch.utils.data.DataLoader):
+
+    def __init__(self, cache_size, batch_size=1, shuffle=True, **kwargs):
+        super(ExperienceReplay, self).__init__(DataQueue(), batch_size, shuffle, **kwargs)
+        self.cache_size = cache_size
+
+    def update(self, items):
+        for item in items:
+            self.dataset.append(item)
+        while len(self.dataset) > self.cache_size:
+            self.dataset.pop()
+
+    @property
+    def cold(self):
+        return len(self.dataset) < self.cache_size
\ No newline at end of file
diff --git a/build/lib/torchdrug/data/dataset.py b/build/lib/torchdrug/data/dataset.py
new file mode 100644
index 00000000..34e9606b
--- /dev/null
+++ b/build/lib/torchdrug/data/dataset.py
@@ -0,0 +1,1223 @@
+import os
+import csv
+import math
+import lmdb
+import pickle
+import logging
+import warnings
+from collections import defaultdict
+from collections.abc import Sequence
+
+from tqdm import tqdm
+
+import numpy as np
+
+from rdkit import Chem
+from rdkit.Chem.Scaffolds import MurckoScaffold
+import torch
+from torch.utils import data as torch_data
+
+from torchdrug import core, data, utils
+
+
+logger = logging.getLogger(__name__)
+
+
+class MoleculeDataset(torch_data.Dataset, core.Configurable):
+    """
+    Molecule dataset.
+
+    Each sample contains a molecule graph, and any number of prediction targets.
+    """
+
+    @utils.copy_args(data.Molecule.from_molecule)
+    def load_smiles(self, smiles_list, targets, transform=None, lazy=False, verbose=0, **kwargs):
+        """
+        Load the dataset from SMILES and targets.
+
+        Parameters:
+            smiles_list (list of str): SMILES strings
+            targets (dict of list): prediction targets
+            transform (Callable, optional): data transformation function
+            lazy (bool, optional): if lazy mode is used, the molecules are processed in the dataloader.
+                This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
+            verbose (int, optional): output verbose level
+            **kwargs
+        """
+        num_sample = len(smiles_list)
+        if num_sample > 1000000:
+            warnings.warn("Preprocessing molecules of a large dataset consumes a lot of CPU memory and time. 
" + "Use load_smiles(lazy=True) to construct molecules in the dataloader instead.") + for field, target_list in targets.items(): + if len(target_list) != num_sample: + raise ValueError("Number of target `%s` doesn't match with number of molecules. " + "Expect %d but found %d" % (field, num_sample, len(target_list))) + + self.transform = transform + self.lazy = lazy + self.kwargs = kwargs + self.smiles_list = [] + self.data = [] + self.targets = defaultdict(list) + + if verbose: + smiles_list = tqdm(smiles_list, "Constructing molecules from SMILES") + for i, smiles in enumerate(smiles_list): + if not self.lazy or len(self.data) == 0: + mol = Chem.MolFromSmiles(smiles) + if not mol: + logger.debug("Can't construct molecule from SMILES `%s`. Ignore this sample." % smiles) + continue + mol = data.Molecule.from_molecule(mol, **kwargs) + else: + mol = None + self.data.append(mol) + self.smiles_list.append(smiles) + for field in targets: + self.targets[field].append(targets[field][i]) + + @utils.copy_args(load_smiles) + def load_csv(self, csv_file, smiles_field="smiles", target_fields=None, verbose=0, **kwargs): + """ + Load the dataset from a csv file. + + Parameters: + csv_file (str): file name + smiles_field (str, optional): name of the SMILES column in the table. + Use ``None`` if there is no SMILES column. + target_fields (list of str, optional): name of target columns in the table. + Default is all columns other than the SMILES column. + verbose (int, optional): output verbose level + **kwargs + """ + if target_fields is not None: + target_fields = set(target_fields) + + with open(csv_file, "r") as fin: + reader = csv.reader(fin) + if verbose: + reader = iter(tqdm(reader, "Loading %s" % csv_file, utils.get_line_count(csv_file))) + fields = next(reader) + smiles = [] + targets = defaultdict(list) + for values in reader: + if not any(values): + continue + if smiles_field is None: + smiles.append("") + for field, value in zip(fields, values): + if field == smiles_field: + smiles.append(value) + elif target_fields is None or field in target_fields: + value = utils.literal_eval(value) + if value == "": + value = math.nan + targets[field].append(value) + + self.load_smiles(smiles, targets, verbose=verbose, **kwargs) + + def load_pickle(self, pkl_file, verbose=0): + """ + Load the dataset from a pickle file. + + Parameters: + pkl_file (str): file name + verbose (int, optional): output verbose level + """ + with utils.smart_open(pkl_file, "rb") as fin: + num_sample, tasks = pickle.load(fin) + + self.smiles_list = [] + self.data = [] + self.targets = {task: [] for task in tasks} + indexes = range(num_sample) + if verbose: + indexes = tqdm(indexes, "Loading %s" % pkl_file) + for i in indexes: + smiles, mol, values = pickle.load(fin) + self.smiles_list.append(smiles) + self.data.append(mol) + for task, value in zip(tasks, values): + self.targets[task] = value + + def save_pickle(self, pkl_file, verbose=0): + """ + Save the dataset to a pickle file. 
+ + Parameters: + pkl_file (str): file name + verbose (int, optional): output verbose level + """ + with utils.smart_open(pkl_file, "wb") as fout: + num_sample = len(self.data) + tasks = self.targets.keys() + pickle.dump((num_sample, tasks), fout) + + indexes = range(num_sample) + if verbose: + indexes = tqdm(indexes, "Dumping to %s" % pkl_file) + for i in indexes: + values = [v[i] for v in self.targets.values()] + pickle.dump((self.smiles_list[i], self.data[i], values), fout) + + def _standarize_index(self, index, count): + if isinstance(index, slice): + start = index.start or 0 + if start < 0: + start += count + stop = index.stop or count + if stop < 0: + stop += count + step = index.step or 1 + index = range(start, stop, step) + elif not isinstance(index, list): + raise ValueError("Unknown index `%s`" % index) + return index + + def get_item(self, index): + if getattr(self, "lazy", False): + # TODO: what if the smiles is invalid here? + item = {"graph": data.Molecule.from_smiles(self.smiles_list[index], **self.kwargs)} + else: + item = {"graph": self.data[index]} + item.update({k: v[index] for k, v in self.targets.items()}) + if self.transform: + item = self.transform(item) + return item + + def __getitem__(self, index): + if isinstance(index, int): + return self.get_item(index) + + index = self._standarize_index(index, len(self)) + return [self.get_item(i) for i in index] + + @property + def tasks(self): + """List of tasks.""" + return list(self.targets.keys()) + + @property + def node_feature_dim(self): + """Dimension of node features.""" + return self.data[0].node_feature.shape[-1] + + @property + def edge_feature_dim(self): + """Dimension of edge features.""" + return self.data[0].edge_feature.shape[-1] + + @property + def num_atom_type(self): + """Number of different atom types.""" + return len(self.atom_types) + + @property + def num_bond_type(self): + """Number of different bond types.""" + return len(self.bond_types) + + @utils.cached_property + def atom_types(self): + """All atom types.""" + atom_types = set() + + if getattr(self, "lazy", False): + warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.") + for smiles in self.smiles_list: + graph = data.Molecule.from_smiles(smiles, **self.kwargs) + atom_types.update(graph.atom_type.tolist()) + else: + for graph in self.data: + atom_types.update(graph.atom_type.tolist()) + + return sorted(atom_types) + + @utils.cached_property + def bond_types(self): + """All bond types.""" + bond_types = set() + + if getattr(self, "lazy", False): + warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.") + for smiles in self.smiles_list: + graph = data.Molecule.from_smiles(smiles, **self.kwargs) + bond_types.update(graph.edge_list[:, 2].tolist()) + else: + for graph in self.data: + bond_types.update(graph.edge_list[:, 2].tolist()) + + return sorted(bond_types) + + def __len__(self): + return len(self.data) + + def __repr__(self): + lines = [ + "#sample: %d" % len(self), + "#task: %d" % len(self.tasks), + ] + return "%s(\n %s\n)" % (self.__class__.__name__, "\n ".join(lines)) + + +class ReactionDataset(MoleculeDataset, core.Configurable): + """ + Chemical reaction dataset. + + Each sample contains two molecule graphs, and any number of prediction targets. + """ + + @utils.copy_args(data.Molecule.from_molecule) + def load_smiles(self, smiles_list, targets, transform=None, verbose=0, **kwargs): + """ + Load the dataset from SMILES and targets. 
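+        Each SMILES is expected in the ``reactant>agent>product`` reaction format parsed below,
+        e.g. ``CCO>>CC=O`` (an illustrative reaction with an empty agent field).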
+ + Parameters: + smiles_list (list of str): SMILES strings + targets (dict of list): prediction targets + transform (Callable, optional): data transformation function + verbose (int, optional): output verbose level + **kwargs + """ + num_sample = len(smiles_list) + for field, target_list in targets.items(): + if len(target_list) != num_sample: + raise ValueError("Number of target `%s` doesn't match with number of molecules. " + "Expect %d but found %d" % (field, num_sample, len(target_list))) + + self.smiles_list = [] + self.data = [] + self.targets = defaultdict(list) + + if verbose: + smiles_list = tqdm(smiles_list, "Constructing molecules from SMILES") + for i, smiles in enumerate(smiles_list): + smiles_reactant, agent, smiles_product = smiles.split(">") + mols = [] + for _smiles in [smiles_reactant, smiles_product]: + mol = Chem.MolFromSmiles(_smiles) + if not mol: + logger.debug("Can't construct molecule from SMILES `%s`. Ignore this sample." % _smiles) + break + mol = data.Molecule.from_molecule(mol, **kwargs) + mols.append(mol) + else: + self.data.append(mols) + self.smiles_list.append(smiles) + for field in targets: + self.targets[field].append(targets[field][i]) + self.transform = transform + + @property + def node_feature_dim(self): + """Dimension of node features.""" + return self.data[0][0].node_feature.shape[-1] + + @property + def edge_feature_dim(self): + """Dimension of edge features.""" + return self.data[0][0].edge_feature.shape[-1] + + @property + def num_atom_type(self): + """Number of different atom types.""" + return len(self.atom_types) + + @property + def num_bond_type(self): + """Number of different bond types.""" + return len(self.bond_types) + + @utils.cached_property + def atom_types(self): + """All atom types.""" + atom_types = set() + for graphs in self.data: + for graph in graphs: + atom_types.update(graph.atom_type.tolist()) + return sorted(atom_types) + + @utils.cached_property + def bond_types(self): + """All bond types.""" + bond_types = set() + for graphs in self.data: + for graph in graphs: + bond_types.update(graph.edge_list[:, 2].tolist()) + return sorted(bond_types) + + def __len__(self): + return len(self.data) + + +class NodeClassificationDataset(torch_data.Dataset, core.Configurable): + """ + Node classification dataset. + + The whole dataset contains one graph, where each node has its own node feature and label. + """ + + def load_tsv(self, node_file, edge_file, verbose=0): + """ + Load the edge list from a tsv file. 
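+        As parsed below, each row of ``node_file`` is ``node<TAB>feature_1 ... feature_n<TAB>label``,
+        and each row of ``edge_file`` is ``head<TAB>tail``.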
+
+        Parameters:
+            node_file (str): node feature and label file
+            edge_file (str): edge list file
+            verbose (int, optional): output verbose level
+        """
+        inv_node_vocab = {}
+        inv_label_vocab = {}
+        node_feature = []
+        node_label = []
+
+        with open(node_file, "r") as fin:
+            reader = csv.reader(fin, delimiter="\t")
+            if verbose:
+                reader = tqdm(reader, "Loading %s" % node_file, utils.get_line_count(node_file))
+            for tokens in reader:
+                node_token = tokens[0]
+                feature_tokens = tokens[1: -1]
+                label_token = tokens[-1]
+                inv_node_vocab[node_token] = len(inv_node_vocab)
+                if label_token not in inv_label_vocab:
+                    inv_label_vocab[label_token] = len(inv_label_vocab)
+                feature = [utils.literal_eval(f) for f in feature_tokens]
+                label = inv_label_vocab[label_token]
+                node_feature.append(feature)
+                node_label.append(label)
+
+        edge_list = []
+
+        with open(edge_file, "r") as fin:
+            reader = csv.reader(fin, delimiter="\t")
+            if verbose:
+                reader = tqdm(reader, "Loading %s" % edge_file, utils.get_line_count(edge_file))
+            for tokens in reader:
+                h_token, t_token = tokens
+                if h_token not in inv_node_vocab:
+                    inv_node_vocab[h_token] = len(inv_node_vocab)
+                h = inv_node_vocab[h_token]
+                if t_token not in inv_node_vocab:
+                    inv_node_vocab[t_token] = len(inv_node_vocab)
+                t = inv_node_vocab[t_token]
+                edge_list.append((h, t))
+
+        self.load_edge(edge_list, node_feature, node_label, inv_node_vocab=inv_node_vocab,
+                       inv_label_vocab=inv_label_vocab)
+
+    def load_edge(self, edge_list, node_feature, node_label, node_vocab=None, inv_node_vocab=None, label_vocab=None,
+                  inv_label_vocab=None):
+        node_vocab, inv_node_vocab = self._standarize_vocab(node_vocab, inv_node_vocab)
+        label_vocab, inv_label_vocab = self._standarize_vocab(label_vocab, inv_label_vocab)
+
+        self.num_labeled_node = len(node_feature)
+        if len(node_vocab) > len(node_feature):
+            logger.warning("Missing features & labels for %d / %d nodes" %
+                           (len(node_vocab) - len(node_feature), len(node_vocab)))
+            dummy_label = 0
+            dummy_feature = [0] * len(node_feature[0])
+            node_label += [dummy_label] * (len(node_vocab) - len(node_feature))
+            node_feature += [dummy_feature] * (len(node_vocab) - len(node_feature))
+
+        self.graph = data.Graph(edge_list, num_node=len(node_vocab), node_feature=node_feature)
+        with self.graph.node():
+            self.graph.node_label = torch.as_tensor(node_label)
+        self.node_vocab = node_vocab
+        self.inv_node_vocab = inv_node_vocab
+        self.label_vocab = label_vocab
+        # store the label vocabulary under its own name instead of clobbering inv_node_vocab
+        self.inv_label_vocab = inv_label_vocab
+
+    def _standarize_vocab(self, vocab, inverse_vocab):
+        if vocab is not None:
+            if isinstance(vocab, dict):
+                assert set(vocab.keys()) == set(range(len(vocab))), "Vocab keys should be consecutive numbers"
+                vocab = [vocab[k] for k in range(len(vocab))]
+            if inverse_vocab is None:
+                inverse_vocab = {v: i for i, v in enumerate(vocab)}
+        if inverse_vocab is not None:
+            assert set(inverse_vocab.values()) == set(range(len(inverse_vocab))), \
+                "Inverse vocab values should be consecutive numbers"
+            if vocab is None:
+                vocab = sorted(inverse_vocab, key=lambda k: inverse_vocab[k])
+        return vocab, inverse_vocab
+
+    @property
+    def num_node(self):
+        """Number of nodes."""
+        return self.graph.num_node
+
+    @property
+    def num_edge(self):
+        """Number of edges."""
+        return self.graph.num_edge
+
+    @property
+    def node_feature_dim(self):
+        """Dimension of node features."""
+        return self.graph.node_feature.shape[-1]
+
+    def __getitem__(self, index):
+        return {
+            "node_index": index,
+            "label": self.graph.node_label[index]
+        }
+
+    def __len__(self):
+        return 
self.num_labeled_node + + def __repr__(self): + lines = [ + "#node: %d" % self.num_node, + "#edge: %d" % self.num_edge, + "#class: %d" % len(self.label_vocab), + ] + return "%s(\n %s\n)" % (self.__class__.__name__, "\n ".join(lines)) + + +class KnowledgeGraphDataset(torch_data.Dataset, core.Configurable): + """ + Knowledge graph dataset. + + The whole dataset contains one knowledge graph. + """ + + def load_triplet(self, triplets, entity_vocab=None, relation_vocab=None, inv_entity_vocab=None, + inv_relation_vocab=None): + """ + Load the dataset from triplets. + The mapping between indexes and tokens is specified through either vocabularies or inverse vocabularies. + + Parameters: + triplets (array_like): triplets of shape :math:`(n, 3)` + entity_vocab (dict of str, optional): maps entity indexes to tokens + relation_vocab (dict of str, optional): maps relation indexes to tokens + inv_entity_vocab (dict of str, optional): maps tokens to entity indexes + inv_relation_vocab (dict of str, optional): maps tokens to relation indexes + """ + entity_vocab, inv_entity_vocab = self._standarize_vocab(entity_vocab, inv_entity_vocab) + relation_vocab, inv_relation_vocab = self._standarize_vocab(relation_vocab, inv_relation_vocab) + + num_node = len(entity_vocab) if entity_vocab else None + num_relation = len(relation_vocab) if relation_vocab else None + self.graph = data.Graph(triplets, num_node=num_node, num_relation=num_relation) + self.entity_vocab = entity_vocab + self.relation_vocab = relation_vocab + self.inv_entity_vocab = inv_entity_vocab + self.inv_relation_vocab = inv_relation_vocab + + def load_tsv(self, tsv_file, verbose=0): + """ + Load the dataset from a tsv file. + + Parameters: + tsv_file (str): file name + verbose (int, optional): output verbose level + """ + inv_entity_vocab = {} + inv_relation_vocab = {} + triplets = [] + + with open(tsv_file, "r") as fin: + reader = csv.reader(fin, delimiter="\t") + if verbose: + reader = tqdm(reader, "Loading %s" % tsv_file) + for tokens in reader: + h_token, r_token, t_token = tokens + if h_token not in inv_entity_vocab: + inv_entity_vocab[h_token] = len(inv_entity_vocab) + h = inv_entity_vocab[h_token] + if r_token not in inv_relation_vocab: + inv_relation_vocab[r_token] = len(inv_relation_vocab) + r = inv_relation_vocab[r_token] + if t_token not in inv_entity_vocab: + inv_entity_vocab[t_token] = len(inv_entity_vocab) + t = inv_entity_vocab[t_token] + triplets.append((h, t, r)) + + self.load_triplet(triplets, inv_entity_vocab=inv_entity_vocab, inv_relation_vocab=inv_relation_vocab) + + def load_tsvs(self, tsv_files, verbose=0): + """ + Load the dataset from multiple tsv files. 
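+        A typical call (file names are illustrative):
+
+        >>> dataset.load_tsvs(["train.txt", "valid.txt", "test.txt"], verbose=1)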
+ + Parameters: + tsv_files (list of str): list of file names + verbose (int, optional): output verbose level + """ + inv_entity_vocab = {} + inv_relation_vocab = {} + triplets = [] + num_samples = [] + + for tsv_file in tsv_files: + with open(tsv_file, "r") as fin: + reader = csv.reader(fin, delimiter="\t") + if verbose: + reader = tqdm(reader, "Loading %s" % tsv_file, utils.get_line_count(tsv_file)) + + num_sample = 0 + for tokens in reader: + h_token, r_token, t_token = tokens + if h_token not in inv_entity_vocab: + inv_entity_vocab[h_token] = len(inv_entity_vocab) + h = inv_entity_vocab[h_token] + if r_token not in inv_relation_vocab: + inv_relation_vocab[r_token] = len(inv_relation_vocab) + r = inv_relation_vocab[r_token] + if t_token not in inv_entity_vocab: + inv_entity_vocab[t_token] = len(inv_entity_vocab) + t = inv_entity_vocab[t_token] + triplets.append((h, t, r)) + num_sample += 1 + num_samples.append(num_sample) + + self.load_triplet(triplets, inv_entity_vocab=inv_entity_vocab, inv_relation_vocab=inv_relation_vocab) + self.num_samples = num_samples + + def _standarize_vocab(self, vocab, inverse_vocab): + if vocab is not None: + if isinstance(vocab, dict): + assert set(vocab.keys()) == set(range(len(vocab))), "Vocab keys should be consecutive numbers" + vocab = [vocab[k] for k in range(len(vocab))] + if inverse_vocab is None: + inverse_vocab = {v: i for i, v in enumerate(vocab)} + if inverse_vocab is not None: + assert set(inverse_vocab.values()) == set(range(len(inverse_vocab))), \ + "Inverse vocab values should be consecutive numbers" + if vocab is None: + vocab = sorted(inverse_vocab, key=lambda k: inverse_vocab[k]) + return vocab, inverse_vocab + + @property + def num_entity(self): + """Number of entities.""" + return self.graph.num_node + + @property + def num_triplet(self): + """Number of triplets.""" + return self.graph.num_edge + + @property + def num_relation(self): + """Number of relations.""" + return self.graph.num_relation + + def __getitem__(self, index): + return self.graph.edge_list[index] + + def __len__(self): + return self.graph.num_edge + + def __repr__(self): + lines = [ + "#entity: %d" % self.num_entity, + "#relation: %d" % self.num_relation, + "#triplet: %d" % self.num_triplet, + ] + return "%s(\n %s\n)" % (self.__class__.__name__, "\n ".join(lines)) + + +class ProteinDataset(MoleculeDataset, core.Configurable): + """ + Protein dataset. + + Each sample contains a protein graph, and any number of prediction targets. + """ + + @utils.copy_args(data.Protein.from_sequence) + def load_sequence(self, sequences, targets, attributes=None, transform=None, lazy=False, verbose=0, **kwargs): + """ + Load the dataset from protein sequences and targets. + + Parameters: + sequences (list of str): protein sequence strings + targets (dict of list): prediction targets + attributes (dict of list): protein-level attributes + transform (Callable, optional): protein sequence transformation function + lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader. + This may slow down the data loading process, but save a lot of CPU memory and dataset loading time. + verbose (int, optional): output verbose level + **kwargs + """ + num_sample = len(sequences) + if num_sample > 1000000: + warnings.warn("Preprocessing proteins of a large dataset consumes a lot of CPU memory and time. 
" + "Use load_sequence(lazy=True) to construct molecules in the dataloader instead.") + for field, target_list in targets.items(): + if len(target_list) != num_sample: + raise ValueError("Number of target `%s` doesn't match with number of molecules. " + "Expect %d but found %d" % (field, num_sample, len(target_list))) + + self.transform = transform + self.lazy = lazy + self.kwargs = kwargs + self.sequences = [] + self.data = [] + self.targets = defaultdict(list) + + if verbose: + sequences = tqdm(sequences, "Constructing proteins from sequences") + for i, sequence in enumerate(sequences): + if not self.lazy or len(self.data) == 0: + protein = data.Protein.from_sequence(sequence, **kwargs) + else: + protein = None + if attributes is not None: + with protein.graph(): + for field in attributes: + setattr(protein, field, attributes[field][i]) + self.data.append(protein) + self.sequences.append(sequence) + for field in targets: + self.targets[field].append(targets[field][i]) + + @utils.copy_args(load_sequence) + def load_lmdbs(self, lmdb_files, sequence_field="primary", target_fields=None, number_field="num_examples", + transform=None, lazy=False, verbose=0, **kwargs): + """ + Load the dataset from lmdb files. + + Parameters: + lmdb_files (list of str): list of lmdb files + sequence_field (str, optional): name of the field of protein sequence in lmdb files + target_fields (list of str, optional): name of target fields in lmdb files + number_field (str, optional): name of the field of sample count in lmdb files + transform (Callable, optional): protein sequence transformation function + lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader. + This may slow down the data loading process, but save a lot of CPU memory and dataset loading time. + verbose (int, optional): output verbose level + **kwargs + """ + if target_fields is not None: + target_fields = set(target_fields) + + sequences = [] + num_samples = [] + targets = defaultdict(list) + for lmdb_file in lmdb_files: + env = lmdb.open(lmdb_file, readonly=True, lock=False, readahead=False, meminit=False) + with env.begin(write=False) as txn: + num_sample = pickle.loads(txn.get(number_field.encode())) + for i in range(num_sample): + item = pickle.loads(txn.get(str(i).encode())) + sequences.append(item[sequence_field]) + if target_fields: + for field in target_fields: + value = item[field] + if isinstance(value, np.ndarray) and value.size == 1: + value = value.item() + targets[field].append(value) + num_samples.append(num_sample) + + self.load_sequence(sequences, targets, attributes=None, transform=transform, + lazy=lazy, verbose=verbose, **kwargs) + self.num_samples = num_samples + + @utils.copy_args(data.Protein.from_molecule) + def load_pdbs(self, pdb_files, transform=None, lazy=False, verbose=0, **kwargs): + """ + Load the dataset from pdb files. + + Parameters: + pdb_files (list of str): pdb file names + transform (Callable, optional): protein sequence transformation function + lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader. + This may slow down the data loading process, but save a lot of CPU memory and dataset loading time. + verbose (int, optional): output verbose level + **kwargs + """ + num_sample = len(pdb_files) + if num_sample > 1000000: + warnings.warn("Preprocessing proteins of a large dataset consumes a lot of CPU memory and time. 
" + "Use load_pdbs(lazy=True) to construct molecules in the dataloader instead.") + + self.transform = transform + self.lazy = lazy + self.kwargs = kwargs + self.data = [] + self.pdb_files = [] + self.sequences = [] + + if verbose: + pdb_files = tqdm(pdb_files, "Constructing proteins from pdbs") + for i, pdb_file in enumerate(pdb_files): + if not lazy or i == 0: + mol = Chem.MolFromPDBFile(pdb_file) + if not mol: + logger.debug("Can't construct molecule from pdb file `%s`. Ignore this sample." % pdb_file) + continue + protein = data.Protein.from_molecule(mol, **kwargs) + if not protein: + logger.debug("Can't construct protein from pdb file `%s`. Ignore this sample." % pdb_file) + continue + else: + protein = None + if hasattr(protein, "residue_feature"): + with protein.residue(): + protein.residue_feature = protein.residue_feature.to_sparse() + self.data.append(protein) + self.pdb_files.append(pdb_file) + self.sequences.append(protein.to_sequence() if protein else None) + + @utils.copy_args(load_sequence) + def load_fasta(self, fasta_file, verbose=0, **kwargs): + """ + Load the dataset from a fasta file. + + Parameters: + fasta_file (str): file name + verbose (int, optional): output verbose level + **kwargs + """ + with open(fasta_file, "r") as fin: + if verbose: + fin = tqdm(fin, "Loading %s" % fasta_file, utils.get_line_count(fasta_file)) + sequences = [] + lines = [] + for line in fin: + line = line.strip() + if line.startswith(">") and lines: + sequence = "".join(lines) + sequences.append(sequence) + lines = [] + else: + lines.append(line) + if lines: + sequence = "".join(lines) + sequences.append(sequence) + + return self.load_sequence(sequences, verbose=verbose, **kwargs) + + @utils.copy_args(data.Protein.from_molecule) + def load_pickle(self, pkl_file, transform=None, lazy=False, verbose=0, **kwargs): + """ + Load the dataset from a pickle file. + + Parameters: + pkl_file (str): file name + transform (Callable, optional): protein sequence transformation function + lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader. + This may slow down the data loading process, but save a lot of CPU memory and dataset loading time. + verbose (int, optional): output verbose level + **kwargs + """ + with utils.smart_open(pkl_file, "rb") as fin: + num_sample = pickle.load(fin) + + self.transform = transform + self.lazy = lazy + self.kwargs = kwargs + self.sequences = [] + self.pdb_files = [] + self.data = [] + indexes = range(num_sample) + if verbose: + indexes = tqdm(indexes, "Loading %s" % pkl_file) + for i in indexes: + pdb_file, sequence, protein = pickle.load(fin) + self.sequences.append(sequence) + self.pdb_files.append(pdb_file) + self.data.append(protein) + + def save_pickle(self, pkl_file, verbose=0): + with utils.smart_open(pkl_file, "wb") as fout: + num_sample = len(self.data) + pickle.dump(num_sample, fout) + + indexes = range(num_sample) + if verbose: + indexes = tqdm(indexes, "Dumping to %s" % pkl_file) + for i in indexes: + pdb_dir, pdb_name = os.path.split(self.pdb_files[i]) + split = os.path.basename(pdb_dir) + pdb_file = os.path.join(split, pdb_name) + pickle.dump((pdb_file, self.sequences[i], self.data[i]), fout) + + @property + def residue_feature_dim(self): + """Dimension of residue features.""" + return self.data[0].residue_feature.shape[-1] + + +class ProteinPairDataset(ProteinDataset, core.Configurable): + """ + Protein pair dataset. + + Each sample contains two protein graphs, and any number of prediction targets. 
+ """ + + @utils.copy_args(data.Protein.from_sequence) + def load_sequence(self, sequences, targets, attributes=None, transform=None, lazy=False, verbose=0, **kwargs): + """ + Load the dataset from protein sequences and targets. + + Parameters: + sequences (list of list of str): protein sequence string pairs + targets (dict of list): prediction targets + attributes (dict of list): protein-level attributes + transform (Callable, optional): protein sequence transformation function + lazy (bool, optional): if lazy mode is used, the protein pairs are processed in the dataloader. + This may slow down the data loading process, but save a lot of CPU memory and dataset loading time. + verbose (int, optional): output verbose level + **kwargs + """ + num_sample = len(sequences) + if num_sample > 1000000: + warnings.warn("Preprocessing proteins of a large dataset consumes a lot of CPU memory and time. " + "Use load_sequence(lazy=True) to construct molecules in the dataloader instead.") + for field, target_list in targets.items(): + if len(target_list) != num_sample: + raise ValueError("Number of target `%s` doesn't match with number of molecules. " + "Expect %d but found %d" % (field, num_sample, len(target_list))) + + self.transform = transform + self.lazy = lazy + self.kwargs = kwargs + self.sequences = [] + self.data = [] + self.targets = defaultdict(list) + + if verbose: + sequences = tqdm(sequences, "Constructing proteins from sequences") + for i, sequence in enumerate(sequences): + if not self.lazy or len(self.data) == 0: + proteins = [data.Protein.from_sequence(seq, **kwargs) for seq in sequence] + else: + proteins = None + if attributes is not None: + for protein in proteins: + with protein.graph(): + for field in attributes: + setattr(protein, field, attributes[field][i]) + self.data.append(proteins) + self.sequences.append(sequence) + for field in targets: + self.targets[field].append(targets[field][i]) + + @utils.copy_args(load_sequence) + def load_lmdbs(self, lmdb_files, sequence_field="primary", target_fields=None, number_field="num_examples", + transform=None, lazy=False, verbose=0, **kwargs): + """ + Load the dataset from lmdb files. + + Parameters: + lmdb_files (list of str): file names + sequence_field (str or list of str, optional): names of the fields of protein sequence in lmdb files + target_fields (list of str, optional): name of target fields in lmdb files + number_field (str, optional): name of the field of sample count in lmdb files + transform (Callable, optional): protein sequence transformation function + lazy (bool, optional): if lazy mode is used, the protein pairs are processed in the dataloader. + This may slow down the data loading process, but save a lot of CPU memory and dataset loading time. 
+ verbose (int, optional): output verbose level + **kwargs + """ + if target_fields is not None: + target_fields = set(target_fields) + else: + target_fields = set() + if not isinstance(sequence_field, Sequence): + sequence_field = [sequence_field] + + sequences = [] + num_samples = [] + targets = defaultdict(list) + for lmdb_file in lmdb_files: + env = lmdb.open(lmdb_file, readonly=True, lock=False, readahead=False, meminit=False) + with env.begin(write=False) as txn: + num_sample = pickle.loads(txn.get(number_field.encode())) + for i in range(num_sample): + item = pickle.loads(txn.get(str(i).encode())) + sequences.append([item[field] for field in sequence_field]) + for field in target_fields: + value = item[field] + if isinstance(value, np.ndarray) and value.size == 1: + value = value.item() + targets[field].append(value) + num_samples.append(num_sample) + + self.load_sequence(sequences, targets, transform=transform, lazy=lazy, verbose=verbose, **kwargs) + self.num_samples = num_samples + + @property + def node_feature_dim(self): + """Dimension of node features.""" + return self.data[0][0].node_feature.shape[-1] + + @property + def residue_feature_dim(self): + """Dimension of residue features.""" + return self.data[0][0].residue_feature.shape[-1] + + +class ProteinLigandDataset(ProteinDataset, core.Configurable): + """ + Protein-ligand dataset. + + Each sample contains a protein graph and a molecule graph, and any number of prediction targets. + """ + + @utils.copy_args(data.Protein.from_sequence) + def load_sequence(self, sequences, smiles, targets, num_samples, attributes=None, transform=None, + lazy=False, verbose=0, **kwargs): + """ + Load the dataset from protein sequences, ligand SMILES strings and targets. + + Parameters: + sequences (list of str): protein sequence strings + smiles (list of str): ligand SMILES strings + targets (dict of list): prediction targets + num_samples (list of int): numbers of protein-ligand pairs in all splits + attributes (dict of list): protein-level attributes + transform (Callable, optional): protein sequence transformation function + lazy (bool, optional): if lazy mode is used, the protein-ligand pairs are processed in the dataloader. + This may slow down the data loading process, but save a lot of CPU memory and dataset loading time. + verbose (int, optional): output verbose level + **kwargs + """ + num_sample = len(sequences) + if num_sample > 1000000: + warnings.warn("Preprocessing proteins of a large dataset consumes a lot of CPU memory and time. " + "Use load_sequence(lazy=True) to construct molecules in the dataloader instead.") + if len(smiles) != len(sequences): + raise ValueError("Number of smiles doesn't match with number of proteins. " + "Expect %d but found %d" % (num_sample, len(smiles))) + for field, target_list in targets.items(): + if len(target_list) != num_sample: + raise ValueError("Number of target `%s` doesn't match with number of molecules. 
" + "Expect %d but found %d" % (field, num_sample, len(target_list))) + + self.transform = transform + self.lazy = lazy + self.kwargs = kwargs + self.sequences = [] + self.smiles = [] + self.data = [] + self.targets = defaultdict(list) + cum_num_samples = [num_samples[0]] + for num in num_samples[1:]: + cum_num_samples.append(cum_num_samples[-1] + num) + _cur_split = 0 + + if verbose: + sequences = tqdm(sequences, "Constructing proteins from sequences") + for i, (sequence, smile) in enumerate(zip(sequences, smiles)): + if i >= cum_num_samples[_cur_split]: + _cur_split += 1 + if not self.lazy or len(self.data) == 0: + protein = data.Protein.from_sequence(sequence, **kwargs) + mol = Chem.MolFromSmiles(smile) + if not mol: + logger.debug("Can't construct molecule from SMILES `%s`. Ignore this sample." % smile) + num_samples[_cur_split] -= 1 + continue + mol = data.Molecule.from_molecule(mol) + else: + protein = None + mol = None + if attributes is not None: + with protein.graph(): + for field in attributes: + setattr(protein, field, attributes[field][i]) + self.data.append([protein, mol]) + self.sequences.append(sequence) + self.smiles.append(smile) + for field in targets: + self.targets[field].append(targets[field][i]) + + return num_samples + + @utils.copy_args(load_sequence) + def load_lmdbs(self, lmdb_files, sequence_field="target", smiles_field="drug", target_fields=None, + number_field="num_examples", transform=None, lazy=False, verbose=0, **kwargs): + """ + Load the dataset from lmdb files. + + Parameters: + lmdb_files (list of str): file names + sequence_field (str, optional): name of the field of protein sequence in lmdb files + smiles_field (str, optional): name of the field of ligand SMILES string in lmdb files + target_fields (list of str, optional): name of target fields in lmdb files + number_field (str, optional): name of the field of sample count in lmdb files + transform (Callable, optional): protein sequence transformation function + lazy (bool, optional): if lazy mode is used, the protein-ligand pairs are processed in the dataloader. + This may slow down the data loading process, but save a lot of CPU memory and dataset loading time. 
+            verbose (int, optional): output verbose level
+            **kwargs
+        """
+        if target_fields is not None:
+            target_fields = set(target_fields)
+
+        sequences = []
+        smiles = []
+        num_samples = []
+        targets = defaultdict(list)
+        for lmdb_file in lmdb_files:
+            env = lmdb.open(lmdb_file, readonly=True, lock=False, readahead=False, meminit=False)
+            with env.begin(write=False) as txn:
+                num_sample = pickle.loads(txn.get(number_field.encode()))
+                for i in range(num_sample):
+                    item = pickle.loads(txn.get(str(i).encode()))
+                    sequences.append(item[sequence_field])
+                    smiles.append(item[smiles_field])
+                    if target_fields:
+                        for field in target_fields:
+                            value = item[field]
+                            if isinstance(value, np.ndarray) and value.size == 1:
+                                value = value.item()
+                            targets[field].append(value)
+                num_samples.append(num_sample)
+
+        num_samples = self.load_sequence(sequences, smiles, targets, num_samples, transform=transform,
+                                         lazy=lazy, verbose=verbose, **kwargs)
+        self.num_samples = num_samples
+
+    @property
+    def ligand_node_feature_dim(self):
+        """Dimension of node features for ligands."""
+        return self.data[0][1].node_feature.shape[-1]
+
+    @property
+    def protein_node_feature_dim(self):
+        """Dimension of node features for proteins."""
+        return self.data[0][0].node_feature.shape[-1]
+
+    @property
+    def residue_feature_dim(self):
+        """Dimension of residue features for proteins."""
+        return self.data[0][0].residue_feature.shape[-1]
+
+
+class SemiSupervised(torch_data.Dataset, core.Configurable):
+    """
+    Semi-supervised dataset.
+
+    Parameters:
+        dataset (Dataset): supervised dataset
+        indices (list of int): sample indices to keep supervision
+    """
+
+    def __init__(self, dataset, indices):
+        self.dataset = dataset
+        self.indices = set(indices)
+
+    def __getitem__(self, idx):
+        item = self.dataset[idx]
+        item["labeled"] = (idx in self.indices)
+        return item
+
+    def __len__(self):
+        return len(self.dataset)
+
+
+def semisupervised(dataset, length):
+    """
+    Randomly construct a semi-supervised dataset based on the given length.
+
+    Parameters:
+        dataset (Dataset): supervised dataset
+        length (int): length of supervised data to keep
+    """
+    if length > len(dataset):
+        raise ValueError("Length of labeled data exceeds the length of the dataset")
+
+    # sample the labeled subset from the whole dataset, not just its first `length` samples
+    indexes = torch.randperm(len(dataset))[:length].tolist()
+    return SemiSupervised(dataset, indexes)
+
+
+def key_split(dataset, keys, lengths=None, key_lengths=None):
+
+    def round_to_boundary(i):
+        for j in range(min(i, len(dataset) - i)):
+            if keys[indexes[i - j]] != keys[indexes[i - j - 1]]:
+                return i - j
+            if keys[indexes[i + j]] != keys[indexes[i + j - 1]]:
+                return i + j
+        if i < len(dataset) - i:
+            return 0
+        else:
+            return len(dataset)
+
+    keys = torch.as_tensor(keys)
+    key_set, keys = torch.unique(keys, return_inverse=True)
+    perm = torch.randperm(len(key_set))
+    keys = perm[keys]
+    indexes = keys.argsort().tolist()
+
+    if key_lengths is not None:
+        assert lengths is None
+        key2count = keys.bincount()
+        key_offset = 0
+        lengths = []
+        for key_length in key_lengths:
+            lengths.append(key2count[key_offset: key_offset + key_length].sum().item())
+            key_offset += key_length
+
+    offset = 0
+    offsets = [offset]
+    for length in lengths:
+        offset = round_to_boundary(offset + length)
+        offsets.append(offset)
+    offsets[-1] = len(dataset)
+    return [torch_data.Subset(dataset, indexes[offsets[i]: offsets[i + 1]]) for i in range(len(lengths))]
+
+
+def scaffold_split(dataset, lengths):
+    """
+    Randomly split a dataset into new datasets with non-overlapping scaffolds.
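+
+    A minimal sketch (assumes a molecule dataset whose samples expose a
+    ``graph`` with a ``to_scaffold()`` method; the lengths are hypothetical)::
+
+        >>> train, valid, test = scaffold_split(dataset, [8000, 1000, 1000])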
+ + Parameters: + dataset (Dataset): dataset to split + lengths (list of int): expected length for each split. + Note the results may be different in length due to rounding. + """ + + scaffold2id = {} + keys = [] + for sample in dataset: + scaffold = sample["graph"].to_scaffold() + if scaffold not in scaffold2id: + id = len(scaffold2id) + scaffold2id[scaffold] = id + else: + id = scaffold2id[scaffold] + keys.append(id) + + return key_split(dataset, keys, lengths) + + +def ordered_scaffold_split(dataset, lengths, chirality=True): + """ + Split a dataset into new datasets with non-overlapping scaffolds and sorted w.r.t. number of each scaffold. + + Parameters: + dataset (Dataset): dataset to split + lengths (list of int): expected length for each split. + Note the results may be different in length due to rounding. + """ + frac_train, frac_valid, frac_test = 0.8, 0.1, 0.1 + + scaffold2id = defaultdict(list) + for idx, smiles in enumerate(dataset.smiles_list): + scaffold = MurckoScaffold.MurckoScaffoldSmiles(smiles=smiles, includeChirality=chirality) + scaffold2id[scaffold].append(idx) + + scaffold2id = {key: sorted(value) for key, value in scaffold2id.items()} + scaffold_sets = [ + scaffold_set for (scaffold, scaffold_set) in sorted( + scaffold2id.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True) + ] + train_cutoff = frac_train * len(dataset) + valid_cutoff = (frac_train + frac_valid) * len(dataset) + train_idx, valid_idx, test_idx = [], [], [] + for scaffold_set in scaffold_sets: + if len(train_idx) + len(scaffold_set) > train_cutoff: + if len(train_idx) + len(valid_idx) + len(scaffold_set) > valid_cutoff: + test_idx.extend(scaffold_set) + else: + valid_idx.extend(scaffold_set) + else: + train_idx.extend(scaffold_set) + + return torch_data.Subset(dataset, train_idx), torch_data.Subset(dataset, valid_idx), torch_data.Subset(dataset, test_idx) diff --git a/build/lib/torchdrug/data/dictionary.py b/build/lib/torchdrug/data/dictionary.py new file mode 100644 index 00000000..e6fa9361 --- /dev/null +++ b/build/lib/torchdrug/data/dictionary.py @@ -0,0 +1,285 @@ +import math +import torch + +from torch_scatter import scatter_max + +from torchdrug import utils + + +class PerfectHash(object): + """ + Perfect hash function. + + The function can be applied to either scalar keys or vector keys. + It takes :math:`O(n\log n)` time and :math:`O(n)` space to construct the hash table. + It maps queries to their indexes in the original key set in :math:`O(1)` time. + If the query is not present in the key set, it returns -1. + + The algorithm is adapted from `Storing a Sparse Table with O(1) Worst Case Access Time`_. + + .. 
_Storing a Sparse Table with O(1) Worst Case Access Time: + http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.91.346&rep=rep1&type=pdf + + Parameters: + keys (LongTensor): keys of shape :math:`(N,)` or :math:`(N, D)` + weight (LongTensor, optional): weight of the level-1 hash + bias (LongTensor, optional): bias of the level-1 hash + sub_weights (LongTensor, optional): weight of level-2 hashes + sub_biases (LongTensor, optional): bias of level-2 hashes + """ + + prime = 1000000000039 + max_attempt = 10 + max_input_dim = (torch.iinfo(torch.int64).max - prime) / prime + + def __init__(self, keys, weight=None, bias=None, sub_weights=None, sub_biases=None): + if keys.ndim == 1: + keys = keys.unsqueeze(-1) + num_input, input_dim = keys.shape + if weight is None: + weight = torch.randint(0, self.prime, (1, input_dim), device=keys.device) + if bias is None: + bias = torch.randint(0, self.prime, (1,), device=keys.device) + if sub_weights is None: + sub_weights = torch.randint(0, self.prime, (num_input, input_dim), device=keys.device) + if sub_biases is None: + sub_biases = torch.randint(0, self.prime, (num_input,), device=keys.device) + + self.keys = keys + self.weight = weight + self.bias = bias + self.sub_weights = sub_weights + self.sub_biases = sub_biases + self.num_input = num_input + self.num_output = num_input + self.input_dim = input_dim + + self._construct_hash_table() + + def _construct_hash_table(self): + index = self.hash(self.keys) + count = index.bincount(minlength=self.num_output) + for i in range(self.max_attempt): + if (count ** 2).sum() < 4 * self.num_output: + break + self._reset_hash() + index = self.hash(self.keys) + count = index.bincount(minlength=self.num_output) + else: + raise RuntimeError("Fail to generate a level-1 hash after %d attempts. " + "Are you sure the keys are unique?" % self.max_attempt) + self.num_sub_outputs = (count ** 2).clamp(min=1) + self.num_sub_output = self.num_sub_outputs.sum() + self.offsets = self.num_sub_outputs.cumsum(0) - self.num_sub_outputs + + sub_index = self.sub_hash(self.keys, index) + count = sub_index.bincount(minlength=self.num_sub_output) + has_collision = scatter_max(count, self.second2first, dim_size=self.num_output)[0] > 1 + max_attempt = int(self.max_attempt * math.log(self.num_input) / math.log(2)) + for i in range(max_attempt): + if not has_collision.any(): + break + self._reset_sub_hash(has_collision) + sub_index = self.sub_hash(self.keys, index) + count = sub_index.bincount(minlength=self.num_sub_output) + has_collision = scatter_max(count, self.second2first, dim_size=self.num_output)[0] > 1 + else: + raise RuntimeError("Fail to generate level-2 hashes after %d attempts. " + "Are you sure the keys are unique?" % max_attempt) + + self.table = -torch.ones(self.num_sub_output, dtype=torch.long, device=self.device) + self.table[sub_index] = torch.arange(self.num_input, device=self.device) + + def __call__(self, keys): + """ + Get the indexes of keys in the original key set. + + Return -1 for keys that are not present in the key set. 
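+
+        A minimal sketch::
+
+            >>> phash = PerfectHash(torch.tensor([3, 7, 42]))
+            >>> assert (phash(torch.tensor([7, 5])) == torch.tensor([1, -1])).all()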
+ """ + keys = torch.as_tensor(keys, dtype=torch.long, device=self.device) + if self.input_dim == 1 and keys.shape[-1] != 1: + keys = keys.unsqueeze(-1) + index = self.hash(keys) + sub_index = self.sub_hash(keys, index) + final_index = self.table[sub_index] + found = final_index != -1 + found_index = final_index[found] + equal = (keys[found] == self.keys[final_index[found]]).all(dim=-1) + final_index[found] = torch.where(equal, found_index, -torch.ones_like(found_index)) + return final_index + + def _reset_hash(self): + self.weight = torch.randint_like(self.weight, 0, self.prime) + self.bias = torch.randint_like(self.bias, 0, self.prime) + + def _reset_sub_hash(self, mask=None): + if mask is None: + self.sub_weights = torch.randint_like(self.sub_weights, 0, self.prime) + self.sub_biases = torch.randint_like(self.sub_biases, 0, self.prime) + else: + self.sub_weights[mask] = torch.randint_like(self.sub_weights[mask], 0, self.prime) + self.sub_biases[mask] = torch.randint_like(self.sub_biases[mask], 0, self.prime) + + def hash(self, keys): + """Apply the level-1 hash function to the keys.""" + keys = keys % self.prime + hash = (keys * self.weight % self.prime).sum(dim=-1) + self.bias + return hash % self.prime % self.num_output + + def sub_hash(self, keys, index): + """ + Apply level-2 hash functions to the keys. + + Parameters: + keys (LongTensor): query keys + index (LongTensor): output of the level-1 hash function + """ + keys = keys % self.prime + weight = self.sub_weights[index] + bias = self.sub_biases[index] + num_outputs = self.num_sub_outputs[index] + offsets = self.offsets[index] + hash = (keys * weight % self.prime).sum(dim=-1) + bias + return hash % self.prime % num_outputs + offsets + + @utils.cached_property + def second2first(self): + """Level-2 hash values to level-1 hash values mapping.""" + range = torch.arange(self.num_output, device=self.device) + second2first = range.repeat_interleave(self.num_sub_outputs) + return second2first + + @property + def device(self): + """Device.""" + return self.keys.device + + def cpu(self): + """ + Return a copy of this hash function in CPU memory. + + This is a non-op if the hash function is already in CPU memory. + """ + keys = self.keys.cpu() + + if keys is self.keys: + return self + else: + return type(self)(keys, weight=self.weight.cpu(), bias=self.bias.cpu(), + sub_weights=self.sub_weights.cpu(), sub_biases=self.sub_biases.cpu()) + + def cuda(self, *args, **kwargs): + """ + Return a copy of this hash function in CUDA memory. + + This is a non-op if the hash function is already on the correct device. + """ + keys = self.keys.cuda(*args, **kwargs) + + if keys is self.keys: + return self + else: + return type(self)(keys, weight=self.weight.cuda(*args, **kwargs), + bias=self.bias.cuda(*args, **kwargs), + sub_weights=self.sub_weights.cuda(*args, **kwargs), + sub_biases=self.sub_biases.cuda(*args, **kwargs)) + + +class Dictionary(object): + """ + Dictionary for mapping keys to values. + + This class has the same behavior as the built-in dict, except it operates on tensors and support batching. 
+ + Example:: + + >>> keys = torch.tensor([[0, 0], [1, 1], [2, 2]]) + >>> values = torch.tensor([[0, 1], [1, 2], [2, 3]]) + >>> d = data.Dictionary(keys, values) + >>> assert (d[[[0, 0], [2, 2]]] == values[[0, 2]]).all() + >>> assert (d.has_key([[0, 1], [1, 2]]) == torch.tensor([False, False])).all() + + Parameters: + keys (LongTensor): keys of shape :math:`(N,)` or :math:`(N, D)` + values (Tensor): values of shape :math:`(N, ...)` + hash (PerfectHash, optional): hash function for keys + """ + def __init__(self, keys, values, hash=None): + self.keys = keys + self.values = values + self.hash = hash or PerfectHash(keys) + + def __getitem__(self, keys): + """ + Return the value for each key. Raise key error if any key is not in the dictionary. + """ + keys = torch.as_tensor(keys, dtype=torch.long, device=self.device) + index = self.hash(keys) + not_found = index == -1 + if not_found.any(): + raise KeyError(keys[not_found].tolist()) + return self.values[index] + + def get(self, keys, default=None): + """ + Return the value for each key if the key is in the dictionary, otherwise the default value is returned. + + Parameters: + keys (LongTensor): keys of arbitrary shape + default (int or Tensor, optional): default return value. By default, 0 is used. + """ + keys = torch.as_tensor(keys, dtype=torch.long, device=self.device) + if default is None: + default = 0 + default = torch.as_tensor(default, dtype=self.values.dtype, device=self.device) + index = self.hash(keys) + shape = list(index.shape) + list(self.values.shape[1:]) + values = torch.ones(shape, dtype=self.values.dtype, device=self.device) * default + found = index != -1 + values[found] = self.values[index[found]] + return values + + def has_key(self, keys): + """Check whether each key exists in the dictionary.""" + index = self.hash(keys) + return index != -1 + + def to_dict(self): + """ + Return a built-in dict object of this dictionary. + """ + keys = self.keys.tolist() + values = self.values.tolist() + dict = {tuple(k): tuple(v) for k, v in zip(keys, values)} + return dict + + @property + def device(self): + """Device.""" + return self.keys.device + + def cpu(self): + """ + Return a copy of this dictionary in CPU memory. + + This is a non-op if the dictionary is already in CPU memory. + """ + keys = self.keys.cpu() + + if keys is self.keys: + return self + else: + return type(self)(keys, self.values.cpu(), hash=self.hash.cpu()) + + def cuda(self, *args, **kwargs): + """ + Return a copy of this dictionary in CUDA memory. + + This is a non-op if the dictionary is already in CUDA memory. 
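+
+        A minimal sketch, continuing the class-level example (requires a visible
+        CUDA device)::
+
+            >>> d_cuda = d.cuda()
+            >>> assert d_cuda.device.type == "cuda"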
+        """
+        keys = self.keys.cuda(*args, **kwargs)
+
+        if keys is self.keys:
+            return self
+        else:
+            return type(self)(keys, self.values.cuda(*args, **kwargs), hash=self.hash.cuda(*args, **kwargs))
\ No newline at end of file
diff --git a/build/lib/torchdrug/data/feature.py b/build/lib/torchdrug/data/feature.py
new file mode 100644
index 00000000..397d30df
--- /dev/null
+++ b/build/lib/torchdrug/data/feature.py
@@ -0,0 +1,347 @@
+import warnings
+
+from rdkit import Chem
+from rdkit.Chem import AllChem
+
+from torchdrug.core import Registry as R
+
+
+# ordered by periodic table
+atom_vocab = ["H", "B", "C", "N", "O", "F", "Mg", "Si", "P", "S", "Cl", "Cu", "Zn", "Se", "Br", "Sn", "I"]
+atom_vocab = {a: i for i, a in enumerate(atom_vocab)}
+degree_vocab = range(7)
+num_hs_vocab = range(7)
+formal_charge_vocab = range(-5, 6)
+chiral_tag_vocab = range(4)
+total_valence_vocab = range(8)
+num_radical_vocab = range(8)
+hybridization_vocab = range(len(Chem.rdchem.HybridizationType.values))
+
+bond_type_vocab = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
+                   Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
+bond_type_vocab = {b: i for i, b in enumerate(bond_type_vocab)}
+bond_dir_vocab = range(len(Chem.rdchem.BondDir.values))
+bond_stereo_vocab = range(len(Chem.rdchem.BondStereo.values))
+
+# ordered by molecular mass
+residue_vocab = ["GLY", "ALA", "SER", "PRO", "VAL", "THR", "CYS", "ILE", "LEU", "ASN",
+                 "ASP", "GLN", "LYS", "GLU", "MET", "HIS", "PHE", "ARG", "TYR", "TRP"]
+
+
+def onehot(x, vocab, allow_unknown=False):
+    if x in vocab:
+        if isinstance(vocab, dict):
+            index = vocab[x]
+        else:
+            index = vocab.index(x)
+    else:
+        index = -1
+    if allow_unknown:
+        feature = [0] * (len(vocab) + 1)
+        if index == -1:
+            warnings.warn("Unknown value `%s`" % x)
+        feature[index] = 1
+    else:
+        feature = [0] * len(vocab)
+        if index == -1:
+            raise ValueError("Unknown value `%s`. Available vocabulary is `%s`" % (x, vocab))
+        feature[index] = 1
+
+    return feature
+
+
+# TODO: this one is too slow
+@R.register("features.atom.default")
+def atom_default(atom):
+    """Default atom feature.
+
+    Features:
+        GetSymbol(): one-hot embedding for the atomic symbol
+
+        GetChiralTag(): one-hot embedding for atomic chiral tag
+
+        GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs
+
+        GetFormalCharge(): one-hot embedding for the number of formal charges in the molecule
+
+        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom
+
+        GetNumRadicalElectrons(): one-hot embedding for the number of radical electrons on the atom
+
+        GetHybridization(): one-hot embedding for the atom's hybridization
+
+        GetIsAromatic(): whether the atom is aromatic
+
+        IsInRing(): whether the atom is in a ring
+    """
+    return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \
+           onehot(atom.GetChiralTag(), chiral_tag_vocab) + \
+           onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True) + \
+           onehot(atom.GetFormalCharge(), formal_charge_vocab) + \
+           onehot(atom.GetTotalNumHs(), num_hs_vocab) + \
+           onehot(atom.GetNumRadicalElectrons(), num_radical_vocab) + \
+           onehot(atom.GetHybridization(), hybridization_vocab) + \
+           [atom.GetIsAromatic(), atom.IsInRing()]
+
+
+@R.register("features.atom.center_identification")
+def atom_center_identification(atom):
+    """Reaction center identification atom feature.
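+
+    A minimal sketch (featurizing the first atom of benzene)::
+
+        >>> from rdkit import Chem
+        >>> mol = Chem.MolFromSmiles("c1ccccc1")
+        >>> feature = atom_center_identification(mol.GetAtomWithIdx(0))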
+ + Features: + GetSymbol(): one-hot embedding for the atomic symbol + + GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom + + GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs + + GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom + + GetIsAromatic(): whether the atom is aromatic + + IsInRing(): whether the atom is in a ring + """ + return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \ + onehot(atom.GetTotalNumHs(), num_hs_vocab) + \ + onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True) + \ + onehot(atom.GetTotalValence(), total_valence_vocab) + \ + [atom.GetIsAromatic(), atom.IsInRing()] + + +@R.register("features.atom.synthon_completion") +def atom_synthon_completion(atom): + """Synthon completion atom feature. + + Features: + GetSymbol(): one-hot embedding for the atomic symbol + + GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom + + GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs + + IsInRing(): whether the atom is in a ring + + IsInRingSize(3, 4, 5, 6): whether the atom is in a ring of a particular size + + IsInRing() and not IsInRingSize(3, 4, 5, 6): whether the atom is in a ring and not in a ring of 3, 4, 5, 6 + """ + return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \ + onehot(atom.GetTotalNumHs(), num_hs_vocab) + \ + onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True) + \ + [atom.IsInRing(), atom.IsInRingSize(3), atom.IsInRingSize(4), + atom.IsInRingSize(5), atom.IsInRingSize(6), + atom.IsInRing() and (not atom.IsInRingSize(3)) and (not atom.IsInRingSize(4)) \ + and (not atom.IsInRingSize(5)) and (not atom.IsInRingSize(6))] + + +@R.register("features.atom.symbol") +def atom_symbol(atom): + """Symbol atom feature. + + Features: + GetSymbol(): one-hot embedding for the atomic symbol + """ + return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + + +@R.register("features.atom.explicit_property_prediction") +def atom_explicit_property_prediction(atom): + """Explicit property prediction atom feature. + + Features: + GetSymbol(): one-hot embedding for the atomic symbol + + GetDegree(): one-hot embedding for the degree of the atom in the molecule + + GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom + + GetFormalCharge(): one-hot embedding for the number of formal charges in the molecule + + GetIsAromatic(): whether the atom is aromatic + """ + return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \ + onehot(atom.GetDegree(), degree_vocab, allow_unknown=True) + \ + onehot(atom.GetTotalValence(), total_valence_vocab, allow_unknown=True) + \ + onehot(atom.GetFormalCharge(), formal_charge_vocab) + \ + [atom.GetIsAromatic()] + + +@R.register("features.atom.property_prediction") +def atom_property_prediction(atom): + """Property prediction atom feature. 
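+
+    A minimal sketch (featurizers are usually selected by their registered name
+    when constructing a molecule)::
+
+        >>> mol = data.Molecule.from_smiles("CCO", atom_feature="property_prediction")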
+ + Features: + GetSymbol(): one-hot embedding for the atomic symbol + + GetDegree(): one-hot embedding for the degree of the atom in the molecule + + GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom + + GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom + + GetFormalCharge(): one-hot embedding for the number of formal charges in the molecule + + GetIsAromatic(): whether the atom is aromatic + """ + return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \ + onehot(atom.GetDegree(), degree_vocab, allow_unknown=True) + \ + onehot(atom.GetTotalNumHs(), num_hs_vocab, allow_unknown=True) + \ + onehot(atom.GetTotalValence(), total_valence_vocab, allow_unknown=True) + \ + onehot(atom.GetFormalCharge(), formal_charge_vocab, allow_unknown=True) + \ + [atom.GetIsAromatic()] + + +@R.register("features.atom.position") +def atom_position(atom): + """ + Atom position in the molecular conformation. + Return 3D position if available, otherwise 2D position is returned. + + Note it takes much time to compute the conformation for large molecules. + """ + mol = atom.GetOwningMol() + if mol.GetNumConformers() == 0: + mol.Compute2DCoords() + conformer = mol.GetConformer() + pos = conformer.GetAtomPosition(atom.GetIdx()) + return [pos.x, pos.y, pos.z] + + +@R.register("features.atom.pretrain") +def atom_pretrain(atom): + """Atom feature for pretraining. + + Features: + GetSymbol(): one-hot embedding for the atomic symbol + + GetChiralTag(): one-hot embedding for atomic chiral tag + """ + return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \ + onehot(atom.GetChiralTag(), chiral_tag_vocab) + + +@R.register("features.atom.residue_symbol") +def atom_residue_symbol(atom): + """Residue symbol as atom feature. Only support atoms in a protein. + + Features: + GetSymbol(): one-hot embedding for the atomic symbol + GetResidueName(): one-hot embedding for the residue symbol + """ + residue = atom.GetPDBResidueInfo() + return onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True) + \ + onehot(residue.GetResidueName() if residue else -1, residue_vocab, allow_unknown=True) + + +@R.register("features.bond.default") +def bond_default(bond): + """Default bond feature. + + Features: + GetBondType(): one-hot embedding for the type of the bond + + GetBondDir(): one-hot embedding for the direction of the bond + + GetStereo(): one-hot embedding for the stereo configuration of the bond + + GetIsConjugated(): whether the bond is considered to be conjugated + """ + return onehot(bond.GetBondType(), bond_type_vocab) + \ + onehot(bond.GetBondDir(), bond_dir_vocab) + \ + onehot(bond.GetStereo(), bond_stereo_vocab) + \ + [int(bond.GetIsConjugated())] + + +@R.register("features.bond.length") +def bond_length(bond): + """ + Bond length in the molecular conformation. + + Note it takes much time to compute the conformation for large molecules. + """ + mol = bond.GetOwningMol() + if mol.GetNumConformers() == 0: + mol.Compute2DCoords() + conformer = mol.GetConformer() + h = conformer.GetAtomPosition(bond.GetBeginAtomIdx()) + t = conformer.GetAtomPosition(bond.GetEndAtomIdx()) + return [h.Distance(t)] + + +@R.register("features.bond.property_prediction") +def bond_property_prediction(bond): + """Property prediction bond feature. 
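+
+    A minimal sketch (featurizing the double bond of formaldehyde)::
+
+        >>> from rdkit import Chem
+        >>> mol = Chem.MolFromSmiles("C=O")
+        >>> feature = bond_property_prediction(mol.GetBondWithIdx(0))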
+ + Features: + GetBondType(): one-hot embedding for the type of the bond + + GetIsConjugated(): whether the bond is considered to be conjugated + + IsInRing(): whether the bond is in a ring + """ + return onehot(bond.GetBondType(), bond_type_vocab) + \ + [int(bond.GetIsConjugated()), bond.IsInRing()] + + +@R.register("features.bond.pretrain") +def bond_pretrain(bond): + """Bond feature for pretraining. + + Features: + GetBondType(): one-hot embedding for the type of the bond + + GetBondDir(): one-hot embedding for the direction of the bond + """ + return onehot(bond.GetBondType(), bond_type_vocab) + \ + onehot(bond.GetBondDir(), bond_dir_vocab) + + +@R.register("features.residue.symbol") +def residue_symbol(residue): + """Symbol residue feature. + + Features: + GetResidueName(): one-hot embedding for the residue symbol + """ + return onehot(residue.GetResidueName(), residue_vocab, allow_unknown=True) + + +@R.register("features.residue.default") +def residue_default(residue): + """Default residue feature. + + Features: + GetResidueName(): one-hot embedding for the residue symbol + """ + return residue_symbol(residue) + + +@R.register("features.molecule.ecfp") +def ExtendedConnectivityFingerprint(mol, radius=2, length=1024): + """Extended Connectivity Fingerprint molecule feature. + + Features: + GetMorganFingerprintAsBitVect(): a Morgan fingerprint for a molecule as a bit vector + """ + ecfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, length) + return list(ecfp) + + +@R.register("features.molecule.default") +def molecule_default(mol): + """Default molecule feature.""" + return ExtendedConnectivityFingerprint(mol) + + +ECFP = ExtendedConnectivityFingerprint + + +__all__ = [ + "atom_default", "atom_center_identification", "atom_synthon_completion", + "atom_symbol", "atom_explicit_property_prediction", "atom_property_prediction", + "atom_position", "atom_pretrain", "atom_residue_symbol", + "bond_default", "bond_length", "bond_property_prediction", "bond_pretrain", + "residue_symbol", "residue_default", + "ExtendedConnectivityFingerprint", "molecule_default", + "ECFP", +] \ No newline at end of file diff --git a/build/lib/torchdrug/data/graph.py b/build/lib/torchdrug/data/graph.py new file mode 100644 index 00000000..6439736c --- /dev/null +++ b/build/lib/torchdrug/data/graph.py @@ -0,0 +1,1851 @@ +import math +import warnings +from functools import reduce +from collections import defaultdict + +import networkx as nx + +from matplotlib import pyplot as plt +import torch +from torch_scatter import scatter_add, scatter_min + +from torchdrug import core, utils +from torchdrug.data import Dictionary +from torchdrug.utils import pretty + +plt.switch_backend("agg") + + +class Graph(core._MetaContainer): + r""" + Basic container for sparse graphs. + + To batch graphs with variadic sizes, use :meth:`data.Graph.pack `. + This will return a PackedGraph object with the following block diagonal adjacency matrix. + + .. math:: + + \begin{bmatrix} + A_1 & \cdots & 0 \\ + \vdots & \ddots & \vdots \\ + 0 & \cdots & A_n + \end{bmatrix} + + where :math:`A_i` is the adjacency of :math:`i`-th graph. + + You may register dynamic attributes for each graph. + The registered attributes will be automatically processed during packing. + + .. warning:: + + This class doesn't enforce any order on the edges. 
+ + Example:: + + >>> graph = data.Graph(torch.randint(10, (30, 2))) + >>> with graph.node(): + >>> graph.my_node_attr = torch.rand(10, 5, 5) + + Parameters: + edge_list (array_like, optional): list of edges of shape :math:`(|E|, 2)` or :math:`(|E|, 3)`. + Each tuple is (node_in, node_out) or (node_in, node_out, relation). + edge_weight (array_like, optional): edge weights of shape :math:`(|E|,)` + num_node (int, optional): number of nodes. + By default, it will be inferred from the largest id in `edge_list` + num_relation (int, optional): number of relations + node_feature (array_like, optional): node features of shape :math:`(|V|, ...)` + edge_feature (array_like, optional): edge features of shape :math:`(|E|, ...)` + graph_feature (array_like, optional): graph feature of any shape + """ + + _meta_types = {"node", "edge", "graph", "node reference", "edge reference", "graph reference"} + + def __init__(self, edge_list=None, edge_weight=None, num_node=None, num_relation=None, + node_feature=None, edge_feature=None, graph_feature=None, **kwargs): + super(Graph, self).__init__(**kwargs) + # edge_list: N * [h, t] or N * [h, t, r] + edge_list, num_edge = self._standarize_edge_list(edge_list, num_relation) + edge_weight = self._standarize_edge_weight(edge_weight, edge_list) + + num_node = self._standarize_num_node(num_node, edge_list) + num_relation = self._standarize_num_relation(num_relation, edge_list) + + self._edge_list = edge_list + self._edge_weight = edge_weight + self.num_node = num_node + self.num_edge = num_edge + self.num_relation = num_relation + + if node_feature is not None: + with self.node(): + self.node_feature = torch.as_tensor(node_feature, device=self.device) + if edge_feature is not None: + with self.edge(): + self.edge_feature = torch.as_tensor(edge_feature, device=self.device) + if graph_feature is not None: + with self.graph(): + self.graph_feature = torch.as_tensor(graph_feature, device=self.device) + + def node(self): + """ + Context manager for node attributes. + """ + return self.context("node") + + def edge(self): + """ + Context manager for edge attributes. + """ + return self.context("edge") + + def graph(self): + """ + Context manager for graph attributes. + """ + return self.context("graph") + + def node_reference(self): + """ + Context manager for node references. + """ + return self.context("node reference") + + def edge_reference(self): + """ + Context manager for edge references. + """ + return self.context("edge reference") + + def graph_reference(self): + """ + Context manager for graph references. 
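+
+        A minimal sketch (``graph2graph`` is a hypothetical attribute; reference
+        attributes must be long tensors)::
+
+            >>> with graph.graph_reference():
+            >>>     graph.graph2graph = torch.zeros(1, dtype=torch.long)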
+ """ + return self.context("graph reference") + + def _check_attribute(self, key, value): + for type in self._meta_contexts: + if "reference" in type: + if value.dtype != torch.long: + raise TypeError("Tensors used as reference must be long tensors") + if type == "node": + if len(value) != self.num_node: + raise ValueError("Expect node attribute `%s` to have shape (%d, *), but found %s" % + (key, self.num_node, value.shape)) + elif type == "edge": + if len(value) != self.num_edge: + raise ValueError("Expect edge attribute `%s` to have shape (%d, *), but found %s" % + (key, self.num_edge, value.shape)) + elif type == "node reference": + is_valid = (value >= -1) & (value < self.num_node) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect node reference in [-1, %d), but found %d" % + (self.num_node, error_value[0])) + elif type == "edge reference": + is_valid = (value >= -1) & (value < self.num_edge) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect edge reference in [-1, %d), but found %d" % + (self.num_edge, error_value[0])) + elif type == "graph reference": + is_valid = (value >= -1) & (value < self.batch_size) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect graph reference in [-1, %d), but found %d" % + (self.batch_size, error_value[0])) + + def __setattr__(self, key, value): + if hasattr(self, "meta_dict"): + self._check_attribute(key, value) + super(Graph, self).__setattr__(key, value) + + def _standarize_edge_list(self, edge_list, num_relation): + if edge_list is not None and len(edge_list): + if isinstance(edge_list, torch.Tensor) and edge_list.dtype != torch.long: + try: + edge_list = torch.LongTensor(edge_list) + except TypeError: + raise TypeError("Can't convert `edge_list` to torch.long") + else: + edge_list = torch.as_tensor(edge_list, dtype=torch.long) + else: + num_element = 2 if num_relation is None else 3 + if isinstance(edge_list, torch.Tensor): + device = edge_list.device + else: + device = "cpu" + edge_list = torch.zeros(0, num_element, dtype=torch.long, device=device) + if (edge_list < 0).any(): + raise ValueError("`edge_list` should only contain non-negative indexes") + num_edge = torch.tensor(len(edge_list), device=edge_list.device) + return edge_list, num_edge + + def _standarize_edge_weight(self, edge_weight, edge_list): + if edge_weight is not None: + edge_weight = torch.as_tensor(edge_weight, dtype=torch.float, device=edge_list.device) + if len(edge_list) != len(edge_weight): + raise ValueError("`edge_list` and `edge_weight` should be the same size, but found %d and %d" + % (len(edge_list), len(edge_weight))) + else: + edge_weight = torch.ones(len(edge_list), device=edge_list.device) + return edge_weight + + def _standarize_num_node(self, num_node, edge_list): + if num_node is None: + num_node = self._maybe_num_node(edge_list) + num_node = torch.as_tensor(num_node, device=edge_list.device) + if (edge_list[:, :2] >= num_node).any(): + raise ValueError("`num_node` is %d, but found node %d in `edge_list`" % (num_node, edge_list[:, :2].max())) + return num_node + + def _standarize_num_relation(self, num_relation, edge_list): + if num_relation is None and edge_list.shape[1] > 2: + num_relation = self._maybe_num_relation(edge_list) + if num_relation is not None: + num_relation = torch.as_tensor(num_relation, device=edge_list.device) + if edge_list.shape[1] <= 2: + raise ValueError("`num_relation` is provided, but the number of dims of `edge_list` is less than 3.") + 
elif (edge_list[:, 2] >= num_relation).any(): + raise ValueError("`num_relation` is %d, but found relation %d in `edge_list`" % (num_relation, edge_list[:, 2].max())) + return num_relation + + def _maybe_num_node(self, edge_list): + warnings.warn("_maybe_num_node() is used to determine the number of nodes. " + "This may underestimate the count if there are isolated nodes.") + if len(edge_list): + return edge_list[:, :2].max().item() + 1 + else: + return 0 + + def _maybe_num_relation(self, edge_list): + warnings.warn("_maybe_num_relation() is used to determine the number of relations. " + "This may underestimate the count if there are unseen relations.") + return edge_list[:, 2].max().item() + 1 + + def _standarize_index(self, index, count): + if isinstance(index, slice): + start = index.start or 0 + if start < 0: + start += count + stop = index.stop or count + if stop < 0: + stop += count + step = index.step or 1 + index = torch.arange(start, stop, step, device=self.device) + else: + index = torch.as_tensor(index, device=self.device) + if index.ndim == 0: + index = index.unsqueeze(0) + if index.dtype == torch.bool: + if index.shape != (count,): + raise IndexError("Invalid mask. Expect mask to have shape %s, but found %s" % + ((int(count),), tuple(index.shape))) + index = index.nonzero().squeeze(-1) + else: + index = index.long() + max_index = -1 if len(index) == 0 else index.max().item() + if max_index >= count: + raise IndexError("Invalid index. Expect index smaller than %d, but found %d" % (count, max_index)) + return index + + def _get_mapping(self, index, count): + index = self._standarize_index(index, count) + if (index.bincount() > 1).any(): + raise ValueError("Can't create mapping for duplicate index") + mapping = -torch.ones(count + 1, dtype=torch.long, device=self.device) + mapping[index] = torch.arange(len(index), device=self.device) + return mapping + + def _get_repeat_pack_offsets(self, num_xs, repeats): + new_num_xs = num_xs.repeat_interleave(repeats) + cum_repeats_shifted = repeats.cumsum(0) - repeats + new_num_xs[cum_repeats_shifted] -= num_xs + offsets = new_num_xs.cumsum(0) + return offsets + + @classmethod + def from_dense(cls, adjacency, node_feature=None, edge_feature=None): + """ + Create a sparse graph from a dense adjacency matrix. + For zero entries in the adjacency matrix, their edge features will be ignored. + + Parameters: + adjacency (array_like): adjacency matrix of shape :math:`(|V|, |V|)` or :math:`(|V|, |V|, |R|)` + node_feature (array_like): node features of shape :math:`(|V|, ...)` + edge_feature (array_like): edge features of shape :math:`(|V|, |V|, ...)` or :math:`(|V|, |V|, |R|, ...)` + """ + adjacency = torch.as_tensor(adjacency) + if adjacency.shape[0] != adjacency.shape[1]: + raise ValueError("`adjacency` should be a square matrix, but found %d and %d" % adjacency.shape[:2]) + + edge_list = adjacency.nonzero() + edge_weight = adjacency[tuple(edge_list.t())] + num_node = adjacency.shape[0] + num_relation = adjacency.shape[2] if adjacency.ndim > 2 else None + if edge_feature is not None: + edge_feature = torch.as_tensor(edge_feature) + edge_feature = edge_feature[tuple(edge_list.t())] + + return cls(edge_list, edge_weight, num_node, num_relation, node_feature, edge_feature) + + def connected_components(self): + """ + Split this graph into connected components. 
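+
+        A minimal sketch (one edge plus an isolated node gives two components)::
+
+            >>> graph = data.Graph([[0, 1]], num_node=3)
+            >>> components, num_cc = graph.connected_components()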
+ + Returns: + (PackedGraph, LongTensor): connected components, number of connected components per graph + """ + node_in, node_out = self.edge_list.t()[:2] + range = torch.arange(self.num_node, device=self.device) + node_in, node_out = torch.cat([node_in, node_out, range]), torch.cat([node_out, node_in, range]) + + # find connected component + # O(|E|d), d is the diameter of the graph + min_neighbor = torch.arange(self.num_node, device=self.device) + last = torch.zeros_like(min_neighbor) + while not torch.equal(min_neighbor, last): + last = min_neighbor + min_neighbor = scatter_min(min_neighbor[node_out], node_in, dim_size=self.num_node)[0] + anchor = torch.unique(min_neighbor) + num_cc = self.node2graph[anchor].bincount(minlength=self.batch_size) + return self.split(min_neighbor), num_cc + + def split(self, node2graph): + """ + Split a graph into multiple disconnected graphs. + + Parameters: + node2graph (array_like): ID of the graph each node belongs to + + Returns: + PackedGraph + """ + node2graph = torch.as_tensor(node2graph, dtype=torch.long, device=self.device) + # coalesce arbitrary graph IDs to [0, n) + _, node2graph = torch.unique(node2graph, return_inverse=True) + num_graph = node2graph.max() + 1 + index = node2graph.argsort() + mapping = torch.zeros_like(index) + mapping[index] = torch.arange(len(index), device=self.device) + + node_in, node_out = self.edge_list.t()[:2] + edge_mask = node2graph[node_in] == node2graph[node_out] + edge2graph = node2graph[node_in] + edge_index = edge2graph.argsort() + edge_index = edge_index[edge_mask[edge_index]] + + prepend = torch.tensor([-1], device=self.device) + is_first_node = torch.diff(node2graph[index], prepend=prepend) > 0 + graph_index = self.node2graph[index[is_first_node]] + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + + num_nodes = node2graph.bincount(minlength=num_graph) + num_edges = edge2graph[edge_index].bincount(minlength=num_graph) + + num_cum_nodes = num_nodes.cumsum(0) + offsets = (num_cum_nodes - num_nodes)[edge2graph[edge_index]] + + data_dict, meta_dict = self.data_mask(index, edge_index, graph_index=graph_index, exclude="graph reference") + + return self.packed_type(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_nodes=num_nodes, + num_edges=num_edges, num_relation=self.num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) + + @classmethod + def pack(cls, graphs): + """ + Pack a list of graphs into a PackedGraph object. + + Parameters: + graphs (list of Graph): list of graphs + + Returns: + PackedGraph + """ + edge_list = [] + edge_weight = [] + num_nodes = [] + num_edges = [] + num_relation = -1 + num_cum_node = 0 + num_cum_edge = 0 + num_graph = 0 + data_dict = defaultdict(list) + meta_dict = graphs[0].meta_dict + for graph in graphs: + edge_list.append(graph.edge_list) + edge_weight.append(graph.edge_weight) + num_nodes.append(graph.num_node) + num_edges.append(graph.num_edge) + for k, v in graph.data_dict.items(): + for type in meta_dict[k]: + if type == "graph": + v = v.unsqueeze(0) + elif type == "node reference": + v = v + num_cum_node + elif type == "edge reference": + v = v + num_cum_edge + elif type == "graph reference": + v = v + num_graph + data_dict[k].append(v) + if num_relation == -1: + num_relation = graph.num_relation + elif num_relation != graph.num_relation: + raise ValueError("Inconsistent `num_relation` in graphs. Expect %d but got %d." 
+                                 % (num_relation, graph.num_relation))
+            num_cum_node += graph.num_node
+            num_cum_edge += graph.num_edge
+            num_graph += 1
+
+        edge_list = torch.cat(edge_list)
+        edge_weight = torch.cat(edge_weight)
+        data_dict = {k: torch.cat(v) for k, v in data_dict.items()}
+
+        return cls.packed_type(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges,
+                               num_relation=num_relation, meta_dict=meta_dict, **data_dict)
+
+    def repeat(self, count):
+        """
+        Repeat this graph.
+
+        Parameters:
+            count (int): number of repetitions
+
+        Returns:
+            PackedGraph
+        """
+        edge_list = self.edge_list.repeat(count, 1)
+        edge_weight = self.edge_weight.repeat(count)
+        num_nodes = [self.num_node] * count
+        num_edges = [self.num_edge] * count
+        num_relation = self.num_relation
+
+        data_dict = {}
+        for k, v in self.data_dict.items():
+            if "graph" in self.meta_dict[k]:
+                v = v.unsqueeze(0)
+            shape = [1] * v.ndim
+            shape[0] = count
+            length = len(v)
+            v = v.repeat(shape)
+            for type in self.meta_dict[k]:
+                if type == "node reference":
+                    offsets = torch.arange(count, device=self.device) * self.num_node
+                    v = v + offsets.repeat_interleave(length)
+                elif type == "edge reference":
+                    offsets = torch.arange(count, device=self.device) * self.num_edge
+                    v = v + offsets.repeat_interleave(length)
+                elif type == "graph reference":
+                    offsets = torch.arange(count, device=self.device)
+                    v = v + offsets.repeat_interleave(length)
+            data_dict[k] = v
+
+        return self.packed_type(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges,
+                                num_relation=num_relation, meta_dict=self.meta_dict, **data_dict)
+
+    def get_edge(self, edge):
+        """
+        Get the weight of an edge.
+
+        Parameters:
+            edge (array_like): index of shape :math:`(2,)` or :math:`(3,)`
+
+        Returns:
+            Tensor: weight of the edge
+        """
+        if len(edge) != self.edge_list.shape[1]:
+            raise ValueError("Incorrect edge index. Expect %d axes but got %d axes"
+                             % (self.edge_list.shape[1], len(edge)))
+
+        edge_index, num_match = self.match(edge)
+        return self.edge_weight[edge_index].sum()
+
+    def match(self, pattern):
+        """
+        Return all matched indexes for each pattern. Support patterns with ``-1`` as the wildcard.
+ + Parameters: + pattern (array_like): index of shape :math:`(N, 2)` or :math:`(N, 3)` + + Returns: + (LongTensor, LongTensor): matched indexes, number of matches per edge + + Examples:: + + >>> graph = data.Graph([[0, 1], [1, 0], [1, 2], [2, 1], [2, 0], [0, 2]]) + >>> index, num_match = graph.match([[0, -1], [1, 2]]) + >>> assert (index == torch.tensor([0, 5, 2])).all() + >>> assert (num_match == torch.tensor([2, 1])).all() + + """ + if len(pattern) == 0: + index = num_match = torch.zeros(0, dtype=torch.long, device=self.device) + return index, num_match + + if not hasattr(self, "edge_inverted_index"): + self.edge_inverted_index = {} + pattern = torch.as_tensor(pattern, dtype=torch.long, device=self.device) + if pattern.ndim == 1: + pattern = pattern.unsqueeze(0) + mask = pattern != -1 + scale = 2 ** torch.arange(pattern.shape[-1], device=self.device) + query_type = (mask * scale).sum(dim=-1) + query_index = query_type.argsort() + num_query = query_type.unique(return_counts=True)[1] + query_ends = num_query.cumsum(0) + query_starts = query_ends - num_query + mask_set = mask[query_index[query_starts]].tolist() + + type_ranges = [] + type_orders = [] + # get matched range for each query type + for i, mask in enumerate(mask_set): + query_type = tuple(mask) + type_index = query_index[query_starts[i]: query_ends[i]] + type_edge = pattern[type_index][:, mask] + if query_type not in self.edge_inverted_index: + self.edge_inverted_index[query_type] = self._build_edge_inverted_index(mask) + inverted_range, order = self.edge_inverted_index[query_type] + ranges = inverted_range.get(type_edge, default=0) + type_ranges.append(ranges) + type_orders.append(order) + ranges = torch.cat(type_ranges) + orders = torch.stack(type_orders) + types = torch.arange(len(mask_set), device=self.device) + types = types.repeat_interleave(num_query) + + # reorder matched ranges according to the query order + ranges = scatter_add(ranges, query_index, dim=0, dim_size=len(pattern)) + types = scatter_add(types, query_index, dim_size=len(pattern)) + # convert range to indexes + starts, ends = ranges.t() + num_match = ends - starts + offsets = num_match.cumsum(0) - num_match + types = types.repeat_interleave(num_match) + ranges = torch.arange(num_match.sum(), device=self.device) + ranges = ranges + (starts - offsets).repeat_interleave(num_match) + index = orders[types, ranges] + + return index, num_match + + def _build_edge_inverted_index(self, mask): + keys = self.edge_list[:, mask] + base = torch.tensor(self.shape, device=self.device) + base = base[mask] + max = reduce(int.__mul__, base.tolist()) + if max > torch.iinfo(torch.int64).max: + raise ValueError("Fail to build an inverted index table based on sorting. " + "The graph is too large.") + scale = base.cumprod(0) + scale = torch.div(scale[-1], scale, rounding_mode="floor") + key = (keys * scale).sum(dim=-1) + order = key.argsort() + num_keys = key.unique(return_counts=True)[1] + ends = num_keys.cumsum(0) + starts = ends - num_keys + ranges = torch.stack([starts, ends], dim=-1) + keys_set = keys[order[starts]] + inverted_range = Dictionary(keys_set, ranges) + return inverted_range, order + + def __getitem__(self, index): + # why do we check tuple? 
+        # case 1: x[0, 1] is parsed as (0, 1)
+        # case 2: x[[0, 1]] is parsed as [0, 1]
+        if not isinstance(index, tuple):
+            index = (index,)
+        index = list(index)
+
+        while len(index) < 2:
+            index.append(slice(None))
+        if len(index) > 2:
+            raise ValueError("Graph has only 2 axes, but %d axes are indexed" % len(index))
+
+        if all([isinstance(axis_index, int) for axis_index in index]):
+            return self.get_edge(index)
+
+        edge_list = self.edge_list.clone()
+        for i, axis_index in enumerate(index):
+            axis_index = self._standarize_index(axis_index, self.num_node)
+            mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
+            mapping[axis_index] = axis_index
+            edge_list[:, i] = mapping[edge_list[:, i]]
+        edge_index = (edge_list >= 0).all(dim=-1)
+
+        return self.edge_mask(edge_index)
+
+    def __len__(self):
+        return 1
+
+    @property
+    def batch_size(self):
+        """Batch size."""
+        return 1
+
+    def subgraph(self, index):
+        """
+        Return a subgraph based on the specified nodes.
+        Equivalent to :meth:`node_mask(index, compact=True) <node_mask>`.
+
+        Parameters:
+            index (array_like): node index
+
+        Returns:
+            Graph
+
+        See also:
+            :meth:`Graph.node_mask`
+        """
+        return self.node_mask(index, compact=True)
+
+    def data_mask(self, node_index=None, edge_index=None, graph_index=None, include=None, exclude=None):
+        data_dict, meta_dict = self.data_by_meta(include, exclude)
+        node_mapping = None
+        edge_mapping = None
+        graph_mapping = None
+        for k, v in data_dict.items():
+            for type in meta_dict[k]:
+                if type == "node" and node_index is not None:
+                    v = v[node_index]
+                elif type == "edge" and edge_index is not None:
+                    v = v[edge_index]
+                elif type == "graph" and graph_index is not None:
+                    v = v.unsqueeze(0)[graph_index]
+                elif type == "node reference" and node_index is not None:
+                    if node_mapping is None:
+                        node_mapping = self._get_mapping(node_index, self.num_node)
+                    v = node_mapping[v]
+                elif type == "edge reference" and edge_index is not None:
+                    if edge_mapping is None:
+                        edge_mapping = self._get_mapping(edge_index, self.num_edge)
+                    v = edge_mapping[v]
+                elif type == "graph reference" and graph_index is not None:
+                    if graph_mapping is None:
+                        graph_mapping = self._get_mapping(graph_index, self.batch_size)
+                    v = graph_mapping[v]
+            data_dict[k] = v
+
+        return data_dict, meta_dict
+
+    def node_mask(self, index, compact=False):
+        """
+        Return a masked graph based on the specified nodes.
+
+        This function can also be used to re-order the nodes.
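+
+        A re-ordering sketch (a hypothetical path graph; with ``compact=True``, the i-th node
+        listed in ``index`` becomes node i of the result)::
+
+            >>> graph = data.Graph([[0, 1], [1, 2]])
+            >>> reordered = graph.node_mask([2, 1, 0], compact=True)
+            >>> assert reordered.edge_list[:, :2].tolist() == [[2, 1], [1, 0]]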
+ + Parameters: + index (array_like): node index + compact (bool, optional): compact node ids or not + + Returns: + Graph + + Examples:: + + >>> graph = data.Graph.from_dense(torch.eye(3)) + >>> assert graph.node_mask([1, 2]).adjacency.shape == (3, 3) + >>> assert graph.node_mask([1, 2], compact=True).adjacency.shape == (2, 2) + + """ + index = self._standarize_index(index, self.num_node) + mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device) + if compact: + mapping[index] = torch.arange(len(index), device=self.device) + num_node = len(index) + else: + mapping[index] = index + num_node = self.num_node + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + edge_index = (edge_list[:, :2] >= 0).all(dim=-1) + + if compact: + data_dict, meta_dict = self.data_mask(index, edge_index) + else: + data_dict, meta_dict = self.data_mask(edge_index=edge_index) + + return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_node=num_node, + num_relation=self.num_relation, meta_dict=meta_dict, **data_dict) + + def compact(self): + """ + Remove isolated nodes and compact node ids. + + Returns: + Graph + """ + index = self.degree_out + self.degree_in > 0 + return self.subgraph(index) + + def edge_mask(self, index): + """ + Return a masked graph based on the specified edges. + + This function can also be used to re-order the edges. + + Parameters: + index (array_like): edge index + + Returns: + Graph + """ + index = self._standarize_index(index, self.num_edge) + data_dict, meta_dict = self.data_mask(edge_index=index) + + return type(self)(self.edge_list[index], edge_weight=self.edge_weight[index], num_node=self.num_node, + num_relation=self.num_relation, meta_dict=meta_dict, **data_dict) + + def line_graph(self): + """ + Construct a line graph of this graph. + The node feature of the line graph is inherited from the edge feature of the original graph. + + In the line graph, each node corresponds to an edge in the original graph. + For a pair of edges (a, b) and (b, c) that share the same intermediate node in the original graph, + there is a directed edge (a, b) -> (b, c) in the line graph. + + Returns: + Graph + """ + node_in, node_out = self.edge_list.t()[:2] + edge_index = torch.arange(self.num_edge, device=self.device) + edge_in = edge_index[node_out.argsort()] + edge_out = edge_index[node_in.argsort()] + + degree_in = node_in.bincount(minlength=self.num_node) + degree_out = node_out.bincount(minlength=self.num_node) + size = degree_out * degree_in + starts = (size.cumsum(0) - size).repeat_interleave(size) + range = torch.arange(size.sum(), device=self.device) + # each node u has degree_out[u] * degree_in[u] local edges + local_index = range - starts + local_inner_size = degree_in.repeat_interleave(size) + edge_in_offset = (degree_out.cumsum(0) - degree_out).repeat_interleave(size) + edge_out_offset = (degree_in.cumsum(0) - degree_in).repeat_interleave(size) + edge_in_index = torch.div(local_index, local_inner_size, rounding_mode="floor") + edge_in_offset + edge_out_index = local_index % local_inner_size + edge_out_offset + + edge_in = edge_in[edge_in_index] + edge_out = edge_out[edge_out_index] + edge_list = torch.stack([edge_in, edge_out], dim=-1) + node_feature = getattr(self, "edge_feature", None) + num_node = self.num_edge + num_edge = size.sum() + + return Graph(edge_list, num_node=num_node, num_edge=num_edge, node_feature=node_feature) + + def full(self): + """ + Return a fully connected graph over the nodes. 
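+
+        A usage sketch of the intended behavior (a hypothetical 3-node graph; the full graph
+        contains all :math:`|V|^2` node pairs, self-loops included)::
+
+            >>> graph = data.Graph([[0, 1]], num_node=3)
+            >>> full = graph.full()   # 3 * 3 = 9 edges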
+ + Returns: + Graph + """ + index = torch.arange(self.num_node, device=self.device) + if self.num_relation: + edge_list = torch.meshgrid(index, index, torch.arange(self.num_relation, device=self.device)) + else: + edge_list = torch.meshgrid(index, index) + edge_list = torch.stack(edge_list).flatten(1) + edge_weight = torch.ones(len(edge_list)) + + data_dict, meta_dict = self.data_by_meta(exclude="edge") + + return type(self)(edge_list, edge_weight=edge_weight, num_node=self.num_node, num_relation=self.num_relation, + meta_dict=meta_dict, **data_dict) + + def directed(self, order=None): + """ + Mask the edges to create a directed graph. + Edges that go from a node index to a larger or equal node index will be kept. + + Parameters: + order (Tensor, optional): topological order of the nodes + """ + node_in, node_out = self.edge_list.t()[:2] + if order is not None: + edge_index = order[node_in] <= order[node_out] + else: + edge_index = node_in <= node_out + + return self.edge_mask(edge_index) + + def undirected(self, add_inverse=False): + """ + Flip all the edges to create an undirected graph. + + For knowledge graphs, the flipped edges can either have the original relation or an inverse relation. + The inverse relation for relation :math:`r` is defined as :math:`|R| + r`. + + Parameters: + add_inverse (bool, optional): whether to use inverse relations for flipped edges + """ + edge_list = self.edge_list.clone() + edge_list[:, :2] = edge_list[:, :2].flip(1) + num_relation = self.num_relation + if num_relation and add_inverse: + edge_list[:, 2] += num_relation + num_relation = num_relation * 2 + edge_list = torch.stack([self.edge_list, edge_list], dim=1).flatten(0, 1) + + index = torch.arange(self.num_edge, device=self.device).unsqueeze(-1).expand(-1, 2).flatten() + data_dict, meta_dict = self.data_mask(edge_index=index) + + return type(self)(edge_list, edge_weight=self.edge_weight[index], num_node=self.num_node, + num_relation=num_relation, meta_dict=meta_dict, **data_dict) + + @utils.cached_property + def adjacency(self): + """ + Adjacency matrix of this graph. + + If :attr:`num_relation` is specified, a sparse tensor of shape :math:`(|V|, |V|, num\_relation)` will be + returned. + Otherwise, a sparse tensor of shape :math:`(|V|, |V|)` will be returned. + """ + return utils.sparse_coo_tensor(self.edge_list.t(), self.edge_weight, self.shape) + + _tensor_names = ["edge_list", "edge_weight", "num_node", "num_relation", "edge_feature"] + + def to_tensors(self): + edge_feature = getattr(self, "edge_feature", torch.tensor(0, device=self.device)) + return self.edge_list, self.edge_weight, self.num_node, self.num_relation, edge_feature + + @classmethod + def from_tensors(cls, tensors): + edge_list, edge_weight, num_node, num_relation, edge_feature = tensors + if edge_feature.ndim == 0: + edge_feature = None + return cls(edge_list, edge_weight, num_node, num_relation, edge_feature=edge_feature) + + @property + def node2graph(self): + """Node id to graph id mapping.""" + return torch.zeros(self.num_node, dtype=torch.long, device=self.device) + + @property + def edge2graph(self): + """Edge id to graph id mapping.""" + return torch.zeros(self.num_edge, dtype=torch.long, device=self.device) + + @utils.cached_property + def degree_out(self): + """ + Weighted number of edges containing each node as output. + + Note this is the **in-degree** in graph theory. 
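+
+        A quick check (a hypothetical graph; both edges point into node 1)::
+
+            >>> graph = data.Graph([[0, 1], [2, 1]])
+            >>> assert graph.degree_out.tolist() == [0, 2, 0]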
+ """ + return scatter_add(self.edge_weight, self.edge_list[:, 1], dim_size=self.num_node) + + @utils.cached_property + def degree_in(self): + """ + Weighted number of edges containing each node as input. + + Note this is the **out-degree** in graph theory. + """ + return scatter_add(self.edge_weight, self.edge_list[:, 0], dim_size=self.num_node) + + @property + def edge_list(self): + """List of edges.""" + return self._edge_list + + @property + def edge_weight(self): + """Edge weights.""" + return self._edge_weight + + @property + def device(self): + """Device.""" + return self.edge_list.device + + @property + def requires_grad(self): + return self.edge_weight.requires_grad + + @property + def grad(self): + return self.edge_weight.grad + + @property + def data(self): + return self + + def requires_grad_(self): + self.edge_weight.requires_grad_() + return self + + def size(self, dim=None): + if self.num_relation: + size = torch.Size((self.num_node, self.num_node, self.num_relation)) + else: + size = torch.Size((self.num_node, self.num_node)) + if dim is None: + return size + return size[dim] + + @property + def shape(self): + return self.size() + + def copy_(self, src): + """ + Copy data from ``src`` into ``self`` and return ``self``. + + The ``src`` graph must have the same set of attributes as ``self``. + """ + self.edge_list.copy_(src.edge_list) + self.edge_weight.copy_(src.edge_weight) + self.num_node.copy_(src.num_node) + self.num_edge.copy_(src.num_edge) + if self.num_relation is not None: + self.num_relation.copy_(src.num_relation) + + keys = set(self.data_dict.keys()) + src_keys = set(src.data_dict.keys()) + if keys != src_keys: + raise RuntimeError("Attributes mismatch. Trying to assign attributes %s, " + "but current graph has attributes %s" % (src_keys, keys)) + for k, v in self.data_dict.items(): + v.copy_(src.data_dict[k]) + + return self + + def detach(self): + """ + Detach this graph. + """ + return type(self)(self.edge_list.detach(), edge_weight=self.edge_weight.detach(), + num_node=self.num_node, num_relation=self.num_relation, + meta_dict=self.meta_dict, **utils.detach(self.data_dict)) + + def clone(self): + """ + Clone this graph. + """ + return type(self)(self.edge_list.clone(), edge_weight=self.edge_weight.clone(), + num_node=self.num_node, num_relation=self.num_relation, + meta_dict=self.meta_dict, **utils.clone(self.data_dict)) + + def cuda(self, *args, **kwargs): + """ + Return a copy of this graph in CUDA memory. + + This is a non-op if the graph is already on the correct device. + """ + edge_list = self.edge_list.cuda(*args, **kwargs) + + if edge_list is self.edge_list: + return self + else: + return type(self)(edge_list, edge_weight=self.edge_weight, + num_node=self.num_node, num_relation=self.num_relation, + meta_dict=self.meta_dict, **utils.cuda(self.data_dict, *args, **kwargs)) + + def cpu(self): + """ + Return a copy of this graph in CPU memory. + + This is a non-op if the graph is already in CPU memory. + """ + edge_list = self.edge_list.cpu() + + if edge_list is self.edge_list: + return self + else: + return type(self)(edge_list, edge_weight=self.edge_weight, num_node=self.num_node, + num_relation=self.num_relation, meta_dict=self.meta_dict, **utils.cpu(self.data_dict)) + + def to(self, device, *args, **kwargs): + """ + Return a copy of this graph on the given device. 
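+
+        For example (a sketch; this is a no-op when the graph already lives on the target device)::
+
+            >>> graph = data.Graph([[0, 1]])
+            >>> assert graph.to("cpu") is graph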
+ """ + device = torch.device(device) + if device.type == "cpu": + return self.cpu(*args, **kwargs) + else: + return self.cuda(device, *args, **kwargs) + + def __repr__(self): + fields = ["num_node=%d" % self.num_node, "num_edge=%d" % self.num_edge] + if self.num_relation is not None: + fields.append("num_relation=%d" % self.num_relation) + if self.device.type != "cpu": + fields.append("device='%s'" % self.device) + return "%s(%s)" % (self.__class__.__name__, ", ".join(fields)) + + def visualize(self, title=None, save_file=None, figure_size=(3, 3), ax=None, layout="spring"): + """ + Visualize this graph with matplotlib. + + Parameters: + title (str, optional): title for this graph + save_file (str, optional): ``png`` or ``pdf`` file to save visualization. + If not provided, show the figure in window. + figure_size (tuple of int, optional): width and height of the figure + ax (matplotlib.axes.Axes, optional): axis to plot the figure + layout (str, optional): graph layout + + See also: + `NetworkX graph layout`_ + + .. _NetworkX graph layout: + https://networkx.github.io/documentation/stable/reference/drawing.html#module-networkx.drawing.layout + """ + is_root = ax is None + if ax is None: + fig = plt.figure(figsize=figure_size) + if title is not None: + ax = plt.gca() + else: + ax = fig.add_axes([0, 0, 1, 1]) + if title is not None: + ax.set_title(title) + + edge_list = self.edge_list[:, :2].tolist() + G = nx.DiGraph(edge_list) + G.add_nodes_from(range(self.num_node)) + if hasattr(nx, "%s_layout" % layout): + func = getattr(nx, "%s_layout" % layout) + else: + raise ValueError("Unknown networkx layout `%s`" % layout) + if layout == "spring" or layout == "random": + pos = func(G, seed=0) + else: + pos = func(G) + nx.draw_networkx(G, pos, ax=ax) + if self.num_relation: + edge_labels = self.edge_list[:, 2].tolist() + edge_labels = {tuple(e): l for e, l in zip(edge_list, edge_labels)} + nx.draw_networkx_edge_labels(G, pos, edge_labels, ax=ax) + ax.set_frame_on(False) + + if is_root: + if save_file: + fig.savefig(save_file) + else: + fig.show() + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + return NotImplemented + + def __getstate__(self): + state = {} + cls = self.__class__ + for k, v in self.__dict__.items(): + # do not pickle property / cached property + if hasattr(cls, k) and isinstance(getattr(cls, k), property): + continue + state[k] = v + return state + + +class PackedGraph(Graph): + """ + Container for sparse graphs with variadic sizes. + + To create a PackedGraph from Graph objects + + >>> batch = data.Graph.pack(graphs) + + To retrieve Graph objects from a PackedGraph + + >>> graphs = batch.unpack() + + .. warning:: + + Edges of the same graph are guaranteed to be consecutive in the edge list. + However, this class doesn't enforce any order on the edges. + + Parameters: + edge_list (array_like, optional): list of edges of shape :math:`(|E|, 2)` or :math:`(|E|, 3)`. + Each tuple is (node_in, node_out) or (node_in, node_out, relation). 
+ edge_weight (array_like, optional): edge weights of shape :math:`(|E|,)` + num_nodes (array_like, optional): number of nodes in each graph + By default, it will be inferred from the largest id in `edge_list` + num_edges (array_like, optional): number of edges in each graph + num_relation (int, optional): number of relations + node_feature (array_like, optional): node features of shape :math:`(|V|, ...)` + edge_feature (array_like, optional): edge features of shape :math:`(|E|, ...)` + offsets (array_like, optional): node id offsets of shape :math:`(|E|,)`. + If not provided, nodes in `edge_list` should be relative index, i.e., the index in each graph. + If provided, nodes in `edge_list` should be absolute index, i.e., the index in the packed graph. + """ + + unpacked_type = Graph + + def __init__(self, edge_list=None, edge_weight=None, num_nodes=None, num_edges=None, num_relation=None, + offsets=None, **kwargs): + edge_list, num_nodes, num_edges, num_cum_nodes, num_cum_edges, offsets = \ + self._get_cumulative(edge_list, num_nodes, num_edges, offsets) + + if offsets is None: + offsets = self._get_offsets(num_nodes, num_edges, num_cum_nodes) + edge_list = edge_list.clone() + edge_list[:, :2] += offsets.unsqueeze(-1) + + num_node = num_nodes.sum() + if (edge_list[:, :2] >= num_node).any(): + raise ValueError("Sum of `num_nodes` is %d, but found %d in `edge_list`" % + (num_node, edge_list[:, :2].max())) + + self._offsets = offsets + self.num_nodes = num_nodes + self.num_edges = num_edges + self.num_cum_nodes = num_cum_nodes + self.num_cum_edges = num_cum_edges + + super(PackedGraph, self).__init__(edge_list, edge_weight=edge_weight, num_node=num_node, + num_relation=num_relation, **kwargs) + + def _get_offsets(self, num_nodes=None, num_edges=None, num_cum_nodes=None, num_cum_edges=None): + if num_nodes is None: + prepend = torch.tensor([0], device=self.device) + num_nodes = torch.diff(num_cum_nodes, prepend=prepend) + if num_edges is None: + prepend = torch.tensor([0], device=self.device) + num_edges = torch.diff(num_cum_edges, prepend=prepend) + if num_cum_nodes is None: + num_cum_nodes = num_nodes.cumsum(0) + return (num_cum_nodes - num_nodes).repeat_interleave(num_edges) + + def merge(self, graph2graph): + """ + Merge multiple graphs into a single graph. + + Parameters: + graph2graph (array_like): ID of the new graph each graph belongs to + """ + graph2graph = torch.as_tensor(graph2graph, dtype=torch.long, device=self.device) + # coalesce arbitrary graph IDs to [0, n) + _, graph2graph = torch.unique(graph2graph, return_inverse=True) + + graph_key = graph2graph * self.batch_size + torch.arange(self.batch_size, device=self.device) + graph_index = graph_key.argsort() + graph = self.subbatch(graph_index) + graph2graph = graph2graph[graph_index] + + num_graph = graph2graph[-1] + 1 + num_nodes = scatter_add(graph.num_nodes, graph2graph, dim_size=num_graph) + num_edges = scatter_add(graph.num_edges, graph2graph, dim_size=num_graph) + offsets = self._get_offsets(num_nodes, num_edges) + + data_dict, meta_dict = graph.data_mask(exclude="graph") + + return type(self)(graph.edge_list, edge_weight=graph.edge_weight, num_nodes=num_nodes, + num_edges=num_edges, num_relation=graph.num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) + + def unpack(self): + """ + Unpack this packed graph into a list of graphs. 
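+
+        A round-trip sketch with two hypothetical graphs::
+
+            >>> g1 = data.Graph([[0, 1]], num_node=2)
+            >>> g2 = data.Graph([[0, 1], [1, 2]], num_node=3)
+            >>> batch = data.Graph.pack([g1, g2])
+            >>> graphs = batch.unpack()
+            >>> assert graphs[1].num_edge == 2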
+ + Returns: + list of Graph + """ + graphs = [] + for i in range(self.batch_size): + graphs.append(self.get_item(i)) + return graphs + + def __iter__(self): + self._iter_index = 0 + return self + + def __next__(self): + if self._iter_index < self.batch_size: + item = self[self._iter_index] + self._iter_index += 1 + return item + raise StopIteration + + def _check_attribute(self, key, value): + for type in self._meta_contexts: + if "reference" in type: + if value.dtype != torch.long: + raise TypeError("Tensors used as reference must be long tensors") + if type == "node": + if len(value) != self.num_node: + raise ValueError("Expect node attribute `%s` to have shape (%d, *), but found %s" % + (key, self.num_node, value.shape)) + elif type == "edge": + if len(value) != self.num_edge: + raise ValueError("Expect edge attribute `%s` to have shape (%d, *), but found %s" % + (key, self.num_edge, value.shape)) + elif type == "graph": + if len(value) != self.batch_size: + raise ValueError("Expect graph attribute `%s` to have shape (%d, *), but found %s" % + (key, self.batch_size, value.shape)) + elif type == "node reference": + is_valid = (value >= -1) & (value < self.num_node) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect node reference in [-1, %d), but found %d" % + (self.num_node, error_value[0])) + elif type == "edge reference": + is_valid = (value >= -1) & (value < self.num_edge) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect edge reference in [-1, %d), but found %d" % + (self.num_edge, error_value[0])) + elif type == "graph reference": + is_valid = (value >= -1) & (value < self.batch_size) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect graph reference in [-1, %d), but found %d" % + (self.batch_size, error_value[0])) + + def unpack_data(self, data, type="auto"): + """ + Unpack node or edge data according to the packed graph. + + Parameters: + data (Tensor): data to unpack + type (str, optional): data type. Can be ``auto``, ``node``, or ``edge``. + + Returns: + list of Tensor + """ + if type == "auto": + if self.num_node == self.num_edge: + raise ValueError("Ambiguous type. Please specify either `node` or `edge`") + if len(data) == self.num_node: + type = "node" + elif len(data) == self.num_edge: + type = "edge" + else: + raise ValueError("Graph has %d nodes and %d edges, but data has %d entries" % + (self.num_node, self.num_edge, len(data))) + data_list = [] + if type == "node": + for i in range(self.batch_size): + data_list.append(data[self.num_cum_nodes[i] - self.num_nodes[i]: self.num_cum_nodes[i]]) + elif type == "edge": + for i in range(self.batch_size): + data_list.append(data[self.num_cum_edges[i] - self.num_edges[i]: self.num_cum_edges[i]]) + + return data_list + + def repeat(self, count): + """ + Repeat this packed graph. This function behaves similarly to `torch.Tensor.repeat`_. + + .. 
_torch.Tensor.repeat: + https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html + + Parameters: + count (int): number of repetitions + + Returns: + PackedGraph + """ + num_nodes = self.num_nodes.repeat(count) + num_edges = self.num_edges.repeat(count) + offsets = self._get_offsets(num_nodes, num_edges) + edge_list = self.edge_list.repeat(count, 1) + edge_list[:, :2] += (offsets - self._offsets.repeat(count)).unsqueeze(-1) + + data_dict = {} + for k, v in self.data_dict.items(): + shape = [1] * v.ndim + shape[0] = count + length = len(v) + v = v.repeat(shape) + for _type in self.meta_dict[k]: + if _type == "node reference": + pack_offsets = torch.arange(count, device=self.device) * self.num_node + v = v + pack_offsets.repeat_interleave(length) + elif _type == "edge reference": + pack_offsets = torch.arange(count, device=self.device) * self.num_edge + v = v + pack_offsets.repeat_interleave(length) + elif _type == "graph reference": + pack_offsets = torch.arange(count, device=self.device) * self.batch_size + v = v + pack_offsets.repeat_interleave(length) + data_dict[k] = v + + return type(self)(edge_list, edge_weight=self.edge_weight.repeat(count), + num_nodes=num_nodes, num_edges=num_edges, num_relation=self.num_relation, + offsets=offsets, meta_dict=self.meta_dict, **data_dict) + + def repeat_interleave(self, repeats): + """ + Repeat this packed graph. This function behaves similarly to `torch.repeat_interleave`_. + + .. _torch.repeat_interleave: + https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html + + Parameters: + repeats (Tensor or int): number of repetitions for each graph + + Returns: + PackedGraph + """ + repeats = torch.as_tensor(repeats, dtype=torch.long, device=self.device) + if repeats.numel() == 1: + repeats = repeats * torch.ones(self.batch_size, dtype=torch.long, device=self.device) + num_nodes = self.num_nodes.repeat_interleave(repeats) + num_edges = self.num_edges.repeat_interleave(repeats) + num_cum_nodes = num_nodes.cumsum(0) + num_cum_edges = num_edges.cumsum(0) + num_node = num_nodes.sum() + num_edge = num_edges.sum() + batch_size = repeats.sum() + num_graphs = torch.ones(batch_size, device=self.device) + + # special case 1: graphs[i] may have no node or no edge + # special case 2: repeats[i] may be 0 + cum_repeats_shifted = repeats.cumsum(0) - repeats + graph_mask = cum_repeats_shifted < batch_size + cum_repeats_shifted = cum_repeats_shifted[graph_mask] + + index = num_cum_nodes - num_nodes + index = torch.cat([index, index[cum_repeats_shifted]]) + value = torch.cat([-num_nodes, self.num_nodes[graph_mask]]) + mask = index < num_node + node_index = scatter_add(value[mask], index[mask], dim_size=num_node) + node_index = (node_index + 1).cumsum(0) - 1 + + index = num_cum_edges - num_edges + index = torch.cat([index, index[cum_repeats_shifted]]) + value = torch.cat([-num_edges, self.num_edges[graph_mask]]) + mask = index < num_edge + edge_index = scatter_add(value[mask], index[mask], dim_size=num_edge) + edge_index = (edge_index + 1).cumsum(0) - 1 + + graph_index = torch.repeat_interleave(repeats) + + offsets = self._get_offsets(num_nodes, num_edges) + edge_list = self.edge_list[edge_index] + edge_list[:, :2] += (offsets - self._offsets[edge_index]).unsqueeze(-1) + + node_offsets = None + edge_offsets = None + graph_offsets = None + data_dict = {} + for k, v in self.data_dict.items(): + num_xs = None + pack_offsets = None + for _type in self.meta_dict[k]: + if _type == "node": + v = v[node_index] + num_xs = num_nodes + elif _type == "edge": + v 
= v[edge_index] + num_xs = num_edges + elif _type == "graph": + v = v[graph_index] + num_xs = num_graphs + elif _type == "node reference": + if node_offsets is None: + node_offsets = self._get_repeat_pack_offsets(self.num_nodes, repeats) + pack_offsets = node_offsets + elif _type == "edge reference": + if edge_offsets is None: + edge_offsets = self._get_repeat_pack_offsets(self.num_edges, repeats) + pack_offsets = edge_offsets + elif _type == "graph reference": + if graph_offsets is None: + graph_offsets = self._get_repeat_pack_offsets(num_graphs, repeats) + pack_offsets = graph_offsets + # add offsets to make references point to indexes in their own graph + if num_xs is not None and pack_offsets is not None: + v = v + pack_offsets.repeat_interleave(num_xs) + data_dict[k] = v + + return type(self)(edge_list, edge_weight=self.edge_weight[edge_index], + num_nodes=num_nodes, num_edges=num_edges, num_relation=self.num_relation, + offsets=offsets, meta_dict=self.meta_dict, **data_dict) + + def get_item(self, index): + """ + Get the i-th graph from this packed graph. + + Parameters: + index (int): graph index + + Returns: + Graph + """ + node_index = torch.arange(self.num_cum_nodes[index] - self.num_nodes[index], self.num_cum_nodes[index], + device=self.device) + edge_index = torch.arange(self.num_cum_edges[index] - self.num_edges[index], self.num_cum_edges[index], + device=self.device) + graph_index = index + edge_list = self.edge_list[edge_index].clone() + edge_list[:, :2] -= self._offsets[edge_index].unsqueeze(-1) + data_dict, meta_dict = self.data_mask(node_index, edge_index, graph_index=graph_index) + + return self.unpacked_type(edge_list, edge_weight=self.edge_weight[edge_index], num_node=self.num_nodes[index], + num_relation=self.num_relation, meta_dict=meta_dict, **data_dict) + + def _get_cumulative(self, edge_list, num_nodes, num_edges, offsets): + if edge_list is None: + raise ValueError("`edge_list` should be provided") + if num_edges is None: + raise ValueError("`num_edges` should be provided") + + edge_list = torch.as_tensor(edge_list) + num_edges = torch.as_tensor(num_edges, device=edge_list.device) + num_edge = num_edges.sum() + if num_edge != len(edge_list): + raise ValueError("Sum of `num_edges` is %d, but found %d edges in `edge_list`" % (num_edge, len(edge_list))) + num_cum_edges = num_edges.cumsum(0) + + if offsets is None: + _edge_list = edge_list + else: + offsets = torch.as_tensor(offsets, device=edge_list.device) + _edge_list = edge_list.clone() + _edge_list[:, :2] -= offsets.unsqueeze(-1) + if num_nodes is None: + num_nodes = [] + for num_edge, num_cum_edge in zip(num_edges, num_cum_edges): + num_nodes.append(self._maybe_num_node(_edge_list[num_cum_edge - num_edge: num_cum_edge])) + num_nodes = torch.as_tensor(num_nodes, device=edge_list.device) + num_cum_nodes = num_nodes.cumsum(0) + + return edge_list, num_nodes, num_edges, num_cum_nodes, num_cum_edges, offsets + + def _get_num_xs(self, index, num_cum_xs): + x = torch.zeros(num_cum_xs[-1], dtype=torch.long, device=self.device) + x[index] = 1 + num_cum_indexes = x.cumsum(0) + num_cum_indexes = torch.cat([torch.zeros(1, dtype=torch.long, device=self.device), num_cum_indexes]) + new_num_cum_xs = num_cum_indexes[num_cum_xs] + prepend = torch.zeros(1, dtype=torch.long, device=self.device) + new_num_xs = torch.diff(new_num_cum_xs, prepend=prepend) + return new_num_xs + + def data_mask(self, node_index=None, edge_index=None, graph_index=None, include=None, exclude=None): + data_dict, meta_dict = self.data_by_meta(include, 
exclude) + node_mapping = None + edge_mapping = None + graph_mapping = None + for k, v in data_dict.items(): + for type in meta_dict[k]: + if type == "node" and node_index is not None: + v = v[node_index] + elif type == "edge" and edge_index is not None: + v = v[edge_index] + elif type == "graph" and graph_index is not None: + v = v[graph_index] + elif type == "node reference" and node_index is not None: + if node_mapping is None: + node_mapping = self._get_mapping(node_index, self.num_node) + v = node_mapping[v] + elif type == "edge reference" and edge_index is not None: + if edge_mapping is None: + edge_mapping = self._get_mapping(edge_index, self.num_edge) + v = edge_mapping[v] + elif type == "graph reference" and graph_index is not None: + if graph_mapping is None: + graph_mapping = self._get_mapping(graph_index, self.batch_size) + v = graph_mapping[v] + data_dict[k] = v + + return data_dict, meta_dict + + def __getitem__(self, index): + # why do we check tuple? + # case 1: x[0, 1] is parsed as (0, 1) + # case 2: x[[0, 1]] is parsed as [0, 1] + if not isinstance(index, tuple): + index = (index,) + + if isinstance(index[0], int): + item = self.get_item(index[0]) + if len(index) > 1: + item = item[index[1:]] + return item + if len(index) > 1: + raise ValueError("Complex indexing is not supported for PackedGraph") + + index = self._standarize_index(index[0], self.batch_size) + count = index.bincount(minlength=self.batch_size) + if self.batch_size > 0 and count.max() > 1: + graph = self.repeat_interleave(count) + index_order = index.argsort() + order = torch.zeros_like(index) + order[index_order] = torch.arange(len(index), dtype=torch.long, device=self.device) + return graph.subbatch(order) + + return self.subbatch(index) + + def __len__(self): + return len(self.num_nodes) + + def full(self): + """ + Return a pack of fully connected graphs. + + This is useful for computing node-pair-wise features. + The computation can be implemented as message passing over a fully connected graph. + + Returns: + PackedGraph + """ + # TODO: more efficient implementation? + graphs = self.unpack() + graphs = [graph.full() for graph in graphs] + return graphs[0].pack(graphs) + + @utils.cached_property + def node2graph(self): + """Node id to graph id mapping.""" + node2graph = torch.repeat_interleave(self.num_nodes) + return node2graph + + @utils.cached_property + def edge2graph(self): + """Edge id to graph id mapping.""" + edge2graph = torch.repeat_interleave(self.num_edges) + return edge2graph + + @property + def batch_size(self): + """Batch size.""" + return len(self.num_nodes) + + def node_mask(self, index, compact=False): + """ + Return a masked packed graph based on the specified nodes. + + Note the compact option is only applied to node ids but not graph ids. + To generate compact graph ids, use :meth:`subbatch`. 
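+
+        For instance (a sketch; node ids index into the packed graph, so dropping node 3
+        only shrinks the second graph)::
+
+            >>> batch = data.Graph.pack([data.Graph([[0, 1]]), data.Graph([[0, 1]])])
+            >>> masked = batch.node_mask([0, 1, 2], compact=True)
+            >>> assert masked.num_nodes.tolist() == [2, 1]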
+ + Parameters: + index (array_like): node index + compact (bool, optional): compact node ids or not + + Returns: + PackedGraph + """ + index = self._standarize_index(index, self.num_node) + mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device) + if compact: + mapping[index] = torch.arange(len(index), device=self.device) + num_nodes = self._get_num_xs(index, self.num_cum_nodes) + offsets = self._get_offsets(num_nodes, self.num_edges) + else: + mapping[index] = index + num_nodes = self.num_nodes + offsets = self._offsets + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + edge_index = (edge_list[:, :2] >= 0).all(dim=-1) + num_edges = self._get_num_xs(edge_index, self.num_cum_edges) + + if compact: + data_dict, meta_dict = self.data_mask(index, edge_index) + else: + data_dict, meta_dict = self.data_mask(edge_index=edge_index) + + return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_nodes=num_nodes, + num_edges=num_edges, num_relation=self.num_relation, offsets=offsets[edge_index], + meta_dict=meta_dict, **data_dict) + + def edge_mask(self, index): + """ + Return a masked packed graph based on the specified edges. + + Parameters: + index (array_like): edge index + + Returns: + PackedGraph + """ + index = self._standarize_index(index, self.num_edge) + data_dict, meta_dict = self.data_mask(edge_index=index) + num_edges = self._get_num_xs(index, self.num_cum_edges) + + return type(self)(self.edge_list[index], edge_weight=self.edge_weight[index], num_nodes=self.num_nodes, + num_edges=num_edges, num_relation=self.num_relation, offsets=self._offsets[index], + meta_dict=meta_dict, **data_dict) + + def graph_mask(self, index, compact=False): + """ + Return a masked packed graph based on the specified graphs. + + This function can also be used to re-order the graphs. 
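+
+        A re-ordering sketch (two hypothetical graphs; with ``compact=True`` this is
+        the same as :meth:`subbatch`)::
+
+            >>> g1 = data.Graph([[0, 1]], num_node=2)
+            >>> g2 = data.Graph([[0, 1], [1, 2]], num_node=3)
+            >>> batch = data.Graph.pack([g1, g2])
+            >>> swapped = batch.graph_mask([1, 0], compact=True)
+            >>> assert swapped.num_nodes.tolist() == [3, 2]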
+ + Parameters: + index (array_like): graph index + compact (bool, optional): compact graph ids or not + + Returns: + PackedGraph + """ + index = self._standarize_index(index, self.batch_size) + graph_mapping = -torch.ones(self.batch_size, dtype=torch.long, device=self.device) + graph_mapping[index] = torch.arange(len(index), device=self.device) + + node_index = graph_mapping[self.node2graph] >= 0 + node_index = self._standarize_index(node_index, self.num_node) + mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device) + if compact: + key = graph_mapping[self.node2graph[node_index]] * self.num_node + node_index + order = key.argsort() + node_index = node_index[order] + mapping[node_index] = torch.arange(len(node_index), device=self.device) + num_nodes = self.num_nodes[index] + else: + mapping[node_index] = node_index + num_nodes = torch.zeros_like(self.num_nodes) + num_nodes[index] = self.num_nodes[index] + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + edge_index = (edge_list[:, :2] >= 0).all(dim=-1) + edge_index = self._standarize_index(edge_index, self.num_edge) + if compact: + key = graph_mapping[self.edge2graph[edge_index]] * self.num_edge + edge_index + order = key.argsort() + edge_index = edge_index[order] + num_edges = self.num_edges[index] + else: + num_edges = torch.zeros_like(self.num_edges) + num_edges[index] = self.num_edges[index] + offsets = self._get_offsets(num_nodes, num_edges) + + if compact: + data_dict, meta_dict = self.data_mask(node_index, edge_index, graph_index=index) + else: + data_dict, meta_dict = self.data_mask(edge_index=edge_index) + + return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_nodes=num_nodes, + num_edges=num_edges, num_relation=self.num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) + + def subbatch(self, index): + """ + Return a subbatch based on the specified graphs. + Equivalent to :meth:`graph_mask(index, compact=True) `. + + Parameters: + index (array_like): graph index + + Returns: + PackedGraph + + See also: + :meth:`PackedGraph.graph_mask` + """ + return self.graph_mask(index, compact=True) + + def line_graph(self): + """ + Construct a packed line graph of this packed graph. + The node features of the line graphs are inherited from the edge features of the original graphs. + + In the line graph, each node corresponds to an edge in the original graph. + For a pair of edges (a, b) and (b, c) that share the same intermediate node in the original graph, + there is a directed edge (a, b) -> (b, c) in the line graph. 
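+
+        For example (a sketch on a packed path graph; its two edges share node 1, so the
+        line graph has two nodes joined by one edge)::
+
+            >>> batch = data.Graph.pack([data.Graph([[0, 1], [1, 2]])])
+            >>> line = batch.line_graph()
+            >>> assert line.num_nodes.tolist() == [2] and line.num_edges.tolist() == [1]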
+ + Returns: + PackedGraph + """ + node_in, node_out = self.edge_list.t()[:2] + edge_index = torch.arange(self.num_edge, device=self.device) + edge_in = edge_index[node_out.argsort()] + edge_out = edge_index[node_in.argsort()] + + degree_in = node_in.bincount(minlength=self.num_node) + degree_out = node_out.bincount(minlength=self.num_node) + size = degree_out * degree_in + starts = (size.cumsum(0) - size).repeat_interleave(size) + range = torch.arange(size.sum(), device=self.device) + # each node u has degree_out[u] * degree_in[u] local edges + local_index = range - starts + local_inner_size = degree_in.repeat_interleave(size) + edge_in_offset = (degree_out.cumsum(0) - degree_out).repeat_interleave(size) + edge_out_offset = (degree_in.cumsum(0) - degree_in).repeat_interleave(size) + edge_in_index = torch.div(local_index, local_inner_size, rounding_mode="floor") + edge_in_offset + edge_out_index = local_index % local_inner_size + edge_out_offset + + edge_in = edge_in[edge_in_index] + edge_out = edge_out[edge_out_index] + edge_list = torch.stack([edge_in, edge_out], dim=-1) + node_feature = getattr(self, "edge_feature", None) + num_nodes = self.num_edges + num_edges = scatter_add(size, self.node2graph, dim=0, dim_size=self.batch_size) + offsets = self._get_offsets(num_nodes, num_edges) + + return PackedGraph(edge_list, num_nodes=num_nodes, num_edges=num_edges, offsets=offsets, + node_feature=node_feature) + + def undirected(self, add_inverse=False): + """ + Flip all the edges to create undirected graphs. + + For knowledge graphs, the flipped edges can either have the original relation or an inverse relation. + The inverse relation for relation :math:`r` is defined as :math:`|R| + r`. + + Parameters: + add_inverse (bool, optional): whether to use inverse relations for flipped edges + """ + edge_list = self.edge_list.clone() + edge_list[:, :2] = edge_list[:, :2].flip(1) + num_relation = self.num_relation + if num_relation and add_inverse: + edge_list[:, 2] += num_relation + num_relation = num_relation * 2 + edge_list = torch.stack([self.edge_list, edge_list], dim=1).flatten(0, 1) + offsets = self._offsets.unsqueeze(-1).expand(-1, 2).flatten() + + index = torch.arange(self.num_edge, device=self.device).unsqueeze(-1).expand(-1, 2).flatten() + data_dict, meta_dict = self.data_mask(edge_index=index, exclude="edge reference") + + return type(self)(edge_list, edge_weight=self.edge_weight[index], num_nodes=self.num_nodes, + num_edges=self.num_edges * 2, num_relation=num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) + + def detach(self): + """ + Detach this packed graph. + """ + return type(self)(self.edge_list.detach(), edge_weight=self.edge_weight.detach(), + num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation, + offsets=self._offsets, meta_dict=self.meta_dict, **utils.detach(self.data_dict)) + + def clone(self): + """ + Clone this packed graph. + """ + return type(self)(self.edge_list.clone(), edge_weight=self.edge_weight.clone(), + num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation, + offsets=self._offsets, meta_dict=self.meta_dict, **utils.clone(self.data_dict)) + + def cuda(self, *args, **kwargs): + """ + Return a copy of this packed graph in CUDA memory. + + This is a non-op if the graph is already on the correct device. 
+ """ + edge_list = self.edge_list.cuda(*args, **kwargs) + + if edge_list is self.edge_list: + return self + else: + return type(self)(edge_list, edge_weight=self.edge_weight, + num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation, + offsets=self._offsets, meta_dict=self.meta_dict, + **utils.cuda(self.data_dict, *args, **kwargs)) + + def cpu(self): + """ + Return a copy of this packed graph in CPU memory. + + This is a non-op if the graph is already in CPU memory. + """ + edge_list = self.edge_list.cpu() + + if edge_list is self.edge_list: + return self + else: + return type(self)(edge_list, edge_weight=self.edge_weight, + num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation, + offsets=self._offsets, meta_dict=self.meta_dict, **utils.cpu(self.data_dict)) + + def __repr__(self): + fields = ["batch_size=%d" % self.batch_size, + "num_nodes=%s" % pretty.long_array(self.num_nodes.tolist()), + "num_edges=%s" % pretty.long_array(self.num_edges.tolist())] + if self.num_relation is not None: + fields.append("num_relation=%d" % self.num_relation) + if self.device.type != "cpu": + fields.append("device='%s'" % self.device) + return "%s(%s)" % (self.__class__.__name__, ", ".join(fields)) + + def visualize(self, titles=None, save_file=None, figure_size=(3, 3), layout="spring", num_row=None, num_col=None): + """ + Visualize the packed graphs with matplotlib. + + Parameters: + titles (list of str, optional): title for each graph. Default is the ID of each graph. + save_file (str, optional): ``png`` or ``pdf`` file to save visualization. + If not provided, show the figure in window. + figure_size (tuple of int, optional): width and height of the figure + layout (str, optional): graph layout + num_row (int, optional): number of rows in the figure + num_col (int, optional): number of columns in the figure + + See also: + `NetworkX graph layout`_ + + .. 
_NetworkX graph layout: + https://networkx.github.io/documentation/stable/reference/drawing.html#module-networkx.drawing.layout + """ + if titles is None: + graph = self.get_item(0) + titles = ["%s %d" % (type(graph).__name__, i) for i in range(self.batch_size)] + if num_col is None: + if num_row is None: + num_col = math.ceil(self.batch_size ** 0.5) + else: + num_col = math.ceil(self.batch_size / num_row) + if num_row is None: + num_row = math.ceil(self.batch_size / num_col) + + figure_size = (num_col * figure_size[0], num_row * figure_size[1]) + fig = plt.figure(figsize=figure_size) + + for i in range(self.batch_size): + graph = self.get_item(i) + ax = fig.add_subplot(num_row, num_col, i + 1) + graph.visualize(title=titles[i], ax=ax, layout=layout) + # remove the space of axis labels + fig.tight_layout() + + if save_file: + fig.savefig(save_file) + else: + fig.show() + + +Graph.packed_type = PackedGraph + + +def cat(graphs): + for i, graph in enumerate(graphs): + if not isinstance(graph, PackedGraph): + graphs[i] = graph.pack([graph]) + + edge_list = torch.cat([graph.edge_list for graph in graphs]) + pack_num_nodes = torch.stack([graph.num_node for graph in graphs]) + pack_num_edges = torch.stack([graph.num_edge for graph in graphs]) + pack_num_cum_edges = pack_num_edges.cumsum(0) + graph_index = pack_num_cum_edges < len(edge_list) + pack_offsets = scatter_add(pack_num_nodes[graph_index], pack_num_cum_edges[graph_index], + dim_size=len(edge_list)) + pack_offsets = pack_offsets.cumsum(0) + + edge_list[:, :2] += pack_offsets.unsqueeze(-1) + offsets = torch.cat([graph._offsets for graph in graphs]) + pack_offsets + + edge_weight = torch.cat([graph.edge_weight for graph in graphs]) + num_nodes = torch.cat([graph.num_nodes for graph in graphs]) + num_edges = torch.cat([graph.num_edges for graph in graphs]) + num_relation = graphs[0].num_relation + assert all(graph.num_relation == num_relation for graph in graphs) + + # only keep attributes that exist in all graphs + # TODO: this interface is not safe. re-design the interface + keys = set(graphs[0].meta_dict.keys()) + for graph in graphs: + keys = keys.intersection(graph.meta_dict.keys()) + + meta_dict = {k: graphs[0].meta_dict[k] for k in keys} + data_dict = {} + for k in keys: + data_dict[k] = torch.cat([graph.data_dict[k] for graph in graphs]) + + return type(graphs[0])(edge_list, edge_weight=edge_weight, + num_nodes=num_nodes, num_edges=num_edges, num_relation=num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) \ No newline at end of file diff --git a/build/lib/torchdrug/data/molecule.py b/build/lib/torchdrug/data/molecule.py new file mode 100644 index 00000000..edf5406f --- /dev/null +++ b/build/lib/torchdrug/data/molecule.py @@ -0,0 +1,1034 @@ +import math +import warnings +from copy import copy +from collections.abc import Sequence + +from matplotlib import pyplot as plt +from rdkit import Chem +from rdkit.Chem.Scaffolds import MurckoScaffold +import torch +from torch_scatter import scatter_add, scatter_min + +from torchdrug import utils +from torchdrug.data import constant, Graph, PackedGraph +from torchdrug.core import Registry as R +from torchdrug.data.rdkit import draw +from torchdrug.utils import pretty + +plt.switch_backend("agg") + + +class Molecule(Graph): + """ + Molecules with predefined chemical features. + + By nature, molecules are undirected graphs. Each bond is stored as two directed edges in this class. + + .. warning:: + + This class doesn't enforce any order on edges. 
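+
+    A construction sketch (benzene; hydrogens are dropped by default, and each of the six
+    bonds is stored as two directed edges)::
+
+        >>> mol = data.Molecule.from_smiles("C1=CC=CC=C1")
+        >>> assert mol.num_node == 6 and mol.num_edge == 12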
+
+    Parameters:
+        edge_list (array_like, optional): list of edges of shape :math:`(|E|, 3)`.
+            Each tuple is (node_in, node_out, bond_type).
+        atom_type (array_like, optional): atom types of shape :math:`(|V|,)`
+        bond_type (array_like, optional): bond types of shape :math:`(|E|,)`
+        formal_charge (array_like, optional): formal charges of shape :math:`(|V|,)`
+        explicit_hs (array_like, optional): number of explicit hydrogens of shape :math:`(|V|,)`
+        chiral_tag (array_like, optional): chirality tags of shape :math:`(|V|,)`
+        radical_electrons (array_like, optional): number of radical electrons of shape :math:`(|V|,)`
+        atom_map (array_like, optional): atom mappings of shape :math:`(|V|,)`
+        bond_stereo (array_like, optional): bond stereochemistry of shape :math:`(|E|,)`
+        stereo_atoms (array_like, optional): ids of stereo atoms of shape :math:`(|E|, 2)`
+    """
+
+    bond2id = {"SINGLE": 0, "DOUBLE": 1, "TRIPLE": 2, "AROMATIC": 3}
+    atom2valence = {1: 1, 5: 3, 6: 4, 7: 3, 8: 2, 9: 1, 14: 4, 15: 5, 16: 6, 17: 1, 35: 1, 53: 7}
+    bond2valence = [1, 2, 3, 1.5]
+    id2bond = {v: k for k, v in bond2id.items()}
+    empty_mol = Chem.MolFromSmiles("")
+    dummy_mol = Chem.MolFromSmiles("CC")
+
+    def __init__(self, edge_list=None, atom_type=None, bond_type=None, atom_feature=None, bond_feature=None,
+                 mol_feature=None, formal_charge=None, explicit_hs=None, chiral_tag=None, radical_electrons=None,
+                 atom_map=None, bond_stereo=None, stereo_atoms=None, node_position=None, **kwargs):
+        if "num_relation" not in kwargs:
+            kwargs["num_relation"] = len(self.bond2id)
+        super(Molecule, self).__init__(edge_list=edge_list, **kwargs)
+        atom_type, bond_type = self._standarize_atom_bond(atom_type, bond_type)
+
+        formal_charge = self._standarize_attribute(formal_charge, self.num_node)
+        explicit_hs = self._standarize_attribute(explicit_hs, self.num_node)
+        chiral_tag = self._standarize_attribute(chiral_tag, self.num_node)
+        radical_electrons = self._standarize_attribute(radical_electrons, self.num_node)
+        atom_map = self._standarize_attribute(atom_map, self.num_node)
+        bond_stereo = self._standarize_attribute(bond_stereo, self.num_edge)
+        stereo_atoms = self._standarize_attribute(stereo_atoms, (self.num_edge, 2))
+        if node_position is not None:
+            node_position = torch.as_tensor(node_position, dtype=torch.float, device=self.device)
+
+        with self.atom():
+            if atom_feature is not None:
+                self.atom_feature = torch.as_tensor(atom_feature, device=self.device)
+            self.atom_type = atom_type
+            self.formal_charge = formal_charge
+            self.explicit_hs = explicit_hs
+            self.chiral_tag = chiral_tag
+            self.radical_electrons = radical_electrons
+            self.atom_map = atom_map
+            if node_position is not None:
+                self.node_position = node_position
+
+        with self.bond():
+            if bond_feature is not None:
+                self.bond_feature = torch.as_tensor(bond_feature, device=self.device)
+            self.bond_type = bond_type
+            self.bond_stereo = bond_stereo
+            self.stereo_atoms = stereo_atoms
+
+        with self.mol():
+            if mol_feature is not None:
+                self.mol_feature = torch.as_tensor(mol_feature, device=self.device)
+
+    def _standarize_atom_bond(self, atom_type, bond_type):
+        if atom_type is None:
+            raise ValueError("`atom_type` should be provided")
+        if bond_type is None:
+            raise ValueError("`bond_type` should be provided")
+
+        atom_type = torch.as_tensor(atom_type, dtype=torch.long, device=self.device)
+        bond_type = torch.as_tensor(bond_type, dtype=torch.long, device=self.device)
+        return atom_type, bond_type
+
+    def _standarize_attribute(self, attribute, size, dtype=torch.long, default=0):
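+        # pass a provided attribute through as a tensor; otherwise fill a tensor of the
+        # given size with the default value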
+ if attribute is not None: + attribute = torch.as_tensor(attribute, dtype=dtype, device=self.device) + else: + if isinstance(size, torch.Tensor): + size = size.tolist() + if not isinstance(size, Sequence): + size = [size] + attribute = torch.full(size, default, dtype=dtype, device=self.device) + return attribute + + @classmethod + def _standarize_option(cls, option): + if option is None: + option = [] + elif isinstance(option, str): + option = [option] + return option + + def _check_no_stereo(self): + if (self.bond_stereo > 0).any(): + warnings.warn("Try to apply masks on molecules with stereo bonds. This may produce invalid molecules. " + "To discard stereo information, call `mol.bond_stereo[:] = 0` before applying masks.") + + def _maybe_num_node(self, edge_list): + if len(edge_list): + return edge_list[:, :2].max().item() + 1 + else: + return 0 + + @classmethod + @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature") + def from_molecule(cls, mol, atom_feature="default", bond_feature="default", mol_feature=None, + with_hydrogen=False, kekulize=False): + """ + Create a molecule from an RDKit object. + + Parameters: + mol (rdchem.Mol): molecule + atom_feature (str or list of str, optional): atom features to extract + bond_feature (str or list of str, optional): bond features to extract + mol_feature (str or list of str, optional): molecule features to extract + with_hydrogen (bool, optional): store hydrogens in the molecule graph. + By default, hydrogens are dropped + kekulize (bool, optional): convert aromatic bonds to single/double bonds. + Note this only affects the relation in ``edge_list``. + For ``bond_type``, aromatic bonds are always stored explicitly. + By default, aromatic bonds are stored. 
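+
+        A usage sketch (ethanol, built from an RDKit molecule)::
+
+            >>> from rdkit import Chem
+            >>> mol = data.Molecule.from_molecule(Chem.MolFromSmiles("CCO"))
+            >>> assert mol.atom_type.tolist() == [6, 6, 8]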
+ """ + if mol is None: + mol = cls.empty_mol + # some RDKit operations are in-place + # copy the object to avoid undesired behavior in the caller + mol = copy(mol) + if with_hydrogen: + mol = Chem.AddHs(mol) + if kekulize: + Chem.Kekulize(mol) + + atom_feature = cls._standarize_option(atom_feature) + bond_feature = cls._standarize_option(bond_feature) + mol_feature = cls._standarize_option(mol_feature) + + atom_type = [] + formal_charge = [] + explicit_hs = [] + chiral_tag = [] + radical_electrons = [] + atom_map = [] + _atom_feature = [] + dummy_atom = copy(cls.dummy_mol).GetAtomWithIdx(0) + atoms = [mol.GetAtomWithIdx(i) for i in range(mol.GetNumAtoms())] + [dummy_atom] + if mol.GetNumConformers() > 0: + node_position = torch.tensor(mol.GetConformer().GetPositions()) + else: + node_position = None + for atom in atoms: + atom_type.append(atom.GetAtomicNum()) + formal_charge.append(atom.GetFormalCharge()) + explicit_hs.append(atom.GetNumExplicitHs()) + chiral_tag.append(atom.GetChiralTag()) + radical_electrons.append(atom.GetNumRadicalElectrons()) + atom_map.append(atom.GetAtomMapNum()) + feature = [] + for name in atom_feature: + func = R.get("features.atom.%s" % name) + feature += func(atom) + _atom_feature.append(feature) + atom_type = torch.tensor(atom_type)[:-1] + atom_map = torch.tensor(atom_map)[:-1] + formal_charge = torch.tensor(formal_charge)[:-1] + explicit_hs = torch.tensor(explicit_hs)[:-1] + chiral_tag = torch.tensor(chiral_tag)[:-1] + radical_electrons = torch.tensor(radical_electrons)[:-1] + if len(atom_feature) > 0: + _atom_feature = torch.tensor(_atom_feature)[:-1] + else: + _atom_feature = None + + edge_list = [] + bond_type = [] + bond_stereo = [] + stereo_atoms = [] + _bond_feature = [] + dummy_bond = copy(cls.dummy_mol).GetBondWithIdx(0) + bonds = [mol.GetBondWithIdx(i) for i in range(mol.GetNumBonds())] + [dummy_bond] + for bond in bonds: + type = str(bond.GetBondType()) + stereo = bond.GetStereo() + if stereo: + _atoms = [a for a in bond.GetStereoAtoms()] + else: + _atoms = [0, 0] + if type not in cls.bond2id: + continue + type = cls.bond2id[type] + h, t = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + edge_list += [[h, t, type], [t, h, type]] + # always explicitly store aromatic bonds, no matter kekulize or not + if bond.GetIsAromatic(): + type = cls.bond2id["AROMATIC"] + bond_type += [type, type] + bond_stereo += [stereo, stereo] + stereo_atoms += [_atoms, _atoms] + feature = [] + for name in bond_feature: + func = R.get("features.bond.%s" % name) + feature += func(bond) + _bond_feature += [feature, feature] + edge_list = edge_list[:-2] + bond_type = torch.tensor(bond_type)[:-2] + bond_stereo = torch.tensor(bond_stereo)[:-2] + stereo_atoms = torch.tensor(stereo_atoms)[:-2] + if len(bond_feature) > 0: + _bond_feature = torch.tensor(_bond_feature)[:-2] + else: + _bond_feature = None + + _mol_feature = [] + for name in mol_feature: + func = R.get("features.molecule.%s" % name) + _mol_feature += func(mol) + if len(mol_feature) > 0: + _mol_feature = torch.tensor(_mol_feature) + else: + _mol_feature = None + + num_relation = len(cls.bond2id) - 1 if kekulize else len(cls.bond2id) + return cls(edge_list, atom_type, bond_type, + formal_charge=formal_charge, explicit_hs=explicit_hs, + chiral_tag=chiral_tag, radical_electrons=radical_electrons, atom_map=atom_map, + bond_stereo=bond_stereo, stereo_atoms=stereo_atoms, node_position=node_position, + atom_feature=_atom_feature, bond_feature=_bond_feature, mol_feature=_mol_feature, + num_node=mol.GetNumAtoms(), 
num_relation=num_relation) + + @classmethod + @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature") + def from_smiles(cls, smiles, atom_feature="default", bond_feature="default", mol_feature=None, + with_hydrogen=False, kekulize=False): + """ + Create a molecule from a SMILES string. + + Parameters: + smiles (str): SMILES string + atom_feature (str or list of str, optional): atom features to extract + bond_feature (str or list of str, optional): bond features to extract + mol_feature (str or list of str, optional): molecule features to extract + with_hydrogen (bool, optional): store hydrogens in the molecule graph. + By default, hydrogens are dropped + kekulize (bool, optional): convert aromatic bonds to single/double bonds. + Note this only affects the relation in ``edge_list``. + For ``bond_type``, aromatic bonds are always stored explicitly. + By default, aromatic bonds are stored. + """ + mol = Chem.MolFromSmiles(smiles) + if mol is None: + raise ValueError("Invalid SMILES `%s`" % smiles) + + return cls.from_molecule(mol, atom_feature, bond_feature, mol_feature, with_hydrogen, kekulize) + + def to_smiles(self, isomeric=True, atom_map=True, canonical=False): + """ + Return a SMILES string of this molecule. + + Parameters: + isomeric (bool, optional): keep isomeric information or not + atom_map (bool, optional): keep atom mapping or not + canonical (bool, optional): if true, return the canonical form of smiles + + Returns: + str + """ + mol = self.to_molecule() + if not atom_map: + for atom in mol.GetAtoms(): + atom.SetAtomMapNum(0) + smiles = Chem.MolToSmiles(mol, isomericSmiles=isomeric) + if canonical: + smiles_set = set() + while smiles not in smiles_set: + smiles_set.add(smiles) + mol = Chem.MolFromSmiles(smiles) + smiles = Chem.MolToSmiles(mol, isomericSmiles=isomeric) + return smiles + + def to_molecule(self, ignore_error=False): + """ + Return an RDKit object of this molecule. + + Parameters: + ignore_error (bool, optional): if true, return ``None`` for illegal molecules. + Otherwise, raise an exception. 
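+
+        A round-trip sketch (the SMILES string is arbitrary):
+
+        >>> mol = Molecule.from_smiles("CCO")
+        >>> rdkit_mol = mol.to_molecule()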
+
+        Returns:
+            rdchem.Mol
+        """
+        mol = Chem.RWMol()
+
+        atom_type = self.atom_type.tolist()
+        bond_type = self.bond_type.tolist()
+        formal_charge = self.formal_charge.tolist()
+        explicit_hs = self.explicit_hs.tolist()
+        chiral_tag = self.chiral_tag.tolist()
+        radical_electrons = self.radical_electrons.tolist()
+        atom_map = self.atom_map.tolist()
+        bond_stereo = self.bond_stereo.tolist()
+        stereo_atoms = self.stereo_atoms.tolist()
+        if hasattr(self, "node_position"):
+            node_position = self.node_position.tolist()
+            conformer = Chem.Conformer()
+        else:
+            conformer = None
+        for i in range(self.num_node):
+            atom = Chem.Atom(atom_type[i])
+            atom.SetFormalCharge(formal_charge[i])
+            atom.SetNumExplicitHs(explicit_hs[i])
+            atom.SetChiralTag(Chem.ChiralType(chiral_tag[i]))
+            atom.SetNumRadicalElectrons(radical_electrons[i])
+            atom.SetNoImplicit(explicit_hs[i] > 0 or radical_electrons[i] > 0)
+            atom.SetAtomMapNum(atom_map[i])
+            if conformer:
+                conformer.SetAtomPosition(i, node_position[i])
+            mol.AddAtom(atom)
+        if conformer:
+            mol.AddConformer(conformer)
+
+        edge_list = self.edge_list.tolist()
+        for i in range(self.num_edge):
+            h, t, type = edge_list[i]
+            if h < t:
+                j = mol.AddBond(h, t, Chem.BondType.names[self.id2bond[type]])
+                bond = mol.GetBondWithIdx(j - 1)
+                bond.SetIsAromatic(bond_type[i] == self.bond2id["AROMATIC"])
+                bond.SetStereo(Chem.BondStereo(bond_stereo[i]))
+        j = 0
+        for i in range(self.num_edge):
+            h, t, type = edge_list[i]
+            if h < t:
+                if bond_stereo[i]:
+                    bond = mol.GetBondWithIdx(j)
+                    bond.SetStereoAtoms(*stereo_atoms[i])
+                j += 1
+
+        if ignore_error:
+            try:
+                with utils.no_rdkit_log():
+                    mol.UpdatePropertyCache()
+                    Chem.AssignStereochemistry(mol)
+                    mol.ClearComputedProps()
+                    mol.UpdatePropertyCache()
+            except Exception:
+                mol = None
+        else:
+            mol.UpdatePropertyCache()
+            Chem.AssignStereochemistry(mol)
+            mol.ClearComputedProps()
+            mol.UpdatePropertyCache()
+
+        return mol
+
+    def ion_to_molecule(self):
+        """
+        Convert ions to molecules by adjusting hydrogens and electrons.
+
+        Note [N+] will not be converted.
+        """
+        data_dict = self.data_dict
+
+        formal_charge = data_dict.pop("formal_charge")
+        explicit_hs = data_dict.pop("explicit_hs")
+        radical_electrons = data_dict.pop("radical_electrons")
+        pos_nitrogen = (self.atom_type == 7) & (self.explicit_valence > 3)
+        formal_charge = pos_nitrogen.long()
+        explicit_hs = torch.zeros_like(explicit_hs)
+        radical_electrons = torch.zeros_like(radical_electrons)
+
+        return type(self)(self.edge_list, edge_weight=self.edge_weight,
+                          num_node=self.num_node, num_relation=self.num_relation,
+                          formal_charge=formal_charge, explicit_hs=explicit_hs, radical_electrons=radical_electrons,
+                          meta_dict=self.meta_dict, **data_dict)
+
+    def to_scaffold(self, chirality=False):
+        """
+        Return a scaffold SMILES string of this molecule.
+
+        Parameters:
+            chirality (bool, optional): consider chirality in the scaffold or not
+
+        Returns:
+            str
+        """
+        smiles = self.to_smiles()
+        scaffold = MurckoScaffold.MurckoScaffoldSmiles(smiles, includeChirality=chirality)
+        return scaffold
+
+    def node_mask(self, index, compact=False):
+        self._check_no_stereo()
+        return super(Molecule, self).node_mask(index, compact)
+
+    def edge_mask(self, index):
+        self._check_no_stereo()
+        return super(Molecule, self).edge_mask(index)
+
+    def undirected(self, add_inverse=False):
+        if add_inverse:
+            raise ValueError("Bonds are undirected relations, but `add_inverse` is specified")
+        return super(Molecule, self).undirected(add_inverse)
+
+    def atom(self):
+        """
+        Context manager for atom attributes.
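+
+        A sketch of assigning a custom atom-level attribute (``mol`` and the
+        attribute name ``atom_flag`` are illustrative):
+
+        >>> with mol.atom():
+        >>>     mol.atom_flag = torch.zeros(mol.num_atom, dtype=torch.bool)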
+ """ + return self.node() + + def bond(self): + """ + Context manager for bond attributes. + """ + return self.edge() + + def mol(self): + """ + Context manager for molecule attributes. + """ + return self.graph() + + def atom_reference(self): + """ + Context manager for atom references. + """ + return self.node_reference() + + def bond_reference(self): + """ + Context manager for bond references. + """ + return self.edge_reference() + + def mol_reference(self): + """ + Context mangaer for molecule references. + """ + return self.graph_reference() + + @property + def num_node(self): + return self.num_atom + + @num_node.setter + def num_node(self, value): + self.num_atom = value + + @property + def num_edge(self): + return self.num_bond + + @num_edge.setter + def num_edge(self, value): + self.num_bond = value + + atom2graph = Graph.node2graph + bond2graph = Graph.edge2graph + + @property + def node_feature(self): + return self.atom_feature + + @node_feature.setter + def node_feature(self, value): + self.atom_feature = value + + @property + def edge_feature(self): + return self.bond_feature + + @edge_feature.setter + def edge_feature(self, value): + self.bond_feature = value + + @property + def graph_feature(self): + return self.mol_feature + + @graph_feature.setter + def graph_feature(self, value): + self.mol_feature = value + + @utils.cached_property + def explicit_valence(self): + bond2valence = torch.tensor(self.bond2valence, device=self.device) + explicit_valence = scatter_add(bond2valence[self.edge_list[:, 2]], self.edge_list[:, 0], dim_size=self.num_node) + return explicit_valence.round().long() + + @utils.cached_property + def is_valid(self): + """A coarse implementation of valence check.""" + # TODO: cross-check by any domain expert + atom2valence = torch.tensor(float("nan")).repeat(constant.NUM_ATOM) + for k, v in self.atom2valence: + atom2valence[k] = v + atom2valence = torch.as_tensor(atom2valence, device=self.device) + + max_atom_valence = atom2valence[self.atom_type] + # special case for nitrogen + pos_nitrogen = (self.atom_type == 7) & (self.formal_charge == 1) + max_atom_valence[pos_nitrogen] = 4 + if torch.isnan(max_atom_valence).any(): + index = torch.isnan(max_atom_valence).nonzero()[0] + raise ValueError("Fail to check valence. Unknown atom type %d" % self.atom_type[index]) + + is_valid = (self.explicit_valence <= max_atom_valence).all() + return is_valid + + @utils.cached_property + def is_valid_rdkit(self): + try: + with utils.no_rdkit_log(): + mol = self.to_molecule() + Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES) + is_valid = torch.ones(1, dtype=torch.bool, device=self.device) + except ValueError: + is_valid = torch.zeros(1, dtype=torch.bool, device=self.device) + return is_valid + + def __repr__(self): + fields = ["num_atom=%d" % self.num_atom, "num_bond=%d" % self.num_bond] + if self.device.type != "cpu": + fields.append("device='%s'" % self.device) + return "%s(%s)" % (self.__class__.__name__, ", ".join(fields)) + + def visualize(self, title=None, save_file=None, figure_size=(3, 3), ax=None, atom_map=False): + """ + Visualize this molecule with matplotlib. + + Parameters: + title (str, optional): title for this molecule + save_file (str, optional): ``png`` or ``pdf`` file to save visualization. + If not provided, show the figure in window. 
+ figure_size (tuple of int, optional): width and height of the figure + ax (matplotlib.axes.Axes, optional): axis to plot the figure + atom_map (bool, optional): visualize atom mapping or not + """ + is_root = ax is None + if ax is None: + fig = plt.figure(figsize=figure_size) + if title is not None: + ax = plt.gca() + else: + ax = fig.add_axes([0, 0, 1, 1]) + if title is not None: + ax.set_title(title) + + mol = self.to_molecule() + if not atom_map: + for atom in mol.GetAtoms(): + atom.SetAtomMapNum(0) + draw.MolToMPL(mol, ax=ax) + ax.set_frame_on(False) + + if is_root: + if save_file: + fig.savefig(save_file) + else: + fig.show() + + def __eq__(self, other): + smiles = self.to_smiles(isomeric=False, atom_map=False, canonical=True) + other_smiles = other.to_smiles(isomeric=False, atom_map=False, canonical=True) + return smiles == other_smiles + + +class PackedMolecule(PackedGraph, Molecule): + """ + Container for molecules with variadic sizes. + + .. warning:: + + Edges of the same molecule are guaranteed to be consecutive in the edge list. + However, this class doesn't enforce any order on the edges. + + Parameters: + edge_list (array_like, optional): list of edges of shape :math:`(|E|, 3)`. + Each tuple is (node_in, node_out, bond_type). + atom_type (array_like, optional): atom types of shape :math:`(|V|,)` + bond_type (array_like, optional): bond types of shape :math:`(|E|,)` + num_nodes (array_like, optional): number of nodes in each graph + By default, it will be inferred from the largest id in `edge_list` + num_edges (array_like, optional): number of edges in each graph + offsets (array_like, optional): node id offsets of shape :math:`(|E|,)`. + If not provided, nodes in `edge_list` should be relative index, i.e., the index in each graph. + If provided, nodes in `edge_list` should be absolute index, i.e., the index in the packed graph. + """ + + unpacked_type = Molecule + atom2graph = PackedGraph.node2graph + bond2graph = PackedGraph.edge2graph + + def __init__(self, edge_list=None, atom_type=None, bond_type=None, num_nodes=None, num_edges=None, offsets=None, + **kwargs): + if "num_relation" not in kwargs: + kwargs["num_relation"] = len(self.bond2id) + super(PackedMolecule, self).__init__(edge_list=edge_list, num_nodes=num_nodes, num_edges=num_edges, + offsets=offsets, atom_type=atom_type, bond_type=bond_type, **kwargs) + + def ion_to_molecule(self): + """ + Convert ions to molecules by adjusting hydrogens and electrons. + + Note [N+] will not be converted. 
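+
+        A sketch (the SMILES strings are illustrative; note the charged nitrogen
+        in the second molecule is kept as-is):
+
+        >>> mols = PackedMolecule.from_smiles(["CC[O-]", "CC[NH3+]"])
+        >>> neutral = mols.ion_to_molecule()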
+ """ + data_dict = self.data_dict + + formal_charge = data_dict.pop("formal_charge") + explicit_hs = data_dict.pop("explicit_hs") + radical_electrons = data_dict.pop("radical_electrons") + pos_nitrogen = (self.atom_type == 7) & (self.explicit_valence > 3) + formal_charge = pos_nitrogen.long() + explicit_hs = torch.zeros_like(explicit_hs) + radical_electrons = torch.zeros_like(radical_electrons) + + return type(self)(self.edge_list, edge_weight=self.edge_weight, + num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation, + offsets=self._offsets, formal_charge=formal_charge, explicit_hs=explicit_hs, + radical_electrons=radical_electrons, meta_dict=self.meta_dict, **data_dict) + + @utils.cached_property + def is_valid(self): + """A coarse implementation of valence check.""" + # TODO: cross-check by any domain expert + atom2valence = torch.tensor(float("nan")).repeat(constant.NUM_ATOM) + for k, v in self.atom2valence.items(): + atom2valence[k] = v + atom2valence = torch.as_tensor(atom2valence, device=self.device) + + max_atom_valence = atom2valence[self.atom_type] + # special case for nitrogen + pos_nitrogen = (self.atom_type == 7) & (self.formal_charge == 1) + max_atom_valence[pos_nitrogen] = 4 + if torch.isnan(max_atom_valence).any(): + index = torch.isnan(max_atom_valence).nonzero()[0] + raise ValueError("Fail to check valence. Unknown atom type %d" % self.atom_type[index]) + + is_valid = self.explicit_valence <= max_atom_valence + is_valid = scatter_min(is_valid.long(), self.node2graph, dim_size=self.batch_size)[0].bool() + return is_valid + + @utils.cached_property + def is_valid_rdkit(self): + return torch.cat([mol.is_valid_rdkit for mol in self]) + + @classmethod + @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature") + def from_molecule(cls, mols, atom_feature="default", bond_feature="default", mol_feature=None, + with_hydrogen=False, kekulize=False): + """ + Create a packed molecule from a list of RDKit objects. + + Parameters: + mols (list of rdchem.Mol): molecules + atom_feature (str or list of str, optional): atom features to extract + bond_feature (str or list of str, optional): bond features to extract + mol_feature (str or list of str, optional): molecule features to extract + with_hydrogen (bool, optional): store hydrogens in the molecule graph. + By default, hydrogens are dropped + kekulize (bool, optional): convert aromatic bonds to single/double bonds. + Note this only affects the relation in ``edge_list``. + For ``bond_type``, aromatic bonds are always stored explicitly. + By default, aromatic bonds are stored. 
+ """ + atom_feature = cls._standarize_option(atom_feature) + bond_feature = cls._standarize_option(bond_feature) + mol_feature = cls._standarize_option(mol_feature) + + atom_type = [] + formal_charge = [] + explicit_hs = [] + chiral_tag = [] + radical_electrons = [] + atom_map = [] + + edge_list = [] + bond_type = [] + bond_stereo = [] + stereo_atoms = [] + node_position = [] + + _atom_feature = [] + _bond_feature = [] + _mol_feature = [] + num_nodes = [] + num_edges = [] + + mols = mols + [cls.dummy_mol] + for mol in mols: + if mol is None: + mol = cls.empty_mol + # some RDKit operations are in-place + # copy the object to avoid undesired behavior in the caller + mol = copy(mol) + if with_hydrogen: + mol = Chem.AddHs(mol) + if kekulize: + Chem.Kekulize(mol) + + if mol.GetNumConformers() > 0: + node_position += mol.GetConformer().GetPositions().tolist() + for atom in mol.GetAtoms(): + atom_type.append(atom.GetAtomicNum()) + formal_charge.append(atom.GetFormalCharge()) + explicit_hs.append(atom.GetNumExplicitHs()) + chiral_tag.append(atom.GetChiralTag()) + radical_electrons.append(atom.GetNumRadicalElectrons()) + atom_map.append(atom.GetAtomMapNum()) + feature = [] + for name in atom_feature: + func = R.get("features.atom.%s" % name) + feature += func(atom) + _atom_feature.append(feature) + + for bond in mol.GetBonds(): + type = str(bond.GetBondType()) + stereo = bond.GetStereo() + if stereo: + _atoms = list(bond.GetStereoAtoms()) + else: + _atoms = [0, 0] + if type not in cls.bond2id: + continue + type = cls.bond2id[type] + h, t = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + feature = [] + for name in bond_feature: + func = R.get("features.bond.%s" % name) + feature += func(bond) + edge_list += [[h, t, type], [t, h, type]] + # always explicitly store aromatic bonds + if bond.GetIsAromatic(): + type = cls.bond2id["AROMATIC"] + bond_type += [type, type] + bond_stereo += [stereo, stereo] + stereo_atoms += [_atoms, _atoms] + _bond_feature += [feature, feature] + + feature = [] + for name in mol_feature: + func = R.get("features.molecule.%s" % name) + feature += func(mol) + _mol_feature.append(feature) + + num_nodes.append(mol.GetNumAtoms()) + num_edges.append(mol.GetNumBonds() * 2) + + atom_type = torch.tensor(atom_type)[:-2] + atom_map = torch.tensor(atom_map)[:-2] + formal_charge = torch.tensor(formal_charge)[:-2] + explicit_hs = torch.tensor(explicit_hs)[:-2] + chiral_tag = torch.tensor(chiral_tag)[:-2] + radical_electrons = torch.tensor(radical_electrons)[:-2] + if len(node_position) > 0: + node_position = torch.tensor(node_position) + else: + node_position = None + if len(atom_feature) > 0: + _atom_feature = torch.tensor(_atom_feature)[:-2] + else: + _atom_feature = None + + num_nodes = num_nodes[:-1] + num_edges = num_edges[:-1] + edge_list = torch.tensor(edge_list)[:-2] + bond_type = torch.tensor(bond_type)[:-2] + bond_stereo = torch.tensor(bond_stereo)[:-2] + stereo_atoms = torch.tensor(stereo_atoms)[:-2] + if len(bond_feature) > 0: + _bond_feature = torch.tensor(_bond_feature)[:-2] + else: + _bond_feature = None + if len(mol_feature) > 0: + _mol_feature = torch.tensor(_mol_feature)[:-1] + else: + _mol_feature = None + + num_relation = len(cls.bond2id) - 1 if kekulize else len(cls.bond2id) + return cls(edge_list, atom_type, bond_type, + formal_charge=formal_charge, explicit_hs=explicit_hs, + chiral_tag=chiral_tag, radical_electrons=radical_electrons, atom_map=atom_map, + bond_stereo=bond_stereo, stereo_atoms=stereo_atoms, node_position=node_position, + atom_feature=_atom_feature, 
bond_feature=_bond_feature, mol_feature=_mol_feature,
+                   num_nodes=num_nodes, num_edges=num_edges, num_relation=num_relation)
+
+    @classmethod
+    @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
+    def from_smiles(cls, smiles_list, atom_feature="default", bond_feature="default", mol_feature=None,
+                    with_hydrogen=False, kekulize=False):
+        """
+        Create a packed molecule from a list of SMILES strings.
+
+        Parameters:
+            smiles_list (list of str): list of SMILES strings
+            atom_feature (str or list of str, optional): atom features to extract
+            bond_feature (str or list of str, optional): bond features to extract
+            mol_feature (str or list of str, optional): molecule features to extract
+            with_hydrogen (bool, optional): store hydrogens in the molecule graph.
+                By default, hydrogens are dropped
+            kekulize (bool, optional): convert aromatic bonds to single/double bonds.
+                Note this only affects the relation in ``edge_list``.
+                For ``bond_type``, aromatic bonds are always stored explicitly.
+                By default, aromatic bonds are stored.
+        """
+        mols = []
+        for smiles in smiles_list:
+            mol = Chem.MolFromSmiles(smiles)
+            if mol is None:
+                raise ValueError("Invalid SMILES `%s`" % smiles)
+            mols.append(mol)
+
+        return cls.from_molecule(mols, atom_feature, bond_feature, mol_feature, with_hydrogen, kekulize)
+
+    def to_smiles(self, isomeric=True, atom_map=True, canonical=False):
+        """
+        Return a list of SMILES strings.
+
+        Parameters:
+            isomeric (bool, optional): keep isomeric information or not
+            atom_map (bool, optional): keep atom mapping or not
+            canonical (bool, optional): if true, return the canonical form of smiles
+
+        Returns:
+            list of str
+        """
+        mols = self.to_molecule()
+        smiles_list = []
+        for mol in mols:
+            if not atom_map:
+                for atom in mol.GetAtoms():
+                    atom.SetAtomMapNum(0)
+            smiles = Chem.MolToSmiles(mol, isomericSmiles=isomeric)
+            if canonical:
+                smiles_set = set()
+                while smiles not in smiles_set:
+                    smiles_set.add(smiles)
+                    mol = Chem.MolFromSmiles(smiles)
+                    smiles = Chem.MolToSmiles(mol, isomericSmiles=isomeric)
+            smiles_list.append(smiles)
+        return smiles_list
+
+    def to_molecule(self, ignore_error=False):
+        """
+        Return a list of RDKit objects.
+
+        Parameters:
+            ignore_error (bool, optional): if true, return ``None`` for illegal molecules.
+                Otherwise, raise an exception.
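+
+        A sketch of unpacking back to RDKit objects (``batch`` is an illustrative
+        :class:`PackedMolecule`):
+
+        >>> rdkit_mols = batch.to_molecule()
+        >>> assert len(rdkit_mols) == batch.batch_size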
+ + Returns: + list of rdchem.Mol + """ + atom_type = self.atom_type.tolist() + bond_type = self.bond_type.tolist() + formal_charge = self.formal_charge.tolist() + explicit_hs = self.explicit_hs.tolist() + chiral_tag = self.chiral_tag.tolist() + radical_electrons = self.radical_electrons.tolist() + atom_map = self.atom_map.tolist() + bond_stereo = self.bond_stereo.tolist() + stereo_atoms = self.stereo_atoms.tolist() + if hasattr(self, "node_position"): + node_position = self.node_position.tolist() + else: + node_position = None + num_cum_nodes = [0] + self.num_cum_nodes.tolist() + num_cum_edges = [0] + self.num_cum_edges.tolist() + edge_list = self.edge_list.clone() + edge_list[:, :2] -= self._offsets.unsqueeze(-1) + edge_list = edge_list.tolist() + + mols = [] + for i in range(self.batch_size): + mol = Chem.RWMol() + if node_position: + conformer = Chem.Conformer() + else: + conformer = None + for j in range(num_cum_nodes[i], num_cum_nodes[i + 1]): + atom = Chem.Atom(atom_type[j]) + atom.SetFormalCharge(formal_charge[j]) + atom.SetNumExplicitHs(explicit_hs[j]) + atom.SetChiralTag(Chem.ChiralType(chiral_tag[j])) + atom.SetNumRadicalElectrons(radical_electrons[j]) + atom.SetNoImplicit(explicit_hs[j] > 0 or radical_electrons[j] > 0) + atom.SetAtomMapNum(atom_map[j]) + if conformer: + conformer.SetAtomPosition(j - num_cum_nodes[i], node_position[j]) + mol.AddAtom(atom) + if conformer: + mol.AddConformer(conformer) + + for j in range(num_cum_edges[i], num_cum_edges[i + 1]): + h, t, type = edge_list[j] + if h < t: + k = mol.AddBond(h, t, Chem.BondType.names[self.id2bond[type]]) + bond = mol.GetBondWithIdx(k - 1) + bond.SetIsAromatic(bond_type[j] == self.bond2id["AROMATIC"]) + bond.SetStereo(Chem.BondStereo(bond_stereo[j])) + k = 0 + for j in range(num_cum_edges[i], num_cum_edges[i + 1]): + h, t, type = edge_list[j] + if h < t: + if bond_stereo[j]: + bond = mol.GetBondWithIdx(k) + # These do not necessarily need to be the highest 'ranking' atoms like CIP stereo requires. + # They can be any arbitrary atoms neighboring the begin and end atoms of this bond respectively. + # STEREOCIS or STEREOTRANS is then set relative to only these atoms. 
+                        bond.SetStereoAtoms(*stereo_atoms[j])
+                    k += 1
+
+            if ignore_error:
+                try:
+                    with utils.no_rdkit_log():
+                        mol.UpdatePropertyCache()
+                        Chem.AssignStereochemistry(mol)
+                        mol.ClearComputedProps()
+                        mol.UpdatePropertyCache()
+                except Exception:
+                    mol = None
+            else:
+                mol.UpdatePropertyCache()
+                Chem.AssignStereochemistry(mol)
+                mol.ClearComputedProps()
+                mol.UpdatePropertyCache()
+            mols.append(mol)
+
+        return mols
+
+    def node_mask(self, index, compact=False):
+        self._check_no_stereo()
+        return super(PackedMolecule, self).node_mask(index, compact)
+
+    def edge_mask(self, index):
+        self._check_no_stereo()
+        return super(PackedMolecule, self).edge_mask(index)
+
+    def undirected(self, add_inverse=False):
+        if add_inverse:
+            raise ValueError("Bonds are undirected relations, but `add_inverse` is specified")
+        return super(PackedMolecule, self).undirected(add_inverse)
+
+    @property
+    def num_nodes(self):
+        return self.num_atoms
+
+    @num_nodes.setter
+    def num_nodes(self, value):
+        self.num_atoms = value
+
+    @property
+    def num_edges(self):
+        return self.num_bonds
+
+    @num_edges.setter
+    def num_edges(self, value):
+        self.num_bonds = value
+
+    def __repr__(self):
+        fields = ["batch_size=%d" % self.batch_size,
+                  "num_atoms=%s" % pretty.long_array(self.num_atoms.tolist()),
+                  "num_bonds=%s" % pretty.long_array(self.num_bonds.tolist())]
+        if self.device.type != "cpu":
+            fields.append("device='%s'" % self.device)
+        return "%s(%s)" % (self.__class__.__name__, ", ".join(fields))
+
+    def visualize(self, titles=None, save_file=None, figure_size=(3, 3), num_row=None, num_col=None, atom_map=False):
+        """
+        Visualize the packed molecules with matplotlib.
+
+        Parameters:
+            titles (list of str, optional): title for each molecule. Default is the ID of each molecule.
+            save_file (str, optional): ``png`` or ``pdf`` file to save visualization.
+                If not provided, show the figure in a window.
+ figure_size (tuple of int, optional): width and height of the figure + num_row (int, optional): number of rows in the figure + num_col (int, optional): number of columns in the figure + atom_map (bool, optional): visualize atom mapping or not + """ + if titles is None: + graph = self.get_item(0) + titles = ["%s %d" % (type(graph).__name__, i) for i in range(self.batch_size)] + if num_col is None: + if num_row is None: + num_col = math.ceil(self.batch_size ** 0.5) + else: + num_col = math.ceil(self.batch_size / num_row) + if num_row is None: + num_row = math.ceil(self.batch_size / num_col) + + figure_size = (num_col * figure_size[0], num_row * figure_size[1]) + fig = plt.figure(figsize=figure_size) + + for i in range(self.batch_size): + graph = self.get_item(i) + ax = fig.add_subplot(num_row, num_col, i + 1) + graph.visualize(title=titles[i], ax=ax, atom_map=atom_map) + # remove the space of axis labels + fig.tight_layout() + + if save_file: + fig.savefig(save_file) + else: + fig.show() + + +Molecule.packed_type = PackedMolecule \ No newline at end of file diff --git a/build/lib/torchdrug/data/protein.py b/build/lib/torchdrug/data/protein.py new file mode 100644 index 00000000..4d9724d4 --- /dev/null +++ b/build/lib/torchdrug/data/protein.py @@ -0,0 +1,1358 @@ +import os +import string +import warnings +from collections import defaultdict + +from rdkit import Chem +import torch +from torch_scatter import scatter_add, scatter_max, scatter_min + +from torchdrug import utils +from torchdrug.data import Molecule, PackedMolecule, Dictionary, feature +from torchdrug.core import Registry as R +from torchdrug.utils import pretty + + +class Protein(Molecule): + """ + Proteins with predefined chemical features. + Support both residue-level and atom-level operations and ensure consistency between two views. + + .. warning:: + + The order of residues must be the same as the protein sequence. + However, this class doesn't enforce any order on nodes or edges. + Nodes may have a different order with residues. + + Parameters: + edge_list (array_like, optional): list of edges of shape :math:`(|E|, 3)`. + Each tuple is (node_in, node_out, bond_type). + atom_type (array_like, optional): atom types of shape :math:`(|V|,)` + bond_type (array_like, optional): bond types of shape :math:`(|E|,)` + residue_type (array_like, optional): residue types of shape :math:`(|V_{res}|,)` + view (str, optional): default view for this protein. Can be ``atom`` or ``residue``. 
+ atom_name (array_like, optional): atom names in a residue of shape :math:`(|V|,)` + atom2residue (array_like, optional): atom id to residue id mapping of shape :math:`(|V|,)` + residue_feature (array_like, optional): residue features of shape :math:`(|V_{res}|, ...)` + is_hetero_atom (array_like, optional): hetero atom indicators of shape :math:`(|V|,)` + occupancy (array_like, optional): occupancy of shape :math:`(|V|,)` + b_factor (array_like, optional): temperature factors of shape :math:`(|V|,)` + residue_number (array_like, optional): residue numbers of shape :math:`(|V_{res}|,)` + insertion_code (array_like, optional): insertion codes of shape :math:`(|V_{res}|,)` + chain_id (array_like, optional): chain ids of shape :math:`(|V_{res}|,)` + """ + + _meta_types = {"node", "edge", "residue", "graph", + "node reference", "edge reference", "residue reference", "graph reference"} + dummy_protein = Chem.MolFromSequence("G") + dummy_atom = dummy_protein.GetAtomWithIdx(0) + + # TODO: rdkit isn't compatible with X in the sequence + residue2id = {"GLY": 0, "ALA": 1, "SER": 2, "PRO": 3, "VAL": 4, "THR": 5, "CYS": 6, "ILE": 7, "LEU": 8, + "ASN": 9, "ASP": 10, "GLN": 11, "LYS": 12, "GLU": 13, "MET": 14, "HIS": 15, "PHE": 16, + "ARG": 17, "TYR": 18, "TRP": 19} + residue_symbol2id = {"G": 0, "A": 1, "S": 2, "P": 3, "V": 4, "T": 5, "C": 6, "I": 7, "L": 8, "N": 9, + "D": 10, "Q": 11, "K": 12, "E": 13, "M": 14, "H": 15, "F": 16, "R": 17, "Y": 18, "W": 19} + atom_name2id = {"C": 0, "CA": 1, "CB": 2, "CD": 3, "CD1": 4, "CD2": 5, "CE": 6, "CE1": 7, "CE2": 8, + "CE3": 9, "CG": 10, "CG1": 11, "CG2": 12, "CH2": 13, "CZ": 14, "CZ2": 15, "CZ3": 16, + "N": 17, "ND1": 18, "ND2": 19, "NE": 20, "NE1": 21, "NE2": 22, "NH1": 23, "NH2": 24, + "NZ": 25, "O": 26, "OD1": 27, "OD2": 28, "OE1": 29, "OE2": 30, "OG": 31, "OG1": 32, + "OH": 33, "OXT": 34, "SD": 35, "SG": 36, "UNK": 37} + alphabet2id = {c: i for i, c in enumerate(" " + string.ascii_uppercase + string.ascii_lowercase + string.digits)} + id2residue = {v: k for k, v in residue2id.items()} + id2residue_symbol = {v: k for k, v in residue_symbol2id.items()} + id2atom_name = {v: k for k, v in atom_name2id.items()} + id2alphabet = {v: k for k, v in alphabet2id.items()} + + def __init__(self, edge_list=None, atom_type=None, bond_type=None, residue_type=None, view=None, + atom_name=None, atom2residue=None, residue_feature=None, is_hetero_atom=None, occupancy=None, + b_factor=None, residue_number=None, insertion_code=None, chain_id=None, **kwargs): + super(Protein, self).__init__(edge_list, atom_type, bond_type, **kwargs) + residue_type, num_residue = self._standarize_num_residue(residue_type) + self.num_residue = num_residue + self.view = self._standarize_view(view) + + atom_name = self._standarize_attribute(atom_name, self.num_node) + atom2residue = self._standarize_attribute(atom2residue, self.num_node) + is_hetero_atom = self._standarize_attribute(is_hetero_atom, self.num_node, dtype=torch.bool) + occupancy = self._standarize_attribute(occupancy, self.num_node, dtype=torch.float, default=1) + b_factor = self._standarize_attribute(b_factor, self.num_node, dtype=torch.float) + residue_number = self._standarize_attribute(residue_number, self.num_residue) + insertion_code = self._standarize_attribute(insertion_code, self.num_residue) + chain_id = self._standarize_attribute(chain_id, self.num_residue) + + with self.atom(): + self.atom_name = atom_name + with self.residue_reference(): + self.atom2residue = atom2residue + self.is_hetero_atom = is_hetero_atom + 
self.occupancy = occupancy + self.b_factor = b_factor + + with self.residue(): + self.residue_type = residue_type + if residue_feature is not None: + self.residue_feature = torch.as_tensor(residue_feature, device=self.device) + self.residue_number = residue_number + self.insertion_code = insertion_code + self.chain_id = chain_id + + def residue(self): + """ + Context manager for residue attributes. + """ + return self.context("residue") + + def residue_reference(self): + """ + Context manager for residue references. + """ + return self.context("residue reference") + + @property + def node_feature(self): + if getattr(self, "view", "atom") == "atom": + return self.atom_feature + else: + return self.residue_feature + + @node_feature.setter + def node_feature(self, value): + self.atom_feature = value + + @property + def num_node(self): + return self.num_atom + + @num_node.setter + def num_node(self, value): + self.num_atom = value + + def _check_attribute(self, key, value): + super(Protein, self)._check_attribute(key, value) + for type in self._meta_contexts: + if type == "residue": + if len(value) != self.num_residue: + raise ValueError("Expect residue attribute `%s` to have shape (%d, *), but found %s" % + (key, self.num_residue, value.shape)) + elif type == "residue reference": + is_valid = (value >= -1) & (value < self.num_residue) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect residue reference in [-1, %d), but found %d" % + (self.num_residue, error_value[0])) + + def _standarize_num_residue(self, residue_type): + if residue_type is None: + raise ValueError("`residue_type` should be provided") + + residue_type = torch.as_tensor(residue_type, dtype=torch.long, device=self.device) + num_residue = torch.tensor(len(residue_type), device=self.device) + return residue_type, num_residue + + def __setattr__(self, key, value): + if key == "view" and value not in ["atom", "residue"]: + raise ValueError("Expect `view` to be either `atom` or `residue`, but found `%s`" % value) + return super(Protein, self).__setattr__(key, value) + + def _standarize_view(self, view): + if view is None: + if self.num_atom > 0: + view = "atom" + else: + view = "residue" + return view + + @classmethod + @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature") + def from_molecule(cls, mol, atom_feature="default", bond_feature="default", residue_feature="default", + mol_feature=None, kekulize=False): + """ + Create a protein from an RDKit object. + + Parameters: + mol (rdchem.Mol): molecule + atom_feature (str or list of str, optional): atom features to extract + bond_feature (str or list of str, optional): bond features to extract + residue_feature (str, list of str, optional): residue features to extract + mol_feature (str or list of str, optional): molecule features to extract + kekulize (bool, optional): convert aromatic bonds to single/double bonds. + Note this only affects the relation in ``edge_list``. + For ``bond_type``, aromatic bonds are always stored explicitly. + By default, aromatic bonds are stored. 
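+
+        A minimal sketch (assumes the molecule carries PDB residue information
+        that RDKit can parse; the file name is illustrative):
+
+        >>> from rdkit import Chem
+        >>> mol = Chem.MolFromPDBFile("2lwz.pdb")
+        >>> protein = Protein.from_molecule(mol)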
+ """ + protein = Molecule.from_molecule(mol, atom_feature=atom_feature, bond_feature=bond_feature, + mol_feature=mol_feature, with_hydrogen=False, kekulize=kekulize) + residue_feature = cls._standarize_option(residue_feature) + + if kekulize: + Chem.Kekulize(mol) + + residue_type = [] + atom_name = [] + is_hetero_atom = [] + occupancy = [] + b_factor = [] + atom2residue = [] + residue_number = [] + insertion_code = [] + chain_id = [] + _residue_feature = [] + last_residue = None + atoms = [mol.GetAtomWithIdx(i) for i in range(mol.GetNumAtoms())] + [cls.dummy_atom] + for atom in atoms: + pdbinfo = atom.GetPDBResidueInfo() + number = pdbinfo.GetResidueNumber() + code = pdbinfo.GetInsertionCode() + type = pdbinfo.GetResidueName().strip() + canonical_residue = (number, code, type) + if canonical_residue != last_residue: + last_residue = canonical_residue + if type not in cls.residue2id: + warnings.warn("Unknown residue `%s`. Treat as glycine" % type) + type = "GLY" + residue_type.append(cls.residue2id[type]) + residue_number.append(number) + if pdbinfo.GetInsertionCode() not in cls.alphabet2id: + warnings.warn(f"Fail to create the protein. Unknown insertion code {pdbinfo.GetInsertionCode()}.") + return None + if pdbinfo.GetChainId() not in cls.alphabet2id: + warnings.warn(f"Fail to create the protein. Unknown chain id {pdbinfo.GetChainId()}.") + return None + insertion_code.append(cls.alphabet2id[pdbinfo.GetInsertionCode()]) + chain_id.append(cls.alphabet2id[pdbinfo.GetChainId()]) + feature = [] + for name in residue_feature: + func = R.get("features.residue.%s" % name) + feature += func(pdbinfo) + _residue_feature.append(feature) + name = pdbinfo.GetName().strip() + if name not in cls.atom_name2id: + name = "UNK" + atom_name.append(cls.atom_name2id[name]) + is_hetero_atom.append(pdbinfo.GetIsHeteroAtom()) + occupancy.append(pdbinfo.GetOccupancy()) + b_factor.append(pdbinfo.GetTempFactor()) + atom2residue.append(len(residue_type) - 1) + residue_type = torch.tensor(residue_type)[:-1] + atom_name = torch.tensor(atom_name)[:-1] + is_hetero_atom = torch.tensor(is_hetero_atom)[:-1] + occupancy = torch.tensor(occupancy)[:-1] + b_factor = torch.tensor(b_factor)[:-1] + atom2residue = torch.tensor(atom2residue)[:-1] + residue_number = torch.tensor(residue_number)[:-1] + insertion_code = torch.tensor(insertion_code)[:-1] + chain_id = torch.tensor(chain_id)[:-1] + if len(residue_feature) > 0: + _residue_feature = torch.tensor(_residue_feature)[:-1] + else: + _residue_feature = None + + return cls(protein.edge_list, num_node=protein.num_node, residue_type=residue_type, + atom_name=atom_name, atom2residue=atom2residue, residue_feature=_residue_feature, + is_hetero_atom=is_hetero_atom, occupancy=occupancy, b_factor=b_factor, + residue_number=residue_number, insertion_code=insertion_code, chain_id=chain_id, + meta_dict=protein.meta_dict, **protein.data_dict) + + @classmethod + def _residue_from_sequence(cls, sequence): + residue_type = [] + residue_feature = [] + sequence = sequence + "G" + for residue in sequence: + if residue not in cls.residue_symbol2id: + warnings.warn("Unknown residue symbol `%s`. 
Treat as glycine" % residue) + residue = "G" + residue_type.append(cls.residue_symbol2id[residue]) + residue_feature.append(feature.onehot(residue, cls.residue_symbol2id, allow_unknown=True)) + + residue_type = residue_type[:-1] + residue_feature = torch.tensor(residue_feature)[:-1] + + return cls(edge_list=None, atom_type=[], bond_type=[], num_node=0, residue_type=residue_type, + residue_feature=residue_feature) + + @classmethod + @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature") + def from_sequence(cls, sequence, atom_feature="default", bond_feature="default", residue_feature="default", + mol_feature=None, kekulize=False): + """ + Create a protein from a sequence. + + .. note:: + + It takes considerable time to construct proteins with a large number of atoms and bonds. + If you only need residue information, you may speed up the construction by setting + ``atom_feature`` and ``bond_feature`` to ``None``. + + Parameters: + sequence (str): protein sequence + atom_feature (str or list of str, optional): atom features to extract + bond_feature (str or list of str, optional): bond features to extract + residue_feature (str, list of str, optional): residue features to extract + mol_feature (str or list of str, optional): molecule features to extract + kekulize (bool, optional): convert aromatic bonds to single/double bonds. + Note this only affects the relation in ``edge_list``. + For ``bond_type``, aromatic bonds are always stored explicitly. + By default, aromatic bonds are stored. + """ + if atom_feature is None and bond_feature is None and residue_feature == "default": + return cls._residue_from_sequence(sequence) + + mol = Chem.MolFromSequence(sequence) + if mol is None: + raise ValueError("Invalid sequence `%s`" % sequence) + + return cls.from_molecule(mol, atom_feature, bond_feature, residue_feature, mol_feature, kekulize) + + @classmethod + @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature") + def from_pdb(cls, pdb_file, atom_feature="default", bond_feature="default", residue_feature="default", + mol_feature=None, kekulize=False): + """ + Create a protein from a PDB file. + + Parameters: + pdb_file (str): file name + atom_feature (str or list of str, optional): atom features to extract + bond_feature (str or list of str, optional): bond features to extract + residue_feature (str, list of str, optional): residue features to extract + mol_feature (str or list of str, optional): molecule features to extract + kekulize (bool, optional): convert aromatic bonds to single/double bonds. + Note this only affects the relation in ``edge_list``. + For ``bond_type``, aromatic bonds are always stored explicitly. + By default, aromatic bonds are stored. + """ + if not os.path.exists(pdb_file): + raise FileNotFoundError("No such file `%s`" % pdb_file) + mol = Chem.MolFromPDBFile(pdb_file) + if mol is None: + raise ValueError("RDKit cannot read PDB file `%s`" % pdb_file) + return cls.from_molecule(mol, atom_feature, bond_feature, residue_feature, mol_feature, kekulize) + + def to_molecule(self, ignore_error=False): + """ + Return an RDKit object of this protein. + + Parameters: + ignore_error (bool, optional): if true, return ``None`` for illegal molecules. + Otherwise, raise an exception. 
+ + Returns: + rdchem.Mol + """ + mol = super(Protein, self).to_molecule(ignore_error) + if mol is None: + return mol + + residue_type = self.residue_type.tolist() + atom_name = self.atom_name.tolist() + atom2residue = self.atom2residue.tolist() + is_hetero_atom = self.is_hetero_atom.tolist() + occupancy = self.occupancy.tolist() + b_factor = self.b_factor.tolist() + residue_number = self.residue_number.tolist() + chain_id = self.chain_id.tolist() + insertion_code = self.insertion_code.tolist() + for i, atom in enumerate(mol.GetAtoms()): + r = atom2residue[i] + residue = Chem.AtomPDBResidueInfo() + residue.SetResidueNumber(residue_number[r]) + residue.SetChainId(self.id2alphabet[chain_id[r]]) + residue.SetInsertionCode(self.id2alphabet[insertion_code[r]]) + residue.SetName(" %-3s" % self.id2atom_name[atom_name[i]]) + residue.SetResidueName(self.id2residue[residue_type[r]]) + residue.SetIsHeteroAtom(is_hetero_atom[i]) + residue.SetOccupancy(occupancy[i]) + residue.SetTempFactor(b_factor[i]) + atom.SetPDBResidueInfo(residue) + + return mol + + def to_sequence(self): + """ + Return a sequence of this protein. + + Returns: + str + """ + residue_type = self.residue_type.tolist() + cc_id = self.connected_component_id.tolist() + sequence = [] + for i in range(self.num_residue): + if i > 0 and cc_id[i] > cc_id[i - 1]: + sequence.append(".") + sequence.append(self.id2residue_symbol[residue_type[i]]) + return "".join(sequence) + + def to_pdb(self, pdb_file): + """ + Write this protein to a pdb file. + + Parameters: + pdb_file (str): file name + """ + mol = self.to_molecule() + Chem.MolToPDBFile(mol, pdb_file, flavor=10) + + def split(self, node2graph): + node2graph = torch.as_tensor(node2graph, dtype=torch.long, device=self.device) + # coalesce arbitrary graph IDs to [0, n) + _, node2graph = torch.unique(node2graph, return_inverse=True) + num_graph = node2graph.max() + 1 + index = node2graph.argsort() + mapping = torch.zeros_like(index) + mapping[index] = torch.arange(len(index), device=self.device) + + node_in, node_out = self.edge_list.t()[:2] + edge_mask = node2graph[node_in] == node2graph[node_out] + edge2graph = node2graph[node_in] + edge_index = edge2graph.argsort() + edge_index = edge_index[edge_mask[edge_index]] + + prepend = torch.tensor([-1], device=self.device) + is_first_node = torch.diff(node2graph[index], prepend=prepend) > 0 + graph_index = self.node2graph[index[is_first_node]] + + # a residue can be split into multiple graphs + max_num_node = node2graph.bincount(minlength=num_graph).max() + key = node2graph[index] * max_num_node + self.atom2residue[index] + key_set, atom2residue = key.unique(return_inverse=True) + residue_index = key_set % max_num_node + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + + num_nodes = node2graph.bincount(minlength=num_graph) + num_edges = edge2graph[edge_index].bincount(minlength=num_graph) + num_cum_residues = scatter_max(atom2residue, node2graph[index], dim_size=num_graph)[0] + 1 + prepend = torch.tensor([0], device=self.device) + num_residues = torch.diff(num_cum_residues, prepend=prepend) + + num_cum_nodes = num_nodes.cumsum(0) + offsets = (num_cum_nodes - num_nodes)[edge2graph[edge_index]] + + data_dict, meta_dict = self.data_mask(index, edge_index, residue_index, graph_index, + exclude=("residue reference", "graph reference")) + + return self.packed_type(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], + num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues, view=self.view, + 
offsets=offsets, atom2residue=atom2residue, meta_dict=meta_dict, **data_dict) + + @classmethod + def pack(cls, graphs): + edge_list = [] + edge_weight = [] + num_nodes = [] + num_edges = [] + num_residues = [] + num_cum_node = 0 + num_cum_edge = 0 + num_cum_residue = 0 + num_graph = 0 + data_dict = defaultdict(list) + meta_dict = graphs[0].meta_dict + view = graphs[0].view + for graph in graphs: + edge_list.append(graph.edge_list) + edge_weight.append(graph.edge_weight) + num_nodes.append(graph.num_node) + num_edges.append(graph.num_edge) + num_residues.append(graph.num_residue) + for k, v in graph.data_dict.items(): + for type in meta_dict[k]: + if type == "graph": + v = v.unsqueeze(0) + elif type == "node reference": + v = torch.where(v != -1, v + num_cum_node, -1) + elif type == "edge reference": + v = torch.where(v != -1, v + num_cum_edge, -1) + elif type == "residue reference": + v = torch.where(v != -1, v + num_cum_residue, -1) + elif type == "graph reference": + v = torch.where(v != -1, v + num_graph, -1) + data_dict[k].append(v) + num_cum_node += graph.num_node + num_cum_edge += graph.num_edge + num_cum_residue += graph.num_residue + num_graph += 1 + + edge_list = torch.cat(edge_list) + edge_weight = torch.cat(edge_weight) + data_dict = {k: torch.cat(v) for k, v in data_dict.items()} + + return cls.packed_type(edge_list, edge_weight=edge_weight, num_relation=graphs[0].num_relation, + num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues, view=view, + meta_dict=meta_dict, **data_dict) + + def repeat(self, count): + edge_list = self.edge_list.repeat(count, 1) + edge_weight = self.edge_weight.repeat(count) + num_nodes = [self.num_node] * count + num_edges = [self.num_edge] * count + num_residues = [self.num_residue] * count + num_relation = self.num_relation + + data_dict = {} + for k, v in self.data_dict.items(): + if "graph" in self.meta_dict[k]: + v = v.unsqueeze(0) + shape = [1] * v.ndim + shape[0] = count + length = len(v) + v = v.repeat(shape) + for type in self.meta_dict[k]: + if type == "node reference": + offsets = torch.arange(count, device=self.device) * self.num_node + v = v + offsets.repeat_interleave(length) + elif type == "edge reference": + offsets = torch.arange(count, device=self.device) * self.num_edge + v = v + offsets.repeat_interleave(length) + elif type == "residue reference": + offsets = torch.arange(count, device=self.device) * self.num_residue + v = v + offsets.repeat_interleave(length) + elif type == "graph reference": + offsets = torch.arange(count, device=self.device) + v = v + offsets.repeat_interleave(length) + data_dict[k] = v + + return self.packed_type(edge_list, edge_weight=edge_weight, + num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues, view=self.view, + num_relation=num_relation, meta_dict=self.meta_dict, **data_dict) + + def residue2atom(self, residue_index): + """Map residue ids to atom ids.""" + residue_index = self._standarize_index(residue_index, self.num_residue) + if not hasattr(self, "node_inverted_index"): + self.node_inverted_index = self._build_node_inverted_index() + inverted_range, order = self.node_inverted_index + starts, ends = inverted_range[residue_index].t() + num_match = ends - starts + offsets = num_match.cumsum(0) - num_match + ranges = torch.arange(num_match.sum(), device=self.device) + ranges = ranges + (starts - offsets).repeat_interleave(num_match) + index = order[ranges] + return index + + def _build_node_inverted_index(self): + keys = self.atom2residue + order = keys.argsort() + keys_set, 
num_keys = keys.unique(return_counts=True) + ends = num_keys.cumsum(0) + starts = ends - num_keys + ranges = torch.stack([starts, ends], dim=-1) + inverted_range = Dictionary(keys_set, ranges) + return inverted_range, order + + def __getitem__(self, index): + # why do we check tuple? + # case 1: x[0, 1] is parsed as (0, 1) + # case 2: x[[0, 1]] is parsed as [0, 1] + if not isinstance(index, tuple): + index = (index,) + + if len(index) > 1: + raise ValueError("Protein has only 1 axis, but %d axis is indexed" % len(index)) + + return self.residue_mask(index[0], compact=True) + + def data_mask(self, node_index=None, edge_index=None, residue_index=None, graph_index=None, include=None, + exclude=None): + data_dict, meta_dict = super(Protein, self).data_mask(node_index, edge_index, graph_index=graph_index, + include=include, exclude=exclude) + residue_mapping = None + for k, v in data_dict.items(): + for type in meta_dict[k]: + if type == "residue" and residue_index is not None: + if v.is_sparse: + v = v.to_dense()[residue_index].to_sparse() + else: + v = v[residue_index] + elif type == "residue reference" and residue_index is not None: + if residue_mapping is None: + residue_mapping = self._get_mapping(residue_index, self.num_residue) + v = residue_mapping[v] + data_dict[k] = v + + return data_dict, meta_dict + + def residue_mask(self, index, compact=False): + """ + Return a masked protein based on the specified residues. + + Note the compact option is applied to both residue and atom ids. + + Parameters: + index (array_like): residue index + compact (bool, optional): compact residue ids or not + + Returns: + Protein + """ + index = self._standarize_index(index, self.num_residue) + if (torch.diff(index) <= 0).any(): + warnings.warn("`residue_mask()` is called to re-order the residues. This will change the protein sequence. " + "If this is not desired, you might have passed a wrong index to this function.") + residue_mapping = -torch.ones(self.num_residue, dtype=torch.long, device=self.device) + residue_mapping[index] = torch.arange(len(index), device=self.device) + + node_index = residue_mapping[self.atom2residue] >= 0 + node_index = self._standarize_index(node_index, self.num_node) + mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device) + if compact: + mapping[node_index] = torch.arange(len(node_index), device=self.device) + num_node = len(node_index) + else: + mapping[node_index] = node_index + num_node = self.num_node + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + edge_index = (edge_list[:, :2] >= 0).all(dim=-1) + edge_index = self._standarize_index(edge_index, self.num_edge) + + if compact: + data_dict, meta_dict = self.data_mask(node_index, edge_index, residue_index=index) + else: + data_dict, meta_dict = self.data_mask(edge_index=edge_index) + + return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_node=num_node, + view=self.view, num_relation=self.num_relation, meta_dict=meta_dict, **data_dict) + + def subresidue(self, index): + """ + Return a subgraph based on the specified residues. + Equivalent to :meth:`residue_mask(index, compact=True) `. 
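+
+        A sketch (``protein`` is an illustrative :class:`Protein`); this keeps
+        the first ten residues and compacts both residue and atom ids:
+
+        >>> fragment = protein.subresidue(range(10))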
+ + Parameters: + index (array_like): residue index + + Returns: + Protein + + See also: + :meth:`Protein.residue_mask` + """ + return self.residue_mask(index, compact=True) + + @property + def residue2graph(self): + """Residue id to graph id mapping.""" + return torch.zeros(self.num_residue, dtype=torch.long, device=self.device) + + @utils.cached_property + def connected_component_id(self): + """Connected component id of each residue.""" + node_in, node_out = self.edge_list.t()[:2] + residue_in, residue_out = self.atom2residue[node_in], self.atom2residue[node_out] + mask = residue_in != residue_out + residue_in, residue_out = residue_in[mask], residue_out[mask] + range = torch.arange(self.num_residue, device=self.device) + residue_in, residue_out = torch.cat([residue_in, residue_out, range]), \ + torch.cat([residue_out, residue_in, range]) + + min_neighbor = torch.arange(self.num_residue, device=self.device) + last = torch.zeros_like(min_neighbor) + while not torch.equal(min_neighbor, last): + last = min_neighbor + min_neighbor = scatter_min(min_neighbor[residue_out], residue_in, dim_size=self.num_residue)[0] + cc_id = torch.unique(min_neighbor, return_inverse=True)[1] + return cc_id + + def __repr__(self): + fields = ["num_atom=%d" % self.num_node, "num_bond=%d" % self.num_edge, + "num_residue=%d" % self.num_residue] + if self.device.type != "cpu": + fields.append("device='%s'" % self.device) + return "%s(%s)" % (self.__class__.__name__, ", ".join(fields)) + + +class PackedProtein(PackedMolecule, Protein): + """ + Container for proteins with variadic sizes. + Support both residue-level and atom-level operations and ensure consistency between two views. + + .. warning:: + + Edges of the same graph are guaranteed to be consecutive in the edge list. + The order of residues must be the same as the protein sequence. + However, this class doesn't enforce any order on nodes or edges. + Nodes may have a different order with residues. + + Parameters: + edge_list (array_like, optional): list of edges of shape :math:`(|E|, 3)`. + Each tuple is (node_in, node_out, bond_type). + atom_type (array_like, optional): atom types of shape :math:`(|V|,)` + bond_type (array_like, optional): bond types of shape :math:`(|E|,)` + residue_type (array_like, optional): residue types of shape :math:`(|V_{res}|,)` + view (str, optional): default view for this protein. Can be ``atom`` or ``residue``. + num_nodes (array_like, optional): number of nodes in each graph + By default, it will be inferred from the largest id in `edge_list` + num_edges (array_like, optional): number of edges in each graph + num_residues (array_like, optional): number of residues in each graph + offsets (array_like, optional): node id offsets of shape :math:`(|E|,)`. + If not provided, nodes in `edge_list` should be relative index, i.e., the index in each graph. + If provided, nodes in `edge_list` should be absolute index, i.e., the index in the packed graph. 
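+
+    A :class:`PackedProtein` is typically obtained by packing individual proteins
+    (a sketch; ``protein1`` and ``protein2`` are illustrative instances):
+
+    >>> proteins = Protein.pack([protein1, protein2])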
+ """ + + unpacked_type = Protein + _check_attribute = Protein._check_attribute + + def __init__(self, edge_list=None, atom_type=None, bond_type=None, residue_type=None, view=None, num_nodes=None, + num_edges=None, num_residues=None, offsets=None, **kwargs): + super(PackedProtein, self).__init__(edge_list=edge_list, num_nodes=num_nodes, num_edges=num_edges, + offsets=offsets, atom_type=atom_type, bond_type=bond_type, + residue_type=residue_type, view=view, **kwargs) + + num_residues = torch.as_tensor(num_residues, device=self.device) + num_cum_residues = num_residues.cumsum(0) + + self.num_residues = num_residues + self.num_cum_residues = num_cum_residues + + @property + def num_nodes(self): + return self.num_atoms + + @num_nodes.setter + def num_nodes(self, value): + self.num_atoms = value + + def data_mask(self, node_index=None, edge_index=None, residue_index=None, graph_index=None, include=None, + exclude=None): + data_dict, meta_dict = super(PackedProtein, self).data_mask(node_index, edge_index, graph_index=graph_index, + include=include, exclude=exclude) + residue_mapping = None + for k, v in data_dict.items(): + for type in meta_dict[k]: + if type == "residue" and residue_index is not None: + if v.is_sparse: + v = v.to_dense()[residue_index].to_sparse() + else: + v = v[residue_index] + elif type == "residue reference" and residue_index is not None: + if residue_mapping is None: + residue_mapping = self._get_mapping(residue_index, self.num_residue) + v = residue_mapping[v] + data_dict[k] = v + + return data_dict, meta_dict + + def node_mask(self, index, compact=True): + index = self._standarize_index(index, self.num_node) + mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device) + if compact: + mapping[index] = torch.arange(len(index), device=self.device) + num_nodes = self._get_num_xs(index, self.num_cum_nodes) + offsets = self._get_offsets(num_nodes, self.num_edges) + else: + mapping[index] = index + num_nodes = self.num_nodes + offsets = self._offsets + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + edge_index = (edge_list[:, :2] >= 0).all(dim=-1) + num_edges = self._get_num_xs(edge_index, self.num_cum_edges) + + if compact: + data_dict, meta_dict = self.data_mask(index, edge_index) + else: + data_dict, meta_dict = self.data_mask(edge_index=edge_index) + + return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], + num_nodes=num_nodes, num_edges=num_edges, num_residues=self.num_residues, + view=self.view, num_relation=self.num_relation, offsets=offsets[edge_index], + meta_dict=meta_dict, **data_dict) + + def edge_mask(self, index): + index = self._standarize_index(index, self.num_edge) + data_dict, meta_dict = self.data_mask(edge_index=index) + num_edges = self._get_num_xs(index, self.num_cum_edges) + + return type(self)(self.edge_list[index], edge_weight=self.edge_weight[index], + num_nodes=self.num_nodes, num_edges=num_edges, num_residues=self.num_residues, + view=self.view, num_relation=self.num_relation, offsets=self._offsets[index], + meta_dict=meta_dict, **data_dict) + + def residue_mask(self, index, compact=False): + """ + Return a masked packed protein based on the specified residues. + + Note the compact option is applied to both residue and atom ids, but not graph ids. 
+ + Parameters: + index (array_like): residue index + compact (bool, optional): compact residue ids or not + + Returns: + PackedProtein + """ + index = self._standarize_index(index, self.num_residue) + residue_mapping = -torch.ones(self.num_residue, dtype=torch.long, device=self.device) + residue_mapping[index] = torch.arange(len(index), device=self.device) + + node_index = residue_mapping[self.atom2residue] >= 0 + node_index = self._standarize_index(node_index, self.num_node) + mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device) + if compact: + mapping[node_index] = torch.arange(len(node_index), device=self.device) + num_nodes = self._get_num_xs(node_index, self.num_cum_nodes) + num_residues = self._get_num_xs(index, self.num_cum_residues) + else: + mapping[node_index] = node_index + num_nodes = self.num_nodes + num_residues = self.num_residues + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + edge_index = (edge_list[:, :2] >= 0).all(dim=-1) + edge_index = self._standarize_index(edge_index, self.num_edge) + num_edges = self._get_num_xs(edge_index, self.num_cum_edges) + offsets = self._get_offsets(num_nodes, num_edges) + + if compact: + data_dict, meta_dict = self.data_mask(node_index, edge_index, residue_index=index) + else: + data_dict, meta_dict = self.data_mask(edge_index=edge_index) + + return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], + num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues, + view=self.view, num_relation=self.num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) + + def graph_mask(self, index, compact=False): + index = self._standarize_index(index, self.batch_size) + graph_mapping = -torch.ones(self.batch_size, dtype=torch.long, device=self.device) + graph_mapping[index] = torch.arange(len(index), device=self.device) + + node_index = graph_mapping[self.node2graph] >= 0 + node_index = self._standarize_index(node_index, self.num_node) + residue_index = graph_mapping[self.residue2graph] >= 0 + residue_index = self._standarize_index(residue_index, self.num_residue) + mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device) + if compact: + key = graph_mapping[self.node2graph[node_index]] * self.num_node + node_index + order = key.argsort() + node_index = node_index[order] + key = graph_mapping[self.residue2graph[residue_index]] * self.num_residue + residue_index + order = key.argsort() + residue_index = residue_index[order] + mapping[node_index] = torch.arange(len(node_index), device=self.device) + num_nodes = self.num_nodes[index] + num_residues = self.num_residues[index] + else: + mapping[node_index] = node_index + num_nodes = torch.zeros_like(self.num_nodes) + num_nodes[index] = self.num_nodes[index] + num_residues = torch.zeros_like(self.num_residues) + num_residues[index] = self.num_residues[index] + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + edge_index = (edge_list[:, :2] >= 0).all(dim=-1) + edge_index = self._standarize_index(edge_index, self.num_edge) + if compact: + key = graph_mapping[self.edge2graph[edge_index]] * self.num_edge + edge_index + order = key.argsort() + edge_index = edge_index[order] + num_edges = self.num_edges[index] + else: + num_edges = torch.zeros_like(self.num_edges) + num_edges[index] = self.num_edges[index] + offsets = self._get_offsets(num_nodes, num_edges) + + if compact: + data_dict, meta_dict = self.data_mask(node_index, edge_index, + 
residue_index=residue_index, graph_index=index) + else: + data_dict, meta_dict = self.data_mask(edge_index=edge_index) + + return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], + num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues, + view=self.view, num_relation=self.num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) + + def get_item(self, index): + node_index = torch.arange(self.num_cum_nodes[index] - self.num_nodes[index], self.num_cum_nodes[index], + device=self.device) + edge_index = torch.arange(self.num_cum_edges[index] - self.num_edges[index], self.num_cum_edges[index], + device=self.device) + residue_index = torch.arange(self.num_cum_residues[index] - self.num_residues[index], + self.num_cum_residues[index], device=self.device) + graph_index = index + edge_list = self.edge_list[edge_index].clone() + edge_list[:, :2] -= self._offsets[edge_index].unsqueeze(-1) + data_dict, meta_dict = self.data_mask(node_index, edge_index, + residue_index=residue_index, graph_index=graph_index) + + return self.unpacked_type(edge_list, edge_weight=self.edge_weight[edge_index], num_node=self.num_nodes[index], + num_relation=self.num_relation, meta_dict=meta_dict, **data_dict) + + @classmethod + @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature") + def from_molecule(cls, mols, atom_feature="default", bond_feature="default", residue_feature="default", + mol_feature=None, kekulize=False): + """ + Create a packed protein from a list of RDKit objects. + + Parameters: + mols (list of rdchem.Mol): molecules + atom_feature (str or list of str, optional): atom features to extract + bond_feature (str or list of str, optional): bond features to extract + residue_feature (str or list of str, optional): residue features to extract + mol_feature (str or list of str, optional): molecule features to extract + kekulize (bool, optional): convert aromatic bonds to single/double bonds. + Note this only affects the relation in ``edge_list``. + For ``bond_type``, aromatic bonds are always stored explicitly. + By default, aromatic bonds are stored. + """ + protein = PackedMolecule.from_molecule(mols, atom_feature=atom_feature, bond_feature=bond_feature, + mol_feature=mol_feature, with_hydrogen=False, kekulize=kekulize) + residue_feature = cls._standarize_option(residue_feature) + + residue_type = [] + atom_name = [] + is_hetero_atom = [] + occupancy = [] + b_factor = [] + atom2residue = [] + residue_number = [] + insertion_code = [] + chain_id = [] + _residue_feature = [] + last_residue = None + num_residues = [] + num_cum_residue = 0 + + mols = mols + [cls.dummy_protein] + for mol in mols: + if mol is None: + mol = cls.empty_mol + + if kekulize: + Chem.Kekulize(mol) + + for atom in mol.GetAtoms(): + residue = atom.GetPDBResidueInfo() + number = residue.GetResidueNumber() + code = residue.GetInsertionCode() + type = residue.GetResidueName().strip() + canonical_residue = (number, code, type) + if canonical_residue != last_residue: + last_residue = canonical_residue + if type not in cls.residue2id: + warnings.warn("Unknown residue `%s`. 
Treat as glycine" % type) + type = "GLY" + residue_type.append(cls.residue2id[type]) + residue_number.append(number) + insertion_code.append(cls.alphabet2id[residue.GetInsertionCode()]) + chain_id.append(cls.alphabet2id[residue.GetChainId()]) + feature = [] + for name in residue_feature: + func = R.get("features.residue.%s" % name) + feature += func(residue) + _residue_feature.append(feature) + name = residue.GetName().strip() + if name not in cls.atom_name2id: + name = "UNK" + atom_name.append(cls.atom_name2id[name]) + is_hetero_atom.append(residue.GetIsHeteroAtom()) + occupancy.append(residue.GetOccupancy()) + b_factor.append(residue.GetTempFactor()) + atom2residue.append(len(residue_type) - 1) + + num_residues.append(len(residue_type) - num_cum_residue) + num_cum_residue = len(residue_type) + + residue_type = torch.tensor(residue_type)[:-1] + atom_name = torch.tensor(atom_name)[:-5] + is_hetero_atom = torch.tensor(is_hetero_atom)[:-5] + occupancy = torch.tensor(occupancy)[:-5] + b_factor = torch.tensor(b_factor)[:-5] + atom2residue = torch.tensor(atom2residue)[:-5] + residue_number = torch.tensor(residue_number)[:-1] + insertion_code = torch.tensor(insertion_code)[:-1] + chain_id = torch.tensor(chain_id)[:-1] + if len(residue_feature) > 0: + _residue_feature = torch.tensor(_residue_feature)[:-1] + else: + _residue_feature = None + + num_residues = num_residues[:-1] + + return cls(protein.edge_list, residue_type=residue_type, + num_nodes=protein.num_nodes, num_edges=protein.num_edges, num_residues=num_residues, + atom_name=atom_name, atom2residue=atom2residue, residue_feature=_residue_feature, + is_hetero_atom=is_hetero_atom, occupancy=occupancy, b_factor=b_factor, + residue_number=residue_number, insertion_code=insertion_code, chain_id=chain_id, + offsets=protein._offsets, meta_dict=protein.meta_dict, **protein.data_dict) + + @classmethod + def _residue_from_sequence(cls, sequences): + num_residues = [] + residue_type = [] + residue_feature = [] + sequences = sequences + ["G"] + for sequence in sequences: + for residue in sequence: + if residue not in cls.residue_symbol2id: + warnings.warn("Unknown residue symbol `%s`. Treat as glycine" % residue) + residue = "G" + residue_type.append(cls.residue_symbol2id[residue]) + residue_feature.append(feature.onehot(residue, cls.residue_symbol2id, allow_unknown=True)) + num_residues.append(len(sequence)) + + residue_type = residue_type[:-1] + residue_feature = torch.tensor(residue_feature)[:-1] + + edge_list = torch.zeros(0, 3, dtype=torch.long) + num_nodes = [0] * (len(sequences) - 1) + num_edges = [0] * (len(sequences) - 1) + num_residues = num_residues[:-1] + + return cls(edge_list=edge_list, atom_type=[], bond_type=[], residue_type=residue_type, + num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues, + residue_feature=residue_feature) + + @classmethod + @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature") + def from_sequence(cls, sequences, atom_feature="default", bond_feature="default", residue_feature="default", + mol_feature=None, kekulize=False): + """ + Create a packed protein from a list of sequences. + + .. note:: + + It takes considerable time to construct proteins with a large number of atoms and bonds. + If you only need residue information, you may speed up the construction by setting + ``atom_feature`` and ``bond_feature`` to ``None``. 
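+
+            A hypothetical sketch of this fast, residue-only path (the sequences are arbitrary):
+
+            >>> proteins = PackedProtein.from_sequence(["MKTAYIAK"], atom_feature=None, bond_feature=None)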
+
+        Parameters:
+            sequences (list of str): protein sequences
+            atom_feature (str or list of str, optional): atom features to extract
+            bond_feature (str or list of str, optional): bond features to extract
+            residue_feature (str or list of str, optional): residue features to extract
+            mol_feature (str or list of str, optional): molecule features to extract
+            kekulize (bool, optional): convert aromatic bonds to single/double bonds.
+                Note this only affects the relation in ``edge_list``.
+                For ``bond_type``, aromatic bonds are always stored explicitly.
+                By default, aromatic bonds are stored.
+        """
+        if atom_feature is None and bond_feature is None and residue_feature == "default":
+            return cls._residue_from_sequence(sequences)
+
+        mols = []
+        for sequence in sequences:
+            mol = Chem.MolFromSequence(sequence)
+            if mol is None:
+                raise ValueError("Invalid sequence `%s`" % sequence)
+            mols.append(mol)
+
+        return cls.from_molecule(mols, atom_feature, bond_feature, residue_feature, mol_feature, kekulize)
+
+    @classmethod
+    @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
+    def from_pdb(cls, pdb_files, atom_feature="default", bond_feature="default", residue_feature="default",
+                 mol_feature=None, kekulize=False):
+        """
+        Create a packed protein from a list of PDB files.
+
+        Parameters:
+            pdb_files (list of str): PDB file names
+            atom_feature (str or list of str, optional): atom features to extract
+            bond_feature (str or list of str, optional): bond features to extract
+            residue_feature (str or list of str, optional): residue features to extract
+            mol_feature (str or list of str, optional): molecule features to extract
+            kekulize (bool, optional): convert aromatic bonds to single/double bonds.
+                Note this only affects the relation in ``edge_list``.
+                For ``bond_type``, aromatic bonds are always stored explicitly.
+                By default, aromatic bonds are stored.
+        """
+        mols = []
+        for pdb_file in pdb_files:
+            mol = Chem.MolFromPDBFile(pdb_file)
+            mols.append(mol)
+
+        return cls.from_molecule(mols, atom_feature, bond_feature, residue_feature, mol_feature, kekulize)
+
+    def to_molecule(self, ignore_error=False):
+        mols = super(PackedProtein, self).to_molecule(ignore_error)
+
+        residue_type = self.residue_type.tolist()
+        atom_name = self.atom_name.tolist()
+        atom2residue = self.atom2residue.tolist()
+        is_hetero_atom = self.is_hetero_atom.tolist()
+        occupancy = self.occupancy.tolist()
+        b_factor = self.b_factor.tolist()
+        residue_number = self.residue_number.tolist()
+        chain_id = self.chain_id.tolist()
+        insertion_code = self.insertion_code.tolist()
+        num_cum_nodes = [0] + self.num_cum_nodes.tolist()
+
+        for i, mol in enumerate(mols):
+            for j, atom in enumerate(mol.GetAtoms(), num_cum_nodes[i]):
+                r = atom2residue[j]
+                residue = Chem.AtomPDBResidueInfo()
+                residue.SetResidueNumber(residue_number[r])
+                residue.SetChainId(self.id2alphabet[chain_id[r]])
+                residue.SetInsertionCode(self.id2alphabet[insertion_code[r]])
+                residue.SetName(" %-3s" % self.id2atom_name[atom_name[j]])
+                residue.SetResidueName(self.id2residue[residue_type[r]])
+                residue.SetIsHeteroAtom(is_hetero_atom[j])
+                residue.SetOccupancy(occupancy[j])
+                residue.SetTempFactor(b_factor[j])
+                atom.SetPDBResidueInfo(residue)
+
+        return mols
+
+    def to_sequence(self):
+        """
+        Return a list of sequences.
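+        Residues in different connected components of a protein are joined with ``.``
+        in its sequence.
+
+        A hypothetical sketch (assuming ``proteins`` is an existing ``PackedProtein``):
+
+        >>> sequences = proteins.to_sequence()  # one string per protein in the batch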
+ + Returns: + list of str + """ + residue_type = self.residue_type.tolist() + cc_id = self.connected_component_id.tolist() + num_cum_residues = [0] + self.num_cum_residues.tolist() + sequences = [] + for i in range(self.batch_size): + sequence = [] + for j in range(num_cum_residues[i], num_cum_residues[i + 1]): + if j > num_cum_residues[i] and cc_id[j] > cc_id[j - 1]: + sequence.append(".") + sequence.append(self.id2residue_symbol[residue_type[j]]) + sequence = "".join(sequence) + sequences.append(sequence) + return sequences + + def to_pdb(self, pdb_files): + """ + Write this packed protein to several pdb files. + + Parameters: + pdb_files (list of str): list of file names + """ + mols = self.to_molecule() + for mol, pdb_file in zip(mols, pdb_files): + Chem.MolToPDBFile(mol, pdb_file, flavor=10) + + def merge(self, graph2graph): + graph2graph = torch.as_tensor(graph2graph, dtype=torch.long, device=self.device) + # coalesce arbitrary graph IDs to [0, n) + _, graph2graph = torch.unique(graph2graph, return_inverse=True) + + graph_key = graph2graph * self.batch_size + torch.arange(self.batch_size, device=self.device) + graph_index = graph_key.argsort() + graph = self.subbatch(graph_index) + graph2graph = graph2graph[graph_index] + + num_graph = graph2graph[-1] + 1 + num_nodes = scatter_add(graph.num_nodes, graph2graph, dim_size=num_graph) + num_edges = scatter_add(graph.num_edges, graph2graph, dim_size=num_graph) + num_residues = scatter_add(graph.num_residues, graph2graph, dim_size=num_graph) + offsets = self._get_offsets(num_nodes, num_edges) + + data_dict, meta_dict = graph.data_mask(exclude="graph") + + return type(self)(graph.edge_list, edge_weight=graph.edge_weight, num_nodes=num_nodes, + num_edges=num_edges, num_residues=num_residues, view=self.view, offsets=offsets, + meta_dict=meta_dict, **data_dict) + + def repeat(self, count): + num_nodes = self.num_nodes.repeat(count) + num_edges = self.num_edges.repeat(count) + num_residues = self.num_residues.repeat(count) + offsets = self._get_offsets(num_nodes, num_edges) + edge_list = self.edge_list.repeat(count, 1) + edge_list[:, :2] += (offsets - self._offsets.repeat(count)).unsqueeze(-1) + + data_dict = {} + for k, v in self.data_dict.items(): + shape = [1] * v.ndim + shape[0] = count + length = len(v) + v = v.repeat(shape) + for _type in self.meta_dict[k]: + if _type == "node reference": + pack_offsets = torch.arange(count, device=self.device) * self.num_node + v = v + pack_offsets.repeat_interleave(length) + elif _type == "edge reference": + pack_offsets = torch.arange(count, device=self.device) * self.num_edge + v = v + pack_offsets.repeat_interleave(length) + elif _type == "residue reference": + pack_offsets = torch.arange(count, device=self.device) * self.num_residue + v = v + pack_offsets.repeat_interleave(length) + elif _type == "graph reference": + pack_offsets = torch.arange(count, device=self.device) * self.batch_size + v = v + pack_offsets.repeat_interleave(length) + data_dict[k] = v + + return type(self)(edge_list, edge_weight=self.edge_weight.repeat(count), + num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues, view=self.view, + num_relation=self.num_relation, offsets=offsets, + meta_dict=self.meta_dict, **data_dict) + + def repeat_interleave(self, repeats): + repeats = torch.as_tensor(repeats, dtype=torch.long, device=self.device) + if repeats.numel() == 1: + repeats = repeats * torch.ones(self.batch_size, dtype=torch.long, device=self.device) + num_nodes = self.num_nodes.repeat_interleave(repeats) + 
num_edges = self.num_edges.repeat_interleave(repeats) + num_residues = self.num_residues.repeat_interleave(repeats) + num_cum_nodes = num_nodes.cumsum(0) + num_cum_edges = num_edges.cumsum(0) + num_cum_residues = num_residues.cumsum(0) + num_node = num_nodes.sum() + num_edge = num_edges.sum() + num_residue = num_residues.sum() + batch_size = repeats.sum() + num_graphs = torch.ones(batch_size, device=self.device) + + # special case 1: graphs[i] may have no node or no edge + # special case 2: repeats[i] may be 0 + cum_repeats_shifted = repeats.cumsum(0) - repeats + graph_mask = cum_repeats_shifted < batch_size + cum_repeats_shifted = cum_repeats_shifted[graph_mask] + + index = num_cum_nodes - num_nodes + index = torch.cat([index, index[cum_repeats_shifted]]) + value = torch.cat([-num_nodes, self.num_nodes[graph_mask]]) + mask = index < num_node + node_index = scatter_add(value[mask], index[mask], dim_size=num_node) + node_index = (node_index + 1).cumsum(0) - 1 + + index = num_cum_edges - num_edges + index = torch.cat([index, index[cum_repeats_shifted]]) + value = torch.cat([-num_edges, self.num_edges[graph_mask]]) + mask = index < num_edge + edge_index = scatter_add(value[mask], index[mask], dim_size=num_edge) + edge_index = (edge_index + 1).cumsum(0) - 1 + + index = num_cum_residues - num_residues + index = torch.cat([index, index[cum_repeats_shifted]]) + value = torch.cat([-num_residues, self.num_residues[graph_mask]]) + mask = index < num_residue + residue_index = scatter_add(value[mask], index[mask], dim_size=num_residue) + residue_index = (residue_index + 1).cumsum(0) - 1 + + graph_index = torch.repeat_interleave(repeats) + + offsets = self._get_offsets(num_nodes, num_edges) + edge_list = self.edge_list[edge_index] + edge_list[:, :2] += (offsets - self._offsets[edge_index]).unsqueeze(-1) + + node_offsets = None + edge_offsets = None + residue_offsets = None + graph_offsets = None + data_dict = {} + for k, v in self.data_dict.items(): + num_xs = None + pack_offsets = None + for _type in self.meta_dict[k]: + if _type == "node": + v = v[node_index] + num_xs = num_nodes + elif _type == "edge": + v = v[edge_index] + num_xs = num_edges + elif _type == "residue": + v = v[residue_index] + num_xs = num_residues + elif _type == "graph": + v = v[graph_index] + num_xs = num_graphs + elif _type == "node reference": + if node_offsets is None: + node_offsets = self._get_repeat_pack_offsets(self.num_nodes, repeats) + pack_offsets = node_offsets + elif _type == "edge reference": + if edge_offsets is None: + edge_offsets = self._get_repeat_pack_offsets(self.num_edges, repeats) + pack_offsets = edge_offsets + elif _type == "residue reference": + if residue_offsets is None: + residue_offsets = self._get_repeat_pack_offsets(self.num_residues, repeats) + pack_offsets = residue_offsets + elif _type == "graph reference": + if graph_offsets is None: + graph_offsets = self._get_repeat_pack_offsets(num_graphs, repeats) + pack_offsets = graph_offsets + # add offsets to make references point to indexes in their own graph + if num_xs is not None and pack_offsets is not None: + v = v + pack_offsets.repeat_interleave(num_xs) + data_dict[k] = v + + return type(self)(edge_list, edge_weight=self.edge_weight[edge_index], + num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues, view=self.view, + num_relation=self.num_relation, offsets=offsets, meta_dict=self.meta_dict, **data_dict) + + def undirected(self, add_inverse=True): + undirected = PackedMolecule.undirected(self, add_inverse=add_inverse) + + return 
type(self)(undirected.edge_list, edge_weight=undirected.edge_weight, + num_nodes=undirected.num_nodes, num_edges=undirected.num_edges, + num_residues=self.num_residues, view=self.view, num_relation=undirected.num_relation, + offsets=undirected._offsets, meta_dict=undirected.meta_dict, **undirected.data_dict) + + def detach(self): + return type(self)(self.edge_list.detach(), edge_weight=self.edge_weight.detach(), + num_nodes=self.num_nodes, num_edges=self.num_edges, num_residues=self.num_residues, + view=self.view, num_relation=self.num_relation, offsets=self._offsets, + meta_dict=self.meta_dict, **utils.detach(self.data_dict)) + + def clone(self): + return type(self)(self.edge_list.clone(), edge_weight=self.edge_weight.clone(), + num_nodes=self.num_nodes, num_edges=self.num_edges, num_residues=self.num_residues, + view=self.view, num_relation=self.num_relation, offsets=self._offsets, + meta_dict=self.meta_dict, **utils.clone(self.data_dict)) + + def cuda(self, *args, **kwargs): + edge_list = self.edge_list.cuda(*args, **kwargs) + + if edge_list is self.edge_list: + return self + else: + return type(self)(edge_list, edge_weight=self.edge_weight, + num_nodes=self.num_nodes, num_edges=self.num_edges, num_residues=self.num_residues, + view=self.view, num_relation=self.num_relation, offsets=self._offsets, + meta_dict=self.meta_dict, **utils.cuda(self.data_dict, *args, **kwargs)) + + def cpu(self): + edge_list = self.edge_list.cpu() + + if edge_list is self.edge_list: + return self + else: + return type(self)(edge_list, edge_weight=self.edge_weight, + num_nodes=self.num_nodes, num_edges=self.num_edges, num_residues=self.num_residues, + view=self.view, num_relation=self.num_relation, offsets=self._offsets, + meta_dict=self.meta_dict, **utils.cpu(self.data_dict)) + + @utils.cached_property + def residue2graph(self): + """Residue id to graph id mapping.""" + range = torch.arange(self.batch_size, device=self.device) + residue2graph = range.repeat_interleave(self.num_residues) + return residue2graph + + @utils.cached_property + def connected_component_id(self): + cc_id = super(PackedProtein, self).connected_component_id + cc_id_offsets = scatter_min(cc_id, self.residue2graph, dim_size=self.num_residue)[0][self.residue2graph] + cc_id = cc_id - cc_id_offsets + return cc_id + + def __repr__(self): + fields = ["batch_size=%d" % self.batch_size, + "num_atoms=%s" % pretty.long_array(self.num_nodes.tolist()), + "num_bonds=%s" % pretty.long_array(self.num_edges.tolist()), + "num_residues=%s" % pretty.long_array(self.num_residues.tolist())] + if self.device.type != "cpu": + fields.append("device='%s'" % self.device) + return "%s(%s)" % (self.__class__.__name__, ", ".join(fields)) + + +Protein.packed_type = PackedProtein diff --git a/build/lib/torchdrug/data/rdkit/__init__.py b/build/lib/torchdrug/data/rdkit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/torchdrug/data/rdkit/draw.py b/build/lib/torchdrug/data/rdkit/draw.py new file mode 100644 index 00000000..a0e2a4b7 --- /dev/null +++ b/build/lib/torchdrug/data/rdkit/draw.py @@ -0,0 +1,54 @@ +from matplotlib import pyplot as plt + +from rdkit import Chem +from rdkit.Chem import AllChem +from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions +from rdkit.Chem.Draw import mplCanvas + + +class Canvas(mplCanvas.Canvas): + + def __init__(self, ax, name="", imageType="png"): + self._name = name + if ax is None: + size = (3, 3) + self._figure = plt.figure(figsize=size) + self._axes = self._figure.add_axes([0, 0, 1, 1]) + 
else:
+            bbox = ax.get_window_extent().transformed(ax.figure.dpi_scale_trans.inverted())
+            size = (bbox.width, bbox.height)
+            self._figure = ax.figure
+            self._axes = ax
+        self._axes.set_axis_off()
+        # these are rdkit internal size and dpi
+        self.size = tuple(s * 100 for s in size)
+        self._dpi = max(self.size)
+
+
+def MolToMPL(mol, ax=None, kekulize=True, wedgeBonds=True, imageType=None, fitImage=False,
+             options=None, **kwargs):
+    """Generates a drawing of a molecule on a matplotlib canvas."""
+    if not mol:
+        raise ValueError("Null molecule provided")
+
+    canvas = Canvas(ax)
+    if options is None:
+        options = DrawingOptions()
+        options.bgColor = None
+    if fitImage:
+        options.dotsPerAngstrom = int(min(canvas.size) / 10)
+    options.wedgeDashedBonds = wedgeBonds
+    drawer = MolDrawing(canvas=canvas, drawingOptions=options)
+    omol = mol
+    if kekulize:
+        mol = Chem.Mol(mol.ToBinary())
+        Chem.Kekulize(mol)
+
+    if not mol.GetNumConformers():
+        AllChem.Compute2DCoords(mol)
+
+    drawer.AddMol(mol, **kwargs)
+    omol._atomPs = drawer.atomPs[mol]
+    for k, v in omol._atomPs.items():
+        omol._atomPs[k] = canvas.rescalePt(v)
+    return canvas._figure
\ No newline at end of file
diff --git a/build/lib/torchdrug/datasets/__init__.py b/build/lib/torchdrug/datasets/__init__.py
new file mode 100644
index 00000000..7eb75646
--- /dev/null
+++ b/build/lib/torchdrug/datasets/__init__.py
@@ -0,0 +1,64 @@
+from .bace import BACE
+from .bbbp import BBBP
+from .cep import CEP
+from .clintox import ClinTox
+from .delaney import Delaney
+from .freesolv import FreeSolv
+from .hiv import HIV
+from .lipophilicity import Lipophilicity
+from .malaria import Malaria
+from .moses import MOSES
+from .muv import MUV
+from .opv import OPV
+from .qm8 import QM8
+from .qm9 import QM9
+from .sider import SIDER
+from .tox21 import Tox21
+from .toxcast import ToxCast
+from .uspto50k import USPTO50k
+from .zinc250k import ZINC250k
+from .zinc2m import ZINC2m
+from .pcqm4m import PCQM4M
+from .pubchem110m import PubChem110m
+from .chembl_filtered import ChEMBLFiltered
+
+from .beta_lactamase import BetaLactamase
+from .fluorescence import Fluorescence
+from .stability import Stability
+from .solubility import Solubility
+from .fold import Fold
+from .binary_localization import BinaryLocalization
+from .subcellular_localization import SubcellularLocalization
+from .secondary_structure import SecondaryStructure
+from .human_ppi import HumanPPI
+from .yeast_ppi import YeastPPI
+from .ppi_affinity import PPIAffinity
+from .bindingdb import BindingDB
+from .pdbbind import PDBBind
+from .proteinnet import ProteinNet
+
+from .enzyme_commission import EnzymeCommission
+from .gene_ontology import GeneOntology
+from .alphafolddb import AlphaFoldDB
+
+from .fb15k import FB15k, FB15k237
+from .wn18 import WN18, WN18RR
+from .yago310 import YAGO310
+from .hetionet import Hetionet
+
+from .cora import Cora
+from .citeseer import CiteSeer
+from .pubmed import PubMed
+
+__all__ = [
+    "BACE", "BBBP", "CEP", "ClinTox", "Delaney", "FreeSolv", "HIV", "Lipophilicity",
+    "Malaria", "MOSES", "MUV", "OPV", "QM8", "QM9", "SIDER", "Tox21", "ToxCast",
+    "USPTO50k", "ZINC250k",
+    "ZINC2m", "PCQM4M", "PubChem110m", "ChEMBLFiltered",
+    "EnzymeCommission", "GeneOntology", "AlphaFoldDB",
+    "BetaLactamase", "Fluorescence", "Stability", "Solubility", "Fold",
+    "BinaryLocalization", "SubcellularLocalization", "SecondaryStructure",
+    "HumanPPI", "YeastPPI", "PPIAffinity", "BindingDB", "PDBBind", "ProteinNet",
+    "FB15k", "FB15k237", "WN18", "WN18RR", "YAGO310", "Hetionet",
+    "Cora", 
"CiteSeer", "PubMed", +] diff --git a/build/lib/torchdrug/datasets/alphafolddb.py b/build/lib/torchdrug/datasets/alphafolddb.py new file mode 100644 index 00000000..c8ce6e88 --- /dev/null +++ b/build/lib/torchdrug/datasets/alphafolddb.py @@ -0,0 +1,155 @@ +import os +import glob + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.AlphaFoldDB") +@utils.copy_args(data.ProteinDataset.load_pdbs) +class AlphaFoldDB(data.ProteinDataset): + """ + 3D protein structures predicted by AlphaFold. + This dataset covers proteomes of 48 organisms, as well as the majority of Swiss-Prot. + + Statistics: + See https://alphafold.ebi.ac.uk/download + + Parameters: + path (str): path to store the dataset + species_id (int, optional): the id of species to be loaded. The species are numbered + by the order appeared on https://alphafold.ebi.ac.uk/download (0-20 for model + organism proteomes, 21 for Swiss-Prot) + split_id (int, optional): the id of split to be loaded. To avoid large memory consumption + for one dataset, we have cut each species into several splits, each of which contains + at most 22000 proteins. + verbose (int, optional): output verbose level + **kwargs + """ + + urls = [ + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000006548_3702_ARATH_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000001940_6239_CAEEL_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000559_237561_CANAL_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000437_7955_DANRE_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002195_44689_DICDI_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000803_7227_DROME_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000625_83333_ECOLI_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008827_3847_SOYBN_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000005640_9606_HUMAN_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008153_5671_LEIIN_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000805_243232_METJA_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000589_10090_MOUSE_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000001584_83332_MYCTU_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000059680_39947_ORYSJ_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000001450_36329_PLAF7_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002494_10116_RAT_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002311_559292_YEAST_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002485_284812_SCHPO_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008816_93061_STAA8_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002296_353153_TRYCC_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000007305_4577_MAIZE_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/swissprot_pdb_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000001631_447093_AJECG_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000006672_6279_BRUMA_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000799_192222_CAMJE_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000094526_86049_9EURO1_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000274756_318479_DRAME_v2.tar", + 
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000325664_1352_ENTFC_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000053029_1442368_9EURO2_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000579_71421_HAEIN_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000429_85962_HELPY_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000007841_1125630_KLEPH_v2.tar", + # "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008153_5671_LEIIN_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000078237_100816_9PEZI1_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000806_272631_MYCLE_v2.tar", + # "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000001584_83332_MYCTU_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000020681_1299332_MYCUL_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000535_242231_NEIG1_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000006304_1133849_9NOCA1_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000024404_6282_ONCVO_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002059_502779_PARBA_v2.tar", + # "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000001450_36329_PLAF7_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002438_208964_PSEAE_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000001014_99287_SALTY_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008854_6183_SCHMA_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002716_300267_SHIDS_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000018087_1391915_SPOS1_v2.tar", + # "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008816_93061_STAA8_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000586_171101_STRR6_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000035681_6248_STRER_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000030665_36087_TRITR_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008524_185431_TRYB2_v2.tar", + # "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002296_353153_TRYCC_v2.tar", + "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000270924_6293_WUCBA_v2.tar" + ] + md5s = [ + "4cd5f596ebfc3d45d9f6b647dc5684af", "b89bee5507f78f971417cc8fd75b40f7", "a6459a1f1a0a22fbf25f1c05c2889ae3", + "24dfba8ab93dbf3f51e7db6b912dd6b4", "6b81b3086ed9e57e04a54f148ecf974c", "a50f4fd9f581c89e79e1b2857e54b786", + "fdd16245769bf1f7d91a0e285ac00e52", "66b9750c511182bc5f8ee71fe2ab2a17", "5dadeb5aac704025cac33f7557794858", + "99b22e0f050d845782d914becbfe4d2f", "da938dfae4fabf6e144f4b5ede5885ec", "2003c09d437cfb4093552c588a33e06d", + "fba59f386cfa33af3f70ae664b7feac0", "d7a1a6c02213754ee1a1ffb3b41ad4ba", "8a0e8deadffec2aba3b7edd6534b7481", + "1854d0bbcf819de1de7b0cfdb6d32b2e", "d9720e3809db6916405db096b520c236", "6b918e9e4d645b12a80468bcea805f1f", + "ed0eefe927eb8c3b81cf87eaabbb8d6e", "051369e0dc8fed4798c8b2c68e6cbe2e", "b05ff57164167851651c625dca66ed28", + "68e7a6e57bd43cb52e344b3190073387", "75d027ac7833f284fda65ea620353e8a", "7d85bb2ee4130096a6d905ab8d726bcc", + "63498210c88e8bfb1a7346c4ddf73bb1", "5bf2211304ef91d60bb3838ec12d89cd", "4981758eb8980e9df970ac6113e4084c", + "322431789942595b599d2b86670f41b3", "35d7b32e37bcc23d02b12b03b1e0c093", "1b8847dd786fa41b5b38f5e7aa58b813", + "126bdbe59fa82d55bfa098b710bdf650", "6c6d3248ed943dd7137637fc92d7ba37", "532203c6877433df5651b95d27685825", + 
"6e7112411da5843bec576271c44e0a0a", "0e4f913a9b4672b0ad3cc9c4f2de5c8d", "a138d0060b2e8a0ef1f90cf3ab7b7ca0", + "04d491dd1c679e91b5a2f3b9f14db555", "889c051e39305614accdff00414bfa67", "cd87cf24e5135c9d729940194ccc65c8", + "75eb8bfe866cf3040f4c08a566c32bc1", "fd8e6ddb9c159aab781a11c287c85feb", "b91a2e103980b96f755712f2b559ad66", + "26187d09b093649686d7c158aa4fd113", "62e16894bb4b8951a82befd24ad4ee21", "85c001df1d91788bf3cc1f97230b1dac", + "91a25af808351757b101a8c9c787db9e", "8b3e8645cc4c2484c331759b9d1df5bc", "e8a76a6ab290e6743233510e8d1eb4a5", + "38280bd7804f4c060b0775c4abed9b89" + ] + species_nsplit = [ + 2, 1, 1, 2, 1, 1, 1, 3, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 20, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, #1, 1, 1, 1, 1 + ] + split_length = 22000 + + def __init__(self, path, species_id=0, split_id=0, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + species_name = os.path.basename(self.urls[species_id])[:-4] + if split_id >= self.species_nsplit[species_id]: + raise ValueError("Split id %d should be less than %d in species %s" % + (split_id, self.species_nsplit[species_id], species_name)) + self.processed_file = "%s_%d.pkl.gz" % (species_name, split_id) + pkl_file = os.path.join(path, self.processed_file) + + if os.path.exists(pkl_file): + self.load_pickle(pkl_file, verbose=verbose, **kwargs) + else: + tar_file = utils.download(self.urls[species_id], path, md5=self.md5s[species_id]) + pdb_path = utils.extract(tar_file) + gz_files = sorted(glob.glob(os.path.join(pdb_path, "*.pdb.gz"))) + pdb_files = [] + index = slice(split_id * self.split_length, (split_id + 1) * self.split_length) + for gz_file in gz_files[index]: + pdb_files.append(utils.extract(gz_file)) + self.load_pdbs(pdb_files, verbose=verbose, **kwargs) + self.save_pickle(pkl_file, verbose=verbose) + + def get_item(self, index): + if getattr(self, "lazy", False): + protein = data.Protein.from_pdb(self.pdb_files[index], self.kwargs) + else: + protein = self.data[index].clone() + if hasattr(protein, "residue_feature"): + with protein.residue(): + protein.residue_feature = protein.residue_feature.to_dense() + item = {"graph": protein} + if self.transform: + item = self.transform(item) + return item + + def __repr__(self): + lines = [ + "#sample: %d" % len(self), + ] + return "%s(\n %s\n)" % (self.__class__.__name__, "\n ".join(lines)) diff --git a/build/lib/torchdrug/datasets/bace.py b/build/lib/torchdrug/datasets/bace.py new file mode 100644 index 00000000..f3dc53b3 --- /dev/null +++ b/build/lib/torchdrug/datasets/bace.py @@ -0,0 +1,36 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.BACE") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class BACE(data.MoleculeDataset): + r""" + Binary binding results for a set of inhibitors of human :math:`\beta`-secretase 1(BACE-1). 
+ + Statistics: + - #Molecule: 1,513 + - #Classification task: 1 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/bace.csv" + md5 = "ba7f8fa3fdf463a811fa7edea8c982c2" + target_fields = ["Class"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + file_name = utils.download(self.url, path, md5=self.md5) + + self.load_csv(file_name, smiles_field="mol", target_fields=self.target_fields, + verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/bbbp.py b/build/lib/torchdrug/datasets/bbbp.py new file mode 100644 index 00000000..3256419a --- /dev/null +++ b/build/lib/torchdrug/datasets/bbbp.py @@ -0,0 +1,36 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.BBBP") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class BBBP(data.MoleculeDataset): + """ + Binary labels of blood-brain barrier penetration. + + Statistics: + - #Molecule: 2,039 + - #Classification task: 1 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/BBBP.csv" + md5 = "66286cb9e6b148bd75d80c870df580fb" + target_fields = ["p_np"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + file_name = utils.download(self.url, path, md5=self.md5) + + self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/beta_lactamase.py b/build/lib/torchdrug/datasets/beta_lactamase.py new file mode 100644 index 00000000..3b50c64f --- /dev/null +++ b/build/lib/torchdrug/datasets/beta_lactamase.py @@ -0,0 +1,51 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.BetaLactamase") +@utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) +class BetaLactamase(data.ProteinDataset): + """ + The activity values of first-order mutants of the TEM-1 beta-lactamase protein. 
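+
+    A minimal usage sketch (the path is arbitrary; ``split()`` returns the predefined splits):
+
+    >>> dataset = BetaLactamase("~/protein-datasets/")
+    >>> train, valid, test = dataset.split()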
+ + Statistics: + - #Train: 4,158 + - #Valid: 520 + - #Test: 520 + + Parameters: + path (str): the path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/peerdata/beta_lactamase.tar.gz" + md5 = "65766a3969cc0e94b101d4063d204ba4" + splits = ["train", "valid", "test"] + target_fields = ["scaled_effect1"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + data_path = utils.extract(zip_file) + lmdb_files = [os.path.join(data_path, "beta_lactamase/beta_lactamase_%s.lmdb" % split) + for split in self.splits] + + self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/binary_localization.py b/build/lib/torchdrug/datasets/binary_localization.py new file mode 100644 index 00000000..d7b8d95e --- /dev/null +++ b/build/lib/torchdrug/datasets/binary_localization.py @@ -0,0 +1,52 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.BinaryLocalization") +@utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) +class BinaryLocalization(data.ProteinDataset): + """ + Simpler version of the Subcellular Localization with binary labels indicating + whether a protein is membrane-bound or soluble. 
+ + Statistics: + - #Train: 5,161 + - #Valid: 1,727 + - #Test: 1,746 + + Parameters: + path (str): the path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/peerdata/subcellular_localization_2.tar.gz" + md5 = "5d2309bf1c0c2aed450102578e434f4e" + splits = ["train", "valid", "test"] + target_fields = ["localization"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + data_path = utils.extract(zip_file) + lmdb_files = [os.path.join(data_path, "subcellular_localization_2/subcellular_localization_2_%s.lmdb" % split) + for split in self.splits] + + self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/bindingdb.py b/build/lib/torchdrug/datasets/bindingdb.py new file mode 100644 index 00000000..03d8ae9a --- /dev/null +++ b/build/lib/torchdrug/datasets/bindingdb.py @@ -0,0 +1,72 @@ +import os + +from rdkit import Chem + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.BindingDB") +@utils.copy_args(data.ProteinLigandDataset.load_lmdbs, ignore=("sequence_field", "smiles_field", "target_fields")) +class BindingDB(data.ProteinLigandDataset): + """ + The BindingDB dataset with binding affinity indicating the interaction strength + between pairs of protein and ligand. 
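+
+    A minimal usage sketch (the path is arbitrary; one of the two test splits is selected by name):
+
+    >>> dataset = BindingDB("~/protein-ligand-datasets/")
+    >>> train, valid, test = dataset.split(keys=["train", "valid", "holdout_test"])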
+
+    Statistics:
+        - #Train: 7,900
+        - #Valid: 878
+        - #Test: 5,230
+
+    Parameters:
+        path (str): the path to store the dataset
+        verbose (int, optional): output verbose level
+        **kwargs
+    """
+
+    url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/peerdata/BindingDB_Kd.tar.gz"
+    md5 = "0b207cb962c4945f9003fc020b415a74"
+    splits = ["train", "valid", "random_test", "holdout_test"]
+    target_fields = ["affinity"]
+
+    def __init__(self, path, verbose=1, **kwargs):
+        path = os.path.expanduser(path)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        self.path = path
+        zip_file = utils.download(self.url, path, md5=self.md5)
+        data_path = utils.extract(zip_file)
+        lmdb_files = [os.path.join(data_path, "BindingDB_Kd_%s.lmdb" % split) for split in self.splits]
+
+        self.load_lmdbs(lmdb_files, sequence_field="target", smiles_field="drug",
+                        target_fields=self.target_fields, verbose=verbose, **kwargs)
+
+    def split(self, keys=None):
+        keys = keys or self.splits
+        offset = 0
+        splits = []
+        for split_name, num_sample in zip(self.splits, self.num_samples):
+            if split_name in keys:
+                split = torch_data.Subset(self, range(offset, offset + num_sample))
+                splits.append(split)
+            offset += num_sample
+        return splits
+
+    def get_item(self, index):
+        if self.lazy:
+            graph1 = data.Protein.from_sequence(self.sequences[index], **self.kwargs)
+            mol = Chem.MolFromSmiles(self.smiles[index])
+            if not mol:
+                graph2 = None
+            else:
+                graph2 = data.Molecule.from_molecule(mol, **self.kwargs)
+        else:
+            graph1 = self.data[index][0]
+            graph2 = self.data[index][1]
+        item = {"graph1": graph1, "graph2": graph2}
+        item.update({k: v[index] for k, v in self.targets.items()})
+        if self.transform:
+            item = self.transform(item)
+        return item
\ No newline at end of file
diff --git a/build/lib/torchdrug/datasets/cep.py b/build/lib/torchdrug/datasets/cep.py
new file mode 100644
index 00000000..44a1b805
--- /dev/null
+++ b/build/lib/torchdrug/datasets/cep.py
@@ -0,0 +1,36 @@
+import os
+
+from torchdrug import data, utils
+from torchdrug.core import Registry as R
+
+
+@R.register("datasets.CEP")
+@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
+class CEP(data.MoleculeDataset):
+    """
+    Photovoltaic efficiency estimated by the Harvard Clean Energy Project.
+ + Statistics: + - #Molecule: 20,000 + - #Regression task: 1 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-02-cep-pce/cep-processed.csv" + md5 = "b6d257ff416917e4e6baa5e1103f3929" + target_fields = ["PCE"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + file_name = utils.download(self.url, self.path, md5=self.md5) + + self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) diff --git a/build/lib/torchdrug/datasets/chembl_filtered.py b/build/lib/torchdrug/datasets/chembl_filtered.py new file mode 100644 index 00000000..74f3ea68 --- /dev/null +++ b/build/lib/torchdrug/datasets/chembl_filtered.py @@ -0,0 +1,36 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.ChEMBLFiltered") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class ChEMBLFiltered(data.MoleculeDataset): + """ + Statistics: + - #Molecule: 430,710 + - #Regression task: 1,310 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://zenodo.org/record/5528681/files/chembl_filtered_torchdrug.csv.gz" + md5 = "2fff04fecd6e697f28ebb127e8a37561" + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + csv_file = utils.extract(zip_file) + + self.target_fields = ["target_{}".format(i) for i in range(1310)] + + self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/citeseer.py b/build/lib/torchdrug/datasets/citeseer.py new file mode 100644 index 00000000..d61f0a49 --- /dev/null +++ b/build/lib/torchdrug/datasets/citeseer.py @@ -0,0 +1,35 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.CiteSeer") +class CiteSeer(data.NodeClassificationDataset): + """ + A citation network of scientific publications with binary word features. 
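+
+    A minimal usage sketch (the path is arbitrary):
+
+    >>> dataset = CiteSeer("~/node-datasets/")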
+ + Statistics: + - #Node: 3,327 + - #Edge: 8,059 + - #Class: 6 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + """ + + url = "https://linqs-data.soe.ucsc.edu/public/lbc/citeseer.tgz" + md5 = "c8ded8ed395b31899576bfd1e91e4d6e" + + def __init__(self, path, verbose=1): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + node_file = utils.extract(zip_file, "citeseer/citeseer.content") + edge_file = utils.extract(zip_file, "citeseer/citeseer.cites") + + self.load_tsv(node_file, edge_file, verbose=verbose) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/clintox.py b/build/lib/torchdrug/datasets/clintox.py new file mode 100644 index 00000000..aba7a440 --- /dev/null +++ b/build/lib/torchdrug/datasets/clintox.py @@ -0,0 +1,38 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.ClinTox") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class ClinTox(data.MoleculeDataset): + """ + Qualitative data of drugs approved by the FDA and those that have failed clinical + trials for toxicity reasons. + + Statistics: + - #Molecule: 1,478 + - #Classification task: 2 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/clintox.csv.gz" + md5 = "db4f2df08be8ae92814e9d6a2d015284" + target_fields = ["FDA_APPROVED", "CT_TOX"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + csv_file = utils.extract(zip_file) + + self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/cora.py b/build/lib/torchdrug/datasets/cora.py new file mode 100644 index 00000000..72935a97 --- /dev/null +++ b/build/lib/torchdrug/datasets/cora.py @@ -0,0 +1,35 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.Cora") +class Cora(data.NodeClassificationDataset): + """ + A citation network of scientific publications with binary word features. 
+
+    Statistics:
+        - #Node: 2,708
+        - #Edge: 5,429
+        - #Class: 7
+
+    Parameters:
+        path (str): path to store the dataset
+        verbose (int, optional): output verbose level
+    """
+
+    url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz"
+    md5 = "2fc040bee8ce3d920e4204effd1e9214"
+
+    def __init__(self, path, verbose=1):
+        path = os.path.expanduser(path)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        self.path = path
+
+        zip_file = utils.download(self.url, path, md5=self.md5)
+        node_file = utils.extract(zip_file, "cora/cora.content")
+        edge_file = utils.extract(zip_file, "cora/cora.cites")
+
+        self.load_tsv(node_file, edge_file, verbose=verbose)
\ No newline at end of file
diff --git a/build/lib/torchdrug/datasets/delaney.py b/build/lib/torchdrug/datasets/delaney.py
new file mode 100644
index 00000000..b744eadd
--- /dev/null
+++ b/build/lib/torchdrug/datasets/delaney.py
@@ -0,0 +1,36 @@
+import os
+
+from torchdrug import data, utils
+from torchdrug.core import Registry as R
+
+
+@R.register("datasets.Delaney")
+@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
+class Delaney(data.MoleculeDataset):
+    """
+    Log-scale water solubility of molecules.
+
+    Statistics:
+        - #Molecule: 1,128
+        - #Regression task: 1
+
+    Parameters:
+        path (str): path to store the dataset
+        verbose (int, optional): output verbose level
+        **kwargs
+    """
+
+    url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/delaney-processed.csv"
+    md5 = "0c90a51668d446b9e3ab77e67662bd1c"
+    target_fields = ["measured log solubility in mols per litre"]
+
+    def __init__(self, path, verbose=1, **kwargs):
+        path = os.path.expanduser(path)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        self.path = path
+
+        file_name = utils.download(self.url, self.path, md5=self.md5)
+
+        self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields,
+                      verbose=verbose, **kwargs)
\ No newline at end of file
diff --git a/build/lib/torchdrug/datasets/enzyme_commission.py b/build/lib/torchdrug/datasets/enzyme_commission.py
new file mode 100644
index 00000000..799b9a02
--- /dev/null
+++ b/build/lib/torchdrug/datasets/enzyme_commission.py
@@ -0,0 +1,135 @@
+import os
+import csv
+import glob
+
+import torch
+from torch.utils import data as torch_data
+
+from torchdrug import data, utils
+from torchdrug.core import Registry as R
+
+
+@R.register("datasets.EnzymeCommission")
+@utils.copy_args(data.ProteinDataset.load_pdbs)
+class EnzymeCommission(data.ProteinDataset):
+    """
+    A set of proteins with their 3D structures and EC numbers, which describe their
+    catalysis of biochemical reactions.
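+
+    A minimal usage sketch (the path is arbitrary; ``test_cutoff`` must be one of the
+    predefined values):
+
+    >>> dataset = EnzymeCommission("~/protein-datasets/", test_cutoff=0.95)
+    >>> train, valid, test = dataset.split()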
+
+    Statistics (test_cutoff=0.95):
+        - #Train: 15,011
+        - #Valid: 1,664
+        - #Test: 1,840
+
+    Parameters:
+        path (str): the path to store the dataset
+        test_cutoff (float, optional): the test cutoff used to split the dataset
+        verbose (int, optional): output verbose level
+        **kwargs
+    """
+
+    url = "https://zenodo.org/record/6622158/files/EnzymeCommission.zip"
+    md5 = "33f799065f8ad75f87b709a87293bc65"
+    processed_file = "enzyme_commission.pkl.gz"
+    test_cutoffs = [0.3, 0.4, 0.5, 0.7, 0.95]
+
+    def __init__(self, path, test_cutoff=0.95, verbose=1, **kwargs):
+        path = os.path.expanduser(path)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        self.path = path
+        if test_cutoff not in self.test_cutoffs:
+            raise ValueError("Unknown test cutoff `%.2f` for EnzymeCommission dataset" % test_cutoff)
+        self.test_cutoff = test_cutoff
+
+        zip_file = utils.download(self.url, path, md5=self.md5)
+        path = os.path.join(utils.extract(zip_file), "EnzymeCommission")
+        pkl_file = os.path.join(path, self.processed_file)
+
+        csv_file = os.path.join(path, "nrPDB-EC_test.csv")
+        pdb_ids = []
+        with open(csv_file, "r") as fin:
+            reader = csv.reader(fin, delimiter=",")
+            idx = self.test_cutoffs.index(test_cutoff) + 1
+            _ = next(reader)
+            for line in reader:
+                if line[idx] == "0":
+                    pdb_ids.append(line[0])
+
+        if os.path.exists(pkl_file):
+            self.load_pickle(pkl_file, verbose=verbose, **kwargs)
+        else:
+            pdb_files = []
+            for split in ["train", "valid", "test"]:
+                split_path = utils.extract(os.path.join(path, "%s.zip" % split))
+                pdb_files += sorted(glob.glob(os.path.join(split_path, split, "*.pdb")))
+            self.load_pdbs(pdb_files, verbose=verbose, **kwargs)
+            self.save_pickle(pkl_file, verbose=verbose)
+        if len(pdb_ids) > 0:
+            self.filter_pdb(pdb_ids)
+
+        tsv_file = os.path.join(path, "nrPDB-EC_annot.tsv")
+        pdb_ids = [os.path.basename(pdb_file).split("_")[0] for pdb_file in self.pdb_files]
+        self.load_annotation(tsv_file, pdb_ids)
+
+        splits = [os.path.basename(os.path.dirname(pdb_file)) for pdb_file in self.pdb_files]
+        self.num_samples = [splits.count("train"), splits.count("valid"), splits.count("test")]
+
+    def filter_pdb(self, pdb_ids):
+        pdb_ids = set(pdb_ids)
+        sequences = []
+        pdb_files = []
+        data = []
+        for sequence, pdb_file, protein in zip(self.sequences, self.pdb_files, self.data):
+            if os.path.basename(pdb_file).split("_")[0] in pdb_ids:
+                continue
+            sequences.append(sequence)
+            pdb_files.append(pdb_file)
+            data.append(protein)
+        self.sequences = sequences
+        self.pdb_files = pdb_files
+        self.data = data
+
+    def load_annotation(self, tsv_file, pdb_ids):
+        with open(tsv_file, "r") as fin:
+            reader = csv.reader(fin, delimiter="\t")
+            _ = next(reader)
+            tasks = next(reader)
+            task2id = {task: i for i, task in enumerate(tasks)}
+            _ = next(reader)
+            pos_targets = {}
+            for pdb_id, pos_target in reader:
+                pos_target = [task2id[t] for t in pos_target.split(",")]
+                pos_target = torch.tensor(pos_target)
+                pos_targets[pdb_id] = pos_target
+
+        # fake targets to enable the property self.tasks
+        self.targets = task2id
+        self.pos_targets = []
+        for pdb_id in pdb_ids:
+            self.pos_targets.append(pos_targets[pdb_id])
+
+    def split(self):
+        offset = 0
+        splits = []
+        for num_sample in self.num_samples:
+            split = torch_data.Subset(self, range(offset, offset + num_sample))
+            splits.append(split)
+            offset += num_sample
+        return splits
+
+    def get_item(self, index):
+        if getattr(self, "lazy", False):
+            # expand the stored kwargs; passing the dict positionally would bind it to atom_feature
+            protein = data.Protein.from_pdb(self.pdb_files[index], **self.kwargs)
+        else:
+            protein = self.data[index].clone()
+        if 
hasattr(protein, "residue_feature"): + with protein.residue(): + protein.residue_feature = protein.residue_feature.to_dense() + item = {"graph": protein} + if self.transform: + item = self.transform(item) + indices = self.pos_targets[index].unsqueeze(0) + values = torch.ones(len(self.pos_targets[index])) + item["targets"] = utils.sparse_coo_tensor(indices, values, (len(self.tasks),)).to_dense() + return item \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/fb15k.py b/build/lib/torchdrug/datasets/fb15k.py new file mode 100644 index 00000000..9a12fe97 --- /dev/null +++ b/build/lib/torchdrug/datasets/fb15k.py @@ -0,0 +1,106 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.FB15k") +class FB15k(data.KnowledgeGraphDataset): + """ + Subset of Freebase knowledge base for knowledge graph reasoning. + + Statistics: + - #Entity: 14,951 + - #Relation: 1,345 + - #Triplet: 592,213 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + """ + + urls = [ + "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/FB15k/train.txt", + "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/FB15k/valid.txt", + "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/FB15k/test.txt", + ] + md5s = [ + "5a87195e68d7797af00e137a7f6929f2", + "275835062bb86a86477a3c402d20b814", + "71098693b0efcfb8ac6cd61cf3a3b505" + ] + + def __init__(self, path, verbose=1): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + txt_files = [] + for url, md5 in zip(self.urls, self.md5s): + save_file = "fb15k_%s" % os.path.basename(url) + txt_file = utils.download(url, self.path, save_file=save_file, md5=md5) + txt_files.append(txt_file) + + self.load_tsvs(txt_files, verbose=verbose) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits + + +@R.register("datasets.FB15k237") +class FB15k237(data.KnowledgeGraphDataset): + """ + A filtered version of FB15k dataset without trivial cases. 
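+
+    A minimal usage sketch (the path is arbitrary):
+
+    >>> dataset = FB15k237("~/kg-datasets/")
+    >>> train, valid, test = dataset.split()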
+ + Statistics: + - #Entity: 14,541 + - #Relation: 237 + - #Triplet: 310,116 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + """ + + urls = [ + "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/FB15k-237/train.txt", + "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/FB15k-237/valid.txt", + "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/FB15k-237/test.txt", + ] + md5s = [ + "c05b87b9ac00f41901e016a2092d7837", + "6a94efd530e5f43fcf84f50bc6d37b69", + "f5bdf63db39f455dec0ed259bb6f8628" + ] + + def __init__(self, path, verbose=1): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + txt_files = [] + for url, md5 in zip(self.urls, self.md5s): + save_file = "fb15k237_%s" % os.path.basename(url) + txt_file = utils.download(url, self.path, save_file=save_file, md5=md5) + txt_files.append(txt_file) + + self.load_tsvs(txt_files, verbose=verbose) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/fluorescence.py b/build/lib/torchdrug/datasets/fluorescence.py new file mode 100644 index 00000000..81d31d05 --- /dev/null +++ b/build/lib/torchdrug/datasets/fluorescence.py @@ -0,0 +1,51 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.Fluorescence") +@utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) +class Fluorescence(data.ProteinDataset): + """ + The fitness values of a set of green fluorescent protein mutants. 
+ + Statistics: + - #Train: 21,446 + - #Valid: 5,362 + - #Test: 27,217 + + Parameters: + path (str): the path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/fluorescence.tar.gz" + md5 = "d63d1d51ec8c20ff0d981e4cbd67457a" + splits = ["train", "valid", "test"] + target_fields = ["log_fluorescence"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + data_path = utils.extract(zip_file) + lmdb_files = [os.path.join(data_path, "fluorescence/fluorescence_%s.lmdb" % split) + for split in self.splits] + + self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/fold.py b/build/lib/torchdrug/datasets/fold.py new file mode 100644 index 00000000..a41a8c3b --- /dev/null +++ b/build/lib/torchdrug/datasets/fold.py @@ -0,0 +1,53 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.Fold") +@utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) +class Fold(data.ProteinDataset): + """ + Fold labels for a set of proteins determined by the global structural topology. + + Statistics: + - #Train: 12,312 + - #Valid: 736 + - #Test: 718 + + Parameters: + path (str): the path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/remote_homology.tar.gz" + md5 = "1d687bdeb9e3866f77504d6079eed00a" + splits = ["train", "valid", "test_fold_holdout", "test_family_holdout", "test_superfamily_holdout"] + target_fields = ["fold_label"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + data_path = utils.extract(zip_file) + lmdb_files = [os.path.join(data_path, "remote_homology/remote_homology_%s.lmdb" % split) + for split in self.splits] + + self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs) + + def split(self, keys=None): + keys = keys or self.splits + offset = 0 + splits = [] + for split_name, num_sample in zip(self.splits, self.num_samples): + if split_name in keys: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/freesolv.py b/build/lib/torchdrug/datasets/freesolv.py new file mode 100644 index 00000000..51ae15b9 --- /dev/null +++ b/build/lib/torchdrug/datasets/freesolv.py @@ -0,0 +1,37 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.FreeSolv") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class FreeSolv(data.MoleculeDataset): + """ + Experimental and calculated hydration free energy of small molecules in water. 
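+
+    Example (illustrative path; the CSV is downloaded and parsed on first use):
+
+    >>> from torchdrug import datasets
+    >>> dataset = datasets.FreeSolv("~/molecule-datasets/")
+    >>> item = dataset[0]  # a dict holding the molecule graph and the "expt" target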
+ + Statistics: + - #Molecule: 642 + - #Regression task: 1 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://s3-us-west-1.amazonaws.com/deepchem.io/datasets/molnet_publish/FreeSolv.zip" + md5 = "8d681babd239b15e2f8b2d29f025577a" + target_fields = ["expt"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, self.path, md5=self.md5) + csv_file = utils.extract(zip_file, "SAMPL.csv") + + self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/gene_ontology.py b/build/lib/torchdrug/datasets/gene_ontology.py new file mode 100644 index 00000000..68ea4427 --- /dev/null +++ b/build/lib/torchdrug/datasets/gene_ontology.py @@ -0,0 +1,145 @@ +import os +import csv +import glob + +import torch +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.GeneOntology") +@utils.copy_args(data.ProteinDataset.load_pdbs) +class GeneOntology(data.ProteinDataset): + """ + A set of proteins with their 3D structures and GO terms. These terms classify proteins + into hierarchically related functional classes organized into three ontologies: molecular + function (MF), biological process (BP) and cellular component (CC). + + Statistics (test_cutoff=0.95): + - #Train: 27,496 + - #Valid: 3,053 + - #Test: 2,991 + + Parameters: + path (str): the path to store the dataset + branch (str, optional): the GO branch + test_cutoff (float, optional): the test cutoff used to split the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://zenodo.org/record/6622158/files/GeneOntology.zip" + md5 = "376be1f088cd1fe720e1eaafb701b5cb" + branches = ["MF", "BP", "CC"] + processed_file = "gene_ontology.pkl.gz" + test_cutoffs = [0.3, 0.4, 0.5, 0.7, 0.95] + + def __init__(self, path, branch="MF", test_cutoff=0.95, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + if branch not in self.branches: + raise ValueError("Unknown branch `%s` for GeneOntology dataset" % branch) + self.branch = branch + if test_cutoff not in self.test_cutoffs: + raise ValueError("Unknown test cutoff `%.2f` for GeneOntology dataset" % test_cutoff) + self.test_cutoff = test_cutoff + + zip_file = utils.download(self.url, path, md5=self.md5) + path = os.path.join(utils.extract(zip_file), "GeneOntology") + pkl_file = os.path.join(path, self.processed_file) + + csv_file = os.path.join(path, "nrPDB-GO_test.csv") + pdb_ids = [] + with open(csv_file, "r") as fin: + reader = csv.reader(fin, delimiter=",") + idx = self.test_cutoffs.index(test_cutoff) + 1 + _ = next(reader) + for line in reader: + if line[idx] == "0": + pdb_ids.append(line[0]) + + if os.path.exists(pkl_file): + self.load_pickle(pkl_file, verbose=verbose, **kwargs) + else: + pdb_files = [] + for split in ["train", "valid", "test"]: + split_path = utils.extract(os.path.join(path, "%s.zip" % split)) + pdb_files += sorted(glob.glob(os.path.join(split_path, split, "*.pdb"))) + self.load_pdbs(pdb_files, verbose=verbose, **kwargs) + self.save_pickle(pkl_file, verbose=verbose) + if len(pdb_ids) > 0: + self.filter_pdb(pdb_ids) + + tsv_file = os.path.join(path, 
"nrPDB-GO_annot.tsv") + pdb_ids = [os.path.basename(pdb_file).split("_")[0] for pdb_file in self.pdb_files] + self.load_annotation(tsv_file, pdb_ids) + + splits = [os.path.basename(os.path.dirname(pdb_file)) for pdb_file in self.pdb_files] + self.num_samples = [splits.count("train"), splits.count("valid"), splits.count("test")] + + def filter_pdb(self, pdb_ids): + pdb_ids = set(pdb_ids) + sequences = [] + pdb_files = [] + data = [] + for sequence, pdb_file, protein in zip(self.sequences, self.pdb_files, self.data): + if os.path.basename(pdb_file).split("_")[0] in pdb_ids: + continue + sequences.append(sequence) + pdb_files.append(pdb_file) + data.append(protein) + self.sequences = sequences + self.pdb_files = pdb_files + self.data = data + + def load_annotation(self, tsv_file, pdb_ids): + idx = self.branches.index(self.branch) + with open(tsv_file, "r") as fin: + reader = csv.reader(fin, delimiter="\t") + for i in range(12): + _ = next(reader) + if i == idx * 4 + 1: + tasks = _ + task2id = {task: i for i, task in enumerate(tasks)} + _ = next(reader) + pos_targets = {} + for line in reader: + pdb_id, pos_target = line[0], line[idx + 1] if idx + 1 < len(line) else None + pos_target = [task2id[t] for t in pos_target.split(",")] if pos_target else [] + pos_target = torch.LongTensor(pos_target) + pos_targets[pdb_id] = pos_target + + # fake targets to enable the property self.tasks + self.targets = task2id + self.pos_targets = [] + for pdb_id in pdb_ids: + self.pos_targets.append(pos_targets[pdb_id]) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits + + def get_item(self, index): + if getattr(self, "lazy", False): + protein = data.Protein.from_pdb(self.pdb_files[index], self.kwargs) + else: + protein = self.data[index].clone() + if hasattr(protein, "residue_feature"): + with protein.residue(): + protein.residue_feature = protein.residue_feature.to_dense() + item = {"graph": protein} + if self.transform: + item = self.transform(item) + indices = self.pos_targets[index].unsqueeze(0) + values = torch.ones(len(self.pos_targets[index])) + item["targets"] = utils.sparse_coo_tensor(indices, values, (len(self.tasks),)).to_dense() + return item \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/hetionet.py b/build/lib/torchdrug/datasets/hetionet.py new file mode 100644 index 00000000..83aadee2 --- /dev/null +++ b/build/lib/torchdrug/datasets/hetionet.py @@ -0,0 +1,57 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.Hetionet") +class Hetionet(data.KnowledgeGraphDataset): + """ + Hetionet for knowledge graph reasoning. 
+ + Statistics: + - #Entity: 45,158 + - #Relation: 24 + - #Triplet: 2,025,177 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + """ + + urls = [ + "https://www.dropbox.com/s/y47bt9oq57h6l5k/train.txt?dl=1", + "https://www.dropbox.com/s/a0pbrx9tz3dgsff/valid.txt?dl=1", + "https://www.dropbox.com/s/4dhrvg3fyq5tnu4/test.txt?dl=1", + ] + md5s = [ + "6e58915d70ce6d9389c6e4785245e0b3", + "77f15fac4f8170b836392a5b1d315afa", + "e8877aafe89d0c9b9c1efb9027cb7226" + ] + + def __init__(self, path, verbose=1): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + txt_files = [] + for url, md5 in zip(self.urls, self.md5s): + save_file = "hetionet_%s.txt" % os.path.splitext(os.path.basename(url))[0] + txt_file = utils.download(url, self.path, save_file=save_file, md5=md5) + txt_files.append(txt_file) + + self.load_tsvs(txt_files, verbose=verbose) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits + diff --git a/build/lib/torchdrug/datasets/hiv.py b/build/lib/torchdrug/datasets/hiv.py new file mode 100644 index 00000000..f6ec3023 --- /dev/null +++ b/build/lib/torchdrug/datasets/hiv.py @@ -0,0 +1,36 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.HIV") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class HIV(data.MoleculeDataset): + """ + Experimentally measured abilities to inhibit HIV replication. + + Statistics: + - #Molecule: 41,127 + - #Classification task: 1 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/HIV.csv" + md5 = "9ad10c88f82f1dac7eb5c52b668c30a7" + target_fields = ["HIV_active"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + file_name = utils.download(self.url, path, md5=self.md5) + + self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/human_ppi.py b/build/lib/torchdrug/datasets/human_ppi.py new file mode 100644 index 00000000..da149a10 --- /dev/null +++ b/build/lib/torchdrug/datasets/human_ppi.py @@ -0,0 +1,67 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.HumanPPI") +@utils.copy_args(data.ProteinPairDataset.load_lmdbs, ignore=("sequence_field", "target_fields")) +class HumanPPI(data.ProteinPairDataset): + """ + Binary labels indicating whether two human proteins interact or not. 
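+
+    Example (illustrative path; ``split()`` returns one subset per entry in
+    ``splits``, including the cross-species test set):
+
+    >>> from torchdrug import datasets
+    >>> dataset = datasets.HumanPPI("~/protein-datasets/")
+    >>> item = dataset[0]  # {"graph1": <Protein>, "graph2": <Protein>, "interaction": ...}
+    >>> train, valid, test, cross_species_test = dataset.split()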
+ + Statistics: + - #Train: 6,844 + - #Valid: 277 + - #Test: 227 + + Parameters: + path (str): the path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/ppidata/human_ppi.zip" + md5 = "89885545ebc2c11d774c342910230e20" + splits = ["train", "valid", "test", "cross_species_test"] + target_fields = ["interaction"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + data_path = utils.extract(zip_file) + lmdb_files = [os.path.join(data_path, "human_ppi/human_ppi_%s.lmdb" % split) + for split in self.splits] + + self.load_lmdbs(lmdb_files, sequence_field=["primary_1", "primary_2"], target_fields=self.target_fields, + verbose=verbose, **kwargs) + + def split(self, keys=None): + keys = keys or self.splits + offset = 0 + splits = [] + for split_name, num_sample in zip(self.splits, self.num_samples): + if split_name in keys: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits + + def get_item(self, index): + if self.lazy: + graph1 = data.Protein.from_sequence(self.sequences[index][0], **self.kwargs) + graph2 = data.Protein.from_sequence(self.sequences[index][1], **self.kwargs) + else: + graph1 = self.data[index][0] + graph2 = self.data[index][1] + item = {"graph1": graph1, "graph2": graph2} + item.update({k: v[index] for k, v in self.targets.items()}) + if self.transform: + item = self.transform(item) + return item \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/lipophilicity.py b/build/lib/torchdrug/datasets/lipophilicity.py new file mode 100644 index 00000000..73935a58 --- /dev/null +++ b/build/lib/torchdrug/datasets/lipophilicity.py @@ -0,0 +1,36 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.Lipophilicity") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class Lipophilicity(data.MoleculeDataset): + """ + Experimental results of octanol/water distribution coefficient (logD at pH 7.4). + + Statistics: + - #Molecule: 4,200 + - #Regression task: 1 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/Lipophilicity.csv" + md5 = "85a0e1cb8b38b0dfc3f96ff47a57f0ab" + target_fields = ["exp"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + file_name = utils.download(self.url, self.path, md5=self.md5) + + self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) diff --git a/build/lib/torchdrug/datasets/malaria.py b/build/lib/torchdrug/datasets/malaria.py new file mode 100644 index 00000000..0ae4078b --- /dev/null +++ b/build/lib/torchdrug/datasets/malaria.py @@ -0,0 +1,37 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.Malaria") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class Malaria(data.MoleculeDataset): + """ + Half-maximal effective concentration (EC50) against a parasite that causes malaria. 
+ + Statistics: + - #Molecule: 10,000 + - #Regression task: 1 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-03-malaria/" \ + "malaria-processed.csv" + md5 = "ef40ddfd164be0e5ed1bd3dd0cce9b88" + target_fields = ["activity"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + file_name = utils.download(self.url, self.path, md5=self.md5) + + self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/moses.py b/build/lib/torchdrug/datasets/moses.py new file mode 100644 index 00000000..83fca6c7 --- /dev/null +++ b/build/lib/torchdrug/datasets/moses.py @@ -0,0 +1,48 @@ +import os +from collections import defaultdict + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.MOSES") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class MOSES(data.MoleculeDataset): + """ + Subset of ZINC database for molecule generation. + This dataset doesn't contain any label information. + + Statistics: + - #Molecule: 1,936,963 + + Parameters: + path (str): path for the CSV dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv" + md5 = "6bdb0d9526ddf5fdeb87d6aa541df213" + target_fields = ["SPLIT"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + file_name = utils.download(self.url, path, md5=self.md5) + + self.load_csv(file_name, smiles_field="SMILES", target_fields=self.target_fields, + lazy=True, verbose=verbose, **kwargs) + + def split(self): + indexes = defaultdict(list) + for i, split in enumerate(self.targets["SPLIT"]): + indexes[split].append(i) + train_set = torch_data.Subset(self, indexes["train"]) + valid_set = torch_data.Subset(self, indexes["valid"]) + test_set = torch_data.Subset(self, indexes["test"]) + return train_set, valid_set, test_set diff --git a/build/lib/torchdrug/datasets/muv.py b/build/lib/torchdrug/datasets/muv.py new file mode 100644 index 00000000..ffa283b3 --- /dev/null +++ b/build/lib/torchdrug/datasets/muv.py @@ -0,0 +1,38 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.MUV") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class MUV(data.MoleculeDataset): + """ + Subset of PubChem BioAssay by applying a refined nearest neighbor analysis. 
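+
+    Example (illustrative path; per-task labels live in ``dataset.targets``):
+
+    >>> from torchdrug import datasets
+    >>> dataset = datasets.MUV("~/molecule-datasets/")
+    >>> dataset.targets["MUV-466"][:5]  # labels of the first 5 molecules for one assay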
+ + Statistics: + - #Molecule: 93,087 + - #Classification task: 17 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/muv.csv.gz" + md5 = "9c40bd41310991efd40f4d4868fa3ddf" + target_fields = ["MUV-466", "MUV-548", "MUV-600", "MUV-644", "MUV-652", "MUV-689", "MUV-692", "MUV-712", "MUV-713", + "MUV-733", "MUV-737", "MUV-810", "MUV-832", "MUV-846", "MUV-852", "MUV-858", "MUV-859"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + csv_file = utils.extract(zip_file) + + self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/opv.py b/build/lib/torchdrug/datasets/opv.py new file mode 100644 index 00000000..728c3ed8 --- /dev/null +++ b/build/lib/torchdrug/datasets/opv.py @@ -0,0 +1,97 @@ +import os +import csv +import math +from collections import defaultdict +from tqdm import tqdm + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.OPV") +@utils.copy_args(data.MoleculeDataset.load_smiles) +class OPV(data.MoleculeDataset): + """ + Quantum mechanical calculations on organic photovoltaic candidate molecules. + + Statistics: + - #Molecule: 94,576 + - #Regression task: 8 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + train_url = "https://cscdata.nrel.gov/api/datasets/ad5d2c9a-af0a-4d72-b943-1e433d5750d6/download/" \ + "b69cf9a5-e7e0-405b-88cb-40df8007242e" + valid_url = "https://cscdata.nrel.gov/api/datasets/ad5d2c9a-af0a-4d72-b943-1e433d5750d6/download/" \ + "1c8e7379-3071-4360-ba8e-0c6481c33d2c" + test_url = "https://cscdata.nrel.gov/api/datasets/ad5d2c9a-af0a-4d72-b943-1e433d5750d6/download/" \ + "4ef40592-0080-4f00-9bb7-34b25f94962a" + train_md5 = "16e439b7411ea0a8d3a56ba4802b61b1" + valid_md5 = "3aa2ac62015932ca84661feb5d29adda" + test_md5 = "bad072224f0755478f0729476ca99a33" + target_fields = ["gap", "homo", "lumo", "spectral_overlap", "gap_extrapolated", "homo_extrapolated", + "lumo_extrapolated", "optical_lumo_extrapolated"] + + def read_csv(self, csv_file, smiles_field="smiles", target_fields=None, verbose=0): + if target_fields is not None: + target_fields = set(target_fields) + + with open(csv_file, "r") as fin: + reader = csv.reader(fin) + if verbose: + reader = iter(tqdm(reader, "Loading %s" % csv_file, utils.get_line_count(csv_file))) + fields = next(reader) + smiles = [] + targets = defaultdict(list) + for i, values in enumerate(reader): + if not any(values): + continue + if smiles_field is None: + smiles.append("") + for field, value in zip(fields, values): + if field == smiles_field: + smiles.append(value) + elif target_fields is None or field in target_fields: + value = utils.literal_eval(value) + if value == "": + value = math.nan + targets[field].append(value) + + return smiles, targets + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + train_zip_file = utils.download(self.train_url, path, save_file="mol_train.csv.gz", md5=self.train_md5) + valid_zip_file = 
utils.download(self.valid_url, path, save_file="mol_valid.csv.gz", md5=self.valid_md5)
+        test_zip_file = utils.download(self.test_url, path, save_file="mol_test.csv.gz", md5=self.test_md5)
+        train_file = utils.extract(train_zip_file)
+        valid_file = utils.extract(valid_zip_file)
+        test_file = utils.extract(test_zip_file)
+
+        train_smiles, train_targets = self.read_csv(train_file, smiles_field="smile", target_fields=self.target_fields)
+        valid_smiles, valid_targets = self.read_csv(valid_file, smiles_field="smile", target_fields=self.target_fields)
+        test_smiles, test_targets = self.read_csv(test_file, smiles_field="smile", target_fields=self.target_fields)
+        self.num_train = len(train_smiles)
+        self.num_valid = len(valid_smiles)
+        self.num_test = len(test_smiles)
+
+        smiles = train_smiles + valid_smiles + test_smiles
+        targets = {k: train_targets[k] + valid_targets[k] + test_targets[k] for k in train_targets}
+
+        self.load_smiles(smiles, targets, verbose=verbose, **kwargs)
+
+    def split(self):
+        train_set = torch_data.Subset(self, range(self.num_train))
+        valid_set = torch_data.Subset(self, range(self.num_train, self.num_train + self.num_valid))
+        test_set = torch_data.Subset(self, range(-self.num_test, 0))
+        return train_set, valid_set, test_set
\ No newline at end of file
diff --git a/build/lib/torchdrug/datasets/pcqm4m.py b/build/lib/torchdrug/datasets/pcqm4m.py
new file mode 100644
index 00000000..bc0c61ec
--- /dev/null
+++ b/build/lib/torchdrug/datasets/pcqm4m.py
@@ -0,0 +1,38 @@
+import os
+
+from torchdrug import data, utils
+from torchdrug.core import Registry as R
+
+
+@R.register("datasets.PCQM4M")
+@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
+class PCQM4M(data.MoleculeDataset):
+    """
+    Quantum chemistry dataset of small molecules, originally curated under the PubChemQC project.
+
+    Statistics:
+        - #Molecule: 3,803,453
+        - #Regression task: 1
+
+    Parameters:
+        path (str): path to store the dataset
+        verbose (int, optional): output verbose level
+        **kwargs
+    """
+
+    url = "https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip"
+    md5 = "5144ebaa7c67d24da1a2acbe41f57f6a"
+    target_fields = ["homolumogap"]
+
+    def __init__(self, path, verbose=1, **kwargs):
+        path = os.path.expanduser(path)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        self.path = path
+
+        zip_file = utils.download(self.url, self.path, md5=self.md5)
+        zip_file = utils.extract(zip_file, "pcqm4m_kddcup2021/raw/data.csv.gz")
+        file_name = utils.extract(zip_file)
+
+        self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields,
+                      lazy=True, verbose=verbose, **kwargs)
diff --git a/build/lib/torchdrug/datasets/pdbbind.py b/build/lib/torchdrug/datasets/pdbbind.py
new file mode 100644
index 00000000..5c2543f4
--- /dev/null
+++ b/build/lib/torchdrug/datasets/pdbbind.py
@@ -0,0 +1,74 @@
+import os
+
+from rdkit import Chem
+
+from torch.utils import data as torch_data
+
+from torchdrug import data, utils
+from torchdrug.core import Registry as R
+
+
+@R.register("datasets.PDBBind")
+@utils.copy_args(data.ProteinLigandDataset.load_lmdbs, ignore=("sequence_field", "smiles_field", "target_fields"))
+class PDBBind(data.ProteinLigandDataset):
+    """
+    The PDBbind-2019 dataset with binding affinity indicating the interaction strength
+    between pairs of protein and ligand.
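+
+    Example (illustrative path; each item pairs a protein with a ligand molecule):
+
+    >>> from torchdrug import datasets
+    >>> dataset = datasets.PDBBind("~/protein-datasets/")
+    >>> item = dataset[0]  # {"graph1": <Protein>, "graph2": <Molecule>, "affinity": ...}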
+ + Statistics: + - #Train: 16,436 + - #Valid: 937 + - #Test: 285 + + Parameters: + path (str): the path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/peerdata/pdbind.tar.gz" + md5 = "5f5b3d2cd5f5a5fcf9e6da922850f4a0" + splits = ["train", "valid", "test"] + target_fields = ["affinity"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + data_path = utils.extract(zip_file) + lmdb_files = [os.path.join(data_path, "pdbind/pdbind_%s.lmdb" % split) + for split in self.splits] + + self.load_lmdbs(lmdb_files, sequence_field="target", smiles_field="drug", + target_fields=self.target_fields, verbose=verbose, **kwargs) + + def split(self, keys=None): + keys = keys or self.splits + offset = 0 + splits = [] + for split_name, num_sample in zip(self.splits, self.num_samples): + if split_name in keys: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits + + def get_item(self, index): + if self.lazy: + graph1 = data.Protein.from_sequence(self.sequences[index], **self.kwargs) + mol = Chem.MolFromSmiles(self.smiles[index]) + if not mol: + graph2 = None + else: + graph2 = data.Molecule.from_molecule(mol, **self.kwargs) + else: + graph1 = self.data[index][0] + graph2 = self.data[index][1] + item = {"graph1": graph1, "graph2": graph2} + item.update({k: v[index] for k, v in self.targets.items()}) + if self.transform: + item = self.transform(item) + return item \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/ppi_affinity.py b/build/lib/torchdrug/datasets/ppi_affinity.py new file mode 100644 index 00000000..e287486e --- /dev/null +++ b/build/lib/torchdrug/datasets/ppi_affinity.py @@ -0,0 +1,67 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.PPIAffinity") +@utils.copy_args(data.ProteinPairDataset.load_lmdbs, ignore=("sequence_field", "target_fields")) +class PPIAffinity(data.ProteinPairDataset): + r""" + The binding affinity values measured by :math:`p_{K_d}` between two proteins. 
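+
+    Example (illustrative path; ``split`` accepts a subset of the ``splits`` keys):
+
+    >>> from torchdrug import datasets
+    >>> dataset = datasets.PPIAffinity("~/protein-datasets/")
+    >>> train_set, test_set = dataset.split(keys=["train", "test"])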
+ + Statistics: + - #Train: 2,127 + - #Valid: 212 + - #Test: 343 + + Parameters: + path (str): the path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/ppidata/ppi_affinity.zip" + md5 = "d114907fd20c75820e41881f8901e9e4" + splits = ["train", "valid", "test"] + target_fields = ["interaction"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + data_path = utils.extract(zip_file) + lmdb_files = [os.path.join(data_path, "ppi_affinity/ppi_affinity_%s.lmdb" % split) + for split in self.splits] + + self.load_lmdbs(lmdb_files, sequence_field=["primary_1", "primary_2"], target_fields=self.target_fields, + verbose=verbose, **kwargs) + + def split(self, keys=None): + keys = keys or self.splits + offset = 0 + splits = [] + for split_name, num_sample in zip(self.splits, self.num_samples): + if split_name in keys: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits + + def get_item(self, index): + if self.lazy: + graph1 = data.Protein.from_sequence(self.sequences[index][0], **self.kwargs) + graph2 = data.Protein.from_sequence(self.sequences[index][1], **self.kwargs) + else: + graph1 = self.data[index][0] + graph2 = self.data[index][1] + item = {"graph1": graph1, "graph2": graph2} + item.update({k: v[index] for k, v in self.targets.items()}) + if self.transform: + item = self.transform(item) + return item \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/proteinnet.py b/build/lib/torchdrug/datasets/proteinnet.py new file mode 100644 index 00000000..f2231fea --- /dev/null +++ b/build/lib/torchdrug/datasets/proteinnet.py @@ -0,0 +1,69 @@ +import os + +import torch +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.ProteinNet") +@utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) +class ProteinNet(data.ProteinDataset): + """ + A set of proteins with 3D structures for the contact prediction task. 
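+
+    Example (illustrative path; each item carries residue coordinates and a
+    validity mask attached by ``get_item``):
+
+    >>> from torchdrug import datasets
+    >>> dataset = datasets.ProteinNet("~/protein-datasets/")
+    >>> graph = dataset[0]["graph"]
+    >>> graph.residue_position, graph.mask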
+
+    Statistics:
+        - #Train: 25,299
+        - #Valid: 224
+        - #Test: 40
+
+    Parameters:
+        path (str): the path to store the dataset
+        verbose (int, optional): output verbose level
+        **kwargs
+    """
+
+    url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/data/proteinnet.tar.gz"
+    md5 = "ab44ab201b1570c0171a2bba9eb4d389"
+    splits = ["train", "valid", "test"]
+    target_fields = ["tertiary", "valid_mask"]
+
+    def __init__(self, path, verbose=1, **kwargs):
+        path = os.path.expanduser(path)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        self.path = path
+
+        zip_file = utils.download(self.url, path, md5=self.md5)
+        data_path = utils.extract(zip_file)
+        lmdb_files = [os.path.join(data_path, "proteinnet/proteinnet_%s.lmdb" % split)
+                      for split in self.splits]
+
+        self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs)
+
+    def get_item(self, index):
+        if self.lazy:
+            graph = data.Protein.from_sequence(self.sequences[index], **self.kwargs)
+        else:
+            graph = self.data[index]
+        with graph.residue():
+            residue_position = torch.as_tensor(self.targets["tertiary"][index], dtype=torch.float)
+            graph.residue_position = residue_position
+            mask = torch.as_tensor(self.targets["valid_mask"][index], dtype=torch.bool)
+            graph.mask = mask
+        item = {"graph": graph}
+        if self.transform:
+            item = self.transform(item)
+        return item
+
+    def split(self, keys=None):
+        keys = keys or self.splits
+        offset = 0
+        splits = []
+        for split_name, num_sample in zip(self.splits, self.num_samples):
+            if split_name in keys:
+                split = torch_data.Subset(self, range(offset, offset + num_sample))
+                splits.append(split)
+            offset += num_sample
+        return splits
\ No newline at end of file
diff --git a/build/lib/torchdrug/datasets/pubchem110m.py b/build/lib/torchdrug/datasets/pubchem110m.py
new file mode 100644
index 00000000..bae71bc6
--- /dev/null
+++ b/build/lib/torchdrug/datasets/pubchem110m.py
@@ -0,0 +1,47 @@
+import os
+import csv
+from tqdm import tqdm
+
+from torchdrug import data, utils
+from torchdrug.core import Registry as R
+
+
+@R.register("datasets.PubChem110m")
+@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
+class PubChem110m(data.MoleculeDataset):
+    """
+    SMILES-only molecules collected from PubChem.
+    This dataset doesn't contain any label information.
+
+    Statistics:
+        - #Molecule:
+
+    Parameters:
+        path (str): path to store the dataset
+        verbose (int, optional): output verbose level
+        **kwargs
+    """
+    # TODO: fill in the download path & md5. Are the statistics right?
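+    # NOTE: no download URL or md5 is bundled here; the loader below assumes a
+    # pre-downloaded PubChem dump at "<path>/CID-SMILES", a tab-separated file
+    # whose second column is the SMILES string.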
+ + target_fields = [] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + smiles_file = os.path.join(path, "CID-SMILES") + + with open(smiles_file, "r") as fin: + reader = csv.reader(fin, delimiter="\t") + if verbose: + reader = iter(tqdm(reader, "Loading %s" % path, utils.get_line_count(smiles_file))) + smiles_list = [] + + for values in reader: + smiles = values[1] + smiles_list.append(smiles) + + targets = {} + self.load_smiles(smiles_list, targets, lazy=True, verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/pubmed.py b/build/lib/torchdrug/datasets/pubmed.py new file mode 100644 index 00000000..1078c2b6 --- /dev/null +++ b/build/lib/torchdrug/datasets/pubmed.py @@ -0,0 +1,95 @@ +import os +import re +import csv + +from tqdm import tqdm + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.PubMed") +class PubMed(data.NodeClassificationDataset): + """ + A citation network of scientific publications with TF-IDF word features. + + Statistics: + - #Node: 19,717 + - #Edge: 44,338 + - #Class: 3 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + """ + + url = "https://linqs-data.soe.ucsc.edu/public/Pubmed-Diabetes.tgz" + md5 = "9fa24b917990c47e264a94079b9599fe" + + def __init__(self, path, verbose=1): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + node_file = utils.extract(zip_file, "Pubmed-Diabetes/data/Pubmed-Diabetes.NODE.paper.tab") + edge_file = utils.extract(zip_file, "Pubmed-Diabetes/data/Pubmed-Diabetes.DIRECTED.cites.tab") + + inv_node_vocab = {} + node_feature = [] + node_label = [] + + with open(node_file, "r") as fin: + reader = csv.reader(fin, delimiter="\t") + if verbose: + reader = iter(tqdm(reader, "Loading %s" % node_file, utils.get_line_count(node_file))) + _ = next(reader) + fields = next(reader) + group, = re.match(r"cat=(\S+):label", fields[0]).groups() + label_tokens = group.split(",") + inv_label_vocab = {token: i for i, token in enumerate(label_tokens)} + inv_feature_vocab = {} + for field in fields[1:]: + match = re.match(r"numeric:(\S+):0\.0", field) + if not match: + continue + feature_token, = match.groups() + inv_feature_vocab[feature_token] = len(inv_feature_vocab) + + for tokens in reader: + node_token = tokens[0] + label_token, = re.match(r"label=(\S+)", tokens[1]).groups() + feature = [0] * len(inv_feature_vocab) + inv_node_vocab[node_token] = len(inv_node_vocab) + for token in tokens[2:]: + match = re.match(r"(\S+)=([0-9.]+)", token) + if not match: + continue + feature_token, value = match.groups() + feature[inv_feature_vocab[feature_token]] = utils.literal_eval(value) + label = inv_label_vocab[label_token] + node_feature.append(feature) + node_label.append(label) + + edge_list = [] + + with open(edge_file, "r") as fin: + reader = csv.reader(fin, delimiter="\t") + if verbose: + reader = iter(tqdm(reader, "Loading %s" % edge_file, utils.get_line_count(edge_file))) + _ = next(reader) + _ = next(reader) + for tokens in reader: + h_token, = re.match(r"paper:(\S+)", tokens[1]).groups() + t_token, = re.match(r"paper:(\S+)", tokens[3]).groups() + if h_token not in inv_node_vocab: + inv_node_vocab[h_token] = len(inv_node_vocab) + h = inv_node_vocab[h_token] + if t_token not in inv_node_vocab: 
+ inv_node_vocab[t_token] = len(inv_node_vocab) + t = inv_node_vocab[t_token] + edge_list.append((h, t)) + + self.load_edge(edge_list, node_feature, node_label, inv_node_vocab=inv_node_vocab, + inv_label_vocab=inv_label_vocab) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/qm8.py b/build/lib/torchdrug/datasets/qm8.py new file mode 100644 index 00000000..32a24c1f --- /dev/null +++ b/build/lib/torchdrug/datasets/qm8.py @@ -0,0 +1,85 @@ +import os +import csv +from collections import defaultdict + +from tqdm import tqdm +from rdkit import Chem, RDLogger + +import torch + +from torchdrug import data, utils +from torchdrug.data import feature +from torchdrug.core import Registry as R + + +@R.register("datasets.QM8") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class QM8(data.MoleculeDataset): + """ + Electronic spectra and excited state energy of small molecules. + + Statistics: + - #Molecule: 21,786 + - #Regression task: 12 + + Parameters: + path (str): path to store the dataset + node_position (bool, optional): load node position or not. + This will add `node_position` as a node attribute to each sample. + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/gdb8.tar.gz" + md5 = "b7e2a2c823c75b35c596f3013319c86e" + target_fields = ["E1-CC2", "E2-CC2", "f1-CC2", "f2-CC2", + "E1-PBE0/def2SVP", "E2-PBE0/def2SVP", "f1-PBE0/def2SVP", "f2-PBE0/def2SVP", + "E1-PBE0/def2TZVP", "E2-PBE0/def2TZVP", "f1-PBE0/def2TZVP", "f2-PBE0/def2TZVP", + "E1-CAM", "E2-CAM", "f1-CAM", "f2-CAM"] + + def __init__(self, path, node_position=False, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + sdf_file = utils.extract(zip_file, "qm8.sdf") + csv_file = utils.extract(zip_file, "qm8.sdf.csv") + csv_file2 = os.path.join(path, "qm8.sdf.clean.csv") + + if not os.path.exists(csv_file2): + with open(csv_file, "r") as fin, open(csv_file2, "w") as fout: + reader = csv.reader(fin) + writer = csv.writer(fout) + fields = next(reader) + fields[5:9] = [field + "/def2SVP" for field in fields[5:9]] + fields[9:13] = [field + "/def2TZVP" for field in fields[9:13]] + writer.writerow(fields) + for values in reader: + writer.writerow(values) + + self.load_csv(csv_file2, smiles_field=None, target_fields=self.target_fields, verbose=verbose) + + with utils.no_rdkit_log(): + molecules = Chem.SDMolSupplier(sdf_file, True, True, False) + + targets = self.targets + self.data = [] + self.targets = defaultdict(list) + assert len(molecules) == len(targets[self.target_fields[0]]) + indexes = range(len(molecules)) + if verbose: + indexes = tqdm(indexes, "Constructing molecules from SDF") + for i in indexes: + with utils.capture_rdkit_log() as log: + mol = molecules[i] + if mol is None: + continue + d = data.Molecule.from_molecule(mol, **kwargs) + if node_position: + with d.node(): + d.node_position = torch.tensor([feature.atom_position(atom) for atom in mol.GetAtoms()]) + self.data.append(d) + for k in targets: + self.targets[k].append(targets[k][i]) diff --git a/build/lib/torchdrug/datasets/qm9.py b/build/lib/torchdrug/datasets/qm9.py new file mode 100644 index 00000000..5ba94b59 --- /dev/null +++ b/build/lib/torchdrug/datasets/qm9.py @@ -0,0 +1,71 @@ +import os +from collections import defaultdict + +from tqdm import tqdm +from rdkit import Chem, RDLogger + 
+import torch + +from torchdrug import data, utils +from torchdrug.data import feature +from torchdrug.core import Registry as R + + +@R.register("datasets.QM9") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class QM9(data.MoleculeDataset): + """ + Geometric, energetic, electronic and thermodynamic properties of DFT-modeled small molecules. + + Statistics: + - #Molecule: 133,885 + - #Regression task: 12 + + Parameters: + path (str): path to store the dataset + node_position (bool, optional): load node position or not. + This will add `node_position` as a node attribute to each sample. + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/gdb9.tar.gz" + md5 = "560f62d8e6c992ca0cf8ed8d013f9131" + target_fields = ["mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298"] + + def __init__(self, path, node_position=False, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + sdf_file = utils.extract(zip_file, "gdb9.sdf") + csv_file = utils.extract(zip_file, "gdb9.sdf.csv") + + self.load_csv(csv_file, smiles_field=None, target_fields=self.target_fields, verbose=verbose) + + with utils.no_rdkit_log(): + molecules = Chem.SDMolSupplier(sdf_file, True, True, False) + + targets = self.targets + self.data = [] + self.targets = defaultdict(list) + assert len(molecules) == len(targets[self.target_fields[0]]) + indexes = range(len(molecules)) + if verbose: + indexes = tqdm(indexes, "Constructing molecules from SDF") + for i in indexes: + with utils.capture_rdkit_log() as log: + mol = molecules[i] + if mol is None: + continue + if log.content: + print(log.content) + d = data.Molecule.from_molecule(mol, **kwargs) + if node_position: + with d.node(): + d.node_position = torch.tensor([feature.atom_position(atom) for atom in mol.GetAtoms()]) + self.data.append(d) + for k in targets: + self.targets[k].append(targets[k][i]) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/secondary_structure.py b/build/lib/torchdrug/datasets/secondary_structure.py new file mode 100644 index 00000000..dff07da5 --- /dev/null +++ b/build/lib/torchdrug/datasets/secondary_structure.py @@ -0,0 +1,70 @@ +import os + +import torch +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.SecondaryStructure") +@utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) +class SecondaryStructure(data.ProteinDataset): + """ + Secondary structure labels for a set of proteins determined by the local structures + of protein residues in their natural state + + Statistics: + - #Train: 8,678 + - #Valid: 2,170 + - #Test: 513 + + Parameters: + path (str): the path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/secondary_structure.tar.gz" + md5 = "2f61e8e09c215c032ef5bc8b910c8e97" + splits = ["train", "valid", "casp12", "ts115", "cb513"] + target_fields = ["ss3", "valid_mask"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + data_path = utils.extract(zip_file) + 
lmdb_files = [os.path.join(data_path, "secondary_structure/secondary_structure_%s.lmdb" % split) + for split in self.splits] + + self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs) + + def get_item(self, index): + if self.lazy: + graph = data.Protein.from_sequence(self.sequences[index], **self.kwargs) + else: + graph = self.data[index] + with graph.residue(): + target = torch.as_tensor(self.targets["ss3"][index], dtype=torch.long) + graph.target = target + mask = torch.as_tensor(self.targets["valid_mask"][index], dtype=torch.bool) + graph.mask = mask + item = {"graph": graph} + if self.transform: + item = self.transform(item) + return item + + def split(self, keys=None): + keys = keys or self.splits + offset = 0 + splits = [] + for split_name, num_sample in zip(self.splits, self.num_samples): + if split_name in keys: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/sider.py b/build/lib/torchdrug/datasets/sider.py new file mode 100644 index 00000000..39a86c60 --- /dev/null +++ b/build/lib/torchdrug/datasets/sider.py @@ -0,0 +1,37 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.SIDER") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class SIDER(data.MoleculeDataset): + """ + Marketed drugs and adverse drug reactions (ADR) dataset, grouped into 27 system organ classes. + + Statistics: + - #Molecule: 1,427 + - #Classification task: 27 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/sider.csv.gz" + md5 = "77c0ef421f7cc8ce963c5836c8761fd2" + target_fields = None # pick all targets + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + csv_file = utils.extract(zip_file) + + self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/solubility.py b/build/lib/torchdrug/datasets/solubility.py new file mode 100644 index 00000000..29c019b9 --- /dev/null +++ b/build/lib/torchdrug/datasets/solubility.py @@ -0,0 +1,51 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.Solubility") +@utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) +class Solubility(data.ProteinDataset): + """ + Proteins with binary labels indicating their solubility. 
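+
+    Example (illustrative path):
+
+    >>> from torchdrug import datasets
+    >>> dataset = datasets.Solubility("~/protein-datasets/")
+    >>> dataset[0]  # a dict holding the protein graph and the binary "solubility" label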
+ + Statistics: + - #Train: 62,478 + - #Valid: 6,942 + - #Test: 1,999 + + Parameters: + path (str): the path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/peerdata/solubility.tar.gz" + md5 = "8a8612b7bfa2ed80375db6e465ccf77e" + splits = ["train", "valid", "test"] + target_fields = ["solubility"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + data_path = utils.extract(zip_file) + lmdb_files = [os.path.join(data_path, "solubility/solubility_%s.lmdb" % split) + for split in self.splits] + + self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/stability.py b/build/lib/torchdrug/datasets/stability.py new file mode 100644 index 00000000..e6f8a7a3 --- /dev/null +++ b/build/lib/torchdrug/datasets/stability.py @@ -0,0 +1,51 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.Stability") +@utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) +class Stability(data.ProteinDataset): + """ + The stability values of proteins under natural environment. + + Statistics: + - #Train: 53,571 + - #Valid: 2,512 + - #Test: 12,851 + + Parameters: + path (str): the path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/stability.tar.gz" + md5 = "aa1e06eb5a59e0ecdae581e9ea029675" + splits = ["train", "valid", "test"] + target_fields = ["stability_score"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + data_path = utils.extract(zip_file) + lmdb_files = [os.path.join(data_path, "stability/stability_%s.lmdb" % split) + for split in self.splits] + + self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/subcellular_localization.py b/build/lib/torchdrug/datasets/subcellular_localization.py new file mode 100644 index 00000000..d77cc475 --- /dev/null +++ b/build/lib/torchdrug/datasets/subcellular_localization.py @@ -0,0 +1,51 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.SubcellularLocalization") +@utils.copy_args(data.ProteinDataset.load_lmdbs, ignore=("target_fields",)) +class SubcellularLocalization(data.ProteinDataset): + """ + Class labels indicating where a natural protein locates in the cell. 
+ + Statistics: + - #Train: 8,945 + - #Valid: 2,248 + - #Test: 2,768 + + Parameters: + path (str): the path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/peerdata/subcellular_localization.tar.gz" + md5 = "37cb6138b8d4603512530458b7c8a77d" + splits = ["train", "valid", "test"] + target_fields = ["localization"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + data_path = utils.extract(zip_file) + lmdb_files = [os.path.join(data_path, "subcellular_localization/subcellular_localization_%s.lmdb" % split) + for split in self.splits] + + self.load_lmdbs(lmdb_files, target_fields=self.target_fields, verbose=verbose, **kwargs) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/tox21.py b/build/lib/torchdrug/datasets/tox21.py new file mode 100644 index 00000000..5e73e58e --- /dev/null +++ b/build/lib/torchdrug/datasets/tox21.py @@ -0,0 +1,39 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.Tox21") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class Tox21(data.MoleculeDataset): + """ + Qualitative toxicity measurements on 12 biological targets, including nuclear receptors + and stress response pathways. + + Statistics: + - #Molecule: 7,831 + - #Classification task: 12 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/tox21.csv.gz" + md5 = "2882d69e70bba0fec14995f26787cc25" + target_fields = ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", + "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + csv_file = utils.extract(zip_file) + + self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/toxcast.py b/build/lib/torchdrug/datasets/toxcast.py new file mode 100644 index 00000000..3bb05869 --- /dev/null +++ b/build/lib/torchdrug/datasets/toxcast.py @@ -0,0 +1,37 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.ToxCast") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class ToxCast(data.MoleculeDataset): + """ + Toxicology data based on in vitro high-throughput screening. 
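+
+    Example (illustrative path; with ``target_fields = None`` every assay column
+    becomes a task, assuming ``tasks`` mirrors the loaded target fields):
+
+    >>> from torchdrug import datasets
+    >>> dataset = datasets.ToxCast("~/molecule-datasets/")
+    >>> len(dataset.tasks)  # number of assays loaded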
+ + Statistics: + - #Molecule: 8,575 + - #Classification task: 617 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/toxcast_data.csv.gz" + md5 = "92911bbf9c1e2ad85231014859388cd6" + target_fields = None # pick all targets + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + csv_file = utils.extract(zip_file) + + self.load_csv(csv_file, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/uspto50k.py b/build/lib/torchdrug/datasets/uspto50k.py new file mode 100644 index 00000000..1169985e --- /dev/null +++ b/build/lib/torchdrug/datasets/uspto50k.py @@ -0,0 +1,264 @@ +import os +import copy +from collections import defaultdict + +import numpy as np +import networkx as nx +from tqdm import tqdm +from rdkit import Chem + +import torch +from torch.utils import data as torch_data +from torch_scatter import scatter_max + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.USPTO50k") +@utils.copy_args(data.ReactionDataset.load_csv, ignore=("smiles_field", "target_fields")) +class USPTO50k(data.ReactionDataset): + """ + Chemical reactions extracted from USPTO patents. + + Statistics: + - #Reaction: 50,017 + - #Reaction class: 10 + + Parameters: + path (str): path to store the dataset + as_synthon (bool, optional): whether decompose (reactant, product) pairs into (reactant, synthon) pairs + verbose (int, optional): output verbose level + **kwargs + """ + + target_fields = ["class"] + target_alias = {"class": "reaction"} + + reaction_names = ["Heteroatom alkylation and arylation", + "Acylation and related processes", + "C-C bond formation", + "Heterocycle formation", + "Protections", + "Deprotections", + "Reductions", + "Oxidations", + "Functional group interconversion (FGI)", + "Functional group addition (FGA)"] + + url = "https://raw.githubusercontent.com/connorcoley/retrosim/master/retrosim/data/data_processed.csv" + md5 = "404c361dd1568fbdb4d16ca588953749" + + def __init__(self, path, as_synthon=False, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + self.as_synthon = as_synthon + + file_name = utils.download(self.url, path, md5=self.md5) + + self.load_csv(file_name, smiles_field="rxn_smiles", target_fields=self.target_fields, verbose=verbose, + **kwargs) + + if as_synthon: + prefix = "Computing synthons" + process_fn = self._get_synthon + else: + prefix = "Computing reaction centers" + process_fn = self._get_reaction_center + + data = self.data + targets = self.targets + self.data = [] + self.targets = defaultdict(list) + indexes = range(len(data)) + if verbose: + indexes = tqdm(indexes, prefix) + invalid = 0 + for i in indexes: + reactant, product = data[i] + reactant.bond_stereo[:] = 0 + product.bond_stereo[:] = 0 + + reactants, products = process_fn(reactant, product) + if not reactants: + invalid += 1 + continue + + self.data += zip(reactants, products) + for k in targets: + new_k = self.target_alias.get(k, k) + self.targets[new_k] += [targets[k][i] - 1] * len(reactants) + self.targets["sample id"] += [i] * len(reactants) + + self.valid_rate = 1 - invalid / 
len(data) + + def _get_difference(self, reactant, product): + product2id = product.atom_map + id2reactant = torch.zeros(product2id.max() + 1, dtype=torch.long) + id2reactant[reactant.atom_map] = torch.arange(reactant.num_node) + prod2react = id2reactant[product2id] + + # check edges in the product + product = product.directed() + # O(n^2) brute-force match is faster than O(nlogn) data.Graph.match for small molecules + mapped_edge = product.edge_list.clone() + mapped_edge[:, :2] = prod2react[mapped_edge[:, :2]] + is_same_index = mapped_edge.unsqueeze(0) == reactant.edge_list.unsqueeze(1) + has_typed_edge = is_same_index.all(dim=-1).any(dim=0) + has_edge = is_same_index[:, :, :2].all(dim=-1).any(dim=0) + is_added = ~has_edge + is_modified = has_edge & ~has_typed_edge + edge_added = product.edge_list[is_added, :2] + edge_modified = product.edge_list[is_modified, :2] + + return edge_added, edge_modified, prod2react + + def _get_reaction_center(self, reactant, product): + edge_added, edge_modified, prod2react = self._get_difference(reactant, product) + + edge_label = torch.zeros(product.num_edge, dtype=torch.long) + node_label = torch.zeros(product.num_node, dtype=torch.long) + + if len(edge_added) > 0: + if len(edge_added) == 1: # add a single edge + any = -torch.ones(1, 1, dtype=torch.long) + pattern = torch.cat([edge_added, any], dim=-1) + index, num_match = product.match(pattern) + assert num_match.item() == 1 + edge_label[index] = 1 + h, t = edge_added[0] + reaction_center = torch.tensor([product.atom_map[h], product.atom_map[t]]) + else: + if len(edge_modified) == 1: # modify a single edge + h, t = edge_modified[0] + if product.degree_in[h] == 1: + node_label[h] = 1 + reaction_center = torch.tensor([product.atom_map[h], 0]) + elif product.degree_in[t] == 1: + node_label[t] = 1 + reaction_center = torch.tensor([product.atom_map[t], 0]) + else: + # pretend the reaction center is h + node_label[h] = 1 + reaction_center = torch.tensor([product.atom_map[h], 0]) + else: + product_hs = torch.tensor([atom.GetTotalNumHs() for atom in product.to_molecule().GetAtoms()]) + reactant_hs = torch.tensor([atom.GetTotalNumHs() for atom in reactant.to_molecule().GetAtoms()]) + atom_modified = (product_hs != reactant_hs[prod2react]).nonzero().flatten() + if len(atom_modified) == 1: # modify single node + node_label[atom_modified] = 1 + reaction_center = torch.tensor([product.atom_map[atom_modified[0]], 0]) + + if edge_label.sum() + node_label.sum() == 0: + return [], [] + + with product.edge(): + product.edge_label = edge_label + with product.node(): + product.node_label = node_label + with reactant.graph(): + reactant.reaction_center = reaction_center + with product.graph(): + product.reaction_center = reaction_center + return [reactant], [product] + + def _get_synthon(self, reactant, product): + edge_added, edge_modified, prod2react = self._get_difference(reactant, product) + + reactants = [] + synthons = [] + + if len(edge_added) > 0: + if len(edge_added) == 1: # add a single edge + reverse_edge = edge_added.flip(1) + any = -torch.ones(2, 1, dtype=torch.long) + pattern = torch.cat([edge_added, reverse_edge]) + pattern = torch.cat([pattern, any], dim=-1) + index, num_match = product.match(pattern) + edge_mask = torch.ones(product.num_edge, dtype=torch.bool) + edge_mask[index] = 0 + product = product.edge_mask(edge_mask) + _reactants = reactant.connected_components()[0] + _synthons = product.connected_components()[0] + assert len(_synthons) >= len(_reactants) # because a few samples contain multiple 
products
+
+                h, t = edge_added[0]
+                reaction_center = torch.tensor([product.atom_map[h], product.atom_map[t]])
+                with _reactants.graph():
+                    _reactants.reaction_center = reaction_center.expand(len(_reactants), -1)
+                with _synthons.graph():
+                    _synthons.reaction_center = reaction_center.expand(len(_synthons), -1)
+                # reactant / synthon can be uniquely indexed by their maximal atom mapping ID
+                reactant_id = scatter_max(_reactants.atom_map, _reactants.node2graph, dim_size=len(_reactants))[0]
+                synthon_id = scatter_max(_synthons.atom_map, _synthons.node2graph, dim_size=len(_synthons))[0]
+                react2synthon = (reactant_id.unsqueeze(-1) == synthon_id.unsqueeze(0)).long().argmax(-1)
+                react2synthon = react2synthon.tolist()
+                for r, s in enumerate(react2synthon):
+                    reactants.append(_reactants[r])
+                    synthons.append(_synthons[s])
+        else:
+            num_cc = reactant.connected_components()[1]
+            assert num_cc == 1
+
+            if len(edge_modified) == 1:  # modify a single edge
+                synthon = product
+                h, t = edge_modified[0]
+                if product.degree_in[h] == 1:
+                    reaction_center = torch.tensor([product.atom_map[h], 0])
+                elif product.degree_in[t] == 1:
+                    reaction_center = torch.tensor([product.atom_map[t], 0])
+                else:
+                    # pretend the reaction center is h
+                    reaction_center = torch.tensor([product.atom_map[h], 0])
+                with reactant.graph():
+                    reactant.reaction_center = reaction_center
+                with synthon.graph():
+                    synthon.reaction_center = reaction_center
+                reactants.append(reactant)
+                synthons.append(synthon)
+            else:
+                product_hs = torch.tensor([atom.GetTotalNumHs() for atom in product.to_molecule().GetAtoms()])
+                reactant_hs = torch.tensor([atom.GetTotalNumHs() for atom in reactant.to_molecule().GetAtoms()])
+                atom_modified = (product_hs != reactant_hs[prod2react]).nonzero().flatten()
+                if len(atom_modified) == 1:  # modify single node
+                    synthon = product
+                    reaction_center = torch.tensor([product.atom_map[atom_modified[0]], 0])
+                    with reactant.graph():
+                        reactant.reaction_center = reaction_center
+                    with synthon.graph():
+                        synthon.reaction_center = reaction_center
+                    reactants.append(reactant)
+                    synthons.append(synthon)
+
+        return reactants, synthons
+
+    def split(self, ratios=(0.8, 0.1, 0.1)):
+        react2index = defaultdict(list)
+        react2sample = defaultdict(list)
+        for i in range(len(self)):
+            reaction = self.targets["reaction"][i]
+            sample_id = self.targets["sample id"][i]
+            react2index[reaction].append(i)
+            react2sample[reaction].append(sample_id)
+
+        indexes = [[] for _ in ratios]
+        for reaction in react2index:
+            num_sample = len(set(react2sample[reaction]))
+            key_lengths = [int(round(num_sample * ratio)) for ratio in ratios]
+            key_lengths[-1] = num_sample - sum(key_lengths[:-1])
+            react_indexes = data.key_split(react2index[reaction], react2sample[reaction], key_lengths=key_lengths)
+            for index, react_index in zip(indexes, react_indexes):
+                index += [i for i in react_index]
+
+        return [torch_data.Subset(self, index) for index in indexes]
+
+    @property
+    def num_reaction_type(self):
+        return len(self.reaction_types)
+
+    @utils.cached_property
+    def reaction_types(self):
+        """All reaction types."""
+        # after __init__, the raw ``class`` targets are stored under the alias ``reaction``
+        return sorted(set(self.targets["reaction"]))
\ No newline at end of file
diff --git a/build/lib/torchdrug/datasets/wn18.py b/build/lib/torchdrug/datasets/wn18.py
new file mode 100644
index 00000000..dca6c95e
--- /dev/null
+++ b/build/lib/torchdrug/datasets/wn18.py
@@ -0,0 +1,106 @@
+import os
+
+from torch.utils import data as torch_data
+
+from torchdrug import data, utils
+from torchdrug.core import Registry as R
+
+
+@R.register("datasets.WN18")
+class WN18(data.KnowledgeGraphDataset): + """ + WordNet knowledge base. + + Statistics: + - #Entity: 40,943 + - #Relation: 18 + - #Triplet: 151,442 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + """ + + urls = [ + "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18/train.txt", + "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18/valid.txt", + "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18/test.txt", + ] + md5s = [ + "7d68324d293837ac165c3441a6c8b0eb", + "f4f66fec0ca83b5ebe7ad7003404e61d", + "b035247a8916c7ec3443fa949e1ff02c" + ] + + def __init__(self, path, verbose=1): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + txt_files = [] + for url, md5 in zip(self.urls, self.md5s): + save_file = "wn18_%s" % os.path.basename(url) + txt_file = utils.download(url, self.path, save_file=save_file, md5=md5) + txt_files.append(txt_file) + + self.load_tsvs(txt_files, verbose=verbose) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits + + +@R.register("datasets.WN18RR") +class WN18RR(data.KnowledgeGraphDataset): + """ + A filtered version of WN18 dataset without trivial cases. + + Statistics: + - #Entity: 40,943 + - #Relation: 11 + - #Triplet: 93,003 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + """ + + urls = [ + "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18rr/train.txt", + "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18rr/valid.txt", + "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18rr/test.txt", + ] + md5s = [ + "35e81af3ae233327c52a87f23b30ad3c", + "74a2ee9eca9a8d31f1a7d4d95b5e0887", + "2b45ba1ba436b9d4ff27f1d3511224c9" + ] + + def __init__(self, path, verbose=1): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + txt_files = [] + for url, md5 in zip(self.urls, self.md5s): + save_file = "wn18rr_%s" % os.path.basename(url) + txt_file = utils.download(url, self.path, save_file=save_file, md5=md5) + txt_files.append(txt_file) + + self.load_tsvs(txt_files, verbose=verbose) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/yago310.py b/build/lib/torchdrug/datasets/yago310.py new file mode 100644 index 00000000..87e58af4 --- /dev/null +++ b/build/lib/torchdrug/datasets/yago310.py @@ -0,0 +1,56 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.YAGO310") +class YAGO310(data.KnowledgeGraphDataset): + """ + Subset of YAGO3 knowledge base for knowledge graph reasoning. 
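+
+    A minimal usage sketch (the cache path is illustrative); :meth:`split` returns
+    the subsets in train/valid/test order:
+
+    >>> from torchdrug import datasets
+    >>> dataset = datasets.YAGO310("~/kg-datasets/")
+    >>> train, valid, test = dataset.split()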
+ + Statistics: + - #Entity: 123,182 + - #Relation: 37 + - #Triplet: 1,089,040 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + """ + + urls = [ + "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/train.txt", + "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/valid.txt", + "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/test.txt", + ] + md5s = [ + "a9da8f583ec3920570eeccf07199229a", + "2d679a906f2b1ac29d74d5c948c1ad09", + "14bf97890b2fee774dbce5f326acd189" + ] + + def __init__(self, path, verbose=1): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + txt_files = [] + for url, md5 in zip(self.urls, self.md5s): + save_file = "yago310_%s" % os.path.basename(url) + txt_file = utils.download(url, self.path, save_file=save_file, md5=md5) + txt_files.append(txt_file) + + self.load_tsvs(txt_files, verbose=verbose) + + def split(self): + offset = 0 + splits = [] + for num_sample in self.num_samples: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits diff --git a/build/lib/torchdrug/datasets/yeast_ppi.py b/build/lib/torchdrug/datasets/yeast_ppi.py new file mode 100644 index 00000000..f7dd9470 --- /dev/null +++ b/build/lib/torchdrug/datasets/yeast_ppi.py @@ -0,0 +1,67 @@ +import os + +from torch.utils import data as torch_data + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.YeastPPI") +@utils.copy_args(data.ProteinPairDataset.load_lmdbs, ignore=("sequence_field", "target_fields")) +class YeastPPI(data.ProteinPairDataset): + """ + Binary labels indicating whether two yeast proteins interact or not. 
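+
+    A minimal usage sketch (the cache path is illustrative). Note there is a fourth
+    ``cross_species_test`` split besides train/valid/test, so the example below
+    selects the three standard splits explicitly:
+
+    >>> from torchdrug import datasets
+    >>> dataset = datasets.YeastPPI("~/protein-datasets/")
+    >>> train, valid, test = dataset.split(keys=["train", "valid", "test"])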
+ + Statistics: + - #Train: 1,668 + - #Valid: 131 + - #Test: 373 + + Parameters: + path (str): the path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/ppidata/yeast_ppi.zip" + md5 = "3993b02c3080d74996cddf6fe798b1e8" + splits = ["train", "valid", "test", "cross_species_test"] + target_fields = ["interaction"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + zip_file = utils.download(self.url, path, md5=self.md5) + data_path = utils.extract(zip_file) + lmdb_files = [os.path.join(data_path, "yeast_ppi/yeast_ppi_%s.lmdb" % split) + for split in self.splits] + + self.load_lmdbs(lmdb_files, sequence_field=["primary_1", "primary_2"], + target_fields=self.target_fields, verbose=verbose, **kwargs) + + def split(self, keys=None): + keys = keys or self.splits + offset = 0 + splits = [] + for split_name, num_sample in zip(self.splits, self.num_samples): + if split_name in keys: + split = torch_data.Subset(self, range(offset, offset + num_sample)) + splits.append(split) + offset += num_sample + return splits + + def get_item(self, index): + if self.lazy: + graph1 = data.Protein.from_sequence(self.sequences[index][0], **self.kwargs) + graph2 = data.Protein.from_sequence(self.sequences[index][1], **self.kwargs) + else: + graph1 = self.data[index][0] + graph2 = self.data[index][1] + item = {"graph1": graph1, "graph2": graph2} + item.update({k: v[index] for k, v in self.targets.items()}) + if self.transform: + item = self.transform(item) + return item \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/zinc250k.py b/build/lib/torchdrug/datasets/zinc250k.py new file mode 100644 index 00000000..24959c00 --- /dev/null +++ b/build/lib/torchdrug/datasets/zinc250k.py @@ -0,0 +1,37 @@ +import os + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.ZINC250k") +@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields")) +class ZINC250k(data.MoleculeDataset): + """ + Subset of ZINC compound database for virtual screening. 
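+
+    A minimal usage sketch (the cache path is illustrative); each sample is a dict
+    with the molecule graph and the two regression targets:
+
+    >>> from torchdrug import datasets
+    >>> dataset = datasets.ZINC250k("~/molecule-datasets/")
+    >>> sample = dataset[0]
+    >>> sample["graph"], sample["logP"], sample["qed"]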
+ + Statistics: + - #Molecule: 498,910 + - #Regression task: 2 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + url = "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/" \ + "250k_rndm_zinc_drugs_clean_3.csv" + md5 = "b59078b2b04c6e9431280e3dc42048d5" + target_fields = ["logP", "qed"] + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + file_name = utils.download(self.url, path, md5=self.md5) + + self.load_csv(file_name, smiles_field="smiles", target_fields=self.target_fields, + verbose=verbose, **kwargs) \ No newline at end of file diff --git a/build/lib/torchdrug/datasets/zinc2m.py b/build/lib/torchdrug/datasets/zinc2m.py new file mode 100644 index 00000000..957748c0 --- /dev/null +++ b/build/lib/torchdrug/datasets/zinc2m.py @@ -0,0 +1,53 @@ +import os +import csv +from tqdm import tqdm +import shutil + +from torchdrug import data, utils +from torchdrug.core import Registry as R + + +@R.register("datasets.ZINC2m") +@utils.copy_args(data.MoleculeDataset.load_smiles, ignore=("smiles_field", "target_fields")) +class ZINC2m(data.MoleculeDataset): + """ + ZINC compound database for virtual screening. + This dataset doesn't contain any label information. + + Statistics: + - #Molecule: 2,000,000 + + Parameters: + path (str): path to store the dataset + verbose (int, optional): output verbose level + **kwargs + """ + + target_fields = [] + + url = "http://snap.stanford.edu/gnn-pretrain/data/chem_dataset.zip" + md5 = "e95da4dffa0fdb1d4af2726bdf8c23e0" + member = "dataset/zinc_standard_agent/processed/smiles.csv" + + def __init__(self, path, verbose=1, **kwargs): + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + zip_file_name = utils.download(self.url, path, md5=self.md5) + + save_file = utils.extract(zip_file=zip_file_name, member=self.member) + neo_save_file = os.path.join(os.path.dirname(zip_file_name), 'zinc2m_'+os.path.basename(self.member)) + shutil.move(save_file, neo_save_file) + + with open(neo_save_file, "r") as fin: + reader = csv.reader(fin) + if verbose: + reader = iter(tqdm(reader, "Loading %s" % path, utils.get_line_count(neo_save_file))) + smiles_list = [] + + for idx, values in enumerate(reader): + smiles = values[0] + smiles_list.append(smiles) + + targets = {} + self.load_smiles(smiles_list, targets, lazy=True, verbose=verbose, **kwargs) diff --git a/build/lib/torchdrug/layers/__init__.py b/build/lib/torchdrug/layers/__init__.py new file mode 100644 index 00000000..d9df9a8c --- /dev/null +++ b/build/lib/torchdrug/layers/__init__.py @@ -0,0 +1,37 @@ +from .common import MultiLayerPerceptron, GaussianSmearing, MutualInformation, PairNorm, InstanceNorm, Sequential, \ + SinusoidalPositionEmbedding + +from .block import ProteinResNetBlock, SelfAttentionBlock, ProteinBERTBlock +from .conv import MessagePassingBase, GraphConv, GraphAttentionConv, RelationalGraphConv, GraphIsomorphismConv, \ + NeuralFingerprintConv, ContinuousFilterConv, MessagePassing, ChebyshevConv, GeometricRelationalGraphConv +from .pool import DiffPool, MinCutPool +from .readout import MeanReadout, SumReadout, MaxReadout, AttentionReadout, Softmax, Set2Set, Sort +from .flow import ConditionalFlow +from .sampler import NodeSampler, EdgeSampler +from .geometry import GraphConstruction, SpatialLineGraph +from . 
import distribution, functional + +# alias +MLP = MultiLayerPerceptron +RBF = GaussianSmearing +GCNConv = GraphConv +RGCNConv = RelationalGraphConv +GINConv = GraphIsomorphismConv +NFPConv = NeuralFingerprintConv +CFConv = ContinuousFilterConv +MPConv = MessagePassing + +__all__ = [ + "MultiLayerPerceptron", "GaussianSmearing", "MutualInformation", "PairNorm", "InstanceNorm", "Sequential", + "SinusoidalPositionEmbedding", + "MessagePassingBase", "GraphConv", "GraphAttentionConv", "RelationalGraphConv", "GraphIsomorphismConv", + "NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv", "GeometricRelationalGraphConv", + "DiffPool", "MinCutPool", + "MeanReadout", "SumReadout", "MaxReadout", "AttentionReadout", "Softmax", "Set2Set", "Sort", + "ConditionalFlow", + "NodeSampler", "EdgeSampler", + "GraphConstruction", "SpatialLineGraph", + "distribution", "functional", + "MLP", "RBF", "GCNConv", "RGCNConv", "GINConv", "NFPConv", "CFConv", "MPConv", + "ProteinResNetBlock", "SelfAttentionBlock", "ProteinBERTBlock", +] \ No newline at end of file diff --git a/build/lib/torchdrug/layers/block.py b/build/lib/torchdrug/layers/block.py new file mode 100644 index 00000000..f6bd5ac7 --- /dev/null +++ b/build/lib/torchdrug/layers/block.py @@ -0,0 +1,167 @@ +from torch import nn +from torch.nn import functional as F + +from torchdrug import layers + + +class ProteinResNetBlock(nn.Module): + """ + Convolutional block with residual connection from `Deep Residual Learning for Image Recognition`_. + + .. _Deep Residual Learning for Image Recognition: + https://arxiv.org/pdf/1512.03385.pdf + + Parameters: + input_dim (int): input dimension + output_dim (int): output dimension + kernel_size (int, optional): size of convolutional kernel + stride (int, optional): stride of convolution + padding (int, optional): padding added to both sides of the input + activation (str or function, optional): activation function + """ + + def __init__(self, input_dim, output_dim, kernel_size=3, stride=1, padding=1, activation="gelu"): + super(ProteinResNetBlock, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + + if isinstance(activation, str): + self.activation = getattr(F, activation) + else: + self.activation = activation + + self.conv1 = nn.Conv1d(input_dim, output_dim, kernel_size, stride, padding, bias=False) + self.layer_norm1 = nn.LayerNorm(output_dim) + self.conv2 = nn.Conv1d(output_dim, output_dim, kernel_size, stride, padding, bias=False) + self.layer_norm2 = nn.LayerNorm(output_dim) + + def forward(self, input, mask): + """ + Perform 1D convolutions over the input. + + Parameters: + input (Tensor): input representations of shape `(..., length, dim)` + mask (Tensor): bool mask of shape `(..., length, dim)` + """ + identity = input + + input = input * mask # (B, L, d) + out = self.conv1(input.transpose(1, 2)).transpose(1, 2) + out = self.layer_norm1(out) + out = self.activation(out) + + out = out * mask + out = self.conv2(out.transpose(1, 2)).transpose(1, 2) + out = self.layer_norm2(out) + + out += identity + out = self.activation(out) + + return out + + +class SelfAttentionBlock(nn.Module): + """ + Multi-head self-attention block from + `Attention Is All You Need`_. + + .. 
_Attention Is All You Need: + https://arxiv.org/pdf/1706.03762.pdf + + Parameters: + hidden_dim (int): hidden dimension + num_heads (int): number of attention heads + dropout (float, optional): dropout ratio of attention maps + """ + + def __init__(self, hidden_dim, num_heads, dropout=0.0): + super(SelfAttentionBlock, self).__init__() + if hidden_dim % num_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (hidden_dim, num_heads)) + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_size = hidden_dim // num_heads + + self.query = nn.Linear(hidden_dim, hidden_dim) + self.key = nn.Linear(hidden_dim, hidden_dim) + self.value = nn.Linear(hidden_dim, hidden_dim) + + self.attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout) + + def forward(self, input, mask): + """ + Perform self attention over the input. + + Parameters: + input (Tensor): input representations of shape `(..., length, dim)` + mask (Tensor): bool mask of shape `(..., length)` + """ + query = self.query(input).transpose(0, 1) + key = self.key(input).transpose(0, 1) + value = self.value(input).transpose(0, 1) + + mask = (~mask.bool()).squeeze(-1) + output = self.attn(query, key, value, key_padding_mask=mask)[0].transpose(0, 1) + + return output + + +class ProteinBERTBlock(nn.Module): + """ + Transformer encoding block from + `Attention Is All You Need`_. + + .. _Attention Is All You Need: + https://arxiv.org/pdf/1706.03762.pdf + + Parameters: + input_dim (int): input dimension + hidden_dim (int): hidden dimension + num_heads (int): number of attention heads + attention_dropout (float, optional): dropout ratio of attention maps + hidden_dropout (float, optional): dropout ratio of hidden features + activation (str or function, optional): activation function + """ + + def __init__(self, input_dim, hidden_dim, num_heads, attention_dropout=0, + hidden_dropout=0, activation="relu"): + super(ProteinBERTBlock, self).__init__() + self.input_dim = input_dim + self.num_heads = num_heads + self.attention_dropout = attention_dropout + self.hidden_dropout = hidden_dropout + self.hidden_dim = hidden_dim + + self.attention = SelfAttentionBlock(input_dim, num_heads, attention_dropout) + self.linear1 = nn.Linear(input_dim, input_dim) + self.dropout1 = nn.Dropout(hidden_dropout) + self.layer_norm1 = nn.LayerNorm(input_dim) + + self.intermediate = layers.MultiLayerPerceptron(input_dim, hidden_dim, activation=activation) + + self.linear2 = nn.Linear(hidden_dim, input_dim) + self.dropout2 = nn.Dropout(hidden_dropout) + self.layer_norm2 = nn.LayerNorm(input_dim) + + def forward(self, input, mask): + """ + Perform a BERT-block transformation over the input. 
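+
+        A shape sketch (sizes are illustrative):
+
+        >>> import torch
+        >>> block = ProteinBERTBlock(input_dim=512, hidden_dim=2048, num_heads=8)
+        >>> input = torch.randn(2, 100, 512)
+        >>> mask = torch.ones(2, 100, 1, dtype=torch.bool)
+        >>> output = block(input, mask)  # shape (2, 100, 512)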
+
+        Parameters:
+            input (Tensor): input representations of shape `(..., length, dim)`
+            mask (Tensor): bool mask of shape `(..., length)`
+        """
+        x = self.attention(input, mask)
+        x = self.linear1(x)
+        x = self.dropout1(x)
+        x = self.layer_norm1(x + input)
+
+        hidden = self.intermediate(x)
+
+        hidden = self.linear2(hidden)
+        hidden = self.dropout2(hidden)
+        output = self.layer_norm2(hidden + x)
+
+        return output
\ No newline at end of file
diff --git a/build/lib/torchdrug/layers/common.py b/build/lib/torchdrug/layers/common.py
new file mode 100644
index 00000000..7334ed5f
--- /dev/null
+++ b/build/lib/torchdrug/layers/common.py
@@ -0,0 +1,349 @@
+import inspect
+import warnings
+from collections.abc import Sequence
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch_scatter import scatter_mean
+
+from torchdrug.layers import functional
+
+
+class MultiLayerPerceptron(nn.Module):
+    """
+    Multi-layer Perceptron.
+    Note there is no batch normalization, activation or dropout in the last layer.
+
+    Parameters:
+        input_dim (int): input dimension
+        hidden_dims (list of int): hidden dimensions
+        short_cut (bool, optional): use short cut or not
+        batch_norm (bool, optional): apply batch normalization or not
+        activation (str or function, optional): activation function
+        dropout (float, optional): dropout rate
+    """
+
+    def __init__(self, input_dim, hidden_dims, short_cut=False, batch_norm=False, activation="relu", dropout=0):
+        super(MultiLayerPerceptron, self).__init__()
+
+        if not isinstance(hidden_dims, Sequence):
+            hidden_dims = [hidden_dims]
+        self.dims = [input_dim] + hidden_dims
+        self.short_cut = short_cut
+
+        if isinstance(activation, str):
+            self.activation = getattr(F, activation)
+        else:
+            self.activation = activation
+        if dropout:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+
+        self.layers = nn.ModuleList()
+        for i in range(len(self.dims) - 1):
+            self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))
+        if batch_norm:
+            self.batch_norms = nn.ModuleList()
+            for i in range(len(self.dims) - 2):
+                self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1]))
+        else:
+            self.batch_norms = None
+
+    def forward(self, input):
+        """"""
+        layer_input = input
+
+        for i, layer in enumerate(self.layers):
+            hidden = layer(layer_input)
+            if i < len(self.layers) - 1:
+                if self.batch_norms:
+                    x = hidden.flatten(0, -2)
+                    hidden = self.batch_norms[i](x).view_as(hidden)
+                hidden = self.activation(hidden)
+                if self.dropout:
+                    hidden = self.dropout(hidden)
+            if self.short_cut and hidden.shape == layer_input.shape:
+                hidden = hidden + layer_input
+            layer_input = hidden
+
+        return hidden
+
+
+class GaussianSmearing(nn.Module):
+    r"""
+    Gaussian smearing from
+    `SchNet: A continuous-filter convolutional neural network for modeling quantum interactions`_.
+
+    There are two modes for Gaussian smearing.
+
+    Non-centered mode:
+
+    .. math::
+
+        \mu = [0, 1, ..., n], \sigma = [1, 1, ..., 1]
+
+    Centered mode:
+
+    .. math::
+
+        \mu = [0, 0, ..., 0], \sigma = [0, 1, ..., n]
+
+    ..
_SchNet\: A continuous-filter convolutional neural network for modeling quantum interactions: + https://arxiv.org/pdf/1706.08566.pdf + + Parameters: + start (int, optional): minimal input value + stop (int, optional): maximal input value + num_kernel (int, optional): number of RBF kernels + centered (bool, optional): centered mode or not + learnable (bool, optional): learnable gaussian parameters or not + """ + + def __init__(self, start=0, stop=5, num_kernel=100, centered=False, learnable=False): + super(GaussianSmearing, self).__init__() + if centered: + mu = torch.zeros(num_kernel) + sigma = torch.linspace(start, stop, num_kernel) + else: + mu = torch.linspace(start, stop, num_kernel) + sigma = torch.ones(num_kernel) * (mu[1] - mu[0]) + + if learnable: + self.mu = nn.Parameter(mu) + self.sigma = nn.Parameter(sigma) + else: + self.register_buffer("mu", mu) + self.register_buffer("sigma", sigma) + + def forward(self, x, y): + """ + Compute smeared gaussian features between data. + + Parameters: + x (Tensor): data of shape :math:`(..., d)` + y (Tensor): data of shape :math:`(..., d)` + Returns: + Tensor: features of shape :math:`(..., num\_kernel)` + """ + distance = (x - y).norm(2, dim=-1, keepdim=True) + z = (distance - self.mu) / self.sigma + prob = torch.exp(-0.5 * z * z) + return prob + + +class PairNorm(nn.Module): + """ + Pair normalization layer proposed in `PairNorm: Tackling Oversmoothing in GNNs`_. + + .. _PairNorm\: Tackling Oversmoothing in GNNs: + https://openreview.net/pdf?id=rkecl1rtwB + + Parameters: + scale_individual (bool, optional): additionally normalize each node representation to have the same L2-norm + """ + + eps = 1e-8 + + def __init__(self, scale_individual=False): + super(PairNorm, self).__init__() + self.scale_individual = scale_individual + + def forward(self, graph, input): + """""" + if graph.batch_size > 1: + warnings.warn("PairNorm is proposed for a single graph, but now applied to a batch of graphs.") + + x = input.flatten(1) + x = x - x.mean(dim=0) + if self.scale_individual: + output = x / (x.norm(dim=-1, keepdim=True) + self.eps) + else: + output = x * x.shape[0] ** 0.5 / (x.norm() + self.eps) + return output.view_as(input) + + +class InstanceNorm(nn.modules.instancenorm._InstanceNorm): + """ + Instance normalization for graphs. This layer follows the definition in + `GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training`_. + + .. _GraphNorm\: A Principled Approach to Accelerating Graph Neural Network Training: + https://arxiv.org/pdf/2009.03294.pdf + + Parameters: + input_dim (int): input dimension + eps (float, optional): epsilon added to the denominator + affine (bool, optional): use learnable affine parameters or not + """ + def __init__(self, input_dim, eps=1e-5, affine=False): + super(InstanceNorm, self).__init__(input_dim, eps, affine=affine) + + def forward(self, graph, input): + """""" + assert (graph.num_nodes >= 1).all() + + mean = scatter_mean(input, graph.node2graph, dim=0, dim_size=graph.batch_size) + centered = input - mean[graph.node2graph] + var = scatter_mean(centered ** 2, graph.node2graph, dim=0, dim_size=graph.batch_size) + std = (var + self.eps).sqrt() + output = centered / std[graph.node2graph] + + if self.affine: + output = torch.addcmul(self.bias, self.weight, output) + return output + + +class MutualInformation(nn.Module): + """ + Mutual information estimator from + `Learning deep representations by mutual information estimation and maximization`_. + + .. 
_Learning deep representations by mutual information estimation and maximization: + https://arxiv.org/pdf/1808.06670.pdf + + Parameters: + input_dim (int): input dimension + num_mlp_layer (int, optional): number of MLP layers + activation (str or function, optional): activation function + """ + + def __init__(self, input_dim, num_mlp_layer=2, activation="relu"): + super(MutualInformation, self).__init__() + self.x_mlp = MultiLayerPerceptron(input_dim, [input_dim] * num_mlp_layer, activation=activation) + self.y_mlp = MultiLayerPerceptron(input_dim, [input_dim] * num_mlp_layer, activation=activation) + + def forward(self, x, y, pair_index=None): + """""" + x = self.x_mlp(x) + y = self.y_mlp(y) + score = x @ y.t() + score = score.flatten() + + if pair_index is None: + assert len(x) == len(y) + pair_index = torch.arange(len(x), device=x.device).unsqueeze(-1).expand(-1, 2) + + index = pair_index[:, 0] * len(y) + pair_index[:, 1] + positive = torch.zeros_like(score, dtype=torch.bool) + positive[index] = 1 + negative = ~positive + + mutual_info = - functional.shifted_softplus(-score[positive]).mean() \ + - functional.shifted_softplus(score[negative]).mean() + return mutual_info + + +class Sequential(nn.Sequential): + """ + Improved sequential container. + Modules will be called in the order they are passed to the constructor. + + Compared to the vanilla nn.Sequential, this layer additionally supports the following features. + + 1. Multiple input / output arguments. + + >>> # layer1 signature: (...) -> (a, b) + >>> # layer2 signature: (a, b) -> (...) + >>> layer = layers.Sequential(layer1, layer2) + + 2. Global arguments. + + >>> # layer1 signature: (graph, a) -> b + >>> # layer2 signature: (graph, b) -> c + >>> layer = layers.Sequential(layer1, layer2, global_args=("graph",)) + + Note the global arguments don't need to be present in every layer. + + >>> # layer1 signature: (graph, a) -> b + >>> # layer2 signature: b -> c + >>> # layer3 signature: (graph, c) -> d + >>> layer = layers.Sequential(layer1, layer2, global_args=("graph",)) + + 3. Dict outputs. + + >>> # layer1 signature: a -> {"b": b, "c": c} + >>> # layer2 signature: b -> d + >>> layer = layers.Sequential(layer1, layer2, allow_unused=True) + + When dict outputs are used with global arguments, the global arguments can be explicitly + overwritten by any layer outputs. 
+ + >>> # layer1 signature: (graph, a) -> {"graph": graph, "b": b} + >>> # layer2 signature: (graph, b) -> c + >>> # layer2 takes in the graph output by layer1 + >>> layer = layers.Sequential(layer1, layer2, global_args=("graph",)) + """ + + def __init__(self, *args, global_args=None, allow_unused=False): + super(Sequential, self).__init__(*args) + if global_args is not None: + self.global_args = set(global_args) + else: + self.global_args = {} + self.allow_unused = allow_unused + + def forward(self, *args, **kwargs): + """""" + global_kwargs = {} + for i, module in enumerate(self._modules.values()): + sig = inspect.signature(module.forward) + parameters = list(sig.parameters.values()) + param_names = [param.name for param in parameters] + j = 0 + for name in param_names: + if j == len(args): + break + if name in kwargs: + continue + if name in global_kwargs and name not in kwargs: + kwargs[name] = global_kwargs[name] + continue + kwargs[name] = args[j] + j += 1 + if self.allow_unused: + param_names = set(param_names) + # pop unused kwargs + kwargs = {k: v for k, v in kwargs.items() if k in param_names} + if j < len(args): + raise TypeError("too many positional arguments") + + output = module(**kwargs) + + global_kwargs.update({k: v for k, v in kwargs.items() if k in self.global_args}) + args = [] + kwargs = {} + if isinstance(output, dict): + kwargs.update(output) + elif isinstance(output, Sequence): + args += list(output) + else: + args.append(output) + + return output + + +class SinusoidalPositionEmbedding(nn.Module): + """ + Positional embedding based on sine and cosine functions, proposed in `Attention Is All You Need`_. + + .. _Attention Is All You Need: + https://arxiv.org/pdf/1706.03762.pdf + + Parameters: + output_dim (int): output dimension + """ + + def __init__(self, output_dim): + super(SinusoidalPositionEmbedding, self).__init__() + inverse_frequency = 1 / (10000 ** (torch.arange(0.0, output_dim, 2.0) / output_dim)) + self.register_buffer("inverse_frequency", inverse_frequency) + + def forward(self, input): + """""" + # input: [B, L, ...] + positions = torch.arange(input.shape[1] - 1, -1, -1.0, dtype=input.dtype, device=input.device) + sinusoidal_input = torch.outer(positions, self.inverse_frequency) + position_embedding = torch.cat([sinusoidal_input.sin(), sinusoidal_input.cos()], -1) + return position_embedding \ No newline at end of file diff --git a/build/lib/torchdrug/layers/conv.py b/build/lib/torchdrug/layers/conv.py new file mode 100644 index 00000000..2efaed38 --- /dev/null +++ b/build/lib/torchdrug/layers/conv.py @@ -0,0 +1,813 @@ +import functools + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import checkpoint +from torch_scatter import scatter_mean, scatter_add, scatter_max + +from torchdrug import data, layers, utils +from torchdrug.layers import functional + + +class MessagePassingBase(nn.Module): + """ + Base module for message passing. + + Any custom message passing module should be derived from this class. + """ + gradient_checkpoint = False + + def message(self, graph, input): + """ + Compute edge messages for the graph. + + Parameters: + graph (Graph): graph(s) + input (Tensor): node representations of shape :math:`(|V|, ...)` + + Returns: + Tensor: edge messages of shape :math:`(|E|, ...)` + """ + raise NotImplementedError + + def aggregate(self, graph, message): + """ + Aggregate edge messages to nodes. 
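+
+        A typical implementation (sketch) sums incoming messages over each target
+        node, e.g. with :func:`torch_scatter.scatter_add`:
+
+        >>> def aggregate(self, graph, message):
+        >>>     node_out = graph.edge_list[:, 1]
+        >>>     return scatter_add(message, node_out, dim=0, dim_size=graph.num_node)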
+ + Parameters: + graph (Graph): graph(s) + message (Tensor): edge messages of shape :math:`(|E|, ...)` + + Returns: + Tensor: node updates of shape :math:`(|V|, ...)` + """ + raise NotImplementedError + + def message_and_aggregate(self, graph, input): + """ + Fused computation of message and aggregation over the graph. + This may provide better time or memory complexity than separate calls of + :meth:`message ` and :meth:`aggregate `. + + Parameters: + graph (Graph): graph(s) + input (Tensor): node representations of shape :math:`(|V|, ...)` + + Returns: + Tensor: node updates of shape :math:`(|V|, ...)` + """ + message = self.message(graph, input) + update = self.aggregate(graph, message) + return update + + def _message_and_aggregate(self, *tensors): + graph = data.Graph.from_tensors(tensors[:-1]) + input = tensors[-1] + update = self.message_and_aggregate(graph, input) + return update + + def combine(self, input, update): + """ + Combine node input and node update. + + Parameters: + input (Tensor): node representations of shape :math:`(|V|, ...)` + update (Tensor): node updates of shape :math:`(|V|, ...)` + """ + raise NotImplementedError + + def forward(self, graph, input): + """ + Perform message passing over the graph(s). + + Parameters: + graph (Graph): graph(s) + input (Tensor): node representations of shape :math:`(|V|, ...)` + """ + if self.gradient_checkpoint: + update = checkpoint.checkpoint(self._message_and_aggregate, *graph.to_tensors(), input) + else: + update = self.message_and_aggregate(graph, input) + output = self.combine(input, update) + return output + + +class GraphConv(MessagePassingBase): + """ + Graph convolution operator from `Semi-Supervised Classification with Graph Convolutional Networks`_. + + .. _Semi-Supervised Classification with Graph Convolutional Networks: + https://arxiv.org/pdf/1609.02907.pdf + + Parameters: + input_dim (int): input dimension + output_dim (int): output dimension + edge_input_dim (int, optional): dimension of edge features + batch_norm (bool, optional): apply batch normalization on nodes or not + activation (str or function, optional): activation function + """ + + def __init__(self, input_dim, output_dim, edge_input_dim=None, batch_norm=False, activation="relu"): + super(GraphConv, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.edge_input_dim = edge_input_dim + + if batch_norm: + self.batch_norm = nn.BatchNorm1d(output_dim) + else: + self.batch_norm = None + if isinstance(activation, str): + self.activation = getattr(F, activation) + else: + self.activation = activation + + self.linear = nn.Linear(input_dim, output_dim) + if edge_input_dim: + self.edge_linear = nn.Linear(edge_input_dim, input_dim) + else: + self.edge_linear = None + + def message(self, graph, input): + # add self loop + node_in = torch.cat([graph.edge_list[:, 0], torch.arange(graph.num_node, device=graph.device)]) + degree_in = graph.degree_in.unsqueeze(-1) + 1 + message = input[node_in] + if self.edge_linear: + edge_input = self.edge_linear(graph.edge_feature.float()) + edge_input = torch.cat([edge_input, torch.zeros(graph.num_node, self.input_dim, device=graph.device)]) + message += edge_input + message /= (degree_in[node_in].sqrt() + 1e-10) + return message + + def aggregate(self, graph, message): + # add self loop + node_out = torch.cat([graph.edge_list[:, 1], torch.arange(graph.num_node, device=graph.device)]) + edge_weight = torch.cat([graph.edge_weight, torch.ones(graph.num_node, device=graph.device)]) + edge_weight = 
edge_weight.unsqueeze(-1) + degree_out = graph.degree_out.unsqueeze(-1) + 1 + update = scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node) + update = update / (degree_out.sqrt() + 1e-10) + return update + + def message_and_aggregate(self, graph, input): + node_in, node_out = graph.edge_list.t()[:2] + node_in = torch.cat([node_in, torch.arange(graph.num_node, device=graph.device)]) + node_out = torch.cat([node_out, torch.arange(graph.num_node, device=graph.device)]) + edge_weight = torch.cat([graph.edge_weight, torch.ones(graph.num_node, device=graph.device)]) + degree_in = graph.degree_in + 1 + degree_out = graph.degree_out + 1 + edge_weight = edge_weight / ((degree_in[node_in] * degree_out[node_out]).sqrt() + 1e-10) + adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight, + (graph.num_node, graph.num_node)) + update = torch.sparse.mm(adjacency.t(), input) + if self.edge_linear: + edge_input = graph.edge_feature.float() + edge_input = torch.cat([self.edge_linear(edge_input), torch.zeros(graph.num_node, self.input_dim, device=graph.device)]) + edge_weight = edge_weight.unsqueeze(-1) + node_out = torch.cat([graph.edge_list[:, 1], torch.arange(graph.num_node, device=graph.device)]) + edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0, dim_size=graph.num_node) + update += edge_update + + return update + + def combine(self, input, update): + output = self.linear(update) + if self.batch_norm: + output = self.batch_norm(output) + if self.activation: + output = self.activation(output) + return output + + +class GraphAttentionConv(MessagePassingBase): + """ + Graph attentional convolution operator from `Graph Attention Networks`_. + + .. _Graph Attention Networks: + https://arxiv.org/pdf/1710.10903.pdf + + Parameters: + input_dim (int): input dimension + output_dim (int): output dimension + edge_input_dim (int, optional): dimension of edge features + num_head (int, optional): number of attention heads + negative_slope (float, optional): negative slope of leaky relu activation + batch_norm (bool, optional): apply batch normalization on nodes or not + activation (str or function, optional): activation function + """ + + eps = 1e-10 + + def __init__(self, input_dim, output_dim, edge_input_dim=None, num_head=1, negative_slope=0.2, concat=True, + batch_norm=False, activation="relu"): + super(GraphAttentionConv, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.edge_input_dim = edge_input_dim + self.num_head = num_head + self.concat = concat + self.leaky_relu = functools.partial(F.leaky_relu, negative_slope=negative_slope) + + if batch_norm: + self.batch_norm = nn.BatchNorm1d(output_dim) + else: + self.batch_norm = None + if isinstance(activation, str): + self.activation = getattr(F, activation) + else: + self.activation = activation + if output_dim % num_head != 0: + raise ValueError("Expect output_dim to be a multiplier of num_head, but found `%d` and `%d`" + % (output_dim, num_head)) + + self.linear = nn.Linear(input_dim, output_dim) + if edge_input_dim: + self.edge_linear = nn.Linear(edge_input_dim, output_dim) + else: + self.edge_linear = None + self.query = nn.Parameter(torch.zeros(num_head, output_dim * 2 // num_head)) + nn.init.kaiming_uniform_(self.query, negative_slope, mode="fan_in") + + def message(self, graph, input): + # add self loop + node_in = torch.cat([graph.edge_list[:, 0], torch.arange(graph.num_node, device=graph.device)]) + node_out = torch.cat([graph.edge_list[:, 1], 
torch.arange(graph.num_node, device=graph.device)]) + edge_weight = torch.cat([graph.edge_weight, torch.ones(graph.num_node, device=graph.device)]) + edge_weight = edge_weight.unsqueeze(-1) + hidden = self.linear(input) + + key = torch.stack([hidden[node_in], hidden[node_out]], dim=-1) + if self.edge_linear: + edge_input = self.edge_linear(graph.edge_feature.float()) + edge_input = torch.cat([edge_input, torch.zeros(graph.num_node, self.output_dim, device=graph.device)]) + key += edge_input.unsqueeze(-1) + key = key.view(-1, *self.query.shape) + weight = torch.einsum("hd, nhd -> nh", self.query, key) + weight = self.leaky_relu(weight) + + weight = weight - scatter_max(weight, node_out, dim=0, dim_size=graph.num_node)[0][node_out] + attention = weight.exp() * edge_weight + # why mean? because with mean we have normalized message scale across different node degrees + normalizer = scatter_mean(attention, node_out, dim=0, dim_size=graph.num_node)[node_out] + attention = attention / (normalizer + self.eps) + + value = hidden[node_in].view(-1, self.num_head, self.query.shape[-1] // 2) + attention = attention.unsqueeze(-1).expand_as(value) + message = (attention * value).flatten(1) + return message + + def aggregate(self, graph, message): + # add self loop + node_out = torch.cat([graph.edge_list[:, 1], torch.arange(graph.num_node, device=graph.device)]) + update = scatter_mean(message, node_out, dim=0, dim_size=graph.num_node) + return update + + def combine(self, input, update): + output = update + if self.batch_norm: + output = self.batch_norm(output) + if self.activation: + output = self.activation(output) + return output + + +class GraphIsomorphismConv(MessagePassingBase): + """ + Graph isomorphism convolution operator from `How Powerful are Graph Neural Networks?`_ + + .. 
_How Powerful are Graph Neural Networks?:
+        https://arxiv.org/pdf/1810.00826.pdf
+
+    Parameters:
+        input_dim (int): input dimension
+        output_dim (int): output dimension
+        edge_input_dim (int, optional): dimension of edge features
+        hidden_dims (list of int, optional): hidden dimensions
+        eps (float, optional): initial epsilon
+        learn_eps (bool, optional): learn epsilon or not
+        batch_norm (bool, optional): apply batch normalization on nodes or not
+        activation (str or function, optional): activation function
+    """

+    def __init__(self, input_dim, output_dim, edge_input_dim=None, hidden_dims=None, eps=0, learn_eps=False,
+                 batch_norm=False, activation="relu"):
+        super(GraphIsomorphismConv, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.edge_input_dim = edge_input_dim
+
+        eps = torch.tensor([eps], dtype=torch.float32)
+        if learn_eps:
+            self.eps = nn.Parameter(eps)
+        else:
+            self.register_buffer("eps", eps)
+        if batch_norm:
+            self.batch_norm = nn.BatchNorm1d(output_dim)
+        else:
+            self.batch_norm = None
+        if isinstance(activation, str):
+            self.activation = getattr(F, activation)
+        else:
+            self.activation = activation
+
+        if hidden_dims is None:
+            hidden_dims = []
+        # pass the activation by keyword; the third positional argument of MLP is `short_cut`
+        self.mlp = layers.MLP(input_dim, list(hidden_dims) + [output_dim], activation=activation)
+        if edge_input_dim:
+            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
+        else:
+            self.edge_linear = None
+
+    def message(self, graph, input):
+        node_in = graph.edge_list[:, 0]
+        message = input[node_in]
+        if self.edge_linear:
+            message += self.edge_linear(graph.edge_feature.float())
+        return message
+
+    def aggregate(self, graph, message):
+        node_out = graph.edge_list[:, 1]
+        edge_weight = graph.edge_weight.unsqueeze(-1)
+        update = scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node)
+        return update
+
+    def message_and_aggregate(self, graph, input):
+        adjacency = utils.sparse_coo_tensor(graph.edge_list.t()[:2], graph.edge_weight,
+                                            (graph.num_node, graph.num_node))
+        update = torch.sparse.mm(adjacency.t(), input)
+        if self.edge_linear:
+            edge_input = graph.edge_feature.float()
+            edge_weight = graph.edge_weight.unsqueeze(-1)
+            edge_input = self.edge_linear(edge_input)
+            edge_update = scatter_add(edge_input * edge_weight, graph.edge_list[:, 1], dim=0,
+                                      dim_size=graph.num_node)
+            update += edge_update
+
+        return update
+
+    def combine(self, input, update):
+        output = self.mlp((1 + self.eps) * input + update)
+        if self.batch_norm:
+            output = self.batch_norm(output)
+        if self.activation:
+            output = self.activation(output)
+        return output
+
+
+class RelationalGraphConv(MessagePassingBase):
+    """
+    Relational graph convolution operator from `Modeling Relational Data with Graph Convolutional Networks`_.
+
+    ..
_Modeling Relational Data with Graph Convolutional Networks: + https://arxiv.org/pdf/1703.06103.pdf + + Parameters: + input_dim (int): input dimension + output_dim (int): output dimension + num_relation (int): number of relations + edge_input_dim (int, optional): dimension of edge features + batch_norm (bool, optional): apply batch normalization on nodes or not + activation (str or function, optional): activation function + """ + eps = 1e-10 + + def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"): + super(RelationalGraphConv, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.num_relation = num_relation + self.edge_input_dim = edge_input_dim + + if batch_norm: + self.batch_norm = nn.BatchNorm1d(output_dim) + else: + self.batch_norm = None + if isinstance(activation, str): + self.activation = getattr(F, activation) + else: + self.activation = activation + + self.self_loop = nn.Linear(input_dim, output_dim) + self.linear = nn.Linear(num_relation * input_dim, output_dim) + if edge_input_dim: + self.edge_linear = nn.Linear(edge_input_dim, input_dim) + else: + self.edge_linear = None + + def message(self, graph, input): + node_in = graph.edge_list[:, 0] + message = input[node_in] + if self.edge_linear: + message += self.edge_linear(graph.edge_feature.float()) + return message + + def aggregate(self, graph, message): + assert graph.num_relation == self.num_relation + + node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2] + edge_weight = graph.edge_weight.unsqueeze(-1) + update = scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node * self.num_relation) / \ + (scatter_add(edge_weight, node_out, dim=0, dim_size=graph.num_node * self.num_relation) + self.eps) + return update.view(graph.num_node, self.num_relation * self.input_dim) + + def message_and_aggregate(self, graph, input): + assert graph.num_relation == self.num_relation + + node_in, node_out, relation = graph.edge_list.t() + node_out = node_out * self.num_relation + relation + degree_out = scatter_add(graph.edge_weight, node_out, dim_size=graph.num_node * graph.num_relation) + edge_weight = graph.edge_weight / degree_out[node_out] + adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight, + (graph.num_node, graph.num_node * graph.num_relation)) + update = torch.sparse.mm(adjacency.t(), input) + if self.edge_linear: + edge_input = graph.edge_feature.float() + edge_input = self.edge_linear(edge_input) + edge_weight = edge_weight.unsqueeze(-1) + edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0, + dim_size=graph.num_node * graph.num_relation) + update += edge_update + + return update.view(graph.num_node, self.num_relation * self.input_dim) + + def combine(self, input, update): + output = self.linear(update) + self.self_loop(input) + if self.batch_norm: + output = self.batch_norm(output) + if self.activation: + output = self.activation(output) + return output + + +class NeuralFingerprintConv(MessagePassingBase): + """ + Graph neural network operator from `Convolutional Networks on Graphs for Learning Molecular Fingerprints`_. + + Note this operator doesn't include the sparsifying step of the original paper. + + .. 
_Convolutional Networks on Graphs for Learning Molecular Fingerprints: + https://arxiv.org/pdf/1509.09292.pdf + + Parameters: + input_dim (int): input dimension + output_dim (int): output dimension + edge_input_dim (int, optional): dimension of edge features + batch_norm (bool, optional): apply batch normalization on nodes or not + activation (str or function, optional): activation function + """ + + def __init__(self, input_dim, output_dim, edge_input_dim=None, batch_norm=False, activation="relu"): + super(NeuralFingerprintConv, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.edge_input_dim = edge_input_dim + + if batch_norm: + self.batch_norm = nn.BatchNorm1d(output_dim) + else: + self.batch_norm = None + if isinstance(activation, str): + self.activation = getattr(F, activation) + else: + self.activation = activation + + self.linear = nn.Linear(input_dim, output_dim) + if edge_input_dim: + self.edge_linear = nn.Linear(edge_input_dim, input_dim) + else: + self.edge_linear = None + + def message(self, graph, input): + node_in = graph.edge_list[:, 0] + message = input[node_in] + if self.edge_linear: + message += self.edge_linear(graph.edge_feature.float()) + return message + + def aggregate(self, graph, message): + node_out = graph.edge_list[:, 1] + edge_weight = graph.edge_weight.unsqueeze(-1) + update = scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node) + return update + + def message_and_aggregate(self, graph, input): + adjacency = utils.sparse_coo_tensor(graph.edge_list.t()[:2], graph.edge_weight, + (graph.num_node, graph.num_node)) + update = torch.sparse.mm(adjacency.t(), input) + if self.edge_linear: + edge_input = graph.edge_feature.float() + edge_weight = graph.edge_weight.unsqueeze(-1) + edge_input = self.edge_linear(edge_input) + edge_update = scatter_add(edge_input * edge_weight, graph.edge_list[:, 1], dim=0, + dim_size=graph.num_node) + update += edge_update + + return update + + def combine(self, input, update): + output = self.linear(input + update) + if self.batch_norm: + output = self.batch_norm(output) + if self.activation: + output = self.activation(output) + return output + + +class ContinuousFilterConv(MessagePassingBase): + """ + Continuous filter operator from + `SchNet: A continuous-filter convolutional neural network for modeling quantum interactions`_. + + .. _SchNet\: A continuous-filter convolutional neural network for modeling quantum interactions: + https://arxiv.org/pdf/1706.08566.pdf + + Parameters: + input_dim (int): input dimension + output_dim (int): output dimension + edge_input_dim (int, optional): dimension of edge features + hidden_dim (int, optional): hidden dimension. 
By default, same as :attr:`output_dim` + cutoff (float, optional): maximal scale for RBF kernels + num_gaussian (int, optional): number of RBF kernels + batch_norm (bool, optional): apply batch normalization on nodes or not + activation (str or function, optional): activation function + """ + + def __init__(self, input_dim, output_dim, edge_input_dim=None, hidden_dim=None, cutoff=5, num_gaussian=100, + batch_norm=False, activation="shifted_softplus"): + super(ContinuousFilterConv, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.edge_input_dim = edge_input_dim + if hidden_dim is None: + hidden_dim = output_dim + self.hidden_dim = hidden_dim + self.rbf = layers.RBF(stop=cutoff, num_kernel=num_gaussian) + + if batch_norm: + self.batch_norm = nn.BatchNorm1d(output_dim) + else: + self.batch_norm = None + if activation == "shifted_softplus": + self.activation = functional.shifted_softplus + elif isinstance(activation, str): + self.activation = getattr(F, activation) + else: + self.activation = activation + + self.input_layer = nn.Linear(input_dim, hidden_dim) + self.rbf_layer = nn.Linear(num_gaussian, hidden_dim) + self.output_layer = nn.Linear(hidden_dim, output_dim) + if edge_input_dim: + self.edge_linear = nn.Linear(edge_input_dim, hidden_dim) + else: + self.edge_linear = None + + def message(self, graph, input): + node_in, node_out = graph.edge_list.t()[:2] + position = graph.node_position + message = self.input_layer(input)[node_in] + if self.edge_linear: + message += self.edge_linear(graph.edge_feature.float()) + weight = self.rbf_layer(self.rbf(position[node_in], position[node_out])) + message *= weight + return message + + def aggregate(self, graph, message): + node_out = graph.edge_list[:, 1] + edge_weight = graph.edge_weight.unsqueeze(-1) + update = scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node) + return update + + def message_and_aggregate(self, graph, input): + node_in, node_out = graph.edge_list.t()[:2] + position = graph.node_position + rbf_weight = self.rbf_layer(self.rbf(position[node_in], position[node_out])) + indices = torch.stack([node_out, node_in, torch.arange(graph.num_edge, device=graph.device)]) + adjacency = utils.sparse_coo_tensor(indices, graph.edge_weight, (graph.num_node, graph.num_node, graph.num_edge)) + update = functional.generalized_rspmm(adjacency, rbf_weight, self.input_layer(input)) + if self.edge_linear: + edge_input = graph.edge_feature.float() + edge_input = self.edge_linear(edge_input) + edge_weight = graph.edge_weight.unsqueeze(-1) * rbf_weight + edge_update = scatter_add(edge_input * edge_weight, graph.edge_list[:, 1], dim=0, + dim_size=graph.num_node) + update += edge_update + + return update + + def combine(self, input, update): + output = self.output_layer(update) + if self.batch_norm: + output = self.batch_norm(output) + if self.activation: + output = self.activation(output) + return output + + +class MessagePassing(MessagePassingBase): + """ + Message passing operator from `Neural Message Passing for Quantum Chemistry`_. + + This implements the edge network variant in the original paper. + + .. 
_Neural Message Passing for Quantum Chemistry: + https://arxiv.org/pdf/1704.01212.pdf + + Parameters: + input_dim (int): input dimension + edge_input_dim (int): dimension of edge features + hidden_dims (list of int, optional): hidden dims of edge network + batch_norm (bool, optional): apply batch normalization on nodes or not + activation (str or function, optional): activation function + """ + + def __init__(self, input_dim, edge_input_dim, hidden_dims=None, batch_norm=False, activation="relu"): + super(MessagePassing, self).__init__() + self.input_dim = input_dim + self.output_dim = input_dim + self.edge_input_dim = edge_input_dim + if hidden_dims is None: + hidden_dims = [] + + if batch_norm: + self.batch_norm = nn.BatchNorm1d(input_dim) + else: + self.batch_norm = None + if isinstance(activation, str): + self.activation = getattr(F, activation) + else: + self.activation = activation + + self.edge_mlp = layers.MLP(edge_input_dim, list(hidden_dims) + [input_dim * input_dim], activation) + + def message(self, graph, input): + node_in = graph.edge_list[:, 0] + transform = self.edge_mlp(graph.edge_feature.float()).view(-1, self.input_dim, self.input_dim) + if graph.num_edge: + message = torch.einsum("bed, bd -> be", transform, input[node_in]) + else: + message = torch.zeros(0, self.input_dim, device=graph.device) + return message + + def aggregate(self, graph, message): + node_out = graph.edge_list[:, 1] + edge_weight = graph.edge_weight.unsqueeze(-1) + update = scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node) + return update + + def combine(self, input, update): + output = update + if self.batch_norm: + output = self.batch_norm(output) + if self.activation: + output = self.activation(output) + return output + + +class ChebyshevConv(MessagePassingBase): + """ + Chebyshev spectral graph convolution operator from + `Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering`_. + + .. _Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering: + https://arxiv.org/pdf/1606.09375.pdf + + Parameters: + input_dim (int): input dimension + output_dim (int): output dimension + edge_input_dim (int, optional): dimension of edge features + k (int, optional): number of Chebyshev polynomials. + This also corresponds to the radius of the receptive field. 
+        batch_norm (bool, optional): apply batch normalization on nodes or not +        activation (str or function, optional): activation function +    """ + +    def __init__(self, input_dim, output_dim, edge_input_dim=None, k=1, batch_norm=False, activation="relu"): +        super(ChebyshevConv, self).__init__() +        self.input_dim = input_dim +        self.output_dim = output_dim +        self.k = k +        self.edge_input_dim = edge_input_dim + +        if batch_norm: +            self.batch_norm = nn.BatchNorm1d(output_dim) +        else: +            self.batch_norm = None +        if isinstance(activation, str): +            self.activation = getattr(F, activation) +        else: +            self.activation = activation + +        self.linear = nn.Linear((k + 1) * input_dim, output_dim) +        if edge_input_dim: +            self.edge_linear = nn.Linear(edge_input_dim, input_dim) +        else: +            self.edge_linear = None + +    def message(self, graph, input): +        node_in = graph.edge_list[:, 0] +        degree_in = graph.degree_in.unsqueeze(-1) +        # because self-loop messages have a different scale, they are processed in combine() +        message = input[node_in] +        if self.edge_linear: +            message += self.edge_linear(graph.edge_feature.float()) +        message /= (degree_in[node_in].sqrt() + 1e-10) +        return message + +    def aggregate(self, graph, message): +        node_out = graph.edge_list[:, 1] +        edge_weight = graph.edge_weight.unsqueeze(-1) +        degree_out = graph.degree_out.unsqueeze(-1) +        # because self-loop messages have a different scale, they are processed in combine() +        update = -scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node) +        update = update / (degree_out.sqrt() + 1e-10) +        return update + +    def message_and_aggregate(self, graph, input): +        node_in, node_out = graph.edge_list.t()[:2] +        edge_weight = -graph.edge_weight / ((graph.degree_in[node_in] * graph.degree_out[node_out]).sqrt() + 1e-10) +        adjacency = utils.sparse_coo_tensor(graph.edge_list.t()[:2], edge_weight, (graph.num_node, graph.num_node)) +        update = torch.sparse.mm(adjacency.t(), input) +        if self.edge_linear: +            edge_input = graph.edge_feature.float() +            edge_input = self.edge_linear(edge_input) +            edge_weight = edge_weight.unsqueeze(-1) +            edge_update = scatter_add(edge_input * edge_weight, graph.edge_list[:, 1], dim=0, +                                      dim_size=graph.num_node) +            update += edge_update + +        return update + +    def forward(self, graph, input): +        # Chebyshev polynomial bases +        bases = [input] +        for i in range(self.k): +            x = super(ChebyshevConv, self).forward(graph, bases[-1]) +            if i > 0: +                x = 2 * x - bases[-2] +            bases.append(x) +        bases = torch.cat(bases, dim=-1) + +        output = self.linear(bases) +        if self.batch_norm: +            output = self.batch_norm(output) +        if self.activation: +            output = self.activation(output) +        return output + +    def combine(self, input, update): +        output = input + update +        return output + + +class GeometricRelationalGraphConv(RelationalGraphConv): +    """ +    Geometry-aware relational graph convolution operator from +    `Protein Representation Learning by Geometric Structure Pretraining`_. + +    .. 
_Protein Representation Learning by Geometric Structure Pretraining: + https://arxiv.org/pdf/2203.06125.pdf + + Parameters: + input_dim (int): input dimension + output_dim (int): output dimension + num_relation (int): number of relations + edge_input_dim (int, optional): dimension of edge features + batch_norm (bool, optional): apply batch normalization on nodes or not + activation (str or function, optional): activation function + """ + + def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"): + super(GeometricRelationalGraphConv, self).__init__(input_dim, output_dim, num_relation, edge_input_dim, + batch_norm, activation) + + def aggregate(self, graph, message): + assert graph.num_relation == self.num_relation + + node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2] + edge_weight = graph.edge_weight.unsqueeze(-1) + update = scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node * self.num_relation) + update = update.view(graph.num_node, self.num_relation * self.input_dim) + + return update + + def message_and_aggregate(self, graph, input): + assert graph.num_relation == self.num_relation + + node_in, node_out, relation = graph.edge_list.t() + node_out = node_out * self.num_relation + relation + adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), graph.edge_weight, + (graph.num_node, graph.num_node * graph.num_relation)) + update = torch.sparse.mm(adjacency.t(), input) + if self.edge_linear: + edge_input = graph.edge_feature.float() + edge_input = self.edge_linear(edge_input) + edge_weight = graph.edge_weight.unsqueeze(-1) + edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0, + dim_size=graph.num_node * graph.num_relation) + update += edge_update + + return update.view(graph.num_node, self.num_relation * self.input_dim) diff --git a/build/lib/torchdrug/layers/distribution.py b/build/lib/torchdrug/layers/distribution.py new file mode 100644 index 00000000..bdb0ebb5 --- /dev/null +++ b/build/lib/torchdrug/layers/distribution.py @@ -0,0 +1,50 @@ +import math +from collections.abc import Sequence + +import torch +from torch import nn + + +class IndependentGaussian(nn.Module): + """ + Independent Gaussian distribution. + + Parameters: + mu (Tensor): mean of shape :math:`(N,)` + sigma2 (Tensor): variance of shape :math:`(N,)` + learnable (bool, optional): learnable parameters or not + """ + + def __init__(self, mu, sigma2, learnable=False): + super(IndependentGaussian, self).__init__() + if learnable: + self.mu = nn.Parameter(torch.as_tensor(mu)) + self.sigma2 = nn.Parameter(torch.as_tensor(sigma2)) + else: + self.register_buffer("mu", torch.as_tensor(mu)) + self.register_buffer("sigma2", torch.as_tensor(sigma2)) + self.dim = len(mu) + + def forward(self, input): + """ + Compute the likelihood of input data. + + Parameters: + input (Tensor): input data of shape :math:`(..., N)` + """ + log_likelihood = -0.5 * (math.log(2 * math.pi) + self.sigma2.log() + (input - self.mu) ** 2 / self.sigma2) + return log_likelihood + + def sample(self, *size): + """ + Draw samples from the distribution. 
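+ +        A minimal usage sketch (shapes are illustrative; assumes ``torch`` is in scope): + +        >>> dist = IndependentGaussian(torch.zeros(3), torch.ones(3)) +        >>> samples = dist.sample(5)  # draws a tensor of shape (5, 3) +        >>> log_likelihood = dist(samples)  # element-wise log-likelihood, also shape (5, 3)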
+ + Parameters: + size (tuple of int): shape of the samples + """ + if len(size) == 1 and isinstance(size[0], Sequence): + size = size[0] + size = list(size) + [self.dim] + + sample = torch.randn(size, device=self.mu.device) * self.sigma2.sqrt() + self.mu + return sample diff --git a/build/lib/torchdrug/layers/flow.py b/build/lib/torchdrug/layers/flow.py new file mode 100644 index 00000000..ab0b27f4 --- /dev/null +++ b/build/lib/torchdrug/layers/flow.py @@ -0,0 +1,64 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from torchdrug import layers + + +class ConditionalFlow(nn.Module): + """ + Conditional flow transformation from `Masked Autoregressive Flow for Density Estimation`_. + + .. _Masked Autoregressive Flow for Density Estimation: + https://arxiv.org/pdf/1705.07057.pdf + + Parameters: + input_dim (int): input & output dimension + condition_dim (int): condition dimension + hidden_dims (list of int, optional): hidden dimensions + activation (str or function, optional): activation function + """ + + def __init__(self, input_dim, condition_dim, hidden_dims=None, activation="relu"): + super(ConditionalFlow, self).__init__() + self.input_dim = input_dim + self.output_dim = input_dim + + if hidden_dims is None: + hidden_dims = [] + self.mlp = layers.MLP(condition_dim, list(hidden_dims) + [input_dim * 2], activation) + self.rescale = nn.Parameter(torch.zeros(1)) + + def forward(self, input, condition): + """ + Transform data into latent representations. + + Parameters: + input (Tensor): input representations + condition (Tensor): conditional representations + + Returns: + (Tensor, Tensor): latent representations, log-likelihood of the transformation + """ + scale, bias = self.mlp(condition).chunk(2, dim=-1) + scale = (F.tanh(scale) * self.rescale) + output = (input + bias) * scale.exp() + log_det = scale + return output, log_det + + def reverse(self, latent, condition): + """ + Transform latent representations into data. 
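+ +        Since the transform is affine in the input, ``reverse`` exactly inverts :meth:`forward`; +        a round-trip sketch with illustrative dimensions: + +        >>> flow = ConditionalFlow(input_dim=16, condition_dim=32, hidden_dims=[64]) +        >>> x, condition = torch.randn(8, 16), torch.randn(8, 32) +        >>> z, log_det = flow(x, condition) +        >>> x_recovered, _ = flow.reverse(z, condition) +        >>> assert torch.allclose(x, x_recovered, atol=1e-5)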
+ + Parameters: + latent (Tensor): latent representations + condition (Tensor): conditional representations + + Returns: + (Tensor, Tensor): input representations, log-likelihood of the transformation + """ + scale, bias = self.mlp(condition).chunk(2, dim=-1) + scale = (F.tanh(scale) * self.rescale) + output = latent / scale.exp() - bias + log_det = scale + return output, log_det \ No newline at end of file diff --git a/build/lib/torchdrug/layers/functional/__init__.py b/build/lib/torchdrug/layers/functional/__init__.py new file mode 100644 index 00000000..36b2988e --- /dev/null +++ b/build/lib/torchdrug/layers/functional/__init__.py @@ -0,0 +1,17 @@ +from .functional import multinomial, masked_mean, mean_with_nan, shifted_softplus, multi_slice, multi_slice_mask, \ + as_mask, _extend, variadic_log_softmax, variadic_softmax, variadic_sum, variadic_mean, variadic_max, \ + variadic_cross_entropy, variadic_sort, variadic_topk, variadic_arange, variadic_randperm, variadic_sample,\ + variadic_meshgrid, variadic_to_padded, padded_to_variadic, one_hot, clipped_policy_gradient_objective, \ + policy_gradient_objective +from .embedding import transe_score, distmult_score, complex_score, simple_score, rotate_score +from .spmm import generalized_spmm, generalized_rspmm + +__all__ = [ + "multinomial", "masked_mean", "mean_with_nan", "shifted_softplus", "multi_slice_mask", "as_mask", + "variadic_log_softmax", "variadic_softmax", "variadic_sum", "variadic_mean", "variadic_max", + "variadic_cross_entropy", "variadic_sort", "variadic_topk", "variadic_arange", "variadic_randperm", + "variadic_sample", "variadic_meshgrid", "variadic_to_padded", "padded_to_variadic", + "one_hot", "clipped_policy_gradient_objective", "policy_gradient_objective", + "transe_score", "distmult_score", "complex_score", "simple_score", "rotate_score", + "generalized_spmm", "generalized_rspmm", +] \ No newline at end of file diff --git a/build/lib/torchdrug/layers/functional/embedding.py b/build/lib/torchdrug/layers/functional/embedding.py new file mode 100644 index 00000000..eb422ae8 --- /dev/null +++ b/build/lib/torchdrug/layers/functional/embedding.py @@ -0,0 +1,269 @@ +import os + +import torch +from torch import autograd + +from torchdrug import utils + +backend = "fast" + +path = os.path.join(os.path.dirname(__file__), "extension") +embedding = utils.load_extension("embedding", + [os.path.join(path, "embedding.cpp"), os.path.join(path, "embedding.cu")]) + + +class TransEFunction(autograd.Function): + + @staticmethod + def forward(ctx, entity, relation, h_index, t_index, r_index): + if entity.device.type == "cuda": + forward = embedding.transe_forward_cuda + else: + forward = embedding.transe_forward_cpu + score = forward(entity, relation, h_index, t_index, r_index) + ctx.save_for_backward(entity, relation, h_index, t_index, r_index) + return score + + @staticmethod + def backward(ctx, score_grad): + if score_grad.device.type == "cuda": + backward = embedding.transe_backward_cuda + else: + backward = embedding.transe_backward_cpu + entity_grad, relation_grad = backward(*ctx.saved_tensors, score_grad) + return entity_grad, relation_grad, None, None, None + + +class DistMultFunction(autograd.Function): + + @staticmethod + def forward(ctx, entity, relation, h_index, t_index, r_index): + if entity.device.type == "cuda": + forward = embedding.distmult_forward_cuda + else: + forward = embedding.distmult_forward_cpu + score = forward(entity, relation, h_index, t_index, r_index) + ctx.save_for_backward(entity, relation, h_index, t_index, 
r_index) + return score + + @staticmethod + def backward(ctx, score_grad): + if score_grad.device.type == "cuda": + backward = embedding.distmult_backward_cuda + else: + backward = embedding.distmult_backward_cpu + entity_grad, relation_grad = backward(*ctx.saved_tensors, score_grad) + return entity_grad, relation_grad, None, None, None + + +class ComplExFunction(autograd.Function): + + @staticmethod + def forward(ctx, entity, relation, h_index, t_index, r_index): + if entity.device.type == "cuda": + forward = embedding.complex_forward_cuda + else: + forward = embedding.complex_forward_cpu + score = forward(entity, relation, h_index, t_index, r_index) + ctx.save_for_backward(entity, relation, h_index, t_index, r_index) + return score + + @staticmethod + def backward(ctx, score_grad): + if score_grad.device.type == "cuda": + backward = embedding.complex_backward_cuda + else: + backward = embedding.complex_backward_cpu + entity_grad, relation_grad = backward(*ctx.saved_tensors, score_grad) + return entity_grad, relation_grad, None, None, None + + +class SimplEFunction(autograd.Function): + + @staticmethod + def forward(ctx, entity, relation, h_index, t_index, r_index): + if entity.device.type == "cuda": + forward = embedding.simple_forward_cuda + else: + forward = embedding.simple_forward_cpu + score = forward(entity, relation, h_index, t_index, r_index) + ctx.save_for_backward(entity, relation, h_index, t_index, r_index) + return score + + @staticmethod + def backward(ctx, score_grad): + if score_grad.device.type == "cuda": + backward = embedding.simple_backward_cuda + else: + backward = embedding.simple_backward_cpu + entity_grad, relation_grad = backward(*ctx.saved_tensors, score_grad) + return entity_grad, relation_grad, None, None, None + + +class RotatEFunction(autograd.Function): + + @staticmethod + def forward(ctx, entity, relation, h_index, t_index, r_index): + if entity.device.type == "cuda": + forward = embedding.rotate_forward_cuda + else: + forward = embedding.rotate_forward_cpu + score = forward(entity, relation, h_index, t_index, r_index) + ctx.save_for_backward(entity, relation, h_index, t_index, r_index) + return score + + @staticmethod + def backward(ctx, score_grad): + if score_grad.device.type == "cuda": + backward = embedding.rotate_backward_cuda + else: + backward = embedding.rotate_backward_cpu + entity_grad, relation_grad = backward(*ctx.saved_tensors, score_grad) + return entity_grad, relation_grad, None, None, None + + +def transe_score(entity, relation, h_index, t_index, r_index): + """ + TransE score function from `Translating Embeddings for Modeling Multi-relational Data`_. + + .. 
_Translating Embeddings for Modeling Multi-relational Data: + https://proceedings.neurips.cc/paper/2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf + + Parameters: + entity (Tensor): entity embeddings of shape :math:`(|V|, d)` + relation (Tensor): relation embeddings of shape :math:`(|R|, d)` + h_index (LongTensor): index of head entities + t_index (LongTensor): index of tail entities + r_index (LongTensor): index of relations + """ + if backend == "native": + h = entity[h_index] + r = relation[r_index] + t = entity[t_index] + score = (h + r - t).norm(p=1, dim=-1) + elif backend == "fast": + score = TransEFunction.apply(entity, relation, h_index, t_index, r_index) + else: + raise ValueError("Unknown embedding backend `%s`" % backend) + return score + + +def distmult_score(entity, relation, h_index, t_index, r_index): + """ + DistMult score function from `Embedding Entities and Relations for Learning and Inference in Knowledge Bases`_. + + .. _Embedding Entities and Relations for Learning and Inference in Knowledge Bases: + https://arxiv.org/pdf/1412.6575.pdf + + Parameters: + entity (Tensor): entity embeddings of shape :math:`(|V|, d)` + relation (Tensor): relation embeddings of shape :math:`(|R|, d)` + h_index (LongTensor): index of head entities + t_index (LongTensor): index of tail entities + r_index (LongTensor): index of relations + """ + if backend == "native": + h = entity[h_index] + r = relation[r_index] + t = entity[t_index] + score = (h * r * t).sum(dim=-1) + elif backend == "fast": + score = DistMultFunction.apply(entity, relation, h_index, t_index, r_index) + else: + raise ValueError("Unknown embedding backend `%s`" % backend) + return score + + +def complex_score(entity, relation, h_index, t_index, r_index): + """ + ComplEx score function from `Complex Embeddings for Simple Link Prediction`_. + + .. _Complex Embeddings for Simple Link Prediction: + http://proceedings.mlr.press/v48/trouillon16.pdf + + Parameters: + entity (Tensor): entity embeddings of shape :math:`(|V|, 2d)` + relation (Tensor): relation embeddings of shape :math:`(|R|, 2d)` + h_index (LongTensor): index of head entities + t_index (LongTensor): index of tail entities + r_index (LongTensor): index of relations + """ + if backend == "native": + h = entity[h_index] + r = relation[r_index] + t = entity[t_index] + + h_re, h_im = h.chunk(2, dim=-1) + r_re, r_im = r.chunk(2, dim=-1) + t_re, t_im = t.chunk(2, dim=-1) + + x_re = h_re * r_re - h_im * r_im + x_im = h_re * r_im + h_im * r_re + x = x_re * t_re + x_im * t_im + score = x.sum(dim=-1) + elif backend == "fast": + score = ComplExFunction.apply(entity, relation, h_index, t_index, r_index) + else: + raise ValueError("Unknown embedding backend `%s`" % backend) + return score + + +def simple_score(entity, relation, h_index, t_index, r_index): + """ + SimplE score function from `SimplE Embedding for Link Prediction in Knowledge Graphs`_. + + .. 
_SimplE Embedding for Link Prediction in Knowledge Graphs: + https://papers.nips.cc/paper/2018/file/b2ab001909a8a6f04b51920306046ce5-Paper.pdf + + Parameters: + entity (Tensor): entity embeddings of shape :math:`(|V|, 2d)` + relation (Tensor): relation embeddings of shape :math:`(|R|, d)` + h_index (LongTensor): index of head entities + t_index (LongTensor): index of tail entities + r_index (LongTensor): index of relations + """ + if backend == "native": + h = entity[h_index] + r = relation[r_index] + t = entity[t_index] + t_flipped = torch.cat(t.chunk(2, dim=-1)[::-1], dim=-1) + score = (h * r * t_flipped).sum(dim=-1) + elif backend == "fast": + score = SimplEFunction.apply(entity, relation, h_index, t_index, r_index) + else: + raise ValueError("Unknown embedding backend `%s`" % backend) + return score + + +def rotate_score(entity, relation, h_index, t_index, r_index): + """ + RotatE score function from `RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space`_. + + .. _RotatE\: Knowledge Graph Embedding by Relational Rotation in Complex Space: + https://arxiv.org/pdf/1902.10197.pdf + + Parameters: + entity (Tensor): entity embeddings of shape :math:`(|V|, 2d)` + relation (Tensor): relation embeddings of shape :math:`(|R|, d)` + h_index (LongTensor): index of head entities + t_index (LongTensor): index of tail entities + r_index (LongTensor): index of relations + """ + if backend == "native": + h = entity[h_index] + r = relation[r_index] + t = entity[t_index] + + h_re, h_im = h.chunk(2, dim=-1) + r_re, r_im = torch.cos(r), torch.sin(r) + t_re, t_im = t.chunk(2, dim=-1) + + x_re = h_re * r_re - h_im * r_im - t_re + x_im = h_re * r_im + h_im * r_re - t_im + x = torch.stack([x_re, x_im], dim=-1) + score = x.norm(p=2, dim=-1).sum(dim=-1) + elif backend == "fast": + score = RotatEFunction.apply(entity, relation, h_index, t_index, r_index) + else: + raise ValueError("Unknown embedding backend `%s`" % backend) + return score \ No newline at end of file diff --git a/build/lib/torchdrug/layers/functional/extension/__init__.py b/build/lib/torchdrug/layers/functional/extension/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/torchdrug/layers/functional/extension/embedding.cpp b/build/lib/torchdrug/layers/functional/extension/embedding.cpp new file mode 100644 index 00000000..a4a20fba --- /dev/null +++ b/build/lib/torchdrug/layers/functional/extension/embedding.cpp @@ -0,0 +1,416 @@ +#include + +#include "embedding.h" + +namespace at { + +// In PyTorch 1.4.0, parallel_for depends on some functions from at::internal in ATen/Parallel.h +// which are not explicitly included +// This is fixed in some new PyTorch release +using namespace at::internal; + +void embedding_forward_check(CheckedFrom c, const TensorArg &entity_arg, const TensorArg &relation_arg, + const TensorArg &h_index_arg, const TensorArg &t_index_arg, const TensorArg &r_index_arg) { + checkDim(c, entity_arg, 2); + checkDim(c, relation_arg, 2); + checkAllSameNumel(c, {h_index_arg, r_index_arg, t_index_arg}); + checkScalarType(c, h_index_arg, kLong); + checkScalarType(c, t_index_arg, kLong); + checkScalarType(c, r_index_arg, kLong); +} + +void embedding_backward_check(CheckedFrom c, const TensorArg &entity_arg, const TensorArg &relation_arg, + const TensorArg &h_index_arg, const TensorArg &t_index_arg, const TensorArg &r_index_arg, + const TensorArg &score_grad_arg) { + embedding_forward_check(c, entity_arg, relation_arg, h_index_arg, t_index_arg, r_index_arg); + checkSameSize(c, h_index_arg, 
score_grad_arg); +} + +template <class scalar_t> +void transe_forward_out_cpu(const scalar_t *entity, const scalar_t *relation, const int64_t *h_index, +                            const int64_t *t_index, const int64_t *r_index, scalar_t *score, +                            int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { +    parallel_for(0, num_sample, 0, [&](int64_t start, int64_t end) { +        for (int64_t sample_id = start; sample_id < end; sample_id++) { +            const scalar_t *h = entity + h_index[sample_id] * embedding_dim; +            const scalar_t *r = relation + r_index[sample_id] * embedding_dim; +            const scalar_t *t = entity + t_index[sample_id] * embedding_dim; +            scalar_t x = 0; +            for (int64_t i = 0; i < embedding_dim; i++) +                x += ::abs(h[i] + r[i] - t[i]); + +            score[sample_id] = x; +        } +    }); +} + +template <class scalar_t> +void transe_backward_out_cpu(const scalar_t *entity, const scalar_t *relation, +                             const int64_t *h_index, const int64_t *t_index, const int64_t *r_index, +                             const scalar_t *score_grad, scalar_t *entity_grad, scalar_t *relation_grad, +                             int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { +    // since #CPU thread < embedding_dim +    // we can parallel over embedding_dim to avoid atomic operations +    parallel_for(0, embedding_dim, 0, [&](int64_t start, int64_t end) { +        for (int64_t sample_id = 0; sample_id < num_sample; sample_id++) { +            int64_t h_sample = h_index[sample_id]; +            int64_t r_sample = r_index[sample_id]; +            int64_t t_sample = t_index[sample_id]; +            const scalar_t *h = entity + h_sample * embedding_dim; +            const scalar_t *r = relation + r_sample * embedding_dim; +            const scalar_t *t = entity + t_sample * embedding_dim; +            scalar_t *h_grad = entity_grad + h_sample * embedding_dim; +            scalar_t *r_grad = relation_grad + r_sample * embedding_dim; +            scalar_t *t_grad = entity_grad + t_sample * embedding_dim; +            scalar_t grad = score_grad[sample_id]; + +            for (int64_t i = start; i < end; i++) { +                scalar_t s = h[i] + r[i] - t[i] > 0 ? 
1 : -1; + h_grad[i] += grad * s; + r_grad[i] += grad * s; + t_grad[i] += -grad * s; + } + } + }); +} + +template +void distmult_forward_out_cpu(const scalar_t *entity, const scalar_t *relation, const int64_t *h_index, + const int64_t *t_index, const int64_t *r_index, scalar_t *score, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + parallel_for(0, num_sample, 0, [&](int64_t start, int64_t end) { + for (int64_t sample_id = start; sample_id < end; sample_id++) { + const scalar_t *h = entity + h_index[sample_id] * embedding_dim; + const scalar_t *r = relation + r_index[sample_id] * embedding_dim; + const scalar_t *t = entity + t_index[sample_id] * embedding_dim; + scalar_t x = 0; + for (int64_t i = 0; i < embedding_dim; i++) + x += h[i] * r[i] * t[i]; + + score[sample_id] = x; + } + }); +} + +template +void distmult_backward_out_cpu(const scalar_t *entity, const scalar_t *relation, + const int64_t *h_index, const int64_t *t_index, const int64_t *r_index, + const scalar_t *score_grad, scalar_t *entity_grad, scalar_t *relation_grad, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + // since #CPU thread < embedding_dim + // we can parallel over embedding_dim to avoid atomic operations + parallel_for(0, embedding_dim, 0, [&](int64_t start, int64_t end) { + for (int64_t sample_id = 0; sample_id < num_sample; sample_id++) { + int64_t h_sample = h_index[sample_id]; + int64_t r_sample = r_index[sample_id]; + int64_t t_sample = t_index[sample_id]; + const scalar_t *h = entity + h_sample * embedding_dim; + const scalar_t *r = relation + r_sample * embedding_dim; + const scalar_t *t = entity + t_sample * embedding_dim; + scalar_t *h_grad = entity_grad + h_sample * embedding_dim; + scalar_t *r_grad = relation_grad + r_sample * embedding_dim; + scalar_t *t_grad = entity_grad + t_sample * embedding_dim; + scalar_t grad = score_grad[sample_id]; + + for (int64_t i = start; i < end; i++) { + scalar_t h_i = h[i], r_i = r[i], t_i = t[i]; + h_grad[i] += grad * r_i * t_i; + r_grad[i] += grad * h_i * t_i; + t_grad[i] += grad * h_i * r_i; + } + } + }); +} + +template +void complex_forward_out_cpu(const scalar_t *entity, const scalar_t *relation, const int64_t *h_index, + const int64_t *t_index, const int64_t *r_index, scalar_t *score, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + parallel_for(0, num_sample, 0, [&](int64_t start, int64_t end) { + for (int64_t sample_id = start; sample_id < end; sample_id++) { + const scalar_t *h = entity + h_index[sample_id] * embedding_dim; + const scalar_t *r = relation + r_index[sample_id] * embedding_dim; + const scalar_t *t = entity + t_index[sample_id] * embedding_dim; + scalar_t x = 0; + for (int64_t i = 0; i < embedding_dim / 2; i++) { + scalar_t h_re = h[i], h_im = h[i + embedding_dim / 2]; + scalar_t r_re = r[i], r_im = r[i + embedding_dim / 2]; + scalar_t t_re = t[i], t_im = t[i + embedding_dim / 2]; + scalar_t product_re = h_re * r_re - h_im * r_im; + scalar_t product_im = h_re * r_im + h_im * r_re; + x += product_re * t_re + product_im * t_im; + } + + score[sample_id] = x; + } + }); +} + +template +void complex_backward_out_cpu(const scalar_t *entity, const scalar_t *relation, + const int64_t *h_index, const int64_t *t_index, const int64_t *r_index, + const scalar_t *score_grad, scalar_t *entity_grad, scalar_t *relation_grad, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + // since #CPU thread < 
embedding_dim + // we can parallel over embedding_dim to avoid atomic operations + parallel_for(0, embedding_dim / 2, 0, [&](int64_t start, int64_t end) { + for (int64_t sample_id = 0; sample_id < num_sample; sample_id++) { + int64_t h_sample = h_index[sample_id]; + int64_t r_sample = r_index[sample_id]; + int64_t t_sample = t_index[sample_id]; + const scalar_t *h = entity + h_sample * embedding_dim; + const scalar_t *r = relation + r_sample * embedding_dim; + const scalar_t *t = entity + t_sample * embedding_dim; + scalar_t *h_grad = entity_grad + h_sample * embedding_dim; + scalar_t *r_grad = relation_grad + r_sample * embedding_dim; + scalar_t *t_grad = entity_grad + t_sample * embedding_dim; + scalar_t grad = score_grad[sample_id]; + + for (int64_t i = start; i < end; i++) { + scalar_t h_re = h[i], h_im = h[i + embedding_dim / 2]; + scalar_t r_re = r[i], r_im = r[i + embedding_dim / 2]; + scalar_t t_re = t[i], t_im = t[i + embedding_dim / 2]; + h_grad[i] = grad * (r_re * t_re + r_im * t_im); + h_grad[i + embedding_dim / 2] = grad * (-r_im * t_re + r_re * t_im); + r_grad[i] = grad * (h_re * t_re + h_im * t_im); + r_grad[i + embedding_dim / 2] = grad * (-h_im * t_re + h_re * t_im); + t_grad[i] = grad * (h_re * r_re - h_im * r_im); + t_grad[i + embedding_dim / 2] = grad * (h_re * r_im + h_im * r_re); + } + } + }); +} + +template +void rotate_forward_out_cpu(const scalar_t *entity, const scalar_t *relation, const int64_t *h_index, + const int64_t *t_index, const int64_t *r_index, scalar_t *score, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + parallel_for(0, num_sample, 0, [&](int64_t start, int64_t end) { + for (int64_t sample_id = start; sample_id < end; sample_id++) { + const scalar_t *h = entity + h_index[sample_id] * embedding_dim; + const scalar_t *r = relation + r_index[sample_id] * embedding_dim / 2; + const scalar_t *t = entity + t_index[sample_id] * embedding_dim; + scalar_t x = 0; + for (int64_t i = 0; i < embedding_dim / 2; i++) { + scalar_t h_re = h[i], h_im = h[i + embedding_dim / 2]; + scalar_t r_re = ::cos(r[i]), r_im = ::sin(r[i]); + scalar_t t_re = t[i], t_im = t[i + embedding_dim / 2]; + scalar_t distance_re = h_re * r_re - h_im * r_im - t_re; + scalar_t distance_im = h_re * r_im + h_im * r_re - t_im; + x += ::sqrt(distance_re * distance_re + distance_im * distance_im); + } + + score[sample_id] = x; + } + }); +} + +template +void rotate_backward_out_cpu(const scalar_t *entity, const scalar_t *relation, + const int64_t *h_index, const int64_t *t_index, const int64_t *r_index, + const scalar_t *score_grad, scalar_t *entity_grad, scalar_t *relation_grad, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + const float kEpsilon = 1e-15; // 1e-15 from GraphVite + // since #CPU thread < embedding_dim / 2 + // we can parallel over embedding_dim to avoid atomic operations + parallel_for(0, embedding_dim / 2, 0, [&](int64_t start, int64_t end) { + for (int64_t sample_id = 0; sample_id < num_sample; sample_id++) { + int64_t h_sample = h_index[sample_id]; + int64_t r_sample = r_index[sample_id]; + int64_t t_sample = t_index[sample_id]; + const scalar_t *h = entity + h_sample * embedding_dim; + const scalar_t *r = relation + r_sample * embedding_dim / 2; + const scalar_t *t = entity + t_sample * embedding_dim; + scalar_t *h_grad = entity_grad + h_sample * embedding_dim; + scalar_t *r_grad = relation_grad + r_sample * embedding_dim / 2; + scalar_t *t_grad = entity_grad + t_sample * embedding_dim; + 
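// gradient of each |d| term computed in the loop below: d = h * r - t in the +            // complex plane (components distance_re, distance_im), and d|d|/dd = d / |d|, +            // so g folds the incoming gradient together with the 1 / |d| factor from the +            // modulus; kEpsilon keeps the division finite when h * r is numerically equal to t. +            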
scalar_t grad = score_grad[sample_id]; + + for (int64_t i = start; i < end; i++) { + scalar_t h_re = h[i], h_im = h[i + embedding_dim / 2]; + scalar_t r_re = ::cos(r[i]), r_im = ::sin(r[i]); + scalar_t t_re = t[i], t_im = t[i + embedding_dim / 2]; + scalar_t distance_re = h_re * r_re - h_im * r_im - t_re; + scalar_t distance_im = h_re * r_im + h_im * r_re - t_im; + scalar_t g = grad / (::sqrt(distance_re * distance_re + distance_im * distance_im) + kEpsilon); + h_grad[i] += g * (distance_re * r_re + distance_im * r_im); + h_grad[i + embedding_dim / 2] += g * (-distance_re * r_im + distance_im * r_re); + r_grad[i] += g * (-distance_re * (h_re * r_im + h_im * r_re) + + distance_im * (h_re * r_re - h_im * r_im)); + t_grad[i] += -g * distance_re; + t_grad[i + embedding_dim / 2] += -g * distance_im; + } + } + }); +} + +template +void simple_forward_out_cpu(const scalar_t *entity, const scalar_t *relation, const int64_t *h_index, + const int64_t *t_index, const int64_t *r_index, scalar_t *score, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + parallel_for(0, num_sample, 0, [&](int64_t start, int64_t end) { + for (int64_t sample_id = start; sample_id < end; sample_id++) { + const scalar_t *h = entity + h_index[sample_id] * embedding_dim; + const scalar_t *r = relation + r_index[sample_id] * embedding_dim; + const scalar_t *t = entity + t_index[sample_id] * embedding_dim; + scalar_t x = 0; + for (int64_t i = 0; i < embedding_dim; i++) { + int64_t j = (i + embedding_dim / 2) % embedding_dim; + x += h[i] * r[i] * t[j]; + } + score[sample_id] = x; + } + }); +} + +template +void simple_backward_out_cpu(const scalar_t *entity, const scalar_t *relation, + const int64_t *h_index, const int64_t *t_index, const int64_t *r_index, + const scalar_t *score_grad, scalar_t *entity_grad, scalar_t *relation_grad, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + // since #CPU thread < embedding_dim + // we can parallel over embedding_dim to avoid atomic operations + parallel_for(0, embedding_dim, 0, [&](int64_t start, int64_t end) { + for (int64_t sample_id = 0; sample_id < num_sample; sample_id++) { + int64_t h_sample = h_index[sample_id]; + int64_t r_sample = r_index[sample_id]; + int64_t t_sample = t_index[sample_id]; + const scalar_t *h = entity + h_sample * embedding_dim; + const scalar_t *r = relation + r_sample * embedding_dim; + const scalar_t *t = entity + t_sample * embedding_dim; + scalar_t *h_grad = entity_grad + h_sample * embedding_dim; + scalar_t *r_grad = relation_grad + r_sample * embedding_dim; + scalar_t *t_grad = entity_grad + t_sample * embedding_dim; + scalar_t grad = score_grad[sample_id]; + + for (int64_t i = start; i < end; i++) { + int64_t j = (i + embedding_dim / 2) % embedding_dim; + scalar_t h_i = h[i], r_i = r[i], t_j = t[j]; + h_grad[i] += grad * r_i * t_j; + r_grad[i] += grad * h_i * t_j; + t_grad[j] += grad * h_i * r_i; + } + } + }); +} + +// If written in templates, the partial instantiation of template template parameters can't be resolved +// Therefore we opt for a macro implementation +#define DECLARE_FORWARD_IMPL(NAME) \ + Tensor NAME##_forward_cpu(const Tensor &entity_, const Tensor &relation_, \ + const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_) { \ + constexpr const char *fn_name = #NAME"_forward_cpu"; \ + TensorArg entity_arg(entity_, "entity", 1), relation_arg(relation_, "relation", 2), \ + h_index_arg(h_index_, "h_index", 3), r_index_arg(r_index_, "r_index", 4), \ + 
t_index_arg(t_index_, "t_index", 5); \ + \ + embedding_forward_check(fn_name, entity_arg, relation_arg, h_index_arg, r_index_arg, t_index_arg); \ + checkDeviceType(fn_name, {entity_, relation_, h_index_, r_index_, t_index_}, kCPU); \ + \ + const Tensor entity = entity_.contiguous(); \ + const Tensor relation = relation_.contiguous(); \ + const Tensor h_index = h_index_.contiguous(); \ + const Tensor r_index = r_index_.contiguous(); \ + const Tensor t_index = t_index_.contiguous(); \ + \ + int64_t num_entity = entity.size(0); \ + int64_t num_relation = relation.size(0); \ + int64_t embedding_dim = entity.size(-1); \ + int64_t num_sample = h_index.numel(); \ + \ + Tensor score = at::empty(h_index.sizes(), entity.options()); \ + \ + AT_DISPATCH_FLOATING_TYPES(entity.scalar_type(), #NAME"_forward_cpu", [&] { \ + NAME##_forward_out_cpu( \ + entity.data_ptr(), relation.data_ptr(), \ + h_index.data_ptr(), t_index.data_ptr(), r_index.data_ptr(), \ + score.data_ptr(), \ + num_entity, num_relation, embedding_dim, num_sample \ + ); \ + }); \ + \ + return score; \ + } + +#define DECLARE_BACKWARD_IMPL(NAME) \ + std::tuple NAME##_backward_cpu( \ + const Tensor &entity_, const Tensor &relation_, const Tensor &h_index_, \ + const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_) { \ + constexpr const char *fn_name = #NAME"_backward_cpu"; \ + TensorArg entity_arg(entity_, "entity", 1), relation_arg(relation_, "relation", 2), \ + h_index_arg(h_index_, "h_index", 3), r_index_arg(r_index_, "r_index", 4), \ + t_index_arg(t_index_, "t_index", 5), score_grad_arg(score_grad_, "score_grad", 6); \ + \ + embedding_backward_check(fn_name, entity_arg, relation_arg, h_index_arg, r_index_arg, t_index_arg, \ + score_grad_arg); \ + checkDeviceType(fn_name, {entity_, relation_, h_index_, r_index_, t_index_, score_grad_}, kCPU); \ + \ + const Tensor entity = entity_.contiguous(); \ + const Tensor relation = relation_.contiguous(); \ + const Tensor h_index = h_index_.contiguous(); \ + const Tensor r_index = r_index_.contiguous(); \ + const Tensor t_index = t_index_.contiguous(); \ + const Tensor score_grad = score_grad_.contiguous(); \ + \ + int64_t num_entity = entity.size(0); \ + int64_t num_relation = relation.size(0); \ + int64_t embedding_dim = entity.size(1); \ + int64_t num_sample = h_index.numel(); \ + \ + Tensor entity_grad = at::zeros_like(entity); \ + Tensor relation_grad = at::zeros_like(relation); \ + \ + AT_DISPATCH_FLOATING_TYPES(entity.scalar_type(), #NAME"_backward_cpu", [&] { \ + NAME##_backward_out_cpu( \ + entity.data_ptr(), relation.data_ptr(), \ + h_index.data_ptr(), t_index.data_ptr(), r_index.data_ptr(), \ + score_grad.data_ptr(), \ + entity_grad.data_ptr(), relation_grad.data_ptr(), \ + num_entity, num_relation, embedding_dim, num_sample \ + ); \ + }); \ + \ + return std::make_tuple(entity_grad, relation_grad); \ + } + +DECLARE_FORWARD_IMPL(transe) +DECLARE_BACKWARD_IMPL(transe) + +DECLARE_FORWARD_IMPL(distmult) +DECLARE_BACKWARD_IMPL(distmult) + +DECLARE_FORWARD_IMPL(complex) +DECLARE_BACKWARD_IMPL(complex) + +DECLARE_FORWARD_IMPL(rotate) +DECLARE_BACKWARD_IMPL(rotate) + +DECLARE_FORWARD_IMPL(simple) +DECLARE_BACKWARD_IMPL(simple) + +} // namespace at + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("transe_forward_cpu", &at::transe_forward_cpu); + m.def("transe_backward_cpu", &at::transe_backward_cpu); + m.def("distmult_forward_cpu", &at::distmult_forward_cpu); + m.def("distmult_backward_cpu", &at::distmult_backward_cpu); + m.def("complex_forward_cpu", 
&at::complex_forward_cpu); + m.def("complex_backward_cpu", &at::complex_backward_cpu); + m.def("rotate_forward_cpu", &at::rotate_forward_cpu); + m.def("rotate_backward_cpu", &at::rotate_backward_cpu); + m.def("simple_forward_cpu", &at::simple_forward_cpu); + m.def("simple_backward_cpu", &at::simple_backward_cpu); +#ifdef CUDA_OP + m.def("transe_forward_cuda", &at::transe_forward_cuda); + m.def("transe_backward_cuda", &at::transe_backward_cuda); + m.def("distmult_forward_cuda", &at::distmult_forward_cuda); + m.def("distmult_backward_cuda", &at::distmult_backward_cuda); + m.def("complex_forward_cuda", &at::complex_forward_cuda); + m.def("complex_backward_cuda", &at::complex_backward_cuda); + m.def("rotate_forward_cuda", &at::rotate_forward_cuda); + m.def("rotate_backward_cuda", &at::rotate_backward_cuda); + m.def("simple_forward_cuda", &at::simple_forward_cuda); + m.def("simple_backward_cuda", &at::simple_backward_cuda); +#endif +} \ No newline at end of file diff --git a/build/lib/torchdrug/layers/functional/extension/embedding.cu b/build/lib/torchdrug/layers/functional/extension/embedding.cu new file mode 100644 index 00000000..d48dc488 --- /dev/null +++ b/build/lib/torchdrug/layers/functional/extension/embedding.cu @@ -0,0 +1,412 @@ +#include +#include + +#include "util.cuh" +#include "embedding.h" + +// Memory & time efficient implementation of embedding score functions +// Much of the code is adapted from GraphVite +// https://github.com/DeepGraphLearning/graphvite + +namespace at { + +template +__global__ +void transe_forward_out_cuda(const scalar_t *entity, const scalar_t *relation, const int64_t *h_index, + const int64_t *t_index, const int64_t *r_index, scalar_t *score, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int lane_id = thread_id % warpSize; + const int num_thread = gridDim.x * blockDim.x; + + for (int64_t sample_id = thread_id / warpSize; sample_id < num_sample; sample_id += num_thread / warpSize) { + const scalar_t *h = entity + h_index[sample_id] * embedding_dim; + const scalar_t *r = relation + r_index[sample_id] * embedding_dim; + const scalar_t *t = entity + t_index[sample_id] * embedding_dim; + scalar_t x = 0; + for (int64_t i = lane_id; i < embedding_dim; i += warpSize) + x += ::abs(h[i] + r[i] - t[i]); + x = warp_broadcast(warp_reduce(x), 0); + + if (lane_id == 0) + score[sample_id] = x; + } +} + +template +__global__ +void transe_backward_out_cuda(const scalar_t *entity, const scalar_t *relation, + const int64_t *h_index, const int64_t *t_index, const int64_t *r_index, + const scalar_t *score_grad, scalar_t *entity_grad, scalar_t *relation_grad, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int lane_id = thread_id % warpSize; + const int num_thread = gridDim.x * blockDim.x; + + for (int64_t sample_id = thread_id / warpSize; sample_id < num_sample; sample_id += num_thread / warpSize) { + int64_t h_sample = h_index[sample_id]; + int64_t r_sample = r_index[sample_id]; + int64_t t_sample = t_index[sample_id]; + const scalar_t *h = entity + h_sample * embedding_dim; + const scalar_t *r = relation + r_sample * embedding_dim; + const scalar_t *t = entity + t_sample * embedding_dim; + scalar_t *h_grad = entity_grad + h_sample * embedding_dim; + scalar_t *r_grad = relation_grad + r_sample * embedding_dim; + scalar_t *t_grad = entity_grad + t_sample 
* embedding_dim; + scalar_t grad = score_grad[sample_id]; + + for (int64_t i = lane_id; i < embedding_dim; i += warpSize) { + scalar_t s = h[i] + r[i] - t[i] > 0 ? 1 : -1; + atomicAdd(&h_grad[i], grad * s); + atomicAdd(&r_grad[i], grad * s); + atomicAdd(&t_grad[i], -grad * s); + } + } +} + +template +__global__ +void distmult_forward_out_cuda(const scalar_t *entity, const scalar_t *relation, const int64_t *h_index, + const int64_t *t_index, const int64_t *r_index, scalar_t *score, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int lane_id = thread_id % warpSize; + const int num_thread = gridDim.x * blockDim.x; + + for (int64_t sample_id = thread_id / warpSize; sample_id < num_sample; sample_id += num_thread / warpSize) { + const scalar_t *h = entity + h_index[sample_id] * embedding_dim; + const scalar_t *r = relation + r_index[sample_id] * embedding_dim; + const scalar_t *t = entity + t_index[sample_id] * embedding_dim; + scalar_t x = 0; + for (int64_t i = lane_id; i < embedding_dim; i += warpSize) + x += h[i] * r[i] * t[i]; + x = warp_broadcast(warp_reduce(x), 0); + + if (lane_id == 0) + score[sample_id] = x; + } +} + +template +__global__ +void distmult_backward_out_cuda(const scalar_t *entity, const scalar_t *relation, + const int64_t *h_index, const int64_t *t_index, const int64_t *r_index, + const scalar_t *score_grad, scalar_t *entity_grad, scalar_t *relation_grad, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int lane_id = thread_id % warpSize; + const int num_thread = gridDim.x * blockDim.x; + + for (int64_t sample_id = thread_id / warpSize; sample_id < num_sample; sample_id += num_thread / warpSize) { + int64_t h_sample = h_index[sample_id]; + int64_t r_sample = r_index[sample_id]; + int64_t t_sample = t_index[sample_id]; + const scalar_t *h = entity + h_sample * embedding_dim; + const scalar_t *r = relation + r_sample * embedding_dim; + const scalar_t *t = entity + t_sample * embedding_dim; + scalar_t *h_grad = entity_grad + h_sample * embedding_dim; + scalar_t *r_grad = relation_grad + r_sample * embedding_dim; + scalar_t *t_grad = entity_grad + t_sample * embedding_dim; + scalar_t grad = score_grad[sample_id]; + + for (int64_t i = lane_id; i < embedding_dim; i += warpSize) { + scalar_t h_i = h[i], r_i = r[i], t_i = t[i]; + atomicAdd(&h_grad[i], grad * r_i * t_i); + atomicAdd(&r_grad[i], grad * h_i * t_i); + atomicAdd(&t_grad[i], grad * h_i * r_i); + } + } +} + +template +__global__ +void complex_forward_out_cuda(const scalar_t *entity, const scalar_t *relation, const int64_t *h_index, + const int64_t *t_index, const int64_t *r_index, scalar_t *score, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int lane_id = thread_id % warpSize; + const int num_thread = gridDim.x * blockDim.x; + + for (int64_t sample_id = thread_id / warpSize; sample_id < num_sample; sample_id += num_thread / warpSize) { + const scalar_t *h = entity + h_index[sample_id] * embedding_dim; + const scalar_t *r = relation + r_index[sample_id] * embedding_dim; + const scalar_t *t = entity + t_index[sample_id] * embedding_dim; + scalar_t x = 0; + for (int64_t i = lane_id; i < embedding_dim / 2; i += warpSize) { + scalar_t h_re = h[i], h_im = h[i + embedding_dim / 2]; + scalar_t 
r_re = r[i], r_im = r[i + embedding_dim / 2]; + scalar_t t_re = t[i], t_im = t[i + embedding_dim / 2]; + scalar_t product_re = h_re * r_re - h_im * r_im; + scalar_t product_im = h_re * r_im + h_im * r_re; + x += product_re * t_re + product_im * t_im; + } + x = warp_broadcast(warp_reduce(x), 0); + + if (lane_id == 0) + score[sample_id] = x; + } +} + +template +__global__ +void complex_backward_out_cuda(const scalar_t *entity, const scalar_t *relation, + const int64_t *h_index, const int64_t *t_index, const int64_t *r_index, + const scalar_t *score_grad, scalar_t *entity_grad, scalar_t *relation_grad, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int lane_id = thread_id % warpSize; + const int num_thread = gridDim.x * blockDim.x; + + for (int64_t sample_id = thread_id / warpSize; sample_id < num_sample; sample_id += num_thread / warpSize) { + int64_t h_sample = h_index[sample_id]; + int64_t r_sample = r_index[sample_id]; + int64_t t_sample = t_index[sample_id]; + const scalar_t *h = entity + h_sample * embedding_dim; + const scalar_t *r = relation + r_sample * embedding_dim; + const scalar_t *t = entity + t_sample * embedding_dim; + scalar_t *h_grad = entity_grad + h_sample * embedding_dim; + scalar_t *r_grad = relation_grad + r_sample * embedding_dim; + scalar_t *t_grad = entity_grad + t_sample * embedding_dim; + scalar_t grad = score_grad[sample_id]; + + for (int64_t i = lane_id; i < embedding_dim / 2; i += warpSize) { + scalar_t h_re = h[i], h_im = h[i + embedding_dim / 2]; + scalar_t r_re = r[i], r_im = r[i + embedding_dim / 2]; + scalar_t t_re = t[i], t_im = t[i + embedding_dim / 2]; + atomicAdd(&h_grad[i], grad * (r_re * t_re + r_im * t_im)); + atomicAdd(&h_grad[i + embedding_dim / 2], grad * (-r_im * t_re + r_re * t_im)); + atomicAdd(&r_grad[i], grad * (h_re * t_re + h_im * t_im)); + atomicAdd(&r_grad[i + embedding_dim / 2], grad * (-h_im * t_re + h_re * t_im)); + atomicAdd(&t_grad[i], grad * (h_re * r_re - h_im * r_im)); + atomicAdd(&t_grad[i + embedding_dim / 2], grad * (h_re * r_im + h_im * r_re)); + } + } +} + +template +__global__ +void rotate_forward_out_cuda(const scalar_t *entity, const scalar_t *relation, const int64_t *h_index, + const int64_t *t_index, const int64_t *r_index, scalar_t *score, + int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) { + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int lane_id = thread_id % warpSize; + const int num_thread = gridDim.x * blockDim.x; + + for (int64_t sample_id = thread_id / warpSize; sample_id < num_sample; sample_id += num_thread / warpSize) { + const scalar_t *h = entity + h_index[sample_id] * embedding_dim; + const scalar_t *r = relation + r_index[sample_id] * embedding_dim / 2; + const scalar_t *t = entity + t_index[sample_id] * embedding_dim; + scalar_t x = 0; + for (int64_t i = lane_id; i < embedding_dim / 2; i += warpSize) { + scalar_t h_re = h[i], h_im = h[i + embedding_dim / 2]; + scalar_t r_re = ::cos(r[i]), r_im = ::sin(r[i]); + scalar_t t_re = t[i], t_im = t[i + embedding_dim / 2]; + scalar_t distance_re = h_re * r_re - h_im * r_im - t_re; + scalar_t distance_im = h_re * r_im + h_im * r_re - t_im; + x += ::sqrt(distance_re * distance_re + distance_im * distance_im); + } + x = warp_broadcast(warp_reduce(x), 0); + + if (lane_id == 0) + score[sample_id] = x; + } +} + +template +__global__ +void rotate_backward_out_cuda(const scalar_t *entity, const 
scalar_t *relation,
+                             const int64_t *h_index, const int64_t *t_index, const int64_t *r_index,
+                             const scalar_t *score_grad, scalar_t *entity_grad, scalar_t *relation_grad,
+                             int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) {
+    const float kEpsilon = 1e-15; // 1e-15 from GraphVite
+    const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+    const int lane_id = thread_id % warpSize;
+    const int num_thread = gridDim.x * blockDim.x;
+
+    for (int64_t sample_id = thread_id / warpSize; sample_id < num_sample; sample_id += num_thread / warpSize) {
+        int64_t h_sample = h_index[sample_id];
+        int64_t r_sample = r_index[sample_id];
+        int64_t t_sample = t_index[sample_id];
+        const scalar_t *h = entity + h_sample * embedding_dim;
+        const scalar_t *r = relation + r_sample * embedding_dim / 2;
+        const scalar_t *t = entity + t_sample * embedding_dim;
+        scalar_t *h_grad = entity_grad + h_sample * embedding_dim;
+        scalar_t *r_grad = relation_grad + r_sample * embedding_dim / 2;
+        scalar_t *t_grad = entity_grad + t_sample * embedding_dim;
+        scalar_t grad = score_grad[sample_id];
+
+        for (int64_t i = lane_id; i < embedding_dim / 2; i += warpSize) {
+            scalar_t h_re = h[i], h_im = h[i + embedding_dim / 2];
+            scalar_t r_re = ::cos(r[i]), r_im = ::sin(r[i]);
+            scalar_t t_re = t[i], t_im = t[i + embedding_dim / 2];
+            scalar_t distance_re = h_re * r_re - h_im * r_im - t_re;
+            scalar_t distance_im = h_re * r_im + h_im * r_re - t_im;
+            scalar_t g = grad / (::sqrt(distance_re * distance_re + distance_im * distance_im) + kEpsilon);
+            atomicAdd(&h_grad[i], g * (distance_re * r_re + distance_im * r_im));
+            atomicAdd(&h_grad[i + embedding_dim / 2], g * (-distance_re * r_im + distance_im * r_re));
+            atomicAdd(&r_grad[i], g * (-distance_re * (h_re * r_im + h_im * r_re)
+                                       + distance_im * (h_re * r_re - h_im * r_im)));
+            atomicAdd(&t_grad[i], -g * distance_re);
+            atomicAdd(&t_grad[i + embedding_dim / 2], -g * distance_im);
+        }
+    }
+}
+
+template <class scalar_t>
+__global__
+void simple_forward_out_cuda(const scalar_t *entity, const scalar_t *relation, const int64_t *h_index,
+                             const int64_t *t_index, const int64_t *r_index, scalar_t *score,
+                             int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) {
+    const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+    const int lane_id = thread_id % warpSize;
+    const int num_thread = gridDim.x * blockDim.x;
+
+    for (int64_t sample_id = thread_id / warpSize; sample_id < num_sample; sample_id += num_thread / warpSize) {
+        const scalar_t *h = entity + h_index[sample_id] * embedding_dim;
+        const scalar_t *r = relation + r_index[sample_id] * embedding_dim;
+        const scalar_t *t = entity + t_index[sample_id] * embedding_dim;
+        scalar_t x = 0;
+        for (int64_t i = lane_id; i < embedding_dim; i += warpSize) {
+            int64_t j = (i + embedding_dim / 2) % embedding_dim;
+            x += h[i] * r[i] * t[j];
+        }
+        x = warp_broadcast(warp_reduce(x), 0);
+
+        if (lane_id == 0)
+            score[sample_id] = x;
+    }
+}
+
+template <class scalar_t>
+__global__
+void simple_backward_out_cuda(const scalar_t *entity, const scalar_t *relation,
+                              const int64_t *h_index, const int64_t *t_index, const int64_t *r_index,
+                              const scalar_t *score_grad, scalar_t *entity_grad, scalar_t *relation_grad,
+                              int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) {
+    const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+    const int lane_id = thread_id % warpSize;
+    const int num_thread = gridDim.x * blockDim.x;
+
+    for (int64_t sample_id = thread_id / warpSize; sample_id < num_sample; sample_id += num_thread / warpSize) {
+        int64_t h_sample = h_index[sample_id];
+        int64_t r_sample = r_index[sample_id];
+        int64_t t_sample = t_index[sample_id];
+        const scalar_t *h = entity + h_sample * embedding_dim;
+        const scalar_t *r = relation + r_sample * embedding_dim;
+        const scalar_t *t = entity + t_sample * embedding_dim;
+        scalar_t *h_grad = entity_grad + h_sample * embedding_dim;
+        scalar_t *r_grad = relation_grad + r_sample * embedding_dim;
+        scalar_t *t_grad = entity_grad + t_sample * embedding_dim;
+        scalar_t grad = score_grad[sample_id];
+
+        for (int64_t i = lane_id; i < embedding_dim; i += warpSize) {
+            int64_t j = (i + embedding_dim / 2) % embedding_dim;
+            scalar_t h_i = h[i], r_i = r[i], t_j = t[j];
+            atomicAdd(&h_grad[i], grad * r_i * t_j);
+            atomicAdd(&r_grad[i], grad * h_i * t_j);
+            atomicAdd(&t_grad[j], grad * h_i * r_i);
+        }
+    }
+}
+
+// If written in templates, the partial instantiation of template template parameters can't be resolved.
+// Therefore we opt for a macro implementation.
+#define DECLARE_FORWARD_IMPL(NAME) \
+    Tensor NAME##_forward_cuda(const Tensor &entity_, const Tensor &relation_, const Tensor &h_index_, \
+                               const Tensor &t_index_, const Tensor &r_index_) { \
+        constexpr const char *fn_name = #NAME"_forward_cuda"; \
+        TensorArg entity_arg(entity_, "entity", 1), relation_arg(relation_, "relation", 2), \
+                  h_index_arg(h_index_, "h_index", 3), r_index_arg(r_index_, "r_index", 4), \
+                  t_index_arg(t_index_, "t_index", 5); \
+        \
+        embedding_forward_check(fn_name, entity_arg, relation_arg, h_index_arg, r_index_arg, t_index_arg); \
+        checkAllSameGPU(fn_name, {entity_arg, relation_arg, h_index_arg, r_index_arg, t_index_arg}); \
+        \
+        const Tensor entity = entity_.contiguous(); \
+        const Tensor relation = relation_.contiguous(); \
+        const Tensor h_index = h_index_.contiguous(); \
+        const Tensor r_index = r_index_.contiguous(); \
+        const Tensor t_index = t_index_.contiguous(); \
+        \
+        int64_t num_entity = entity.size(0); \
+        int64_t num_relation = relation.size(0); \
+        int64_t embedding_dim = entity.size(-1); \
+        int64_t num_sample = h_index.numel(); \
+        \
+        Tensor score = at::empty(h_index.sizes(), entity.options()); \
+        \
+        cudaSetDevice(entity.get_device()); \
+        auto stream = at::cuda::getCurrentCUDAStream(); \
+        \
+        AT_DISPATCH_FLOATING_TYPES(entity.scalar_type(), fn_name, [&] { \
+            NAME##_forward_out_cuda<scalar_t><<<4096, 512, 0, stream>>>( \
+                entity.data_ptr<scalar_t>(), relation.data_ptr<scalar_t>(), \
+                h_index.data_ptr<int64_t>(), t_index.data_ptr<int64_t>(), r_index.data_ptr<int64_t>(), \
+                score.data_ptr<scalar_t>(), \
+                num_entity, num_relation, embedding_dim, num_sample \
+            ); \
+        }); \
+        \
+        return score; \
+    } \
+
+#define DECLARE_BACKWARD_IMPL(NAME) \
+    std::tuple<Tensor, Tensor> NAME##_backward_cuda( \
+            const Tensor &entity_, const Tensor &relation_, const Tensor &h_index_, \
+            const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_) { \
+        constexpr const char *fn_name = #NAME"_backward_cuda"; \
+        TensorArg entity_arg(entity_, "entity", 1), relation_arg(relation_, "relation", 2), \
+                  h_index_arg(h_index_, "h_index", 3), r_index_arg(r_index_, "r_index", 4), \
+                  t_index_arg(t_index_, "t_index", 5), score_grad_arg(score_grad_, "score_grad", 6); \
+        \
+        embedding_backward_check(fn_name, entity_arg, relation_arg, h_index_arg, r_index_arg, t_index_arg, \
+                                 score_grad_arg); \
+        checkAllSameGPU(fn_name, {entity_arg, relation_arg, h_index_arg, r_index_arg, t_index_arg, score_grad_arg}); \
+        \
+        const Tensor entity = entity_.contiguous(); \
+        const Tensor relation = relation_.contiguous(); \
+        const Tensor h_index = h_index_.contiguous(); \
+        const Tensor r_index = r_index_.contiguous(); \
+        const Tensor t_index = t_index_.contiguous(); \
+        const Tensor score_grad = score_grad_.contiguous(); \
+        \
+        int64_t num_entity = entity.size(0); \
+        int64_t num_relation = relation.size(0); \
+        int64_t embedding_dim = entity.size(-1); \
+        int64_t num_sample = h_index.numel(); \
+        \
+        Tensor entity_grad = at::zeros_like(entity); \
+        Tensor relation_grad = at::zeros_like(relation); \
+        \
+        cudaSetDevice(entity.get_device()); \
+        auto stream = at::cuda::getCurrentCUDAStream(); \
+        \
+        AT_DISPATCH_FLOATING_TYPES(entity.scalar_type(), fn_name, [&] { \
+            NAME##_backward_out_cuda<scalar_t><<<4096, 512, 0, stream>>>( \
+                entity.data_ptr<scalar_t>(), relation.data_ptr<scalar_t>(), \
+                h_index.data_ptr<int64_t>(), t_index.data_ptr<int64_t>(), r_index.data_ptr<int64_t>(), \
+                score_grad.data_ptr<scalar_t>(), \
+                entity_grad.data_ptr<scalar_t>(), relation_grad.data_ptr<scalar_t>(), \
+                num_entity, num_relation, embedding_dim, num_sample \
+            ); \
+        }); \
+        \
+        return std::make_tuple(entity_grad, relation_grad); \
+    }
+
+DECLARE_FORWARD_IMPL(transe)
+DECLARE_BACKWARD_IMPL(transe)
+
+DECLARE_FORWARD_IMPL(distmult)
+DECLARE_BACKWARD_IMPL(distmult)
+
+DECLARE_FORWARD_IMPL(complex)
+DECLARE_BACKWARD_IMPL(complex)
+
+DECLARE_FORWARD_IMPL(rotate)
+DECLARE_BACKWARD_IMPL(rotate)
+
+DECLARE_FORWARD_IMPL(simple)
+DECLARE_BACKWARD_IMPL(simple)
+
+} // namespace at
\ No newline at end of file
diff --git a/build/lib/torchdrug/layers/functional/extension/embedding.h b/build/lib/torchdrug/layers/functional/extension/embedding.h
new file mode 100644
index 00000000..bb48d53f
--- /dev/null
+++ b/build/lib/torchdrug/layers/functional/extension/embedding.h
@@ -0,0 +1,88 @@
+#pragma once
+
+#include <tuple>
+
+#include <torch/extension.h>
+
+namespace at {
+
+void embedding_forward_check(CheckedFrom c, const TensorArg &entity_arg, const TensorArg &relation_arg,
+                             const TensorArg &h_index_arg, const TensorArg &t_index_arg, const TensorArg &r_index_arg);
+
+void embedding_backward_check(CheckedFrom c, const TensorArg &entity_arg, const TensorArg &relation_arg,
+                              const TensorArg &h_index_arg, const TensorArg &t_index_arg, const TensorArg &r_index_arg,
+                              const TensorArg &score_grad_arg);
+
+Tensor transe_forward_cpu(const Tensor &entity_, const Tensor &relation_,
+                          const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_);
+
+std::tuple<Tensor, Tensor> transe_backward_cpu(
+    const Tensor &entity_, const Tensor &relation_,
+    const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_);
+
+Tensor distmult_forward_cpu(const Tensor &entity_, const Tensor &relation_,
+                            const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_);
+
+std::tuple<Tensor, Tensor> distmult_backward_cpu(
+    const Tensor &entity_, const Tensor &relation_,
+    const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_);
+
+Tensor complex_forward_cpu(const Tensor &entity_, const Tensor &relation_,
+                           const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_);
+
+std::tuple<Tensor, Tensor> complex_backward_cpu(
+    const Tensor &entity_, const Tensor &relation_,
+    const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_);
+
+Tensor rotate_forward_cpu(const Tensor &entity_, const Tensor &relation_, const Tensor &h_index_,
+                          const Tensor &t_index_, const Tensor &r_index_);
+
+std::tuple<Tensor, Tensor> rotate_backward_cpu(
+    const Tensor &entity_, const Tensor &relation_,
+    const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_);
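+// The SimplE score couples each dimension i with its shifted partner j = (i + d / 2) % d.
+// Conceptually, for one (h, r, t) triple (scalar sketch, not part of the kernels above):
+//
+//     scalar_t score = 0;
+//     for (int64_t i = 0; i < embedding_dim; i++)
+//         score += h[i] * r[i] * t[(i + embedding_dim / 2) % embedding_dim];
+//
+// simple_forward_out_cuda computes exactly this sum, with one warp per sample and a
+// warp_reduce over the lanes.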
+Tensor simple_forward_cpu(const Tensor &entity_, const Tensor &relation_, const Tensor &h_index_,
+                          const Tensor &t_index_, const Tensor &r_index_);
+
+std::tuple<Tensor, Tensor> simple_backward_cpu(
+    const Tensor &entity_, const Tensor &relation_,
+    const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_);
+
+#ifdef CUDA_OP
+Tensor transe_forward_cuda(const Tensor &entity_, const Tensor &relation_,
+                           const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_);
+
+std::tuple<Tensor, Tensor> transe_backward_cuda(
+    const Tensor &entity_, const Tensor &relation_,
+    const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_);
+
+Tensor distmult_forward_cuda(const Tensor &entity_, const Tensor &relation_,
+                             const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_);
+
+std::tuple<Tensor, Tensor> distmult_backward_cuda(
+    const Tensor &entity_, const Tensor &relation_,
+    const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_);
+
+Tensor complex_forward_cuda(const Tensor &entity_, const Tensor &relation_,
+                            const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_);
+
+std::tuple<Tensor, Tensor> complex_backward_cuda(
+    const Tensor &entity_, const Tensor &relation_,
+    const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_);
+
+Tensor rotate_forward_cuda(const Tensor &entity_, const Tensor &relation_,
+                           const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_);
+
+std::tuple<Tensor, Tensor> rotate_backward_cuda(
+    const Tensor &entity_, const Tensor &relation_,
+    const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_);
+
+Tensor simple_forward_cuda(const Tensor &entity_, const Tensor &relation_,
+                           const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_);
+
+std::tuple<Tensor, Tensor> simple_backward_cuda(
+    const Tensor &entity_, const Tensor &relation_,
+    const Tensor &h_index_, const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_);
+#endif
+
+} // namespace at
\ No newline at end of file
diff --git a/build/lib/torchdrug/layers/functional/extension/operator.cuh b/build/lib/torchdrug/layers/functional/extension/operator.cuh
new file mode 100644
index 00000000..ad217734
--- /dev/null
+++ b/build/lib/torchdrug/layers/functional/extension/operator.cuh
@@ -0,0 +1,82 @@
+#pragma once
+
+#include <limits>
+
+#ifdef __CUDA_ARCH__
+    #define HOST_DEVICE __host__ __device__
+#else
+    #define HOST_DEVICE
+#endif
+
+namespace at {
+
+template <class scalar_t>
+struct BinaryAdd {
+    HOST_DEVICE static scalar_t forward(scalar_t x, scalar_t y) {
+        return x + y;
+    }
+
+    HOST_DEVICE static scalar_t backward_lhs(scalar_t x, scalar_t y) {
+        return 1;
+    }
+
+    HOST_DEVICE static scalar_t backward_rhs(scalar_t x, scalar_t y) {
+        return 1;
+    }
+};
+
+template <class scalar_t>
+struct BinaryMul {
+    HOST_DEVICE static scalar_t forward(scalar_t x, scalar_t y) {
+        return x * y;
+    }
+
+    HOST_DEVICE static scalar_t backward_lhs(scalar_t x, scalar_t y) {
+        return y;
+    }
+
+    HOST_DEVICE static scalar_t backward_rhs(scalar_t x, scalar_t y) {
+        return x;
+    }
+};
+
+template <class scalar_t>
+struct NaryAdd {
+    HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
+        return result + x;
+    }
+
+    HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) {
+        return 1;
+    }
+
+    static constexpr scalar_t zero = 0;
+};
+
+template <class scalar_t>
+struct NaryMin {
+    HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
+        return result < x ?
result : x; + } + + HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) { + return result == x ? 1 : 0; + } + + static constexpr scalar_t zero = std::numeric_limits::max(); +}; + +template +struct NaryMax { + HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) { + return result > x ? result : x; + } + + HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) { + return result == x ? 1 : 0; + } + + static constexpr scalar_t zero = std::numeric_limits::lowest(); +}; + +} // namespace at \ No newline at end of file diff --git a/build/lib/torchdrug/layers/functional/extension/rspmm.cpp b/build/lib/torchdrug/layers/functional/extension/rspmm.cpp new file mode 100644 index 00000000..a2c3f54a --- /dev/null +++ b/build/lib/torchdrug/layers/functional/extension/rspmm.cpp @@ -0,0 +1,263 @@ +#include + +#include + +#include "operator.cuh" +#include "rspmm.h" + +namespace at { + +// In PyTorch 1.4.0, parallel_for depends on some functions from at::internal in ATen/Parallel.h +// which are not explicitly included +// This is fixed in some new PyTorch release +using namespace at::internal; + +void rspmm_forward_check(CheckedFrom c, const TensorArg &sparse_arg, const TensorArg &relation_arg, + const TensorArg &input_arg) { + TORCH_CHECK(sparse_arg->sparse_dim() == 3, + "Expected 3-dimensional sparse tensor, but got ", sparse_arg->sparse_dim(), + "-dimensional tensor for ", sparse_arg," (while checking arguments for ", c, ")"); + TORCH_CHECK(sparse_arg->dense_dim() == 0, + "Expected 3-dimensional sparse tensor, but got ", sparse_arg->dense_dim(), + " dense dimensions for", sparse_arg," (while checking arguments for ", c, ")"); + checkDim(c, relation_arg, 2); + checkDim(c, input_arg, 2); + checkScalarType(c, input_arg, sparse_arg->scalar_type()); + checkSameType(c, relation_arg, input_arg); + checkSize(c, input_arg, 0, sparse_arg->size(1)); + checkSize(c, relation_arg, {sparse_arg->size(2), input_arg->size(1)}); +} + +void rspmm_backward_check(CheckedFrom c, const TensorArg &sparse_arg, const TensorArg &relation_arg, + const TensorArg &input_arg, const TensorArg &output_arg, const TensorArg &output_grad_arg) { + rspmm_forward_check(c, sparse_arg, relation_arg, input_arg); + checkDim(c, output_arg, 2); + checkSameSize(c, output_arg, output_grad_arg); + checkSameType(c, input_arg, output_arg); + checkSameType(c, input_arg, output_grad_arg); + checkSize(c, output_arg, {sparse_arg->size(0), input_arg->size(1)}); +} + +std::tuple coo2csr3d(const SparseTensor &sparse) { + TORCH_CHECK(sparse.is_coalesced(), "Expect coalesced sparse tensor"); + Tensor index = sparse.indices(); + Tensor row_ind = index.select(0, 0); + Tensor col_ind = index.select(0, 1); + Tensor layer_ind = index.select(0, 2); + Tensor value = sparse.values(); + // scatter_add is super slow for int64, due to non-hardware atomic operations + // use int32 instead + Tensor nnz_per_row = at::zeros({sparse.size(0)}, row_ind.options().dtype(at::ScalarType::Int)); + nnz_per_row.scatter_add_(0, row_ind, at::ones(row_ind.sizes(), nnz_per_row.options())); + nnz_per_row = nnz_per_row.toType(at::ScalarType::Long); + Tensor row_ptr = nnz_per_row.cumsum(0) - nnz_per_row; + return std::make_tuple(row_ptr, col_ind, layer_ind, value); +} + +SparseTensor csr2coo3d(const Tensor &row_ptr_, const Tensor &col_ind, const Tensor &layer_ind, const Tensor &value, + IntArrayRef size) { + Tensor row_ptr = row_ptr_.masked_select(row_ptr_ < col_ind.size(0)); + // scatter_add is super slow for int64, due to non-hardware atomic operations + 
// use int32 instead + Tensor row_ind = at::zeros(col_ind.sizes(), col_ind.options().dtype(at::ScalarType::Int)); + row_ind.scatter_add_(0, row_ptr, at::ones(row_ptr.sizes(), row_ind.options())); + row_ind = row_ind.toType(at::ScalarType::Long); + row_ind = row_ind.cumsum(0) - 1; + Tensor index = at::stack({row_ind, col_ind, layer_ind}, 0); + return at::_sparse_coo_tensor_unsafe(index, value, size, value.options().layout(kSparse)); +} + +template +void rspmm_forward_out_cpu(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind, + const scalar_t *value, const scalar_t *relation, const scalar_t *input, + scalar_t *output, + int64_t num_row, int64_t nnz, int64_t dim) { + parallel_for(0, num_row, 0, [&](int64_t row_start, int64_t row_end) { + for (int64_t row = row_start; row < row_end; row++) { + for (int64_t d = 0; d < dim; d++) + output[row * dim + d] = NaryOp::zero; + + int64_t ptr_start = row_ptr[row]; + int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz; + for (int64_t ptr = ptr_start; ptr < ptr_end; ptr++) { + int64_t col = col_ind[ptr]; + int64_t layer = layer_ind[ptr]; + scalar_t val = value[ptr]; + for (int64_t d = 0; d < dim; d++) { + scalar_t x = BinaryOp::forward(relation[layer * dim + d], input[col * dim + d]); + scalar_t y = val * x; + scalar_t &out = output[row * dim + d]; + out = NaryOp::forward(out, y); + } + } + } + }); +} + +template +void rspmm_backward_out_cpu(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind, + const scalar_t *value, const scalar_t *relation, const scalar_t *input, + const scalar_t *output, const scalar_t *output_grad, + scalar_t *value_grad, scalar_t *relation_grad, scalar_t *input_grad, + int64_t num_row, int64_t nnz, int64_t dim, + std::vector &relation_mutex, std::vector &input_mutex) { + parallel_for(0, num_row, 0, [&](int64_t row_start, int64_t row_end) { + for (int64_t row = row_start; row < row_end; row++) { + int64_t ptr_start = row_ptr[row]; + int64_t ptr_end = row + 1 < num_row ? 
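+            // row_ptr produced by coo2csr3d has length num_row (no trailing sentinel),
+            // so the extent of the last row ends at nnz rather than at row_ptr[num_row]: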
row_ptr[row + 1] : nnz; + for (int64_t ptr = ptr_start; ptr < ptr_end; ptr++) { + int64_t col = col_ind[ptr]; + int64_t layer = layer_ind[ptr]; + scalar_t val = value[ptr]; + scalar_t val_grad = 0; + for (int64_t d = 0; d < dim; d++) { + scalar_t rel = relation[layer * dim + d]; + scalar_t in = input[col * dim + d]; + scalar_t out = output[row * dim + d]; + scalar_t out_grad = output_grad[row * dim + d]; + scalar_t x = BinaryOp::forward(rel, in); + scalar_t y = val * x; + scalar_t dx_drel = BinaryOp::backward_lhs(rel, in); + scalar_t dx_din = BinaryOp::backward_rhs(rel, in); + scalar_t dout_dy = NaryOp::backward(out, y); + scalar_t dy_dval = x; + scalar_t dy_dx = val; + val_grad += out_grad * dout_dy * dy_dval; + { + std::lock_guard lock(relation_mutex[layer * dim + d]); + relation_grad[layer * dim + d] += out_grad * dout_dy * dy_dx * dx_drel; + } + { + std::lock_guard lock(input_mutex[col * dim + d]); + input_grad[col * dim + d] += out_grad * dout_dy * dy_dx * dx_din; + } + } + value_grad[ptr] = val_grad; + } + } + }); +} + +template class NaryOp, template class BinaryOp> +Tensor rspmm_forward_cpu(const SparseTensor &sparse, const Tensor &relation_, const Tensor &input_) { + constexpr const char *fn_name = "rspmm_forward_cpu"; + TensorArg sparse_arg(sparse, "sparse", 1), relation_arg(relation_, "relation", 2), input_arg(input_, "input", 3); + + rspmm_forward_check(fn_name, sparse_arg, relation_arg, input_arg); + checkDeviceType(fn_name, {sparse, relation_, input_}, kCPU); + + const Tensor relation = relation_.contiguous(); + const Tensor input = input_.contiguous(); + + int64_t nnz = sparse._nnz(); + int64_t dim = input.size(1); + int64_t num_row = sparse.size(0); + Tensor output = at::empty({num_row, dim}, input.options()); + + auto csr = coo2csr3d(sparse); + Tensor row_ptr = std::get<0>(csr); + Tensor col_ind = std::get<1>(csr); + Tensor layer_ind = std::get<2>(csr); + Tensor value = std::get<3>(csr); + + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_forward_cpu", [&] { + rspmm_forward_out_cpu, BinaryOp>( + row_ptr.data_ptr(), + col_ind.data_ptr(), + layer_ind.data_ptr(), + value.data_ptr(), + relation.data_ptr(), + input.data_ptr(), + output.data_ptr(), + num_row, nnz, dim + ); + }); + + return output; +} + +template class NaryOp, template class BinaryOp> +std::tuple rspmm_backward_cpu( + const SparseTensor &sparse, const Tensor &relation_, const Tensor &input_, const Tensor &output_, + const Tensor &output_grad_) { + constexpr const char *fn_name = "rspmm_backward_cpu"; + TensorArg sparse_arg(sparse, "sparse", 1), relation_arg(relation_, "relation", 2), input_arg(input_, "input", 3), + output_arg(output_, "output", 4), output_grad_arg(output_grad_, "output_grad", 5); + + rspmm_backward_check(fn_name, sparse_arg, relation_arg, input_arg, output_arg, output_grad_arg); + checkDeviceType(fn_name, {sparse, relation_, input_, output_, output_grad_}, kCPU); + + const Tensor relation = relation_.contiguous(); + const Tensor input = input_.contiguous(); + const Tensor output = output_.contiguous(); + const Tensor output_grad = output_grad_.contiguous(); + + int64_t nnz = sparse._nnz(); + int64_t dim = input.size(1); + int64_t num_row = sparse.size(0); + Tensor value_grad = at::zeros_like(sparse.values()); + Tensor relation_grad = at::zeros_like(relation); + Tensor input_grad = at::zeros_like(input); + SparseTensor sparse_grad = at::_sparse_coo_tensor_unsafe(sparse.indices(), value_grad, sparse.sizes()); + + auto csr = coo2csr3d(sparse); + Tensor row_ptr = std::get<0>(csr); + Tensor 
col_ind = std::get<1>(csr); + Tensor layer_ind = std::get<2>(csr); + Tensor value = std::get<3>(csr); + std::vector relation_mutex(relation.numel()); + std::vector input_mutex(input.numel()); + + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_backward_cpu", [&] { + rspmm_backward_out_cpu, BinaryOp>( + row_ptr.data_ptr(), + col_ind.data_ptr(), + layer_ind.data_ptr(), + value.data_ptr(), + relation.data_ptr(), + input.data_ptr(), + output.data_ptr(), + output_grad.data_ptr(), + value_grad.data_ptr(), + relation_grad.data_ptr(), + input_grad.data_ptr(), + num_row, nnz, dim, + relation_mutex, input_mutex + ); + }); + + return std::make_tuple(sparse_grad, relation_grad, input_grad); +} + +#define DECLARE_FORWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \ + Tensor rspmm_##ADD##_##MUL##_forward_cpu( \ + const SparseTensor &sparse, const Tensor &relation, const Tensor &input) { \ + return rspmm_forward_cpu(sparse, relation, input); \ + } + +#define DECLARE_BACKWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \ + std::tuple rspmm_##ADD##_##MUL##_backward_cpu( \ + const SparseTensor &sparse, const Tensor &relation, const Tensor &input, const Tensor &output, \ + const Tensor &output_grad) { \ + return rspmm_backward_cpu(sparse, relation, input, output, output_grad); \ + } + +DECLARE_FORWARD_IMPL(add, mul, NaryAdd, BinaryMul) +DECLARE_BACKWARD_IMPL(add, mul, NaryAdd, BinaryMul) + +DECLARE_FORWARD_IMPL(min, mul, NaryMin, BinaryMul) +DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul) + +DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul) +DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul) + +DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd) +DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd) + +DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd) +DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd) + +DECLARE_FORWARD_IMPL(max, add, NaryMax, BinaryAdd) +DECLARE_BACKWARD_IMPL(max, add, NaryMax, BinaryAdd) + +} // namespace at \ No newline at end of file diff --git a/build/lib/torchdrug/layers/functional/extension/rspmm.cu b/build/lib/torchdrug/layers/functional/extension/rspmm.cu new file mode 100644 index 00000000..1794719f --- /dev/null +++ b/build/lib/torchdrug/layers/functional/extension/rspmm.cu @@ -0,0 +1,374 @@ +#include +#include + +#include "util.cuh" +#include "operator.cuh" +#include "rspmm.h" + +namespace at { + +// Memory & time efficient implementation of generalized spmm +// Much of the code is inspired by GE-SpMM +// https://github.com/hgyhungry/ge-spmm + +namespace { + +const int kCoarseningFactor = 2; +const int kThreadPerBlock = 256; + +} // namespace anonymous + +template +__global__ +void rspmm_forward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind, + const scalar_t *value, const scalar_t *relation, const scalar_t *input, + scalar_t *output, + int64_t num_row, int64_t nnz, int64_t dim) { + // for best optimization, the following code is compiled with constant warpSize + assert(blockDim.x == warpSize); + + extern __shared__ int64_t buffer[]; + int64_t *col_ind_buf = buffer; + int64_t *layer_ind_buf = buffer + blockDim.y * warpSize; + scalar_t *value_buf = reinterpret_cast(layer_ind_buf + blockDim.y * warpSize); + col_ind_buf += threadIdx.y * warpSize; + layer_ind_buf += threadIdx.y * warpSize; + value_buf += threadIdx.y * warpSize; + + int64_t row = blockIdx.x * blockDim.y + threadIdx.y; + if (row >= num_row) + return; + int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x; + int64_t ptr_start = row_ptr[row]; + int64_t ptr_end = row + 1 < 
num_row ? row_ptr[row + 1] : nnz; + scalar_t out[kCoarseningFactor]; +#pragma unroll + for (int64_t i = 0; i < kCoarseningFactor; i++) + out[i] = NaryOp::zero; + + for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) { + int64_t ptr = block_ptr + threadIdx.x; + if (ptr < ptr_end) { + col_ind_buf[threadIdx.x] = col_ind[ptr]; + layer_ind_buf[threadIdx.x] = layer_ind[ptr]; + value_buf[threadIdx.x] = value[ptr]; + } + __syncwarp(); + + int64_t max_offset = warpSize < ptr_end - block_ptr ? warpSize : ptr_end - block_ptr; + for (int64_t offset_ptr = 0; offset_ptr < max_offset; offset_ptr++) { + int64_t col = col_ind_buf[offset_ptr]; + int64_t layer = layer_ind_buf[offset_ptr]; + scalar_t val = value_buf[offset_ptr]; +#pragma unroll + for (int64_t i = 0; i < kCoarseningFactor; i++) { + int64_t d = d_start + i * warpSize; + if (d >= dim) + break; + scalar_t x = BinaryOp::forward(relation[layer * dim + d], input[col * dim + d]); + scalar_t y = val * x; + out[i] = NaryOp::forward(out[i], y); + } + } + __syncwarp(); + } + +#pragma unroll + for (int64_t i = 0; i < kCoarseningFactor; i++) { + int64_t d = d_start + i * warpSize; + if (d >= dim) + break; + output[row * dim + d] = out[i]; + } +} + +template +__global__ +void rspmm_backward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind, + const scalar_t *value, const scalar_t *relation, const scalar_t *input, + const scalar_t *output, const scalar_t *output_grad, + scalar_t *value_grad, scalar_t *relation_grad, scalar_t *input_grad, + int64_t num_row, int64_t nnz, int64_t dim) { + // for best optimization, the following code is compiled with constant warpSize + assert(blockDim.x == warpSize); + + extern __shared__ int64_t buffer[]; + int64_t *col_ind_buf = buffer; + int64_t *layer_ind_buf = col_ind_buf + blockDim.y * warpSize; + scalar_t *value_buf = reinterpret_cast(layer_ind_buf + blockDim.y * warpSize); + col_ind_buf += threadIdx.y * warpSize; + layer_ind_buf += threadIdx.y * warpSize; + value_buf += threadIdx.y * warpSize; + + int64_t row = blockIdx.x * blockDim.y + threadIdx.y; + if (row >= num_row) + return; + int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x; + int64_t ptr_start = row_ptr[row]; + int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz; + + for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) { + int64_t ptr = block_ptr + threadIdx.x; + if (ptr < ptr_end) { + col_ind_buf[threadIdx.x] = col_ind[ptr]; + layer_ind_buf[threadIdx.x] = layer_ind[ptr]; + value_buf[threadIdx.x] = value[ptr]; + } + __syncwarp(); + + int64_t max_offset = warpSize < ptr_end - block_ptr ? 
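+        // Each warp stages up to warpSize nonzeros of this row into shared memory;
+        // max_offset is how many of the staged entries are valid in the final, possibly partial tile: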
warpSize : ptr_end - block_ptr; + for (int64_t offset_ptr = 0; offset_ptr < max_offset; offset_ptr++) { + int64_t col = col_ind_buf[offset_ptr]; + int64_t layer = layer_ind_buf[offset_ptr]; + scalar_t val = value_buf[offset_ptr]; + scalar_t val_grad = 0; +#pragma unroll + for (int64_t i = 0; i < kCoarseningFactor; i++) { + int64_t d = d_start + i * warpSize; + if (d >= dim) + break; + scalar_t rel = relation[layer * dim + d]; + scalar_t in = input[col * dim + d]; + scalar_t out = output[row * dim + d]; + scalar_t out_grad = output_grad[row * dim + d]; + scalar_t x = BinaryOp::forward(rel, in); + scalar_t y = val * x; + scalar_t dx_drel = BinaryOp::backward_lhs(rel, in); + scalar_t dx_din = BinaryOp::backward_rhs(rel, in); + scalar_t dout_dy = NaryOp::backward(out, y); + scalar_t dy_dval = x; + scalar_t dy_dx = val; + val_grad += out_grad * dout_dy * dy_dval; + atomicAdd(&relation_grad[layer * dim + d], out_grad * dout_dy * dy_dx * dx_drel); + atomicAdd(&input_grad[col * dim + d], out_grad * dout_dy * dy_dx * dx_din); + } + val_grad = warp_reduce(val_grad); + if (threadIdx.x == 0) + atomicAdd(&value_grad[block_ptr + offset_ptr], val_grad); + } + __syncwarp(); + } +} + +// only relation & input require gradients +template +__global__ +void rspmm_backward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind, + const scalar_t *value, const scalar_t *relation, const scalar_t *input, + const scalar_t *output, const scalar_t *output_grad, + scalar_t *relation_grad, scalar_t *input_grad, + int64_t num_row, int64_t nnz, int64_t dim) { + // for best optimization, the following code is compiled with constant warpSize + assert(blockDim.x == warpSize); + + extern __shared__ int64_t buffer[]; + int64_t *col_ind_buf = buffer; + int64_t *layer_ind_buf = col_ind_buf + blockDim.y * warpSize; + scalar_t *value_buf = reinterpret_cast(layer_ind_buf + blockDim.y * warpSize); + col_ind_buf += threadIdx.y * warpSize; + layer_ind_buf += threadIdx.y * warpSize; + value_buf += threadIdx.y * warpSize; + + int64_t row = blockIdx.x * blockDim.y + threadIdx.y; + if (row >= num_row) + return; + int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x; + int64_t ptr_start = row_ptr[row]; + int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz; + + for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) { + int64_t ptr = block_ptr + threadIdx.x; + if (ptr < ptr_end) { + col_ind_buf[threadIdx.x] = col_ind[ptr]; + layer_ind_buf[threadIdx.x] = layer_ind[ptr]; + value_buf[threadIdx.x] = value[ptr]; + } + __syncwarp(); + + int64_t max_offset = warpSize < ptr_end - block_ptr ? 
warpSize : ptr_end - block_ptr; + for (int64_t offset_ptr = 0; offset_ptr < max_offset; offset_ptr++) { + int64_t col = col_ind_buf[offset_ptr]; + int64_t layer = layer_ind_buf[offset_ptr]; + scalar_t val = value_buf[offset_ptr]; +#pragma unroll + for (int64_t i = 0; i < kCoarseningFactor; i++) { + int64_t d = d_start + i * warpSize; + if (d >= dim) + break; + scalar_t rel = relation[layer * dim + d]; + scalar_t in = input[col * dim + d]; + scalar_t out = output[row * dim + d]; + scalar_t out_grad = output_grad[row * dim + d]; + scalar_t x = BinaryOp::forward(rel, in); + scalar_t y = val * x; + scalar_t dx_drel = BinaryOp::backward_lhs(rel, in); + scalar_t dx_din = BinaryOp::backward_rhs(rel, in); + scalar_t dout_dy = NaryOp::backward(out, y); + scalar_t dy_dx = val; + atomicAdd(&relation_grad[layer * dim + d], out_grad * dout_dy * dy_dx * dx_drel); + atomicAdd(&input_grad[col * dim + d], out_grad * dout_dy * dy_dx * dx_din); + } + } + __syncwarp(); + } +} + +template class NaryOp, template class BinaryOp> +Tensor rspmm_forward_cuda(const SparseTensor &sparse, const Tensor &relation_, const Tensor &input_) { + constexpr const char *fn_name = "rspmm_forward_cuda"; + TensorArg sparse_arg(sparse, "sparse", 1), relation_arg(relation_, "relation", 2), input_arg(input_, "input", 3); + + rspmm_forward_check(fn_name, sparse_arg, relation_arg, input_arg); + checkAllSameGPU(fn_name, {sparse_arg, relation_arg, input_arg}); + + const Tensor relation = relation_.contiguous(); + const Tensor input = input_.contiguous(); + + int64_t nnz = sparse._nnz(); + int64_t dim = input.size(1); + int64_t num_row = sparse.size(0); + Tensor output = at::empty({num_row, dim}, input.options()); + + auto csr = coo2csr3d(sparse); + Tensor row_ptr = std::get<0>(csr); + Tensor col_ind = std::get<1>(csr); + Tensor layer_ind = std::get<2>(csr); + Tensor value = std::get<3>(csr); + + cudaSetDevice(input.get_device()); + auto stream = at::cuda::getCurrentCUDAStream(); + + const int dim_per_block = 32; // warpSize + const int num_dim_block = (dim + dim_per_block * kCoarseningFactor - 1) / (dim_per_block * kCoarseningFactor); + const int row_per_block = kThreadPerBlock / dim_per_block; + const int num_row_block = (num_row + row_per_block - 1) / row_per_block; + + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_forward_cuda", [&] { + const int memory_size = kThreadPerBlock * (sizeof(int64_t) * 2 + sizeof(scalar_t)); + rspmm_forward_out_cuda, BinaryOp> + <<>>( + row_ptr.data_ptr(), + col_ind.data_ptr(), + layer_ind.data_ptr(), + value.data_ptr(), + relation.data_ptr(), + input.data_ptr(), + output.data_ptr(), + num_row, nnz, dim + ); + }); + + return output; +} + +template class NaryOp, template class BinaryOp> +std::tuple rspmm_backward_cuda( + const SparseTensor &sparse, const Tensor &relation_, const Tensor &input_, const Tensor &output_, + const Tensor &output_grad_) { + constexpr const char *fn_name = "rspmm_backward_cuda"; + TensorArg sparse_arg(sparse, "sparse", 1), relation_arg(relation_, "relation", 2), input_arg(input_, "input", 3), + output_arg(output_, "output", 4), output_grad_arg(output_grad_, "output_grad", 5); + + rspmm_backward_check(fn_name, sparse_arg, relation_arg, input_arg, output_arg, output_grad_arg); + checkAllSameGPU(fn_name, {sparse_arg, relation_arg, input_arg, output_arg, output_grad_arg}); + + const Tensor relation = relation_.contiguous(); + const Tensor input = input_.contiguous(); + const Tensor output = output_.contiguous(); + const Tensor output_grad = output_grad_.contiguous(); + + 
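+    // Two backward kernels are instantiated below: if the sparse tensor itself requires
+    // gradients, the variant that also accumulates value_grad is launched; otherwise the
+    // cheaper overload that only produces relation_grad and input_grad is used.
+    // sparse_grad reuses the indices of the input sparse tensor, so only its values differ.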
+    int64_t nnz = sparse._nnz();
+    int64_t dim = input.size(1);
+    int64_t num_row = sparse.size(0);
+    Tensor value_grad = at::zeros_like(sparse.values());
+    Tensor relation_grad = at::zeros_like(relation);
+    Tensor input_grad = at::zeros_like(input);
+    SparseTensor sparse_grad = at::_sparse_coo_tensor_unsafe(sparse.indices(), value_grad, sparse.sizes());
+
+    auto csr = coo2csr3d(sparse);
+    Tensor row_ptr = std::get<0>(csr);
+    Tensor col_ind = std::get<1>(csr);
+    Tensor layer_ind = std::get<2>(csr);
+    Tensor value = std::get<3>(csr);
+
+    cudaSetDevice(input.get_device());
+    auto stream = at::cuda::getCurrentCUDAStream();
+
+    const int dim_per_block = 32; // warpSize
+    const int num_dim_block = (dim + dim_per_block * kCoarseningFactor - 1) / (dim_per_block * kCoarseningFactor);
+    const int row_per_block = kThreadPerBlock / dim_per_block;
+    const int num_row_block = (num_row + row_per_block - 1) / row_per_block;
+
+    if (sparse.requires_grad())
+        AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_backward_cuda", [&] {
+            const int memory_size = kThreadPerBlock * (sizeof(int64_t) * 2 + sizeof(scalar_t));
+            rspmm_backward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
+                <<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
+                    row_ptr.data_ptr<int64_t>(),
+                    col_ind.data_ptr<int64_t>(),
+                    layer_ind.data_ptr<int64_t>(),
+                    value.data_ptr<scalar_t>(),
+                    relation.data_ptr<scalar_t>(),
+                    input.data_ptr<scalar_t>(),
+                    output.data_ptr<scalar_t>(),
+                    output_grad.data_ptr<scalar_t>(),
+                    value_grad.data_ptr<scalar_t>(),
+                    relation_grad.data_ptr<scalar_t>(),
+                    input_grad.data_ptr<scalar_t>(),
+                    num_row, nnz, dim
+            );
+        });
+    else
+        AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_backward_cuda", [&] {
+            const int memory_size = kThreadPerBlock * (sizeof(int64_t) * 2 + sizeof(scalar_t));
+            rspmm_backward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
+                <<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
+                    row_ptr.data_ptr<int64_t>(),
+                    col_ind.data_ptr<int64_t>(),
+                    layer_ind.data_ptr<int64_t>(),
+                    value.data_ptr<scalar_t>(),
+                    relation.data_ptr<scalar_t>(),
+                    input.data_ptr<scalar_t>(),
+                    output.data_ptr<scalar_t>(),
+                    output_grad.data_ptr<scalar_t>(),
+                    relation_grad.data_ptr<scalar_t>(),
+                    input_grad.data_ptr<scalar_t>(),
+                    num_row, nnz, dim
+            );
+        });
+
+    return std::make_tuple(sparse_grad, relation_grad, input_grad);
+}
+
+#define DECLARE_FORWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
+    Tensor rspmm_##ADD##_##MUL##_forward_cuda( \
+            const SparseTensor &sparse, const Tensor &relation, const Tensor &input) { \
+        return rspmm_forward_cuda<NARYOP, BINARYOP>(sparse, relation, input); \
+    }
+
+#define DECLARE_BACKWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
+    std::tuple<SparseTensor, Tensor, Tensor> rspmm_##ADD##_##MUL##_backward_cuda( \
+            const SparseTensor &sparse, const Tensor &relation, const Tensor &input, const Tensor &output, \
+            const Tensor &output_grad) { \
+        return rspmm_backward_cuda<NARYOP, BINARYOP>(sparse, relation, input, output, output_grad); \
+    }
+
+DECLARE_FORWARD_IMPL(add, mul, NaryAdd, BinaryMul)
+DECLARE_BACKWARD_IMPL(add, mul, NaryAdd, BinaryMul)
+
+DECLARE_FORWARD_IMPL(min, mul, NaryMin, BinaryMul)
+DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul)
+
+DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul)
+DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul)
+
+DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd)
+DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd)
+
+DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd)
+DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd)
+
+DECLARE_FORWARD_IMPL(max, add, NaryMax, BinaryAdd)
+DECLARE_BACKWARD_IMPL(max, add, NaryMax, BinaryAdd)
+
+} // namespace at
\ No newline at end of file
diff --git a/build/lib/torchdrug/layers/functional/extension/rspmm.h b/build/lib/torchdrug/layers/functional/extension/rspmm.h
new file mode 100644
index 00000000..8025f4d7
--- /dev/null
+++ b/build/lib/torchdrug/layers/functional/extension/rspmm.h
@@ -0,0 +1,85 @@
+#pragma once
+
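+// Relational SpMM over a 3-dimensional sparse tensor indexed by (row, col, layer).
+// For each output element, conceptually:
+//
+//     out[row][d] = NaryOp over nonzeros (row, col, layer) of
+//                   value[row, col, layer] * BinaryOp(relation[layer][d], input[col][d])
+//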
+#include + +#include +#include + +namespace at { + +using namespace at::sparse; + +void rspmm_forward_check(CheckedFrom c, const TensorArg &sparse_arg, const TensorArg &relation_arg, + const TensorArg &input_arg); + +void rspmm_backward_check(CheckedFrom c, const TensorArg &sparse_arg, const TensorArg &relation_arg, + const TensorArg &input_arg, const TensorArg &output_arg, const TensorArg &output_grad_arg); + +std::tuple coo2csr3d(const SparseTensor &sparse); + +SparseTensor csr2coo3d(const Tensor &row_ptr, const Tensor &col_ind, const Tensor &layer_ind, const Tensor &value, + IntArrayRef size); + +Tensor rspmm_add_mul_forward_cpu(const SparseTensor &sparse, const Tensor &relation, const Tensor &input); + +std::tuple rspmm_add_mul_backward_cpu(const SparseTensor &sparse, + const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad); + +Tensor rspmm_min_mul_forward_cpu(const SparseTensor &sparse, const Tensor &relation, const Tensor &input); + +std::tuple rspmm_min_mul_backward_cpu(const SparseTensor &sparse, + const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad); + +Tensor rspmm_max_mul_forward_cpu(const SparseTensor &sparse, const Tensor &relation, const Tensor &input); + +std::tuple rspmm_max_mul_backward_cpu(const SparseTensor &sparse, + const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad); + +Tensor rspmm_add_add_forward_cpu(const SparseTensor &sparse, const Tensor &relation, const Tensor &input); + +std::tuple rspmm_add_add_backward_cpu(const SparseTensor &sparse, + const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad); + +Tensor rspmm_min_add_forward_cpu(const SparseTensor &sparse, const Tensor &relation, const Tensor &input); + +std::tuple rspmm_min_add_backward_cpu(const SparseTensor &sparse, + const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad); + +Tensor rspmm_max_add_forward_cpu(const SparseTensor &sparse, const Tensor &relation, const Tensor &input); + +std::tuple rspmm_max_add_backward_cpu(const SparseTensor &sparse, + const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad); + +#ifdef CUDA_OP +Tensor rspmm_add_mul_forward_cuda(const SparseTensor &sparse, const Tensor &relation, const Tensor &input); + +std::tuple rspmm_add_mul_backward_cuda(const SparseTensor &sparse, + const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad); + +Tensor rspmm_min_mul_forward_cuda(const SparseTensor &sparse, const Tensor &relation, const Tensor &input); + +std::tuple rspmm_min_mul_backward_cuda(const SparseTensor &sparse, + const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad); + +Tensor rspmm_max_mul_forward_cuda(const SparseTensor &sparse, const Tensor &relation, const Tensor &input); + +std::tuple rspmm_max_mul_backward_cuda(const SparseTensor &sparse, + const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad); + +Tensor rspmm_add_add_forward_cuda(const SparseTensor &sparse, const Tensor &relation, const Tensor &input); + +std::tuple rspmm_add_add_backward_cuda(const SparseTensor &sparse, + const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad); + +Tensor rspmm_min_add_forward_cuda(const SparseTensor &sparse, const Tensor &relation, const Tensor &input); + +std::tuple rspmm_min_add_backward_cuda(const 
SparseTensor &sparse, + const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad); + +Tensor rspmm_max_add_forward_cuda(const SparseTensor &sparse, const Tensor &relation, const Tensor &input); + +std::tuple rspmm_max_add_backward_cuda(const SparseTensor &sparse, + const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad); +#endif + +} // namespace at \ No newline at end of file diff --git a/build/lib/torchdrug/layers/functional/extension/spmm.cpp b/build/lib/torchdrug/layers/functional/extension/spmm.cpp new file mode 100644 index 00000000..9a3d4afd --- /dev/null +++ b/build/lib/torchdrug/layers/functional/extension/spmm.cpp @@ -0,0 +1,283 @@ +#include + +#include + +#include "operator.cuh" +#include "spmm.h" + +namespace at { + +// In PyTorch 1.4.0, parallel_for depends on some functions from at::internal in ATen/Parallel.h +// which are not explicitly included +// This is fixed in some new PyTorch release +using namespace at::internal; + +void spmm_forward_check(CheckedFrom c, const TensorArg &sparse_arg, const TensorArg &input_arg) { + TORCH_CHECK(sparse_arg->sparse_dim() == 2, + "Expected 2-dimensional sparse tensor, but got ", sparse_arg->sparse_dim(), + "-dimensional tensor for ", sparse_arg," (while checking arguments for ", c, ")"); + TORCH_CHECK(sparse_arg->dense_dim() == 0, + "Expected 2-dimensional sparse tensor, but got ", sparse_arg->dense_dim(), + " dense dimensions for", sparse_arg," (while checking arguments for ", c, ")"); + checkDim(c, input_arg, 2); + checkScalarType(c, input_arg, sparse_arg->scalar_type()); + checkSize(c, input_arg, 0, sparse_arg->size(1)); +} + +void spmm_backward_check(CheckedFrom c, const TensorArg &sparse_arg, const TensorArg &input_arg, + const TensorArg &output_arg, const TensorArg &output_grad_arg) { + spmm_forward_check(c, sparse_arg, input_arg); + checkDim(c, output_arg, 2); + checkSameSize(c, output_arg, output_grad_arg); + checkSameType(c, input_arg, output_arg); + checkSameType(c, input_arg, output_grad_arg); + checkSize(c, output_arg, {sparse_arg->size(0), input_arg->size(1)}); +} + +std::tuple coo2csr(const SparseTensor &sparse) { + TORCH_CHECK(sparse.is_coalesced(), "Expect coalesced sparse tensor"); + Tensor index = sparse.indices(); + Tensor row_ind = index.select(0, 0); + Tensor col_ind = index.select(0, 1); + Tensor value = sparse.values(); + // scatter_add is super slow for int64, due to non-hardware atomic operations + // use int32 instead + Tensor nnz_per_row = at::zeros({sparse.size(0)}, row_ind.options().dtype(at::ScalarType::Int)); + nnz_per_row.scatter_add_(0, row_ind, at::ones(row_ind.sizes(), nnz_per_row.options())); + nnz_per_row = nnz_per_row.toType(at::ScalarType::Long); + Tensor row_ptr = nnz_per_row.cumsum(0) - nnz_per_row; + return std::make_tuple(row_ptr, col_ind, value); +} + +SparseTensor csr2coo(const Tensor &row_ptr_, const Tensor &col_ind, const Tensor &value, IntArrayRef size) { + Tensor row_ptr = row_ptr_.masked_select(row_ptr_ < col_ind.size(0)); + // scatter_add is super slow for int64, due to non-hardware atomic operations + // use int32 instead + Tensor row_ind = at::zeros(col_ind.sizes(), col_ind.options().dtype(at::ScalarType::Int)); + row_ind.scatter_add_(0, row_ptr, at::ones(row_ptr.sizes(), row_ind.options())); + row_ind = row_ind.toType(at::ScalarType::Long); + row_ind = row_ind.cumsum(0) - 1; + Tensor index = at::stack({row_ind, col_ind}, 0); + return at::_sparse_coo_tensor_unsafe(index, value, size, 
value.options().layout(kSparse)); +} + +template +void spmm_forward_out_cpu(const int64_t *row_ptr, const int64_t *col_ind, const scalar_t *value, + const scalar_t *input, scalar_t *output, + int64_t num_row, int64_t nnz, int64_t dim) { + parallel_for(0, num_row, 0, [&](int64_t row_start, int64_t row_end) { + for (int64_t row = row_start; row < row_end; row++) { + for (int64_t d = 0; d < dim; d++) + output[row * dim + d] = NaryOp::zero; + + int64_t ptr_start = row_ptr[row]; + int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz; + for (int64_t ptr = ptr_start; ptr < ptr_end; ptr++) { + int64_t col = col_ind[ptr]; + scalar_t val = value[ptr]; + for (int64_t d = 0; d < dim; d++) { + scalar_t x = BinaryOp::forward(val, input[col * dim + d]); + scalar_t &out = output[row * dim + d]; + out = NaryOp::forward(out, x); + } + } + } + }); +} + +template +void spmm_backward_out_cpu(const int64_t *row_ptr, const int64_t *col_ind, const scalar_t *value, + const scalar_t *input, const scalar_t *output, const scalar_t *output_grad, + scalar_t *value_grad, scalar_t *input_grad, + int64_t num_row, int64_t nnz, int64_t dim, + std::vector &mutex) { + parallel_for(0, num_row, 0, [&](int64_t row_start, int64_t row_end) { + for (int64_t row = row_start; row < row_end; row++) { + int64_t ptr_start = row_ptr[row]; + int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz; + for (int64_t ptr = ptr_start; ptr < ptr_end; ptr++) { + int64_t col = col_ind[ptr]; + scalar_t val = value[ptr]; + scalar_t val_grad = 0; + for (int64_t d = 0; d < dim; d++) { + scalar_t in = input[col * dim + d]; + scalar_t out = output[row * dim + d]; + scalar_t out_grad = output_grad[row * dim + d]; + scalar_t x = BinaryOp::forward(val, in); + scalar_t dx_dval = BinaryOp::backward_lhs(val, in); + scalar_t dx_din = BinaryOp::backward_rhs(val, in); + scalar_t dout_dx = NaryOp::backward(out, x); + val_grad += out_grad * dout_dx * dx_dval; + { + std::lock_guard lock(mutex[col * dim + d]); + input_grad[col * dim + d] += out_grad * dout_dx * dx_din; + } + } + value_grad[ptr] = val_grad; + } + } + }); +} + +template class NaryOp, template class BinaryOp> +Tensor spmm_forward_cpu(const SparseTensor &sparse, const Tensor &input_) { + constexpr const char *fn_name = "spmm_forward_cpu"; + TensorArg sparse_arg(sparse, "sparse", 1), input_arg(input_, "input", 2); + + spmm_forward_check(fn_name, sparse_arg, input_arg); + checkDeviceType(fn_name, {sparse, input_}, kCPU); + + const Tensor input = input_.contiguous(); + + int64_t nnz = sparse._nnz(); + int64_t dim = input.size(1); + int64_t num_row = sparse.size(0); + Tensor output = at::empty({num_row, dim}, input.options()); + + auto csr = coo2csr(sparse); + Tensor row_ptr = std::get<0>(csr); + Tensor col_ind = std::get<1>(csr); + Tensor value = std::get<2>(csr); + + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "spmm_forward_cpu", [&] { + spmm_forward_out_cpu, BinaryOp>( + row_ptr.data_ptr(), + col_ind.data_ptr(), + value.data_ptr(), + input.data_ptr(), + output.data_ptr(), + num_row, nnz, dim + ); + }); + + return output; +} + +template class NaryOp, template class BinaryOp> +std::tuple spmm_backward_cpu( + const SparseTensor &sparse, const Tensor &input_, const Tensor &output_, const Tensor &output_grad_) { + constexpr const char *fn_name = "spmm_backward_cpu"; + TensorArg sparse_arg(sparse, "sparse", 1), input_arg(input_, "input", 2), output_arg(output_, "output", 3), + output_grad_arg(output_grad_, "output_grad", 4); + + spmm_backward_check(fn_name, sparse_arg, input_arg, output_arg, 
output_grad_arg); + checkDeviceType(fn_name, {sparse, input_, output_, output_grad_}, kCPU); + + const Tensor input = input_.contiguous(); + const Tensor output = output_.contiguous(); + const Tensor output_grad = output_grad_.contiguous(); + + int64_t nnz = sparse._nnz(); + int64_t dim = input.size(1); + int64_t num_row = sparse.size(0); + Tensor value_grad = at::zeros_like(sparse.values()); + Tensor input_grad = at::zeros_like(input); + SparseTensor sparse_grad = at::_sparse_coo_tensor_unsafe(sparse.indices(), value_grad, sparse.sizes()); + + auto csr = coo2csr(sparse); + Tensor row_ptr = std::get<0>(csr).contiguous(); + Tensor col_ind = std::get<1>(csr).contiguous(); + Tensor value = std::get<2>(csr).contiguous(); + std::vector mutex(input.numel()); + + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "spmm_backward_cpu", [&] { + spmm_backward_out_cpu, BinaryOp>( + row_ptr.data_ptr(), + col_ind.data_ptr(), + value.data_ptr(), + input.data_ptr(), + output.data_ptr(), + output_grad.data_ptr(), + value_grad.data_ptr(), + input_grad.data_ptr(), + num_row, nnz, dim, + mutex + ); + }); + + return std::make_tuple(sparse_grad, input_grad); +} + +#define DECLARE_FORWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \ + Tensor spmm_##ADD##_##MUL##_forward_cpu(const SparseTensor &sparse, const Tensor &input) { \ + return spmm_forward_cpu(sparse, input); \ + } + +#define DECLARE_BACKWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \ + std::tuple spmm_##ADD##_##MUL##_backward_cpu( \ + const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad) { \ + return spmm_backward_cpu(sparse, input, output, output_grad); \ + } + +DECLARE_FORWARD_IMPL(add, mul, NaryAdd, BinaryMul) +DECLARE_BACKWARD_IMPL(add, mul, NaryAdd, BinaryMul) + +DECLARE_FORWARD_IMPL(min, mul, NaryMin, BinaryMul) +DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul) + +DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul) +DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul) + +DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd) +DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd) + +DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd) +DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd) + +DECLARE_FORWARD_IMPL(max, add, NaryMax, BinaryAdd) +DECLARE_BACKWARD_IMPL(max, add, NaryMax, BinaryAdd) + +} // namespace at + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("spmm_add_mul_forward_cpu", &at::spmm_add_mul_forward_cpu); + m.def("spmm_add_mul_backward_cpu", &at::spmm_add_mul_backward_cpu); + m.def("spmm_min_mul_forward_cpu", &at::spmm_min_mul_forward_cpu); + m.def("spmm_min_mul_backward_cpu", &at::spmm_min_mul_backward_cpu); + m.def("spmm_max_mul_forward_cpu", &at::spmm_max_mul_forward_cpu); + m.def("spmm_max_mul_backward_cpu", &at::spmm_max_mul_backward_cpu); + m.def("spmm_add_add_forward_cpu", &at::spmm_add_add_forward_cpu); + m.def("spmm_add_add_backward_cpu", &at::spmm_add_add_backward_cpu); + m.def("spmm_min_add_forward_cpu", &at::spmm_min_add_forward_cpu); + m.def("spmm_min_add_backward_cpu", &at::spmm_min_add_backward_cpu); + m.def("spmm_max_add_forward_cpu", &at::spmm_max_add_forward_cpu); + m.def("spmm_max_add_backward_cpu", &at::spmm_max_add_backward_cpu); + m.def("rspmm_add_mul_forward_cpu", &at::rspmm_add_mul_forward_cpu); + m.def("rspmm_add_mul_backward_cpu", &at::rspmm_add_mul_backward_cpu); + m.def("rspmm_min_mul_forward_cpu", &at::rspmm_min_mul_forward_cpu); + m.def("rspmm_min_mul_backward_cpu", &at::rspmm_min_mul_backward_cpu); + m.def("rspmm_max_mul_forward_cpu", &at::rspmm_max_mul_forward_cpu); + 
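+    // Bindings follow the naming scheme (r)spmm_<nary>_<binary>_{forward,backward}_{cpu,cuda}.
+    // For spmm, each output element reduces BinaryOp(value, input[col]) with NaryOp;
+    // for rspmm, it reduces value * BinaryOp(relation[layer], input[col]).
+    // E.g. spmm_add_mul_forward_cpu is ordinary sparse-dense matrix multiplication.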
m.def("rspmm_max_mul_backward_cpu", &at::rspmm_max_mul_backward_cpu); + m.def("rspmm_add_add_forward_cpu", &at::rspmm_add_add_forward_cpu); + m.def("rspmm_add_add_backward_cpu", &at::rspmm_add_add_backward_cpu); + m.def("rspmm_min_add_forward_cpu", &at::rspmm_min_add_forward_cpu); + m.def("rspmm_min_add_backward_cpu", &at::rspmm_min_add_backward_cpu); + m.def("rspmm_max_add_forward_cpu", &at::rspmm_max_add_forward_cpu); + m.def("rspmm_max_add_backward_cpu", &at::rspmm_max_add_backward_cpu); +#ifdef CUDA_OP + m.def("spmm_add_mul_forward_cuda", &at::spmm_add_mul_forward_cuda); + m.def("spmm_add_mul_backward_cuda", &at::spmm_add_mul_backward_cuda); + m.def("spmm_min_mul_forward_cuda", &at::spmm_min_mul_forward_cuda); + m.def("spmm_min_mul_backward_cuda", &at::spmm_min_mul_backward_cuda); + m.def("spmm_max_mul_forward_cuda", &at::spmm_max_mul_forward_cuda); + m.def("spmm_max_mul_backward_cuda", &at::spmm_max_mul_backward_cuda); + m.def("spmm_add_add_forward_cuda", &at::spmm_add_add_forward_cuda); + m.def("spmm_add_add_backward_cuda", &at::spmm_add_add_backward_cuda); + m.def("spmm_min_add_forward_cuda", &at::spmm_min_add_forward_cuda); + m.def("spmm_min_add_backward_cuda", &at::spmm_min_add_backward_cuda); + m.def("spmm_max_add_forward_cuda", &at::spmm_max_add_forward_cuda); + m.def("spmm_max_add_backward_cuda", &at::spmm_max_add_backward_cuda); + m.def("rspmm_add_mul_forward_cuda", &at::rspmm_add_mul_forward_cuda); + m.def("rspmm_add_mul_backward_cuda", &at::rspmm_add_mul_backward_cuda); + m.def("rspmm_min_mul_forward_cuda", &at::rspmm_min_mul_forward_cuda); + m.def("rspmm_min_mul_backward_cuda", &at::rspmm_min_mul_backward_cuda); + m.def("rspmm_max_mul_forward_cuda", &at::rspmm_max_mul_forward_cuda); + m.def("rspmm_max_mul_backward_cuda", &at::rspmm_max_mul_backward_cuda); + m.def("rspmm_add_add_forward_cuda", &at::rspmm_add_add_forward_cuda); + m.def("rspmm_add_add_backward_cuda", &at::rspmm_add_add_backward_cuda); + m.def("rspmm_min_add_forward_cuda", &at::rspmm_min_add_forward_cuda); + m.def("rspmm_min_add_backward_cuda", &at::rspmm_min_add_backward_cuda); + m.def("rspmm_max_add_forward_cuda", &at::rspmm_max_add_forward_cuda); + m.def("rspmm_max_add_backward_cuda", &at::rspmm_max_add_backward_cuda); +#endif +} \ No newline at end of file diff --git a/build/lib/torchdrug/layers/functional/extension/spmm.cu b/build/lib/torchdrug/layers/functional/extension/spmm.cu new file mode 100644 index 00000000..7abd3ef5 --- /dev/null +++ b/build/lib/torchdrug/layers/functional/extension/spmm.cu @@ -0,0 +1,333 @@ +#include +#include + +#include "util.cuh" +#include "operator.cuh" +#include "spmm.h" + +// Memory & time efficient implementation of generalized spmm +// Much of the code is inspired by GE-SpMM +// https://github.com/hgyhungry/ge-spmm + +namespace at { + +namespace { + +const int kCoarseningFactor = 2; +const int kThreadPerBlock = 256; + +} // namespace anonymous + +template +__global__ +void spmm_forward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const scalar_t *value, + const scalar_t *input, scalar_t *output, + int64_t num_row, int64_t nnz, int64_t dim) { + // for best optimization, the following code is compiled with constant warpSize + assert(blockDim.x == warpSize); + + extern __shared__ int64_t buffer[]; + int64_t *col_ind_buf = buffer; + scalar_t *value_buf = reinterpret_cast(col_ind_buf + blockDim.y * warpSize); + col_ind_buf += threadIdx.y * warpSize; + value_buf += threadIdx.y * warpSize; + + int64_t row = blockIdx.x * blockDim.y + threadIdx.y; + if (row >= num_row) 
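+        // rows are tiled over (blockIdx.x, threadIdx.y); warps mapped past the last row have no work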
+ return; + int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x; + int64_t ptr_start = row_ptr[row]; + int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz; + scalar_t out[kCoarseningFactor]; +#pragma unroll + for (int64_t i = 0; i < kCoarseningFactor; i++) + out[i] = NaryOp::zero; + + for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) { + int64_t ptr = block_ptr + threadIdx.x; + if (ptr < ptr_end) { + col_ind_buf[threadIdx.x] = col_ind[ptr]; + value_buf[threadIdx.x] = value[ptr]; + } + __syncwarp(); + + int64_t max_offset = warpSize < ptr_end - block_ptr ? warpSize : ptr_end - block_ptr; + for (int64_t offset_ptr = 0; offset_ptr < max_offset; offset_ptr++) { + int64_t col = col_ind_buf[offset_ptr]; + scalar_t val = value_buf[offset_ptr]; +#pragma unroll + for (int64_t i = 0; i < kCoarseningFactor; i++) { + int64_t d = d_start + i * warpSize; + if (d >= dim) + break; + scalar_t x = BinaryOp::forward(val, input[col * dim + d]); + out[i] = NaryOp::forward(out[i], x); + } + } + __syncwarp(); + } + +#pragma unroll + for (int64_t i = 0; i < kCoarseningFactor; i++) { + int64_t d = d_start + i * warpSize; + if (d >= dim) + break; + output[row * dim + d] = out[i]; + } +} + +// both sparse and input require gradients +template +__global__ +void spmm_backward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const scalar_t *value, + const scalar_t *input, const scalar_t *output, const scalar_t *output_grad, + scalar_t *value_grad, scalar_t *input_grad, + int64_t num_row, int64_t nnz, int64_t dim) { + // for best optimization, the following code is compiled with constant warpSize + assert(blockDim.x == warpSize); + + extern __shared__ int64_t buffer[]; + int64_t *col_ind_buf = buffer; + scalar_t *value_buf = reinterpret_cast(col_ind_buf + blockDim.y * warpSize); + col_ind_buf += threadIdx.y * warpSize; + value_buf += threadIdx.y * warpSize; + + int64_t row = blockIdx.x * blockDim.y + threadIdx.y; + if (row >= num_row) + return; + int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x; + int64_t ptr_start = row_ptr[row]; + int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz; + + for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) { + int64_t ptr = block_ptr + threadIdx.x; + if (ptr < ptr_end) { + col_ind_buf[threadIdx.x] = col_ind[ptr]; + value_buf[threadIdx.x] = value[ptr]; + } + __syncwarp(); + + int64_t max_offset = warpSize < ptr_end - block_ptr ? 
+            warpSize : ptr_end - block_ptr;
+        for (int64_t offset_ptr = 0; offset_ptr < max_offset; offset_ptr++) {
+            int64_t col = col_ind_buf[offset_ptr];
+            scalar_t val = value_buf[offset_ptr];
+            scalar_t val_grad = 0;
+#pragma unroll
+            for (int64_t i = 0; i < kCoarseningFactor; i++) {
+                int64_t d = d_start + i * warpSize;
+                if (d >= dim)
+                    break;
+                scalar_t in = input[col * dim + d];
+                scalar_t out = output[row * dim + d];
+                scalar_t out_grad = output_grad[row * dim + d];
+                scalar_t x = BinaryOp::forward(val, in);
+                scalar_t dx_dval = BinaryOp::backward_lhs(val, in);
+                scalar_t dx_din = BinaryOp::backward_rhs(val, in);
+                scalar_t dout_dx = NaryOp::backward(out, x);
+                val_grad += out_grad * dout_dx * dx_dval;
+                atomicAdd(&input_grad[col * dim + d], out_grad * dout_dx * dx_din);
+            }
+            val_grad = warp_reduce(val_grad);
+            if (threadIdx.x == 0)
+                atomicAdd(&value_grad[block_ptr + offset_ptr], val_grad);
+        }
+        __syncwarp();
+    }
+}
+
+// only input requires gradients
+template <class scalar_t, class NaryOp, class BinaryOp>
+__global__
+void spmm_backward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const scalar_t *value,
+                            const scalar_t *input, const scalar_t *output, const scalar_t *output_grad,
+                            scalar_t *input_grad,
+                            int64_t num_row, int64_t nnz, int64_t dim) {
+    // for best optimization, the following code is compiled with constant warpSize
+    assert(blockDim.x == warpSize);
+
+    extern __shared__ int64_t buffer[];
+    int64_t *col_ind_buf = buffer;
+    scalar_t *value_buf = reinterpret_cast<scalar_t *>(col_ind_buf + blockDim.y * warpSize);
+    col_ind_buf += threadIdx.y * warpSize;
+    value_buf += threadIdx.y * warpSize;
+
+    int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
+    if (row >= num_row)
+        return;
+    int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x;
+    int64_t ptr_start = row_ptr[row];
+    int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;
+
+    for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) {
+        int64_t ptr = block_ptr + threadIdx.x;
+        if (ptr < ptr_end) {
+            col_ind_buf[threadIdx.x] = col_ind[ptr];
+            value_buf[threadIdx.x] = value[ptr];
+        }
+        __syncwarp();
+
+        int64_t max_offset = warpSize < ptr_end - block_ptr ?
+            warpSize : ptr_end - block_ptr;
+        for (int64_t offset_ptr = 0; offset_ptr < max_offset; offset_ptr++) {
+            int64_t col = col_ind_buf[offset_ptr];
+            scalar_t val = value_buf[offset_ptr];
+#pragma unroll
+            for (int64_t i = 0; i < kCoarseningFactor; i++) {
+                int64_t d = d_start + i * warpSize;
+                if (d >= dim)
+                    break;
+                scalar_t in = input[col * dim + d];
+                scalar_t out = output[row * dim + d];
+                scalar_t out_grad = output_grad[row * dim + d];
+                scalar_t x = BinaryOp::forward(val, in);
+                scalar_t dx_din = BinaryOp::backward_rhs(val, in);
+                scalar_t dout_dx = NaryOp::backward(out, x);
+                atomicAdd(&input_grad[col * dim + d], out_grad * dout_dx * dx_din);
+            }
+        }
+        __syncwarp();
+    }
+}
+
+template <template <class> class NaryOp, template <class> class BinaryOp>
+Tensor spmm_forward_cuda(const SparseTensor &sparse, const Tensor &input_) {
+    constexpr const char *fn_name = "spmm_forward_cuda";
+    TensorArg sparse_arg(sparse, "sparse", 1), input_arg(input_, "input", 2);
+
+    spmm_forward_check(fn_name, sparse_arg, input_arg);
+    checkAllSameGPU(fn_name, {sparse_arg, input_arg});
+
+    const Tensor input = input_.contiguous();
+
+    int64_t nnz = sparse._nnz();
+    int64_t dim = input.size(1);
+    int64_t num_row = sparse.size(0);
+    Tensor output = at::empty({num_row, dim}, input.options());
+
+    auto csr = coo2csr(sparse);
+    Tensor row_ptr = std::get<0>(csr);
+    Tensor col_ind = std::get<1>(csr);
+    Tensor value = std::get<2>(csr);
+
+    cudaSetDevice(input.get_device());
+    auto stream = at::cuda::getCurrentCUDAStream();
+
+    const int dim_per_block = 32; // warpSize
+    const int num_dim_block = (dim + dim_per_block * kCoarseningFactor - 1) / (dim_per_block * kCoarseningFactor);
+    const int row_per_block = kThreadPerBlock / dim_per_block;
+    const int num_row_block = (num_row + row_per_block - 1) / row_per_block;
+
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "spmm_forward_cuda", [&] {
+        const int memory_size = kThreadPerBlock * (sizeof(int64_t) + sizeof(scalar_t));
+        spmm_forward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
+            <<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
+                row_ptr.data_ptr<int64_t>(),
+                col_ind.data_ptr<int64_t>(),
+                value.data_ptr<scalar_t>(),
+                input.data_ptr<scalar_t>(),
+                output.data_ptr<scalar_t>(),
+                num_row, nnz, dim
+        );
+    });
+
+    return output;
+}
+
+template <template <class> class NaryOp, template <class> class BinaryOp>
+std::tuple<SparseTensor, Tensor> spmm_backward_cuda(
+        const SparseTensor &sparse, const Tensor &input_, const Tensor &output_, const Tensor &output_grad_) {
+    constexpr const char *fn_name = "spmm_backward_cuda";
+    TensorArg sparse_arg(sparse, "sparse", 1), input_arg(input_, "input", 2), output_arg(output_, "output", 3),
+              output_grad_arg(output_grad_, "output_grad", 4);
+
+    spmm_backward_check(fn_name, sparse_arg, input_arg, output_arg, output_grad_arg);
+    checkAllSameGPU(fn_name, {sparse_arg, input_arg, output_arg, output_grad_arg});
+
+    const Tensor input = input_.contiguous();
+    const Tensor output = output_.contiguous();
+    const Tensor output_grad = output_grad_.contiguous();
+
+    int64_t nnz = sparse._nnz();
+    int64_t dim = input.size(1);
+    int64_t num_row = sparse.size(0);
+    Tensor value_grad = at::zeros_like(sparse.values());
+    Tensor input_grad = at::zeros_like(input);
+    SparseTensor sparse_grad = at::_sparse_coo_tensor_unsafe(sparse.indices(), value_grad, sparse.sizes());
+
+    auto csr = coo2csr(sparse);
+    Tensor row_ptr = std::get<0>(csr).contiguous();
+    Tensor col_ind = std::get<1>(csr).contiguous();
+    Tensor value = std::get<2>(csr).contiguous();
+
+    cudaSetDevice(input.get_device());
+    auto stream = at::cuda::getCurrentCUDAStream();
+
+    const int dim_per_block = 32; // warpSize
+    const int num_dim_block = (dim + dim_per_block * kCoarseningFactor - 1) / (dim_per_block * kCoarseningFactor);
+    const int row_per_block = kThreadPerBlock / dim_per_block;
+    const int num_row_block = (num_row + row_per_block - 1) / row_per_block;
+
+    if (sparse.requires_grad())
+        AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "spmm_backward_cuda", [&] {
+            const int memory_size = kThreadPerBlock * (sizeof(int64_t) + sizeof(scalar_t));
+            spmm_backward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
+                <<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
+                    row_ptr.data_ptr<int64_t>(),
+                    col_ind.data_ptr<int64_t>(),
+                    value.data_ptr<scalar_t>(),
+                    input.data_ptr<scalar_t>(),
+                    output.data_ptr<scalar_t>(),
+                    output_grad.data_ptr<scalar_t>(),
+                    value_grad.data_ptr<scalar_t>(),
+                    input_grad.data_ptr<scalar_t>(),
+                    num_row, nnz, dim
+            );
+        });
+    else
+        AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "spmm_backward_cuda", [&] {
+            const int memory_size = kThreadPerBlock * (sizeof(int64_t) + sizeof(scalar_t));
+            spmm_backward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
+                <<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
+                    row_ptr.data_ptr<int64_t>(),
+                    col_ind.data_ptr<int64_t>(),
+                    value.data_ptr<scalar_t>(),
+                    input.data_ptr<scalar_t>(),
+                    output.data_ptr<scalar_t>(),
+                    output_grad.data_ptr<scalar_t>(),
+                    input_grad.data_ptr<scalar_t>(),
+                    num_row, nnz, dim
+            );
+        });
+
+    return std::make_tuple(sparse_grad, input_grad);
+}
+
+#define DECLARE_FORWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
+    Tensor spmm_##ADD##_##MUL##_forward_cuda(const SparseTensor &sparse, const Tensor &input) { \
+        return spmm_forward_cuda<NARYOP, BINARYOP>(sparse, input); \
+    }
+
+#define DECLARE_BACKWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
+    std::tuple<SparseTensor, Tensor> spmm_##ADD##_##MUL##_backward_cuda( \
+            const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad) { \
+        return spmm_backward_cuda<NARYOP, BINARYOP>(sparse, input, output, output_grad); \
+    }
+
+DECLARE_FORWARD_IMPL(add, mul, NaryAdd, BinaryMul)
+DECLARE_BACKWARD_IMPL(add, mul, NaryAdd, BinaryMul)
+
+DECLARE_FORWARD_IMPL(min, mul, NaryMin, BinaryMul)
+DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul)
+
+DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul)
+DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul)
+
+DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd)
+DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd)
+
+DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd)
+DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd)
+
+DECLARE_FORWARD_IMPL(max, add, NaryMax, BinaryAdd)
+DECLARE_BACKWARD_IMPL(max, add, NaryMax, BinaryAdd)
+
+} // namespace at
\ No newline at end of file
diff --git a/build/lib/torchdrug/layers/functional/extension/spmm.h b/build/lib/torchdrug/layers/functional/extension/spmm.h
new file mode 100644
index 00000000..905e7836
--- /dev/null
+++ b/build/lib/torchdrug/layers/functional/extension/spmm.h
@@ -0,0 +1,85 @@
+#pragma once
+
+#include <tuple>
+
+#include <torch/extension.h>
+#include <ATen/SparseTensorUtils.h>
+
+#include "rspmm.h"
+
+namespace at {
+
+using namespace at::sparse;
+
+void spmm_forward_check(CheckedFrom c, const TensorArg &sparse_arg, const TensorArg &input_arg);
+
+void spmm_backward_check(CheckedFrom c, const TensorArg &sparse_arg, const TensorArg &input_arg,
+                         const TensorArg &output_arg, const TensorArg &output_grad_arg);
+
+std::tuple<Tensor, Tensor, Tensor> coo2csr(const SparseTensor &sparse);
+
+SparseTensor csr2coo(const Tensor &row_ptr_, const Tensor &col_ind, const Tensor &value, IntArrayRef size);
+
+Tensor spmm_add_mul_forward_cpu(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_add_mul_backward_cpu(
+    const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
+Tensor spmm_min_mul_forward_cpu(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_min_mul_backward_cpu(
+    const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
+Tensor spmm_max_mul_forward_cpu(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_max_mul_backward_cpu(
+    const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
+Tensor spmm_add_add_forward_cpu(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_add_add_backward_cpu(
+    const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
+Tensor spmm_min_add_forward_cpu(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_min_add_backward_cpu(
+    const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
+Tensor spmm_max_add_forward_cpu(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_max_add_backward_cpu(
+    const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
+#ifdef CUDA_OP
+Tensor spmm_add_mul_forward_cuda(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_add_mul_backward_cuda(
+    const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
+Tensor spmm_min_mul_forward_cuda(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_min_mul_backward_cuda(
+    const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
+Tensor spmm_max_mul_forward_cuda(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_max_mul_backward_cuda(
+    const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
+Tensor spmm_add_add_forward_cuda(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_add_add_backward_cuda(
+    const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
+Tensor spmm_min_add_forward_cuda(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_min_add_backward_cuda(
+    const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
+Tensor spmm_max_add_forward_cuda(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_max_add_backward_cuda(
+    const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+#endif
+
+} // namespace at
\ No newline at end of file
diff --git a/build/lib/torchdrug/layers/functional/extension/util.cuh b/build/lib/torchdrug/layers/functional/extension/util.cuh
new file mode 100644
index 00000000..e014f2dc
--- /dev/null
+++ b/build/lib/torchdrug/layers/functional/extension/util.cuh
@@ -0,0 +1,28 @@
+#pragma once
+
+namespace at {
+
+const unsigned kFullMask = 0xFFFFFFFF;
+
+template <class scalar_t>
+__device__ scalar_t warp_reduce(scalar_t value) {
+#pragma unroll
+    for (int delta = 1; delta < warpSize; delta *= 2)
+#if __CUDACC_VER_MAJOR__ >= 9
+        value += __shfl_down_sync(kFullMask, value, delta);
+#else
+        value += __shfl_down(value, delta);
+#endif
+    return value;
+}
+
+template <class scalar_t>
+__device__ scalar_t warp_broadcast(scalar_t value, int lane_id) {
+#if __CUDACC_VER_MAJOR__ >= 9
+    return __shfl_sync(kFullMask, value, lane_id);
+#else
+    return __shfl(value, lane_id);
+#endif
+}
+
+} // namespace at
\ No newline at end of file
diff --git a/build/lib/torchdrug/layers/functional/functional.py b/build/lib/torchdrug/layers/functional/functional.py
new file mode 100644
index 00000000..bbb989cc
--- /dev/null
+++ b/build/lib/torchdrug/layers/functional/functional.py
@@ -0,0 +1,529 @@
+import torch
+from torch_scatter import scatter_add, scatter_mean, scatter_max
+from torch_scatter.composite import scatter_log_softmax, scatter_softmax
+from torch.nn import functional as F
+
+
+def multinomial(input, num_sample, replacement=False):
+    """
+    Fast multinomial sampling. This matches the default implementation in PyTorch v1.6.0+.
+
+    Parameters:
+        input (Tensor): unnormalized distribution
+        num_sample (int): number of samples
+        replacement (bool, optional): sample with replacement or not
+    """
+    if replacement:
+        return torch.multinomial(input, num_sample, replacement)
+
+    # sample without replacement: the top-k of log(U) / w is a weighted sample (exponential race trick)
+    rand = torch.rand_like(input).log() / input
+    samples = rand.topk(num_sample).indices
+    return samples
+
+
+def masked_mean(input, mask, dim=None, keepdim=False):
+    """
+    Masked mean of a tensor.
+
+    Parameters:
+        input (Tensor): input tensor
+        mask (BoolTensor): mask tensor
+        dim (int or tuple of int, optional): dimension to reduce
+        keepdim (bool, optional): whether retain ``dim`` or not
+    """
+    input = input.masked_scatter(~mask, torch.zeros_like(input))  # safe with nan
+    if dim is None:
+        return input.sum() / mask.sum().clamp(1)
+    return input.sum(dim, keepdim=keepdim) / mask.sum(dim, keepdim=keepdim).clamp(1)
+
+
+def mean_with_nan(input, dim=None, keepdim=False):
+    """
+    Mean of a tensor. Ignore all nan values.
+
+    Parameters:
+        input (Tensor): input tensor
+        dim (int or tuple of int, optional): dimension to reduce
+        keepdim (bool, optional): whether retain ``dim`` or not
+    """
+    mask = ~torch.isnan(input)
+    return masked_mean(input, mask, dim, keepdim)
+
+
+def shifted_softplus(input):
+    """
+    Shifted softplus function.
+
+    Parameters:
+        input (Tensor): input tensor
+    """
+    return F.softplus(input) - F.softplus(torch.zeros(1, device=input.device))
+
+
+def multi_slice(starts, ends):
+    """
+    Compute the union of indexes in multiple slices.
+
+    Example::
+
+        >>> index = multi_slice(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]))
+        >>> assert (index == torch.tensor([0, 1, 2, 4, 5])).all()
+
+    Parameters:
+        starts (LongTensor): start indexes of slices
+        ends (LongTensor): end indexes of slices
+    """
+    values = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
+    slices = torch.cat([starts, ends])
+    slices, order = slices.sort()
+    values = values[order]
+    depth = values.cumsum(0)
+    valid = ((values == 1) & (depth == 1)) | ((values == -1) & (depth == 0))
+    slices = slices[valid]
+
+    starts, ends = slices.view(-1, 2).t()
+    size = ends - starts
+    indexes = variadic_arange(size)
+    indexes = indexes + starts.repeat_interleave(size)
+    return indexes
+
+
+def multi_slice_mask(starts, ends, length):
+    """
+    Compute the union of multiple slices into a binary mask.
+
+    Example::
+
+        >>> mask = multi_slice_mask(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]), 6)
+        >>> assert (mask == torch.tensor([1, 1, 1, 0, 1, 1])).all()
+
+    Parameters:
+        starts (LongTensor): start indexes of slices
+        ends (LongTensor): end indexes of slices
+        length (int): length of mask
+    """
+    values = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
+    slices = torch.cat([starts, ends])
+    if slices.numel():
+        assert slices.min() >= 0 and slices.max() <= length
+    mask = scatter_add(values, slices, dim=0, dim_size=length + 1)[:-1]
+    mask = mask.cumsum(0).bool()
+    return mask
+
+
+def as_mask(indexes, length):
+    """
+    Convert indexes into a binary mask.
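+
+    Example (a minimal sketch)::
+
+        >>> mask = as_mask(torch.tensor([0, 2]), 4)
+        >>> assert (mask == torch.tensor([True, False, True, False])).all()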
+ + Parameters: + indexes (LongTensor): positive indexes + length (int): maximal possible value of indexes + """ + mask = torch.zeros(length, dtype=torch.bool, device=indexes.device) + mask[indexes] = 1 + return mask + + +def _extend(data, size, input, input_size): + """ + Extend variadic-sized data with variadic-sized input. + This is a variadic variant of ``torch.cat([data, input], dim=-1)``. + + Example:: + + >>> data = torch.tensor([0, 1, 2, 3, 4]) + >>> size = torch.tensor([3, 2]) + >>> input = torch.tensor([-1, -2, -3]) + >>> input_size = torch.tensor([1, 2]) + >>> new_data, new_size = _extend(data, size, input, input_size) + >>> assert (new_data == torch.tensor([0, 1, 2, -1, 3, 4, -2, -3])).all() + >>> assert (new_size == torch.tensor([4, 4])).all() + + Parameters: + data (Tensor): variadic data + size (LongTensor): size of data + input (Tensor): variadic input + input_size (LongTensor): size of input + + Returns: + (Tensor, LongTensor): output data, output size + """ + new_size = size + input_size + new_cum_size = new_size.cumsum(0) + new_data = torch.zeros(new_cum_size[-1], *data.shape[1:], dtype=data.dtype, device=data.device) + starts = new_cum_size - new_size + ends = starts + size + index = multi_slice_mask(starts, ends, new_cum_size[-1]) + new_data[index] = data + new_data[~index] = input + return new_data, new_size + + +def variadic_sum(input, size): + """ + Compute sum over sets with variadic sizes. + + Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`. + + Parameters: + input (Tensor): input of shape :math:`(B, ...)` + size (LongTensor): size of sets of shape :math:`(N,)` + """ + index2sample = torch.repeat_interleave(size) + index2sample = index2sample.view([-1] + [1] * (input.ndim - 1)) + index2sample = index2sample.expand_as(input) + + value = scatter_add(input, index2sample, dim=0) + return value + + +def variadic_mean(input, size): + """ + Compute mean over sets with variadic sizes. + + Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`. + + Parameters: + input (Tensor): input of shape :math:`(B, ...)` + size (LongTensor): size of sets of shape :math:`(N,)` + """ + index2sample = torch.repeat_interleave(size) + index2sample = index2sample.view([-1] + [1] * (input.ndim - 1)) + index2sample = index2sample.expand_as(input) + + value = scatter_mean(input, index2sample, dim=0) + return value + + +def variadic_max(input, size): + """ + Compute max over sets with variadic sizes. + + Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`. + + Parameters: + input (Tensor): input of shape :math:`(B, ...)` + size (LongTensor): size of sets of shape :math:`(N,)` + + Returns + (Tensor, LongTensor): max values and indexes + """ + index2sample = torch.repeat_interleave(size) + index2sample = index2sample.view([-1] + [1] * (input.ndim - 1)) + index2sample = index2sample.expand_as(input) + + value, index = scatter_max(input, index2sample, dim=0) + index = index + (size - size.cumsum(0)).view([-1] + [1] * (index.ndim - 1)) + return value, index + + +def variadic_log_softmax(input, size): + """ + Compute log softmax over categories with variadic sizes. + + Suppose there are :math:`N` samples, and the numbers of categories in all samples are summed to :math:`B`. 
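+
+    Example (a minimal sketch; probabilities within each sample sum to 1)::
+
+        >>> input = torch.zeros(5)
+        >>> size = torch.tensor([2, 3])
+        >>> log_prob = variadic_log_softmax(input, size)
+        >>> assert torch.allclose(log_prob.exp()[:2].sum(), torch.tensor(1.))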
+
+    Parameters:
+        input (Tensor): input of shape :math:`(B, ...)`
+        size (LongTensor): number of categories of shape :math:`(N,)`
+    """
+    index2sample = torch.repeat_interleave(size)
+    index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
+    index2sample = index2sample.expand_as(input)
+
+    log_likelihood = scatter_log_softmax(input, index2sample, dim=0)
+    return log_likelihood
+
+
+def variadic_softmax(input, size):
+    """
+    Compute softmax over categories with variadic sizes.
+
+    Suppose there are :math:`N` samples, and the numbers of categories in all samples are summed to :math:`B`.
+
+    Parameters:
+        input (Tensor): input of shape :math:`(B, ...)`
+        size (LongTensor): number of categories of shape :math:`(N,)`
+    """
+    index2sample = torch.repeat_interleave(size)
+    index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
+    index2sample = index2sample.expand_as(input)
+
+    softmax = scatter_softmax(input, index2sample, dim=0)
+    return softmax
+
+
+def variadic_cross_entropy(input, target, size, reduction="mean"):
+    """
+    Compute cross entropy loss over categories with variadic sizes.
+
+    Suppose there are :math:`N` samples, and the numbers of categories in all samples are summed to :math:`B`.
+
+    Parameters:
+        input (Tensor): prediction of shape :math:`(B, ...)`
+        target (Tensor): target of shape :math:`(N, ...)`. Each target is a relative index in a sample.
+        size (LongTensor): number of categories of shape :math:`(N,)`
+        reduction (string, optional): reduction to apply to the output.
+            Available reductions are ``none``, ``sum`` and ``mean``.
+    """
+    index2sample = torch.repeat_interleave(size)
+    index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
+    index2sample = index2sample.expand_as(input)
+
+    log_likelihood = scatter_log_softmax(input, index2sample, dim=0)
+    size = size.view([-1] + [1] * (input.ndim - 1))
+    assert (target >= 0).all() and (target < size).all()
+    target_index = target + size.cumsum(0) - size
+    loss = -log_likelihood.gather(0, target_index)
+    if reduction == "mean":
+        return loss.mean()
+    elif reduction == "sum":
+        return loss.sum()
+    elif reduction == "none":
+        return loss
+    else:
+        raise ValueError("Unknown reduction `%s`" % reduction)
+
+
+def variadic_topk(input, size, k, largest=True):
+    """
+    Compute the :math:`k` largest elements over sets with variadic sizes.
+
+    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
+
+    If any set has fewer than :math:`k` elements, the size-th largest element will be
+    repeated to pad the output to :math:`k`.
+
+    Parameters:
+        input (Tensor): input of shape :math:`(B, ...)`
+        size (LongTensor): size of sets of shape :math:`(N,)`
+        k (int or LongTensor): the k in "top-k". Can be a fixed value for all sets,
+            or different values for different sets of shape :math:`(N,)`.
+ largest (bool, optional): return largest or smallest elements + + Returns + (Tensor, LongTensor): top-k values and indexes + """ + index2graph = torch.repeat_interleave(size) + index2graph = index2graph.view([-1] + [1] * (input.ndim - 1)) + + mask = ~torch.isinf(input) + max = input[mask].max().item() + min = input[mask].min().item() + abs_max = input[mask].abs().max().item() + # special case: max = min + gap = max - min + abs_max * 1e-6 + safe_input = input.clamp(min - gap, max + gap) + offset = gap * 4 + if largest: + offset = -offset + input_ext = safe_input + offset * index2graph + index_ext = input_ext.argsort(dim=0, descending=largest) + if isinstance(k, torch.Tensor) and k.shape == size.shape: + num_actual = torch.min(size, k) + else: + num_actual = size.clamp(max=k) + num_padding = k - num_actual + starts = size.cumsum(0) - size + ends = starts + num_actual + mask = multi_slice_mask(starts, ends, len(index_ext)).nonzero().flatten() + + if (num_padding > 0).any(): + # special case: size < k, pad with the last valid index + padding = ends - 1 + padding2graph = torch.repeat_interleave(num_padding) + mask = _extend(mask, num_actual, padding[padding2graph], num_padding)[0] + + index = index_ext[mask] # (N * k, ...) + value = input.gather(0, index) + if isinstance(k, torch.Tensor) and k.shape == size.shape: + value = value.view(-1, *input.shape[1:]) + index = index.view(-1, *input.shape[1:]) + index = index - (size.cumsum(0) - size).repeat_interleave(k).view([-1] + [1] * (index.ndim - 1)) + else: + value = value.view(-1, k, *input.shape[1:]) + index = index.view(-1, k, *input.shape[1:]) + index = index - (size.cumsum(0) - size).view([-1] + [1] * (index.ndim - 1)) + + return value, index + + +def variadic_sort(input, size, descending=False): + """ + Sort elements in sets with variadic sizes. + + Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`. + + Parameters: + input (Tensor): input of shape :math:`(B, ...)` + size (LongTensor): size of sets of shape :math:`(N,)` + descending (bool, optional): return ascending or descending order + + Returns + (Tensor, LongTensor): sorted values and indexes + """ + index2sample = torch.repeat_interleave(size) + index2sample = index2sample.view([-1] + [1] * (input.ndim - 1)) + + mask = ~torch.isinf(input) + max = input[mask].max().item() + min = input[mask].min().item() + abs_max = input[mask].abs().max().item() + # special case: max = min + gap = max - min + abs_max * 1e-6 + safe_input = input.clamp(min - gap, max + gap) + offset = gap * 4 + if descending: + offset = -offset + input_ext = safe_input + offset * index2sample + index = input_ext.argsort(dim=0, descending=descending) + value = input.gather(0, index) + index = index - (size.cumsum(0) - size)[index2sample] + return value, index + + +def variadic_arange(size): + """ + Return a 1-D tensor that contains integer intervals of variadic sizes. + This is a variadic variant of ``torch.arange(stop).expand(batch_size, -1)``. + + Suppose there are :math:`N` intervals. + + Parameters: + size (LongTensor): size of intervals of shape :math:`(N,)` + """ + starts = size.cumsum(0) - size + + range = torch.arange(size.sum(), device=size.device) + range = range - starts.repeat_interleave(size) + return range + + +def variadic_randperm(size): + """ + Return random permutations for sets with variadic sizes. + The ``i``-th permutation contains integers from 0 to ``size[i] - 1``. + + Suppose there are :math:`N` sets. 
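+
+    Example (a minimal sketch; each set is permuted independently)::
+
+        >>> perm = variadic_randperm(torch.tensor([3, 2]))
+        >>> assert (perm[:3].sort()[0] == torch.arange(3)).all()
+        >>> assert (perm[3:].sort()[0] == torch.arange(2)).all()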
+
+    Parameters:
+        size (LongTensor): size of sets of shape :math:`(N,)`
+    """
+    rand = torch.rand(size.sum(), device=size.device)
+    perm = variadic_sort(rand, size)[1]
+    return perm
+
+
+def variadic_sample(input, size, num_sample):
+    """
+    Draw samples with replacement from sets with variadic sizes.
+
+    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
+
+    Parameters:
+        input (Tensor): input of shape :math:`(B, ...)`
+        size (LongTensor): size of sets of shape :math:`(N,)`
+        num_sample (int): number of samples to draw from each set
+    """
+    rand = torch.rand(len(size), num_sample, device=size.device)
+    index = (rand * size.unsqueeze(-1)).long()
+    index = index + (size.cumsum(0) - size).unsqueeze(-1)
+    sample = input[index]
+    return sample
+
+
+def variadic_meshgrid(input1, size1, input2, size2):
+    """
+    Compute the Cartesian product for two batches of sets with variadic sizes.
+
+    Suppose there are :math:`N` sets in each input,
+    and the sizes of all sets are summed to :math:`B_1` and :math:`B_2` respectively.
+
+    Parameters:
+        input1 (Tensor): input of shape :math:`(B_1, ...)`
+        size1 (LongTensor): size of :attr:`input1` of shape :math:`(N,)`
+        input2 (Tensor): input of shape :math:`(B_2, ...)`
+        size2 (LongTensor): size of :attr:`input2` of shape :math:`(N,)`
+
+    Returns:
+        (Tensor, Tensor): the first and the second elements in the Cartesian product
+    """
+    grid_size = size1 * size2
+    local_index = variadic_arange(grid_size)
+    local_inner_size = size2.repeat_interleave(grid_size)
+    offset1 = (size1.cumsum(0) - size1).repeat_interleave(grid_size)
+    offset2 = (size2.cumsum(0) - size2).repeat_interleave(grid_size)
+    index1 = torch.div(local_index, local_inner_size, rounding_mode="floor") + offset1
+    index2 = local_index % local_inner_size + offset2
+    return input1[index1], input2[index2]
+
+
+def variadic_to_padded(input, size, value=0):
+    """
+    Convert a variadic tensor to a padded tensor.
+
+    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
+
+    Parameters:
+        input (Tensor): input of shape :math:`(B, ...)`
+        size (LongTensor): size of sets of shape :math:`(N,)`
+        value (scalar): fill value for padding
+
+    Returns:
+        (Tensor, BoolTensor): padded tensor and mask
+    """
+    num_sample = len(size)
+    max_size = size.max()
+    starts = torch.arange(num_sample, device=size.device) * max_size
+    ends = starts + size
+    mask = multi_slice_mask(starts, ends, num_sample * max_size)
+    mask = mask.view(num_sample, max_size)
+    shape = (num_sample, max_size) + input.shape[1:]
+    padded = torch.full(shape, value, dtype=input.dtype, device=size.device)
+    padded[mask] = input
+    return padded, mask
+
+
+def padded_to_variadic(padded, size):
+    """
+    Convert a padded tensor to a variadic tensor.
+
+    Parameters:
+        padded (Tensor): padded tensor of shape :math:`(N, ...)`
+        size (LongTensor): size of sets of shape :math:`(N,)`
+    """
+    num_sample, max_size = padded.shape[:2]
+    starts = torch.arange(num_sample, device=size.device) * max_size
+    ends = starts + size
+    mask = multi_slice_mask(starts, ends, num_sample * max_size)
+    mask = mask.view(num_sample, max_size)
+    return padded[mask]
+
+
+def one_hot(index, size):
+    """
+    Expand indexes into one-hot vectors.
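+
+    Example (a minimal sketch)::
+
+        >>> vec = one_hot(torch.tensor([0, 2]), 3)
+        >>> assert (vec == torch.tensor([[1., 0., 0.], [0., 0., 1.]])).all()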
+ + Parameters: + index (Tensor): index + size (int): size of the one-hot dimension + """ + shape = list(index.shape) + [size] + result = torch.zeros(shape, device=index.device) + if index.numel(): + assert index.min() >= 0 + assert index.max() < size + result.scatter_(-1, index.unsqueeze(-1), 1) + return result + + +def clipped_policy_gradient_objective(policy, agent, reward, eps=0.2): + ratio = (policy - agent.detach()).exp() + ratio = ratio.clamp(-10, 10) + objective = torch.min(ratio * reward, ratio.clamp(1 - eps, 1 + eps) * reward) + return objective + + +def policy_gradient_objective(policy, reward): + return policy * reward \ No newline at end of file diff --git a/build/lib/torchdrug/layers/functional/spmm.py b/build/lib/torchdrug/layers/functional/spmm.py new file mode 100644 index 00000000..18dc29bb --- /dev/null +++ b/build/lib/torchdrug/layers/functional/spmm.py @@ -0,0 +1,378 @@ +import os +import sys + +import torch +from torch import autograd + +from torchdrug import utils + +module = sys.modules[__name__] + +path = os.path.join(os.path.dirname(__file__), "extension") +spmm = utils.load_extension("spmm", [os.path.join(path, "spmm.cpp"), os.path.join(path, "rspmm.cpp"), + os.path.join(path, "spmm.cu"), os.path.join(path, "rspmm.cu")]) + + +class SPMMAddMulFunction(autograd.Function): + + @staticmethod + def forward(ctx, sparse, input): + assert sparse.is_coalesced() + if input.device.type == "cuda": + forward = spmm.spmm_add_mul_forward_cuda + else: + forward = spmm.spmm_add_mul_forward_cpu + output = forward(sparse, input) + ctx.save_for_backward(sparse, input, output) + return output + + @staticmethod + def backward(ctx, output_grad): + if output_grad.device.type == "cuda": + backward = spmm.spmm_add_mul_backward_cuda + else: + backward = spmm.spmm_add_mul_backward_cpu + sparse_grad, input_grad = backward(*ctx.saved_tensors, output_grad) + if not ctx.saved_tensors[0].requires_grad: + sparse_grad = None + return sparse_grad, input_grad + + +class SPMMMinMulFunction(autograd.Function): + + @staticmethod + def forward(ctx, sparse, input): + assert sparse.is_coalesced() + if input.device.type == "cuda": + forward = spmm.spmm_min_mul_forward_cuda + else: + forward = spmm.spmm_min_mul_forward_cpu + output = forward(sparse, input) + ctx.save_for_backward(sparse, input, output) + return output + + @staticmethod + def backward(ctx, output_grad): + if output_grad.device.type == "cuda": + backward = spmm.spmm_min_mul_backward_cuda + else: + backward = spmm.spmm_min_mul_backward_cpu + sparse_grad, input_grad = backward(*ctx.saved_tensors, output_grad) + if not ctx.saved_tensors[0].requires_grad: + sparse_grad = None + return sparse_grad, input_grad + + +class SPMMMaxMulFunction(autograd.Function): + + @staticmethod + def forward(ctx, sparse, input): + assert sparse.is_coalesced() + if input.device.type == "cuda": + forward = spmm.spmm_max_mul_forward_cuda + else: + forward = spmm.spmm_max_mul_forward_cpu + output = forward(sparse, input) + ctx.save_for_backward(sparse, input, output) + return output + + @staticmethod + def backward(ctx, output_grad): + if output_grad.device.type == "cuda": + backward = spmm.spmm_max_mul_backward_cuda + else: + backward = spmm.spmm_max_mul_backward_cpu + sparse_grad, input_grad = backward(*ctx.saved_tensors, output_grad) + if not ctx.saved_tensors[0].requires_grad: + sparse_grad = None + return sparse_grad, input_grad + + +class SPMMAddAddFunction(autograd.Function): + + @staticmethod + def forward(ctx, sparse, input): + assert sparse.is_coalesced() + if 
input.device.type == "cuda": + forward = spmm.spmm_add_add_forward_cuda + else: + forward = spmm.spmm_add_add_forward_cpu + output = forward(sparse, input) + ctx.save_for_backward(sparse, input, output) + return output + + @staticmethod + def backward(ctx, output_grad): + if output_grad.device.type == "cuda": + backward = spmm.spmm_add_add_backward_cuda + else: + backward = spmm.spmm_add_add_backward_cpu + sparse_grad, input_grad = backward(*ctx.saved_tensors, output_grad) + if not ctx.saved_tensors[0].requires_grad: + sparse_grad = None + return sparse_grad, input_grad + + +class SPMMMinAddFunction(autograd.Function): + + @staticmethod + def forward(ctx, sparse, input): + assert sparse.is_coalesced() + if input.device.type == "cuda": + forward = spmm.spmm_min_add_forward_cuda + else: + forward = spmm.spmm_min_add_forward_cpu + output = forward(sparse, input) + ctx.save_for_backward(sparse, input, output) + return output + + @staticmethod + def backward(ctx, output_grad): + if output_grad.device.type == "cuda": + backward = spmm.spmm_min_add_backward_cuda + else: + backward = spmm.spmm_min_add_backward_cpu + sparse_grad, input_grad = backward(*ctx.saved_tensors, output_grad) + if not ctx.saved_tensors[0].requires_grad: + sparse_grad = None + return sparse_grad, input_grad + + +class SPMMMaxAddFunction(autograd.Function): + + @staticmethod + def forward(ctx, sparse, input): + assert sparse.is_coalesced() + if input.device.type == "cuda": + forward = spmm.spmm_max_add_forward_cuda + else: + forward = spmm.spmm_max_add_forward_cpu + output = forward(sparse, input) + ctx.save_for_backward(sparse, input, output) + return output + + @staticmethod + def backward(ctx, output_grad): + if output_grad.device.type == "cuda": + backward = spmm.spmm_max_add_backward_cuda + else: + backward = spmm.spmm_max_add_backward_cpu + sparse_grad, input_grad = backward(*ctx.saved_tensors, output_grad) + if not ctx.saved_tensors[0].requires_grad: + sparse_grad = None + return sparse_grad, input_grad + + +class RSPMMAddMulFunction(autograd.Function): + + @staticmethod + def forward(ctx, sparse, relation, input): + assert sparse.is_coalesced() + if input.device.type == "cuda": + forward = spmm.rspmm_add_mul_forward_cuda + else: + forward = spmm.rspmm_add_mul_forward_cpu + output = forward(sparse, relation, input) + ctx.save_for_backward(sparse, relation, input, output) + return output + + @staticmethod + def backward(ctx, output_grad): + if output_grad.device.type == "cuda": + backward = spmm.rspmm_add_mul_backward_cuda + else: + backward = spmm.rspmm_add_mul_backward_cpu + sparse_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad) + if not ctx.saved_tensors[0].requires_grad: + sparse_grad = None + return sparse_grad, relation_grad, input_grad + + +class RSPMMMinMulFunction(autograd.Function): + + @staticmethod + def forward(ctx, sparse, relation, input): + assert sparse.is_coalesced() + if input.device.type == "cuda": + forward = spmm.rspmm_min_mul_forward_cuda + else: + forward = spmm.rspmm_min_mul_forward_cpu + output = forward(sparse, relation, input) + ctx.save_for_backward(sparse, relation, input, output) + return output + + @staticmethod + def backward(ctx, output_grad): + if output_grad.device.type == "cuda": + backward = spmm.rspmm_min_mul_backward_cuda + else: + backward = spmm.rspmm_min_mul_backward_cpu + sparse_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad) + if not ctx.saved_tensors[0].requires_grad: + sparse_grad = None + return sparse_grad, 
relation_grad, input_grad + + +class RSPMMMaxMulFunction(autograd.Function): + + @staticmethod + def forward(ctx, sparse, relation, input): + assert sparse.is_coalesced() + if input.device.type == "cuda": + forward = spmm.rspmm_max_mul_forward_cuda + else: + forward = spmm.rspmm_max_mul_forward_cpu + output = forward(sparse, relation, input) + ctx.save_for_backward(sparse, relation, input, output) + return output + + @staticmethod + def backward(ctx, output_grad): + if output_grad.device.type == "cuda": + backward = spmm.rspmm_max_mul_backward_cuda + else: + backward = spmm.rspmm_max_mul_backward_cpu + sparse_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad) + if not ctx.saved_tensors[0].requires_grad: + sparse_grad = None + return sparse_grad, relation_grad, input_grad + + +class RSPMMAddAddFunction(autograd.Function): + + @staticmethod + def forward(ctx, sparse, relation, input): + assert sparse.is_coalesced() + if input.device.type == "cuda": + forward = spmm.rspmm_add_add_forward_cuda + else: + forward = spmm.rspmm_add_add_forward_cpu + output = forward(sparse, relation, input) + ctx.save_for_backward(sparse, relation, input, output) + return output + + @staticmethod + def backward(ctx, output_grad): + if output_grad.device.type == "cuda": + backward = spmm.rspmm_add_add_backward_cuda + else: + backward = spmm.rspmm_add_add_backward_cpu + sparse_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad) + if not ctx.saved_tensors[0].requires_grad: + sparse_grad = None + return sparse_grad, relation_grad, input_grad + + +class RSPMMMinAddFunction(autograd.Function): + + @staticmethod + def forward(ctx, sparse, relation, input): + assert sparse.is_coalesced() + if input.device.type == "cuda": + forward = spmm.rspmm_min_add_forward_cuda + else: + forward = spmm.rspmm_min_add_forward_cpu + output = forward(sparse, relation, input) + ctx.save_for_backward(sparse, relation, input, output) + return output + + @staticmethod + def backward(ctx, output_grad): + if output_grad.device.type == "cuda": + backward = spmm.rspmm_min_add_backward_cuda + else: + backward = spmm.rspmm_min_add_backward_cpu + sparse_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad) + if not ctx.saved_tensors[0].requires_grad: + sparse_grad = None + return sparse_grad, relation_grad, input_grad + + +class RSPMMMaxAddFunction(autograd.Function): + + @staticmethod + def forward(ctx, sparse, relation, input): + assert sparse.is_coalesced() + if input.device.type == "cuda": + forward = spmm.rspmm_max_add_forward_cuda + else: + forward = spmm.rspmm_max_add_forward_cpu + output = forward(sparse, relation, input) + ctx.save_for_backward(sparse, relation, input, output) + return output + + @staticmethod + def backward(ctx, output_grad): + if output_grad.device.type == "cuda": + backward = spmm.rspmm_max_add_backward_cuda + else: + backward = spmm.rspmm_max_add_backward_cpu + sparse_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad) + if not ctx.saved_tensors[0].requires_grad: + sparse_grad = None + return sparse_grad, relation_grad, input_grad + + +def generalized_spmm(sparse, input, sum="add", mul="mul"): + r""" + Generalized sparse-dense matrix multiplication. + + This function computes the matrix multiplication of a sparse matrix and a dense input matrix. + The output dense matrix satisfies + + .. 
math:: + + output_{i,k} = \bigoplus_{j: sparse_{i,j} \neq 0} sparse_{i,j} \otimes input_{j,k} + + where :math:`\oplus` and :math:`\otimes` are the summation and the multiplication operators respectively. + + .. warning:: + + Gradient w.r.t. the sparse matrix is only computed for non-zero entries of the sparse matrix. + This behaves differently from dense-dense matrix multiplication with zero entries. + + Parameters: + sparse (SparseTensor): 2D sparse tensor + input (Tensor): 2D dense tensor + sum (str, optional): generalized summation operator. Available operators are ``add``, ``min`` and ``max``. + mul (str, optional): generalized multiplication operator. Available operators are ``add`` and ``mul``. + """ + name = "SPMM%s%sFunction" % (sum.capitalize(), mul.capitalize()) + if not hasattr(module, name): + raise ValueError("No generalized spmm implementation found for summation `%s` and multiplication `%s`" + % (sum, mul)) + Function = getattr(module, name) + return Function.apply(sparse.coalesce(), input) + + +def generalized_rspmm(sparse, relation, input, sum="add", mul="mul"): + r""" + Generalized relational sparse-dense matrix multiplication. + + This function computes the matrix multiplication of a sparse matrix, a dense relation matrix and + a dense input matrix. The output dense matrix satisfies + + .. math:: + + output_{i,l} = \bigoplus_{j,k: sparse_{i,j,k} \neq 0} sparse_{i, j, k} \times (relation_{k,l} \otimes input_{j,l}) + + where :math:`\oplus` and :math:`\otimes` are the summation and the multiplication operators respectively. + + .. warning:: + + Gradient w.r.t. the sparse matrix is only computed for non-zero entries of the sparse matrix. + This behaves differently from dense-dense matrix multiplication with zero entries. + + Parameters: + sparse (SparseTensor): 3D sparse tensor + relation (Tensor): 2D dense tensor + input (Tensor): 2D dense tensor + sum (str, optional): generalized summation operator. Available operators are ``add``, ``min`` and ``max``. + mul (str, optional): generalized multiplication operator. Available operators are ``add`` and ``mul``. 
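+
+    Example (an illustrative sketch; the index layout follows the :math:`(i, j, k)` convention
+    of the formula above, and the compiled ``spmm`` extension is required)::
+
+        >>> indices = torch.tensor([[0, 1], [1, 0], [0, 1]])
+        >>> sparse = torch.sparse_coo_tensor(indices, torch.ones(2), (2, 2, 2))
+        >>> relation = torch.randn(2, 4)
+        >>> input = torch.randn(2, 4)
+        >>> output = generalized_rspmm(sparse, relation, input)  # shape (2, 4)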
+ """ + name = "RSPMM%s%sFunction" % (sum.capitalize(), mul.capitalize()) + if not hasattr(module, name): + raise ValueError("No generalized rspmm implementation found for summation `%s` and multiplication `%s`" + % (sum, mul)) + Function = getattr(module, name) + return Function.apply(sparse.coalesce(), relation, input) \ No newline at end of file diff --git a/build/lib/torchdrug/layers/geometry/__init__.py b/build/lib/torchdrug/layers/geometry/__init__.py new file mode 100644 index 00000000..e1e851a9 --- /dev/null +++ b/build/lib/torchdrug/layers/geometry/__init__.py @@ -0,0 +1,9 @@ +from .graph import GraphConstruction, SpatialLineGraph +from .function import BondEdge, KNNEdge, SpatialEdge, SequentialEdge, AlphaCarbonNode, \ + IdentityNode, RandomEdgeMask, SubsequenceNode, SubspaceNode + +__all__ = [ + "GraphConstruction", "SpatialLineGraph", + "BondEdge", "KNNEdge", "SpatialEdge", "SequentialEdge", "AlphaCarbonNode", + "IdentityNode", "RandomEdgeMask", "SubsequenceNode", "SubspaceNode" +] \ No newline at end of file diff --git a/build/lib/torchdrug/layers/geometry/function.py b/build/lib/torchdrug/layers/geometry/function.py new file mode 100644 index 00000000..22fc10e5 --- /dev/null +++ b/build/lib/torchdrug/layers/geometry/function.py @@ -0,0 +1,343 @@ +import torch +from torch import nn +from torch_cluster import knn_graph, radius_graph + +from torchdrug import core, data +from torchdrug.layers import functional +from torchdrug.core import Registry as R + + +@R.register("layers.geometry.BondEdge") +class BondEdge(nn.Module, core.Configurable): + """ + Construct all bond edges. + """ + + def forward(self, graph): + """ + Return bond edges from the input graph. Edge types are inherited from the input graph. + + Parameters: + graph (Graph): :math:`n` graph(s) + + Returns: + (Tensor, int): edge list of shape :math:`(|E|, 3)`, number of relations + """ + return graph.edge_list, graph.num_relation + + +@R.register("layers.geometry.KNNEdge") +class KNNEdge(nn.Module, core.Configurable): + """ + Construct edges between each node and its nearest neighbors. + + Parameters: + k (int, optional): number of neighbors + min_distance (int, optional): minimum distance between the residues of two nodes + """ + + eps = 1e-10 + + def __init__(self, k=10, min_distance=5, max_distance=None): + super(KNNEdge, self).__init__() + self.k = k + self.min_distance = min_distance + self.max_distance = max_distance + + def forward(self, graph): + """ + Return KNN edges constructed from the input graph. 
+ + Parameters: + graph (Graph): :math:`n` graph(s) + + Returns: + (Tensor, int): edge list of shape :math:`(|E|, 3)`, number of relations + """ + edge_list = knn_graph(graph.node_position, k=self.k, batch=graph.node2graph).t() + relation = torch.zeros(len(edge_list), 1, dtype=torch.long, device=graph.device) + edge_list = torch.cat([edge_list, relation], dim=-1) + + if self.min_distance > 0: + node_in, node_out = edge_list.t()[:2] + mask = (graph.atom2residue[node_in] - graph.atom2residue[node_out]).abs() < self.min_distance + edge_list = edge_list[~mask] + + if self.max_distance: + node_in, node_out = edge_list.t()[:2] + mask = (graph.atom2residue[node_in] - graph.atom2residue[node_out]).abs() > self.max_distance + edge_list = edge_list[~mask] + + node_in, node_out = edge_list.t()[:2] + mask = (graph.node_position[node_in] - graph.node_position[node_out]).norm(dim=-1) < self.eps + edge_list = edge_list[~mask] + + return edge_list, 1 + + +@R.register("layers.geometry.SpatialEdge") +class SpatialEdge(nn.Module, core.Configurable): + """ + Construct edges between nodes within a specified radius. + + Parameters: + radius (float, optional): spatial radius + min_distance (int, optional): minimum distance between the residues of two nodes + """ + + eps = 1e-10 + + def __init__(self, radius=5, min_distance=5, max_distance=None, max_num_neighbors=32): + super(SpatialEdge, self).__init__() + self.radius = radius + self.min_distance = min_distance + self.max_distance = max_distance + self.max_num_neighbors = max_num_neighbors + + def forward(self, graph): + """ + Return spatial radius edges constructed based on the input graph. + + Parameters: + graph (Graph): :math:`n` graph(s) + + Returns: + (Tensor, int): edge list of shape :math:`(|E|, 3)`, number of relations + """ + edge_list = radius_graph(graph.node_position, r=self.radius, batch=graph.node2graph, max_num_neighbors=self.max_num_neighbors).t() + relation = torch.zeros(len(edge_list), 1, dtype=torch.long, device=graph.device) + edge_list = torch.cat([edge_list, relation], dim=-1) + + if self.min_distance > 0: + node_in, node_out = edge_list.t()[:2] + mask = (graph.atom2residue[node_in] - graph.atom2residue[node_out]).abs() < self.min_distance + edge_list = edge_list[~mask] + + if self.max_distance: + node_in, node_out = edge_list.t()[:2] + mask = (graph.atom2residue[node_in] - graph.atom2residue[node_out]).abs() > self.max_distance + edge_list = edge_list[~mask] + + node_in, node_out = edge_list.t()[:2] + mask = (graph.node_position[node_in] - graph.node_position[node_out]).norm(dim=-1) < self.eps + edge_list = edge_list[~mask] + + return edge_list, 1 + + +@R.register("layers.geometry.SequentialEdge") +class SequentialEdge(nn.Module, core.Configurable): + """ + Construct edges between atoms within close residues. + + Parameters: + max_distance (int, optional): maximum distance between two residues in the sequence + """ + + def __init__(self, max_distance=2, only_backbone=False): + super(SequentialEdge, self).__init__() + self.max_distance = max_distance + self.only_backbone = only_backbone + + def forward(self, graph): + """ + Return sequential edges constructed based on the input graph. 
+ Edge types are defined by the relative distance between two residues in the sequence + + Parameters: + graph (Graph): :math:`n` graph(s) + + Returns: + (Tensor, int): edge list of shape :math:`(|E|, 3)`, number of relations + """ + if self.only_backbone: + is_backbone = (graph.atom_name == graph.atom_name2id["CA"]) \ + | (graph.atom_name == graph.atom_name2id["C"]) \ + | (graph.atom_name == graph.atom_name2id["N"]) + atom2residue = graph.atom2residue[is_backbone] + else: + atom2residue = graph.atom2residue + residue2num_atom = atom2residue.bincount(minlength=graph.num_residue) + edge_list = [] + for i in range(-self.max_distance, self.max_distance + 1): + node_index = torch.arange(graph.num_node, device=graph.device) + residue_index = torch.arange(graph.num_residue, device=graph.device) + if i > 0: + is_node_in = graph.atom2residue < graph.num_cum_residues[graph.atom2graph] - i + is_node_out = graph.atom2residue >= (graph.num_cum_residues - graph.num_residues)[graph.atom2graph] + i + is_residue_in = residue_index < graph.num_cum_residues[graph.residue2graph] - i + is_residue_out = residue_index >= (graph.num_cum_residues - graph.num_residues)[graph.residue2graph] + i + else: + is_node_in = graph.atom2residue >= (graph.num_cum_residues - graph.num_residues)[graph.atom2graph] - i + is_node_out = graph.atom2residue < graph.num_cum_residues[graph.atom2graph] + i + is_residue_in = residue_index >= (graph.num_cum_residues - graph.num_residues)[graph.residue2graph] - i + is_residue_out = residue_index < graph.num_cum_residues[graph.residue2graph] + i + if self.only_backbone: + is_node_in = is_node_in & is_backbone + is_node_out = is_node_out & is_backbone + node_in = node_index[is_node_in] + node_out = node_index[is_node_out] + # group atoms by residue ids + node_in = node_in[graph.atom2residue[node_in].argsort()] + node_out = node_out[graph.atom2residue[node_out].argsort()] + num_node_in = residue2num_atom[is_residue_in] + num_node_out = residue2num_atom[is_residue_out] + node_in, node_out = functional.variadic_meshgrid(node_in, num_node_in, node_out, num_node_out) + # exclude cross-chain edges + is_same_chain = (graph.chain_id[graph.atom2residue[node_in]] == graph.chain_id[graph.atom2residue[node_out]]) + node_in = node_in[is_same_chain] + node_out = node_out[is_same_chain] + relation = torch.ones(len(node_in), dtype=torch.long, device=graph.device) * (i + self.max_distance) + edges = torch.stack([node_in, node_out, relation], dim=-1) + edge_list.append(edges) + + edge_list = torch.cat(edge_list) + + return edge_list, 2 * self.max_distance + 1 + + +@R.register("layers.geometry.AlphaCarbonNode") +class AlphaCarbonNode(nn.Module, core.Configurable): + """ + Construct only alpha carbon atoms. + """ + + def forward(self, graph): + """ + Return a subgraph that only consists of alpha carbon nodes. + + Parameters: + graph (Graph): :math:`n` graph(s) + """ + mask = (graph.atom_name == data.Protein.atom_name2id["CA"]) & (graph.atom2residue != -1) + residue2num_atom = graph.atom2residue[mask].bincount(minlength=graph.num_residue) + residue_mask = residue2num_atom > 0 + mask = mask & residue_mask[graph.atom2residue] + graph = graph.subgraph(mask).subresidue(residue_mask) + assert (graph.num_node == graph.num_residue).all() + + return graph + + +@R.register("layers.geometry.IdentityNode") +class IdentityNode(nn.Module, core.Configurable): + """ + Construct all nodes as the input. + """ + + def forward(self, graph): + """ + Return the input graph as is. 
+
+        Parameters:
+            graph (Graph): :math:`n` graph(s)
+        """
+        return graph
+
+
+@R.register("layers.geometry.RandomEdgeMask")
+class RandomEdgeMask(nn.Module, core.Configurable):
+    """
+    Construct nodes by random edge masking.
+
+    Parameters:
+        mask_rate (float, optional): rate of masked edges
+    """
+
+    def __init__(self, mask_rate=0.15):
+        super(RandomEdgeMask, self).__init__()
+        self.mask_rate = mask_rate
+
+    def forward(self, graph):
+        """
+        Return a graph with some edges masked out.
+
+        Parameters:
+            graph (Graph): :math:`n` graph(s)
+        """
+        num_samples = (graph.num_edges * self.mask_rate).long().clamp(min=1)
+        num_sample = num_samples.sum()
+        sample2graph = torch.repeat_interleave(num_samples)
+        edge_index = (torch.rand(num_sample, device=graph.device) * graph.num_edges[sample2graph]).long()
+        edge_index = edge_index + (graph.num_cum_edges - graph.num_edges)[sample2graph]
+        edge_mask = ~functional.as_mask(edge_index, graph.num_edge)
+
+        return graph.edge_mask(edge_mask)
+
+
+@R.register("layers.geometry.SubsequenceNode")
+class SubsequenceNode(nn.Module, core.Configurable):
+    """
+    Construct nodes by taking a random subsequence of the original graph.
+
+    Parameters:
+        max_length (int, optional): maximal length of the sequence after cropping
+    """
+
+    def __init__(self, max_length=100):
+        super(SubsequenceNode, self).__init__()
+        self.max_length = max_length
+
+    def forward(self, graph):
+        """
+        Randomly take a subsequence of the specified length.
+        Return the full sequence if the sequence is shorter than the specified length.
+
+        Parameters:
+            graph (Graph): :math:`n` graph(s)
+        """
+        starts = (torch.rand(graph.batch_size, device=graph.device) *
+                  (graph.num_residues - self.max_length).clamp(min=0)).long()
+        ends = torch.min(starts + self.max_length, graph.num_residues)
+        starts = starts + graph.num_cum_residues - graph.num_residues
+        ends = ends + graph.num_cum_residues - graph.num_residues
+
+        residue_mask = functional.multi_slice_mask(starts, ends, graph.num_residue)
+        graph = graph.subresidue(residue_mask)
+
+        return graph
+
+
+@R.register("layers.geometry.SubspaceNode")
+class SubspaceNode(nn.Module, core.Configurable):
+    """
+    Construct nodes by taking a spatial ball of the original graph.
+
+    Parameters:
+        entity_level (str, optional): level to perform cropping.
+            Available options are ``node``, ``atom`` and ``residue``.
+        min_radius (float, optional): minimum radius of the spatial ball
+        min_neighbor (int, optional): minimum number of nodes in the spatial ball
+    """
+
+    def __init__(self, entity_level="node", min_radius=15.0, min_neighbor=50):
+        super(SubspaceNode, self).__init__()
+        self.entity_level = entity_level
+        self.min_radius = min_radius
+        self.min_neighbor = min_neighbor
+
+    def forward(self, graph):
+        """
+        Randomly pick a node as the center, and crop a spatial ball whose radius is
+        at least :attr:`min_radius` and which contains at least :attr:`min_neighbor` nodes.
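+
+        Example (an illustrative sketch; ``protein`` is assumed to be a packed protein graph
+        with ``node_position`` defined)::
+
+            >>> node_layer = SubspaceNode(entity_level="residue", min_radius=15.0, min_neighbor=50)
+            >>> subgraph = node_layer(protein)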
+ + Parameters: + graph (Graph): :math:`n` graph(s) + """ + node_in = torch.arange(graph.num_node, device=graph.device) + node_in = functional.variadic_sample(node_in, graph.num_nodes, 1).squeeze(-1) + node_in = node_in.repeat_interleave(graph.num_nodes) + node_out = torch.arange(graph.num_node, device=graph.device) + dist = (graph.node_position[node_in] - graph.node_position[node_out]).norm(dim=-1) + topk_dist = functional.variadic_topk(dist, graph.num_nodes, self.min_neighbor, largest=False)[0] + radius = (topk_dist[:, -1] * 1.5).clamp(min=self.min_radius) + radius = radius.repeat_interleave(graph.num_nodes) + node_index = node_out[dist < radius] + + if self.entity_level in ["node", "atom"]: + graph = graph.subgraph(node_index) + else: + residue_index = graph.atom2residue[node_index].unique() + graph = graph.subresidue(residue_index) + + return graph diff --git a/build/lib/torchdrug/layers/geometry/graph.py b/build/lib/torchdrug/layers/geometry/graph.py new file mode 100644 index 00000000..7aa16a83 --- /dev/null +++ b/build/lib/torchdrug/layers/geometry/graph.py @@ -0,0 +1,194 @@ +import math + +import torch +from torch import nn + +from torchdrug import core, data +from torchdrug.layers import functional +from torchdrug.core import Registry as R + + +@R.register("layers.GraphConstruction") +class GraphConstruction(nn.Module, core.Configurable): + """ + Construct a new graph from an existing graph. + + See `torchdrug.layers.geometry` for a full list of available node and edge layers. + + Parameters: + node_layers (list of nn.Module, optional): modules to construct nodes of the new graph + edge_layers (list of nn.Module, optional): modules to construct edges of the new graph + edge_feature (str, optional): edge features in the new graph. + Available features are ``residue_type``, ``gearnet``. + + 1. For ``residue_type``, the feature of the edge :math:`e_{ij}` between residue :math:`i` and residue + :math:`j` is the concatenation ``[residue_type(i), residue_type(j)]``. + 2. For ``gearnet``, the feature of the edge :math:`e_{ij}` between residue :math:`i` and residue :math:`j` + is the concatenation ``[residue_type(i), residue_type(j), edge_type(e_ij), + sequential_distance(i,j), spatial_distance(i,j)]``. + + .. note:: + You may customize your own edge features by inheriting this class and define a member function + for your features. Use ``edge_feature="my_feature"`` to call the following feature function. + + .. code:: python + + def edge_my_feature(self, graph, edge_list, num_relation): + ... 
+ return feature # the first dimension must be ``graph.num_edge`` + """ + + max_seq_dist = 10 + + def __init__(self, node_layers=None, edge_layers=None, edge_feature="residue_type"): + super(GraphConstruction, self).__init__() + if node_layers is None: + self.node_layers = nn.ModuleList() + else: + self.node_layers = nn.ModuleList(node_layers) + if edge_layers is None: + edge_layers = nn.ModuleList() + else: + edge_layers = nn.ModuleList(edge_layers) + self.edge_layers = edge_layers + self.edge_feature = edge_feature + + def edge_residue_type(self, graph, edge_list, num_relation): + node_in, node_out, _ = edge_list.t() + residue_in, residue_out = graph.atom2residue[node_in], graph.atom2residue[node_out] + in_residue_type = graph.residue_type[residue_in] + out_residue_type = graph.residue_type[residue_out] + + return torch.cat([ + functional.one_hot(in_residue_type, len(data.Protein.residue2id)), + functional.one_hot(out_residue_type, len(data.Protein.residue2id)) + ], dim=-1) + + def edge_gearnet(self, graph, edge_list, num_relation): + node_in, node_out, r = edge_list.t() + residue_in, residue_out = graph.atom2residue[node_in], graph.atom2residue[node_out] + in_residue_type = graph.residue_type[residue_in] + out_residue_type = graph.residue_type[residue_out] + sequential_dist = torch.abs(residue_in - residue_out) + spatial_dist = (graph.node_position[node_in] - graph.node_position[node_out]).norm(dim=-1) + + return torch.cat([ + functional.one_hot(in_residue_type, len(data.Protein.residue2id)), + functional.one_hot(out_residue_type, len(data.Protein.residue2id)), + functional.one_hot(r, num_relation), + functional.one_hot(sequential_dist.clamp(max=self.max_seq_dist), self.max_seq_dist + 1), + spatial_dist.unsqueeze(-1) + ], dim=-1) + + def apply_node_layer(self, graph): + for layer in self.node_layers: + graph = layer(graph) + return graph + + def apply_edge_layer(self, graph): + if not self.edge_layers: + return graph + + edge_list = [] + num_edges = [] + num_relations = [] + for layer in self.edge_layers: + edges, num_relation = layer(graph) + edge_list.append(edges) + num_edges.append(len(edges)) + num_relations.append(num_relation) + + edge_list = torch.cat(edge_list) + num_edges = torch.tensor(num_edges, device=graph.device) + num_relations = torch.tensor(num_relations, device=graph.device) + num_relation = num_relations.sum() + offsets = (num_relations.cumsum(0) - num_relations).repeat_interleave(num_edges) + edge_list[:, 2] += offsets + + # reorder edges into a valid PackedGraph + node_in = edge_list[:, 0] + edge2graph = graph.node2graph[node_in] + order = edge2graph.argsort() + edge_list = edge_list[order] + num_edges = edge2graph.bincount(minlength=graph.batch_size) + offsets = (graph.num_cum_nodes - graph.num_nodes).repeat_interleave(num_edges) + + if hasattr(self, "edge_%s" % self.edge_feature): + edge_feature = getattr(self, "edge_%s" % self.edge_feature)(graph, edge_list, num_relation) + elif self.edge_feature is None: + edge_feature = None + else: + raise ValueError("Unknown edge feature `%s`" % self.edge_feature) + data_dict, meta_dict = graph.data_by_meta(include=( + "node", "residue", "node reference", "residue reference", "graph" + )) + + if isinstance(graph, data.PackedProtein): + data_dict["num_residues"] = graph.num_residues + if isinstance(graph, data.PackedMolecule): + data_dict["bond_type"] = torch.zeros_like(edge_list[:, 2]) + return type(graph)(edge_list, num_nodes=graph.num_nodes, num_edges=num_edges, num_relation=num_relation, + view=graph.view, 
offsets=offsets, edge_feature=edge_feature, + meta_dict=meta_dict, **data_dict) + + def forward(self, graph): + """ + Generate a new graph based on the input graph and pre-defined node and edge layers. + + Parameters: + graph (Graph): :math:`n` graph(s) + + Returns: + graph (Graph): new graph(s) + """ + graph = self.apply_node_layer(graph) + graph = self.apply_edge_layer(graph) + return graph + + +@R.register("layers.SpatialLineGraph") +class SpatialLineGraph(nn.Module, core.Configurable): + """ + Spatial line graph construction module from `Protein Representation Learning by Geometric Structure Pretraining`_. + + .. _Protein Representation Learning by Geometric Structure Pretraining: + https://arxiv.org/pdf/2203.06125.pdf + + Parameters: + num_angle_bin (int, optional): number of bins to discretize angles between edges + """ + + def __init__(self, num_angle_bin=8): + super(SpatialLineGraph, self).__init__() + self.num_angle_bin = num_angle_bin + + def forward(self, graph): + """ + Generate the spatial line graph of the input graph. + The edge types are decided by the angles between two adjacent edges in the input graph. + + Parameters: + graph (PackedGraph): :math:`n` graph(s) + + Returns: + graph (PackedGraph): the spatial line graph + """ + line_graph = graph.line_graph() + node_in, node_out = graph.edge_list[:, :2].t() + edge_in, edge_out = line_graph.edge_list.t() + + # compute the angle ijk + node_i = node_out[edge_out] + node_j = node_in[edge_out] + node_k = node_in[edge_in] + vector1 = graph.node_position[node_i] - graph.node_position[node_j] + vector2 = graph.node_position[node_k] - graph.node_position[node_j] + x = (vector1 * vector2).sum(dim=-1) + y = torch.cross(vector1, vector2).norm(dim=-1) + angle = torch.atan2(y, x) + relation = (angle / math.pi * self.num_angle_bin).long().clamp(max=self.num_angle_bin - 1) + edge_list = torch.cat([line_graph.edge_list, relation.unsqueeze(-1)], dim=-1) + + return type(line_graph)(edge_list, num_nodes=line_graph.num_nodes, offsets=line_graph._offsets, + num_edges=line_graph.num_edges, num_relation=self.num_angle_bin, + meta_dict=line_graph.meta_dict, **line_graph.data_dict) diff --git a/build/lib/torchdrug/layers/pool.py b/build/lib/torchdrug/layers/pool.py new file mode 100644 index 00000000..973def24 --- /dev/null +++ b/build/lib/torchdrug/layers/pool.py @@ -0,0 +1,207 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch_scatter import scatter_add, scatter_mean + +from torchdrug import data + + +class DiffPool(nn.Module): + """ + Differentiable pooling operator from `Hierarchical Graph Representation Learning with Differentiable Pooling`_ + + .. 
_Hierarchical Graph Representation Learning with Differentiable Pooling:
+        https://papers.nips.cc/paper/7729-hierarchical-graph-representation-learning-with-differentiable-pooling.pdf
+
+    Parameters:
+        input_dim (int): input dimension
+        output_node (int): number of nodes after pooling
+        feature_layer (Module, optional): graph convolution layer for embedding
+        pool_layer (Module, optional): graph convolution layer for pooling assignment
+        loss_weight (float, optional): weight of entropy regularization
+        zero_diagonal (bool, optional): remove self loops in the pooled graph or not
+        sparse (bool, optional): use sparse assignment or not
+    """
+
+    tau = 1
+    eps = 1e-10
+
+    def __init__(self, input_dim, output_node, feature_layer=None, pool_layer=None, loss_weight=1, zero_diagonal=False,
+                 sparse=False):
+        super(DiffPool, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = feature_layer.output_dim if feature_layer else input_dim  # feature_layer is optional
+        self.output_node = output_node
+        self.feature_layer = feature_layer
+        self.pool_layer = pool_layer
+        self.loss_weight = loss_weight
+        self.zero_diagonal = zero_diagonal
+        self.sparse = sparse
+
+        if pool_layer is not None:
+            self.linear = nn.Linear(pool_layer.output_dim, output_node)
+        else:
+            self.linear = nn.Linear(input_dim, output_node)
+
+    def forward(self, graph, input, all_loss=None, metric=None):
+        """
+        Compute the node cluster assignment and pool the nodes.
+
+        Parameters:
+            graph (Graph): graph(s)
+            input (Tensor): input node representations
+            all_loss (Tensor, optional): if specified, add loss to this tensor
+            metric (dict, optional): if specified, output metrics to this dict
+
+        Returns:
+            (PackedGraph, Tensor, Tensor):
+                pooled graph, output node representations, node-to-cluster assignment
+        """
+        feature = input
+        if self.feature_layer:
+            feature = self.feature_layer(graph, feature)
+
+        x = input
+        if self.pool_layer:
+            x = self.pool_layer(graph, x)
+        x = self.linear(x)
+        if self.sparse:
+            assignment = F.gumbel_softmax(x, hard=True, tau=self.tau, dim=-1)
+            new_graph, output = self.sparse_pool(graph, feature, assignment)
+        else:
+            assignment = F.softmax(x, dim=-1)
+            new_graph, output = self.dense_pool(graph, feature, assignment)
+
+        if all_loss is not None:
+            prob = scatter_mean(assignment, graph.node2graph, dim=0, dim_size=graph.batch_size)
+            entropy = -(prob * (prob + self.eps).log()).sum(dim=-1)
+            entropy = entropy.mean()
+            metric["assignment entropy"] = entropy
+            if self.loss_weight > 0:
+                all_loss -= entropy * self.loss_weight
+
+        if self.zero_diagonal:
+            edge_list = new_graph.edge_list[:, :2]
+            is_diagonal = edge_list[:, 0] == edge_list[:, 1]
+            new_graph = new_graph.edge_mask(~is_diagonal)
+
+        return new_graph, output, assignment
+
+    def dense_pool(self, graph, input, assignment):
+        node_in, node_out = graph.edge_list.t()[:2]
+        # S^T A S, O(|V|k^2 + |E|k)
+        x = graph.edge_weight.unsqueeze(-1) * assignment[node_out]
+        x = scatter_add(x, node_in, dim=0, dim_size=graph.num_node)
+        x = torch.einsum("np, nq -> npq", assignment, x)
+        adjacency = scatter_add(x, graph.node2graph, dim=0, dim_size=graph.batch_size)
+        # S^T X
+        x = torch.einsum("na, nd -> nad", assignment, input)
+        output = scatter_add(x, graph.node2graph, dim=0, dim_size=graph.batch_size).flatten(0, 1)
+
+        index = torch.arange(self.output_node, device=graph.device).expand(len(graph), self.output_node, -1)
+        edge_list = torch.stack([index.transpose(-1, -2), index], dim=-1).flatten(0, -2)
+        edge_weight = adjacency.flatten()
+        if isinstance(graph, data.PackedGraph):
+            num_nodes = torch.ones(len(graph),
dtype=torch.long, device=input.device) * self.output_node + num_edges = torch.ones(len(graph), dtype=torch.long, device=input.device) * self.output_node ** 2 + graph = data.PackedGraph(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges) + else: + graph = data.Graph(edge_list, edge_weight=edge_weight, num_node=self.output_node) + return graph, output + + def sparse_pool(self, graph, input, assignment): + assignment = assignment.argmax(dim=-1) + edge_list = graph.edge_list[:, :2] + edge_list = assignment[edge_list] + pooled_node = graph.node2graph * self.output_node + assignment + output = scatter_add(input, pooled_node, dim=0, dim_size=graph.batch_size * self.output_node) + + edge_weight = graph.edge_weight + if isinstance(graph, data.PackedGraph): + num_nodes = torch.ones(len(graph), dtype=torch.long, device=input.device) * self.output_node + num_edges = graph.num_edges + graph = data.PackedGraph(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges) + else: + graph = data.Graph(edge_list, edge_weight=edge_weight, num_node=self.output_node) + return graph, output + + +class MinCutPool(DiffPool): + """ + Min cut pooling operator from `Spectral Clustering with Graph Neural Networks for Graph Pooling`_ + + .. _Spectral Clustering with Graph Neural Networks for Graph Pooling: + http://proceedings.mlr.press/v119/bianchi20a/bianchi20a.pdf + + Parameters: + input_dim (int): input dimension + output_node (int): number of nodes after pooling + feature_layer (Module, optional): graph convolution layer for embedding + pool_layer (Module, optional): graph convolution layer for pooling assignment + loss_weight (float, optional): weight of entropy regularization + zero_diagonal (bool, optional): remove self loops in the pooled graph or not + sparse (bool, optional): use sparse assignment or not + """ + + eps = 1e-10 + + def __init__(self, input_dim, output_node, feature_layer=None, pool_layer=None, loss_weight=1, zero_diagonal=True, + sparse=False): + super(MinCutPool, self).__init__(input_dim, output_node, feature_layer, pool_layer, loss_weight, zero_diagonal, + sparse) + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the node cluster assignment and pool the nodes. 
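+
+        When ``all_loss`` is given, this method also adds the two regularizers
+        from the paper to it: a normalized cut term
+        :math:`1 - \mathrm{Tr}(S^T A S) / \mathrm{Tr}(S^T D S)` and an
+        orthogonality term :math:`\|S^T S / \|S^T S\|_F - I_k / \sqrt{k}\|_F`,
+        where :math:`S` is the soft cluster assignment matrix of each graph.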
+ + Parameters: + graph (Graph): graph(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + (PackedGraph, Tensor, Tensor): + pooled graph, output node representations, node-to-cluster assignment + """ + feature = input + if self.feature_layer: + feature = self.feature_layer(graph, feature) + + x = input + if self.pool_layer: + x = self.pool_layer(graph, x) + x = self.linear(x) + if self.sparse: + assignment = F.gumbel_softmax(x, hard=True, tau=self.tau, dim=-1) + new_graph, output = self.sparse_pool(graph, feature, assignment) + else: + assignment = F.softmax(x, dim=-1) + new_graph, output = self.dense_pool(graph, feature, assignment) + + if all_loss is not None: + edge_list = new_graph.edge_list + is_diagonal = edge_list[:, 0] == edge_list[:, 1] + num_intra = scatter_add(new_graph.edge_weight[is_diagonal], new_graph.edge2graph[is_diagonal], + dim=0, dim_size=new_graph.batch_size) + x = torch.einsum("na, n, nc -> nac", assignment, graph.degree_in, assignment) + x = scatter_add(x, graph.node2graph, dim=0, dim_size=graph.batch_size) + num_all = torch.einsum("baa -> b", x) + cut_loss = (1 - num_intra / (num_all + self.eps)).mean() + metric["normalized cut loss"] = cut_loss + + x = torch.einsum("na, nc -> nac", assignment, assignment) + x = scatter_add(x, graph.node2graph, dim=0, dim_size=graph.batch_size) + x = x / x.flatten(-2).norm(dim=-1, keepdim=True).unsqueeze(-1) + x = x - torch.eye(self.output_node, device=x.device) / (self.output_node ** 0.5) + regularization = x.flatten(-2).norm(dim=-1).mean() + metric["orthogonal regularization"] = regularization + if self.loss_weight > 0: + all_loss += (cut_loss + regularization) * self.loss_weight + + if self.zero_diagonal: + edge_list = new_graph.edge_list[:, :2] + is_diagonal = edge_list[:, 0] == edge_list[:, 1] + new_graph = new_graph.edge_mask(~is_diagonal) + + return new_graph, output, assignment \ No newline at end of file diff --git a/build/lib/torchdrug/layers/readout.py b/build/lib/torchdrug/layers/readout.py new file mode 100644 index 00000000..0446e7e0 --- /dev/null +++ b/build/lib/torchdrug/layers/readout.py @@ -0,0 +1,197 @@ +import torch +from torch import nn +from torch_scatter import scatter_mean, scatter_add, scatter_max, scatter_softmax + + +class Readout(nn.Module): + + def __init__(self, type="node"): + super(Readout, self).__init__() + self.type = type + + def get_index2graph(self, graph): + if self.type == "node": + input2graph = graph.node2graph + elif self.type == "edge": + input2graph = graph.edge2graph + elif self.type == "residue": + input2graph = graph.residue2graph + else: + raise ValueError("Unknown input type `%s` for readout functions" % self.type) + return input2graph + + +class MeanReadout(Readout): + """Mean readout operator over graphs with variadic sizes.""" + + def forward(self, graph, input): + """ + Perform readout over the graph(s). + + Parameters: + graph (Graph): graph(s) + input (Tensor): node representations + + Returns: + Tensor: graph representations + """ + input2graph = self.get_index2graph(graph) + output = scatter_mean(input, input2graph, dim=0, dim_size=graph.batch_size) + return output + + +class SumReadout(Readout): + """Sum readout operator over graphs with variadic sizes.""" + + def forward(self, graph, input): + """ + Perform readout over the graph(s). 
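+
+        A minimal usage sketch (the toy graphs and 16-dimensional random
+        features are illustrative assumptions, not part of this class):
+
+        >>> import torch
+        >>> from torchdrug import data, layers
+        >>> batch = data.Graph.pack([data.Graph([[0, 1], [1, 2]], num_node=3)] * 2)
+        >>> feature = torch.randn(batch.num_node, 16)
+        >>> readout = layers.SumReadout()
+        >>> output = readout(batch, feature)  # shape: (2, 16)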
+
+        Parameters:
+            graph (Graph): graph(s)
+            input (Tensor): node representations
+
+        Returns:
+            Tensor: graph representations
+        """
+        input2graph = self.get_index2graph(graph)
+        output = scatter_add(input, input2graph, dim=0, dim_size=graph.batch_size)
+        return output
+
+
+class MaxReadout(Readout):
+    """Max readout operator over graphs with variadic sizes."""
+
+    def forward(self, graph, input):
+        """
+        Perform readout over the graph(s).
+
+        Parameters:
+            graph (Graph): graph(s)
+            input (Tensor): node representations
+
+        Returns:
+            Tensor: graph representations
+        """
+        input2graph = self.get_index2graph(graph)
+        output = scatter_max(input, input2graph, dim=0, dim_size=graph.batch_size)[0]
+        return output
+
+
+class AttentionReadout(Readout):
+    """
+    Attention readout operator over graphs with variadic sizes.
+
+    Parameters:
+        input_dim (int): input dimension
+    """
+
+    def __init__(self, input_dim, type="node"):
+        super(AttentionReadout, self).__init__(type)
+        self.input_dim = input_dim
+        self.linear = nn.Linear(input_dim, 1)
+
+    def forward(self, graph, input):
+        """
+        Perform attention-weighted readout over the graph(s).
+
+        Parameters:
+            graph (Graph): graph(s)
+            input (Tensor): node representations
+
+        Returns:
+            Tensor: graph representations
+        """
+        index2graph = self.get_index2graph(graph)
+        weight = self.linear(input)
+        attention = scatter_softmax(weight, index2graph, dim=0)
+        output = scatter_add(attention * input, index2graph, dim=0, dim_size=graph.batch_size)
+        return output
+
+
+class Softmax(Readout):
+    """Softmax operator over graphs with variadic sizes."""
+
+    eps = 1e-10
+
+    def forward(self, graph, input):
+        """
+        Perform softmax over the graph(s).
+
+        Parameters:
+            graph (Graph): graph(s)
+            input (Tensor): node logits
+
+        Returns:
+            Tensor: node probabilities
+        """
+        input2graph = self.get_index2graph(graph)
+        x = input - scatter_max(input, input2graph, dim=0, dim_size=graph.batch_size)[0][input2graph]
+        x = x.exp()
+        normalizer = scatter_add(x, input2graph, dim=0, dim_size=graph.batch_size)[input2graph]
+        return x / (normalizer + self.eps)
+
+
+class Sort(Readout):
+    """
+    Sort operator over graphs with variadic sizes.
+
+    Parameters:
+        descending (bool, optional): use descending sort order or not
+    """
+
+    def __init__(self, type="node", descending=False):
+        super(Sort, self).__init__(type)
+        self.descending = descending
+
+    def forward(self, graph, input):
+        """
+        Perform sort over graph(s).
+
+        Parameters:
+            graph (Graph): graph(s)
+            input (Tensor): node values
+
+        Returns:
+            (Tensor, LongTensor): sorted values, sorted indices
+        """
+        input2graph = self.get_index2graph(graph)
+        # max/min with a dim argument return (values, indices) tuples
+        step = input.max(dim=0)[0] - input.min(dim=0)[0] + 1
+        if self.descending:
+            step = -step
+        x = input + input2graph * step
+        sorted, index = x.sort(dim=0, descending=self.descending)
+        sorted = sorted - input2graph * step
+        return sorted, index
+
+
+class Set2Set(Readout):
+    """
+    Set2Set operator from `Order Matters: Sequence to sequence for sets`_.
+
+    .. _Order Matters\: Sequence to sequence for sets:
+        https://arxiv.org/pdf/1511.06391.pdf
+
+    Parameters:
+        input_dim (int): input dimension
+        num_step (int, optional): number of process steps
+        num_lstm_layer (int, optional): number of LSTM layers
+    """
+
+    def __init__(self, input_dim, type="node", num_step=3, num_lstm_layer=1):
+        super(Set2Set, self).__init__(type)
+        self.input_dim = input_dim
+        self.output_dim = self.input_dim * 2
+        self.num_step = num_step
+        self.lstm = nn.LSTM(input_dim * 2, input_dim, num_lstm_layer)
+        self.softmax = Softmax(type)
+
+    def forward(self, graph, input):
+        """
+        Perform Set2Set readout over graph(s).
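+
+        Each step feeds the current query vector through the LSTM, attends over
+        the node features with a per-graph softmax, and concatenates the query
+        with the attention-weighted sum. A minimal sketch (graph sizes and the
+        16-dimensional features are illustrative assumptions):
+
+        >>> import torch
+        >>> from torchdrug import data, layers
+        >>> batch = data.Graph.pack([data.Graph([[0, 1]], num_node=2)] * 3)
+        >>> set2set = layers.Set2Set(input_dim=16)
+        >>> feature = torch.randn(batch.num_node, 16)
+        >>> output = set2set(batch, feature)  # shape: (3, 32), i.e. 2 * input_dim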
+
+        Parameters:
+            graph (Graph): graph(s)
+            input (Tensor): node representations
+
+        Returns:
+            Tensor: graph representations
+        """
+        input2graph = self.get_index2graph(graph)
+        hx = (torch.zeros(self.lstm.num_layers, graph.batch_size, self.lstm.hidden_size, device=input.device),) * 2
+        query_star = torch.zeros(graph.batch_size, self.output_dim, device=input.device)
+
+        for i in range(self.num_step):
+            query, hx = self.lstm(query_star.unsqueeze(0), hx)
+            query = query.squeeze(0)
+            product = torch.einsum("bd, bd -> b", query[input2graph], input)
+            attention = self.softmax(graph, product)
+            output = scatter_add(attention.unsqueeze(-1) * input, input2graph, dim=0, dim_size=graph.batch_size)
+            query_star = torch.cat([query, output], dim=-1)
+
+        return query_star
\ No newline at end of file
diff --git a/build/lib/torchdrug/layers/sampler.py b/build/lib/torchdrug/layers/sampler.py
new file mode 100644
index 00000000..0f97107b
--- /dev/null
+++ b/build/lib/torchdrug/layers/sampler.py
@@ -0,0 +1,92 @@
+from torch import nn
+from torch_scatter import scatter_add
+
+from torchdrug.layers import functional
+
+
+class NodeSampler(nn.Module):
+    """
+    Node sampler from `GraphSAINT: Graph Sampling Based Inductive Learning Method`_.
+
+    .. _GraphSAINT\: Graph Sampling Based Inductive Learning Method:
+        https://arxiv.org/pdf/1907.04931.pdf
+
+    Parameters:
+        budget (int, optional): number of nodes to keep
+        ratio (float, optional): ratio of nodes to keep
+    """
+
+    def __init__(self, budget=None, ratio=None):
+        super(NodeSampler, self).__init__()
+        if budget is None and ratio is None:
+            raise ValueError("At least one of `budget` and `ratio` should be provided")
+        self.budget = budget
+        self.ratio = ratio
+
+    def forward(self, graph):
+        """
+        Sample a subgraph from the graph.
+
+        Parameters:
+            graph (Graph): graph(s)
+        """
+        # this is exact for a single graph
+        # but approximate for packed graphs
+        num_sample = graph.num_node
+        if self.budget:
+            num_sample = min(num_sample, self.budget)
+        if self.ratio:
+            num_sample = min(num_sample, int(self.ratio * graph.num_node))
+
+        prob = scatter_add(graph.edge_weight ** 2, graph.edge_list[:, 1], dim_size=graph.num_node)
+        prob /= prob.mean()
+        index = functional.multinomial(prob, num_sample)
+        new_graph = graph.node_mask(index)
+        node_out = new_graph.edge_list[:, 1]
+        new_graph._edge_weight /= num_sample * prob[node_out] / graph.num_node
+
+        return new_graph
+
+
+class EdgeSampler(nn.Module):
+    """
+    Edge sampler from `GraphSAINT: Graph Sampling Based Inductive Learning Method`_.
+
+    .. _GraphSAINT\: Graph Sampling Based Inductive Learning Method:
+        https://arxiv.org/pdf/1907.04931.pdf
+
+    Parameters:
+        budget (int, optional): number of edges to keep
+        ratio (float, optional): ratio of edges to keep
+    """
+
+    def __init__(self, budget=None, ratio=None):
+        super(EdgeSampler, self).__init__()
+        if budget is None and ratio is None:
+            raise ValueError("At least one of `budget` and `ratio` should be provided")
+        self.budget = budget
+        self.ratio = ratio
+
+    def forward(self, graph):
+        """
+        Sample a subgraph from the graph.
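+
+        Each edge is drawn with probability proportional to the sum of the
+        inverse degrees of its endpoints, and the sampled edge weights are
+        rescaled to keep the adjacency estimate unbiased, following GraphSAINT.
+        A minimal sketch (the toy graph is an illustrative assumption):
+
+        >>> from torchdrug import data, layers
+        >>> graph = data.Graph([[0, 1], [1, 2], [2, 0], [0, 2]], num_node=3)
+        >>> sampler = layers.EdgeSampler(ratio=0.5)
+        >>> subgraph = sampler(graph)  # keeps roughly half of the edges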
+ + Parameters: + graph (Graph): graph(s) + """ + # this is exact for a single graph + # but approximate for packed graphs + node_in, node_out = graph.edge_list.t()[:2] + num_sample = graph.num_edge + if self.budget: + num_sample = min(num_sample, self.budget) + if self.ratio: + num_sample = min(num_sample, int(self.ratio * graph.num_edge)) + + prob = 1 / graph.degree_out[node_out] + 1 / graph.degree_in[node_in] + prob = prob / prob.mean() + index = functional.multinomial(prob, num_sample) + new_graph = graph.edge_mask(index) + new_graph._edge_weight /= num_sample * prob[index] / graph.num_edge + + return new_graph diff --git a/build/lib/torchdrug/metrics/__init__.py b/build/lib/torchdrug/metrics/__init__.py new file mode 100644 index 00000000..a5292fc2 --- /dev/null +++ b/build/lib/torchdrug/metrics/__init__.py @@ -0,0 +1,14 @@ +from .metric import area_under_roc, area_under_prc, r2, QED, logP, penalized_logP, SA, chemical_validity, \ + accuracy, variadic_accuracy, matthews_corrcoef, pearsonr, spearmanr, \ + variadic_area_under_prc, variadic_area_under_roc, variadic_top_precision, f1_max + +# alias +AUROC = area_under_roc +AUPRC = area_under_prc + +__all__ = [ + "area_under_roc", "area_under_prc", "r2", "QED", "logP", "penalized_logP", "SA", "chemical_validity", + "accuracy", "variadic_accuracy", "matthews_corrcoef", "pearsonr", "spearmanr", + "variadic_area_under_prc", "variadic_area_under_roc", "variadic_top_precision", "f1_max", + "AUROC", "AUPRC", +] \ No newline at end of file diff --git a/build/lib/torchdrug/metrics/metric.py b/build/lib/torchdrug/metrics/metric.py new file mode 100644 index 00000000..89d34166 --- /dev/null +++ b/build/lib/torchdrug/metrics/metric.py @@ -0,0 +1,396 @@ +import torch +from torch.nn import functional as F +from torch_scatter import scatter_add, scatter_mean, scatter_max +import networkx as nx +from rdkit import Chem +from rdkit.Chem import Descriptors + +from torchdrug import utils +from torchdrug.layers import functional +from torchdrug.core import Registry as R +from torchdrug.metrics.rdkit import sascorer + + +@R.register("metrics.auroc") +def area_under_roc(pred, target): + """ + Area under receiver operating characteristic curve (ROC). + + Parameters: + pred (Tensor): predictions of shape :math:`(n,)` + target (Tensor): binary targets of shape :math:`(n,)` + """ + order = pred.argsort(descending=True) + target = target[order] + hit = target.cumsum(0) + all = (target == 0).sum() * (target == 1).sum() + auroc = hit[target == 0].sum() / (all + 1e-10) + return auroc + + +@R.register("metrics.auprc") +def area_under_prc(pred, target): + """ + Area under precision-recall curve (PRC). + + Parameters: + pred (Tensor): predictions of shape :math:`(n,)` + target (Tensor): binary targets of shape :math:`(n,)` + """ + order = pred.argsort(descending=True) + target = target[order] + precision = target.cumsum(0) / torch.arange(1, len(target) + 1, device=target.device) + auprc = precision[target == 1].sum() / ((target == 1).sum() + 1e-10) + return auprc + + +@R.register("metrics.r2") +def r2(pred, target): + """ + :math:`R^2` regression score. + + Parameters: + pred (Tensor): predictions of shape :math:`(n,)` + target (Tensor): targets of shape :math:`(n,)` + """ + total = torch.var(target, unbiased=False) + residual = F.mse_loss(pred, target) + return 1 - residual / total + + +@R.register("metrics.logp") +def logP(pred): + """ + Logarithm of partition coefficient between octanol and water for a compound. 
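+
+    A minimal usage sketch (benzene is an arbitrary example molecule):
+
+    >>> from torchdrug import data, metrics
+    >>> mols = data.Molecule.pack([data.Molecule.from_smiles("C1=CC=CC=C1")])
+    >>> score = metrics.logP(mols)  # tensor of shape (1,)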
+ + Parameters: + pred (PackedMolecule): molecules to evaluate + """ + logp = [] + for mol in pred: + mol = mol.to_molecule() + try: + with utils.no_rdkit_log(): + mol.UpdatePropertyCache() + score = Descriptors.MolLogP(mol) + except Chem.AtomValenceException: + score = 0 + logp.append(score) + + return torch.tensor(logp, dtype=torch.float, device=pred.device) + + +@R.register("metrics.plogp") +def penalized_logP(pred): + """ + Logarithm of partition coefficient, penalized by cycle length and synthetic accessibility. + + Parameters: + pred (PackedMolecule): molecules to evaluate + """ + # statistics from ZINC250k + logp_mean = 2.4570953396190123 + logp_std = 1.434324401111988 + sa_mean = 3.0525811293166134 + sa_std = 0.8335207024513095 + cycle_mean = 0.0485696876403053 + cycle_std = 0.2860212110245455 + + plogp = [] + for mol in pred: + cycles = nx.cycle_basis(nx.Graph(mol.edge_list[:, :2].tolist())) + if cycles: + max_cycle = max([len(cycle) for cycle in cycles]) + cycle = max(0, max_cycle - 6) + else: + cycle = 0 + mol = mol.to_molecule() + try: + with utils.no_rdkit_log(): + mol.UpdatePropertyCache() + Chem.GetSymmSSSR(mol) + logp = Descriptors.MolLogP(mol) + sa = sascorer.calculateScore(mol) + logp = (logp - logp_mean) / logp_std + sa = (sa - sa_mean) / sa_std + cycle = (cycle - cycle_mean) / cycle_std + score = logp - sa - cycle + except Chem.AtomValenceException: + score = -30 + plogp.append(score) + + return torch.tensor(plogp, dtype=torch.float, device=pred.device) + + +@R.register("metrics.SA") +def SA(pred): + """ + Synthetic accesibility score. + + Parameters: + pred (PackedMolecule): molecules to evaluate + """ + sa = [] + for mol in pred: + with utils.no_rdkit_log(): + score = sascorer.calculateScore(mol.to_molecule()) + sa.append(score) + + return torch.tensor(sa, dtype=torch.float, device=pred.device) + + +@R.register("metrics.qed") +def QED(pred): + """ + Quantitative estimation of drug-likeness. + + Parameters: + pred (PackedMolecule): molecules to evaluate + """ + qed = [] + for mol in pred: + try: + with utils.no_rdkit_log(): + score = Descriptors.qed(mol.to_molecule()) + except Chem.AtomValenceException: + score = -1 + qed.append(score) + + return torch.tensor(qed, dtype=torch.float, device=pred.device) + + +@R.register("metrics.validity") +def chemical_validity(pred): + """ + Chemical validity of molecules. + + Parameters: + pred (PackedMolecule): molecules to evaluate + """ + validity = [] + for i, mol in enumerate(pred): + with utils.no_rdkit_log(): + smiles = mol.to_smiles() + mol = Chem.MolFromSmiles(smiles) + validity.append(1 if mol else 0) + + return torch.tensor(validity, dtype=torch.float, device=pred.device) + + +@R.register("metrics.variadic_auroc") +def variadic_area_under_roc(pred, target, size): + """ + Area under receiver operating characteristic curve (ROC) for sets with variadic sizes. + + Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`. + + Parameters: + pred (Tensor): prediction of shape :math:`(B,)` + target (Tensor): target of shape :math:`(B,)`. 
+ size (Tensor): size of sets of shape :math:`(N,)` + """ + index2graph = torch.repeat_interleave(size) + _, order = functional.variadic_sort(pred, size, descending=True) + cum_size = (size.cumsum(0) - size)[index2graph] + target = target[order + cum_size] + total_hit = functional.variadic_sum(target, size) + total_hit = total_hit.cumsum(0) - total_hit + hit = target.cumsum(0) - total_hit[index2graph] + hit = torch.where(target == 0, hit, torch.zeros_like(hit)) + all = functional.variadic_sum((target == 0).float(), size) * \ + functional.variadic_sum((target == 1).float(), size) + auroc = functional.variadic_sum(hit, size) / (all + 1e-10) + return auroc + + +@R.register("metrics.variadic_auprc") +def variadic_area_under_prc(pred, target, size): + """ + Area under precision-recall curve (PRC) for sets with variadic sizes. + + Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`. + + Parameters: + pred (Tensor): prediction of shape :math:`(B,)` + target (Tensor): target of shape :math:`(B,)`. + size (Tensor): size of sets of shape :math:`(N,)` + """ + index2graph = torch.repeat_interleave(size) + _, order = functional.variadic_sort(pred, size, descending=True) + cum_size = (size.cumsum(0) - size)[index2graph] + target = target[order + cum_size] + total_hit = functional.variadic_sum(target, size) + total_hit = total_hit.cumsum(0) - total_hit + hit = target.cumsum(0) - total_hit[index2graph] + total = torch.ones_like(target).cumsum(0) - (size.cumsum(0) - size)[index2graph] + precision = hit / total + precision = torch.where(target == 1, precision, torch.zeros_like(precision)) + auprc = functional.variadic_sum(precision, size) / \ + (functional.variadic_sum((target == 1).float(), size) + 1e-10) + return auprc + + +@R.register("metrics.f1_max") +def f1_max(pred, target): + """ + F1 score with the optimal threshold. + + This function first enumerates all possible thresholds for deciding positive and negative + samples, and then pick the threshold with the maximal F1 score. + + Parameters: + pred (Tensor): predictions of shape :math:`(B, N)` + target (Tensor): binary targets of shape :math:`(B, N)` + """ + order = pred.argsort(descending=True, dim=1) + target = target.gather(1, order) + precision = target.cumsum(1) / torch.ones_like(target).cumsum(1) + recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10) + is_start = torch.zeros_like(target).bool() + is_start[:, 0] = 1 + is_start = torch.scatter(is_start, 1, order, is_start) + + all_order = pred.flatten().argsort(descending=True) + order = order + torch.arange(order.shape[0], device=order.device).unsqueeze(1) * order.shape[1] + order = order.flatten() + inv_order = torch.zeros_like(order) + inv_order[order] = torch.arange(order.shape[0], device=order.device) + is_start = is_start.flatten()[all_order] + all_order = inv_order[all_order] + precision = precision.flatten() + recall = recall.flatten() + all_precision = precision[all_order] - \ + torch.where(is_start, torch.zeros_like(precision), precision[all_order - 1]) + all_precision = all_precision.cumsum(0) / is_start.cumsum(0) + all_recall = recall[all_order] - \ + torch.where(is_start, torch.zeros_like(recall), recall[all_order - 1]) + all_recall = all_recall.cumsum(0) / pred.shape[0] + all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10) + return all_f1.max() + + +@R.register("metrics.accuracy") +def accuracy(pred, target): + """ + Classification accuracy. + + Suppose there are :math:`N` sets and :math:`C` categories. 
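+
+    A minimal sketch (the logits and labels are illustrative):
+
+    >>> import torch
+    >>> from torchdrug import metrics
+    >>> pred = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
+    >>> target = torch.tensor([1, 1])
+    >>> metrics.accuracy(pred, target)  # tensor(0.5000)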
+ + Parameters: + pred (Tensor): prediction of shape :math:`(N, C)` + target (Tensor): target of shape :math:`(N,)` + """ + return (pred.argmax(dim=-1) == target).float().mean() + + +@R.register("metrics.variadic_accuracy") +def variadic_accuracy(input, target, size): + """ + Classification accuracy for categories with variadic sizes. + + Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math:`B`. + + Parameters: + input (Tensor): prediction of shape :math:`(B,)` + target (Tensor): target of shape :math:`(N,)`. Each target is a relative index in a sample. + size (Tensor): number of categories of shape :math:`(N,)` + """ + index2graph = torch.repeat_interleave(size) + + input_class = scatter_max(input, index2graph)[1] + target_index = target + size.cumsum(0) - size + accuracy = (input_class == target_index).float() + return accuracy + + +@R.register("metrics.variadic_top_precision") +def variadic_top_precision(pred, target, size, k): + """ + Top-k precision for sets with variadic sizes. + + Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`. + + Parameters: + pred (Tensor): prediction of shape :math:`(B,)` + target (Tensor): target of shape :math:`(B,)` + size (Tensor): size of sets of shape :math:`(N,)` + k (LongTensor): the k in "top-k" for different sets of shape :math:`(N,)` + """ + index = functional.variadic_topk(pred, size, k, largest=True)[1] + index = index + (size.cumsum(0) - size).repeat_interleave(k) + precision = functional.variadic_sum(target[index], k) / k + precision[size < k] = 0 + return precision + + +@R.register("metrics.mcc") +def matthews_corrcoef(pred, target): + """ + Matthews correlation coefficient between prediction and target. + + Definition follows matthews_corrcoef for K classes in sklearn. + For details, see: `https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-corrcoef` + + Parameters: + pred (Tensor): prediction of shape :math: `(N, K)` + target (Tensor): target of shape :math: `(N,)` + """ + num_class = pred.size(-1) + pred = pred.argmax(-1) + ones = torch.ones(len(target), device=pred.device) + confusion_matrix = scatter_add(ones, target * num_class + pred, dim=0, dim_size=num_class ** 2) + confusion_matrix = confusion_matrix.view(num_class, num_class) + t = confusion_matrix.sum(dim=1) + p = confusion_matrix.sum(dim=0) + c = confusion_matrix.trace() + s = confusion_matrix.sum() + return (c * s - t @ p) / ((s * s - p @ p) * (s * s - t @ t) + 1e-10).sqrt() + + +@R.register("metrics.pearsonr") +def pearsonr(pred, target): + """ + Pearson correlation between prediction and target. + + Parameters: + pred (Tensor): prediction of shape :math: `(N,)` + target (Tensor): target of shape :math: `(N,)` + """ + pred_mean = pred.float().mean() + target_mean = target.float().mean() + pred_centered = pred - pred_mean + target_centered = target - target_mean + pred_normalized = pred_centered / pred_centered.norm(2) + target_normalized = target_centered / target_centered.norm(2) + pearsonr = pred_normalized @ target_normalized + return pearsonr + + +@R.register("metrics.spearmanr") +def spearmanr(pred, target): + """ + Spearman correlation between prediction and target. 
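+
+    Both inputs are first converted to rankings, with ties receiving the mean
+    of their ranks, so the result depends only on the orderings. A minimal
+    sketch:
+
+    >>> import torch
+    >>> from torchdrug import metrics
+    >>> pred = torch.tensor([1.0, 2.0, 3.0])
+    >>> target = torch.tensor([10.0, 20.0, 30.0])
+    >>> metrics.spearmanr(pred, target)  # close to 1.0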
+ + Parameters: + pred (Tensor): prediction of shape :math: `(N,)` + target (Tensor): target of shape :math: `(N,)` + """ + + def get_ranking(input): + input_set, input_inverse = input.unique(return_inverse=True) + order = input_inverse.argsort() + ranking = torch.zeros(len(input_inverse), device=input.device) + ranking[order] = torch.arange(1, len(input) + 1, dtype=torch.float, device=input.device) + + # for elements that have the same value, replace their rankings with the mean of their rankings + mean_ranking = scatter_mean(ranking, input_inverse, dim=0, dim_size=len(input_set)) + ranking = mean_ranking[input_inverse] + return ranking + + pred = get_ranking(pred) + target = get_ranking(target) + covariance = (pred * target).mean() - pred.mean() * target.mean() + pred_std = pred.std(unbiased=False) + target_std = target.std(unbiased=False) + spearmanr = covariance / (pred_std * target_std + 1e-10) + return spearmanr diff --git a/build/lib/torchdrug/metrics/rdkit/__init__.py b/build/lib/torchdrug/metrics/rdkit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/torchdrug/metrics/rdkit/sascorer.py b/build/lib/torchdrug/metrics/rdkit/sascorer.py new file mode 100644 index 00000000..49c5c7b9 --- /dev/null +++ b/build/lib/torchdrug/metrics/rdkit/sascorer.py @@ -0,0 +1,90 @@ +import os +import sys +import math +import pickle + +from rdkit import Chem +from rdkit.Chem import rdMolDescriptors + +from torchdrug import utils + +module = sys.modules[__name__] +path = os.path.dirname(__file__) + +# Calculate synthetic accessibility of molecules +# Code adapted from RDKit +# https://github.com/rdkit/rdkit/blob/master/Contrib/SA_Score/sascorer.py + + +def readFragmentScores(): + url = "https://github.com/rdkit/rdkit/raw/master/Contrib/SA_Score/fpscores.pkl.gz" + md5 = "2f80a169f9075e977154f9caec9e5c26" + + zip_file = utils.download(url, path, md5=md5) + pkl_file = utils.extract(zip_file) + with open(pkl_file, "rb") as fin: + data = pickle.load(fin) + outDict = {} + for i in data: + for j in range(1, len(i)): + outDict[i[j]] = float(i[0]) + return outDict + + +def numBridgeheadsAndSpiro(mol, ri=None): + nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) + nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) + return nBridgehead, nSpiro + + +def calculateScore(m): + if not hasattr(module, "fscores"): + module.fscores = readFragmentScores() + fscores = module.fscores + + fp = rdMolDescriptors.GetMorganFingerprint(m, 2) + fps = fp.GetNonzeroElements() + score1 = 0.0 + nf = 0 + for bitId, v in fps.items(): + nf += v + sfp = bitId + score1 += fscores.get(sfp, -4) * v + score1 /= nf + + nAtoms = m.GetNumAtoms() + nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) + ri = m.GetRingInfo() + nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) + nMacrocycles = 0 + for x in ri.AtomRings(): + if len(x) > 8: + nMacrocycles += 1 + + sizePenalty = nAtoms**1.005 - nAtoms + stereoPenalty = math.log10(nChiralCenters + 1) + spiroPenalty = math.log10(nSpiro + 1) + bridgePenalty = math.log10(nBridgeheads + 1) + macrocyclePenalty = 0.0 + if nMacrocycles > 0: + macrocyclePenalty = math.log10(2) + + score2 = 0.0 - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty + + score3 = 0.0 + if nAtoms > len(fps): + score3 = math.log(float(nAtoms) / len(fps)) * 0.5 + + sascore = score1 + score2 + score3 + + min = -4.0 + max = 2.5 + sascore = 11. 
- (sascore - min + 1) / (max - min) * 9.0 + if sascore > 8.0: + sascore = 8.0 + math.log(sascore + 1.0 - 9.0) + if sascore > 10.0: + sascore = 10.0 + elif sascore < 1.0: + sascore = 1.0 + + return sascore \ No newline at end of file diff --git a/build/lib/torchdrug/models/__init__.py b/build/lib/torchdrug/models/__init__.py new file mode 100644 index 00000000..deb00084 --- /dev/null +++ b/build/lib/torchdrug/models/__init__.py @@ -0,0 +1,49 @@ +from .chebnet import ChebyshevConvolutionalNetwork +from .gcn import GraphConvolutionalNetwork, RelationalGraphConvolutionalNetwork +from .gat import GraphAttentionNetwork +from .gin import GraphIsomorphismNetwork +from .schnet import SchNet +from .mpnn import MessagePassingNeuralNetwork +from .neuralfp import NeuralFingerprint +from .infograph import InfoGraph, MultiviewContrast +from .flow import GraphAutoregressiveFlow +from .esm import EvolutionaryScaleModeling +from .embedding import TransE, DistMult, ComplEx, RotatE, SimplE +from .neurallp import NeuralLogicProgramming +from .kbgat import KnowledgeBaseGraphAttentionNetwork +from .cnn import ProteinConvolutionalNetwork, ProteinResNet +from .lstm import ProteinLSTM +from .bert import ProteinBERT +from .statistic import Statistic +from .physicochemical import Physicochemical +from .gearnet import GeometryAwareRelationalGraphNeuralNetwork + +# alias +ChebNet = ChebyshevConvolutionalNetwork +GCN = GraphConvolutionalNetwork +GAT = GraphAttentionNetwork +RGCN = RelationalGraphConvolutionalNetwork +GIN = GraphIsomorphismNetwork +MPNN = MessagePassingNeuralNetwork +NFP = NeuralFingerprint +GraphAF = GraphAutoregressiveFlow +ESM = EvolutionaryScaleModeling +NeuralLP = NeuralLogicProgramming +KBGAT = KnowledgeBaseGraphAttentionNetwork +ProteinCNN = ProteinConvolutionalNetwork +GearNet = GeometryAwareRelationalGraphNeuralNetwork + +__all__ = [ + "ChebyshevConvolutionalNetwork", "GraphConvolutionalNetwork", "RelationalGraphConvolutionalNetwork", + "GraphAttentionNetwork", "GraphIsomorphismNetwork", "SchNet", "MessagePassingNeuralNetwork", + "NeuralFingerprint", + "InfoGraph", "MultiviewContrast", + "GraphAutoregressiveFlow", + "EvolutionaryScaleModeling", "ProteinConvolutionalNetwork", "GeometryAwareRelationalGraphNeuralNetwork", + "Statistic", "Physicochemical", + "TransE", "DistMult", "ComplEx", "RotatE", "SimplE", + "NeuralLogicProgramming", "KnowledgeBaseGraphAttentionNetwork", + "ChebNet", "GCN", "GAT", "RGCN", "GIN", "MPNN", "NFP", + "GraphAF", "ESM", "NeuralLP", "KBGAT", + "ProteinCNN", "ProteinResNet", "ProteinLSTM", "ProteinBERT", "GearNet", +] \ No newline at end of file diff --git a/build/lib/torchdrug/models/bert.py b/build/lib/torchdrug/models/bert.py new file mode 100644 index 00000000..6d700574 --- /dev/null +++ b/build/lib/torchdrug/models/bert.py @@ -0,0 +1,100 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from torchdrug import core, layers +from torchdrug.layers import functional +from torchdrug.core import Registry as R + + +@R.register("models.ProteinBERT") +class ProteinBERT(nn.Module, core.Configurable): + """ + Protein BERT proposed in `Evaluating Protein Transfer Learning with TAPE`_. + + .. 
_Evaluating Protein Transfer Learning with TAPE: + https://arxiv.org/pdf/1906.08230.pdf + + Parameters: + input_dim (int): input dimension + hidden_dim (int, optional): hidden dimension + num_layers (int, optional): number of Transformer blocks + num_heads (int, optional): number of attention heads + intermediate_dim (int, optional): intermediate hidden dimension of Transformer block + activation (str or function, optional): activation function + hidden_dropout (float, optional): dropout ratio of hidden features + attention_dropout (float, optional): dropout ratio of attention maps + max_position (int, optional): maximum number of positions + """ + + def __init__(self, input_dim, hidden_dim=768, num_layers=12, num_heads=12, intermediate_dim=3072, + activation="gelu", hidden_dropout=0.1, attention_dropout=0.1, max_position=8192): + super(ProteinBERT, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = hidden_dim + self.num_layers = num_layers + self.num_heads = num_heads + self.intermediate_dim = intermediate_dim + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.max_position = max_position + + self.num_residue_type = input_dim + self.embedding = nn.Embedding(input_dim + 3, hidden_dim) + self.position_embedding = nn.Embedding(max_position, hidden_dim) + self.layer_norm = nn.LayerNorm(hidden_dim) + self.dropout = nn.Dropout(hidden_dropout) + + self.layers = nn.ModuleList() + for i in range(self.num_layers): + self.layers.append(layers.ProteinBERTBlock(hidden_dim, intermediate_dim, num_heads, + attention_dropout, hidden_dropout, activation)) + + self.linear = nn.Linear(hidden_dim, hidden_dim) + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the residue representations and the graph representation(s). 
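+
+        Internally, residue types are wrapped with begin-of-sequence and
+        end-of-sequence tokens, padded into a dense batch, combined with learned
+        position embeddings, and passed through the Transformer blocks; the
+        projected, tanh-squashed representation at the BOS position serves as
+        the graph feature. A minimal sketch (``input_dim=21`` and the toy
+        sequence are illustrative assumptions):
+
+        >>> from torchdrug import data, models
+        >>> model = models.ProteinBERT(input_dim=21)
+        >>> batch = data.Protein.pack([data.Protein.from_sequence("KALTARQ")])
+        >>> output = model(batch, batch.residue_feature.float())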
+ + Parameters: + graph (Protein): :math:`n` protein(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``residue_feature`` and ``graph_feature`` fields: + residue representations of shape :math:`(|V_{res}|, d)`, graph representations of shape :math:`(n, d)` + """ + input = graph.residue_type + size_ext = graph.num_residues + # Prepend BOS + bos = torch.ones(graph.batch_size, dtype=torch.long, device=self.device) * self.num_residue_type + input, size_ext = functional._extend(bos, torch.ones_like(size_ext), input, size_ext) + # Append EOS + eos = torch.ones(graph.batch_size, dtype=torch.long, device=self.device) * (self.num_residue_type + 1) + input, size_ext = functional._extend(input, size_ext, eos, torch.ones_like(size_ext)) + # Padding + input, mask = functional.variadic_to_padded(input, size_ext, value=self.num_residue_type + 2) + mask = mask.long().unsqueeze(-1) + + input = self.embedding(input) + position_indices = torch.arange(input.shape[1], device=input.device) + input = input + self.position_embedding(position_indices).unsqueeze(0) + input = self.layer_norm(input) + input = self.dropout(input) + + for layer in self.layers: + input = layer(input, mask) + + residue_feature = functional.padded_to_variadic(input, graph.num_residues) + + graph_feature = input[:, 0] + graph_feature = self.linear(graph_feature) + graph_feature = F.tanh(graph_feature) + + return { + "graph_feature": graph_feature, + "residue_feature": residue_feature + } diff --git a/build/lib/torchdrug/models/chebnet.py b/build/lib/torchdrug/models/chebnet.py new file mode 100644 index 00000000..15d11295 --- /dev/null +++ b/build/lib/torchdrug/models/chebnet.py @@ -0,0 +1,89 @@ +from collections.abc import Sequence + +import torch +from torch import nn + +from torchdrug import core, layers +from torchdrug.core import Registry as R + + +@R.register("models.ChebNet") +class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable): + """ + Chebyshev convolutional network proposed in + `Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering`_. + + .. _Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering: + https://arxiv.org/pdf/1606.09375.pdf + + Parameters: + input_dim (int): input dimension + hidden_dims (list of int): hidden dimensions + edge_input_dim (int, optional): dimension of edge features + k (int, optional): number of Chebyshev polynomials + short_cut (bool, optional): use short cut or not + batch_norm (bool, optional): apply batch normalization or not + activation (str or function, optional): activation function + concat_hidden (bool, optional): concat hidden representations from all layers as output + readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. 
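+
+    A minimal construction sketch (the dimensions and ``k`` are illustrative):
+
+    >>> from torchdrug import models
+    >>> model = models.ChebNet(input_dim=16, hidden_dims=[32, 32], k=2)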
+ """ + + def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False, + activation="relu", concat_hidden=False, readout="sum"): + super(ChebyshevConvolutionalNetwork, self).__init__() + + if not isinstance(hidden_dims, Sequence): + hidden_dims = [hidden_dims] + self.input_dim = input_dim + self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1] + self.dims = [input_dim] + list(hidden_dims) + self.short_cut = short_cut + self.concat_hidden = concat_hidden + + self.layers = nn.ModuleList() + for i in range(len(self.dims) - 1): + self.layers.append(layers.ChebyshevConv(self.dims[i], self.dims[i + 1], edge_input_dim, k, + batch_norm, activation)) + + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the node representations and the graph representation(s). + + Parameters: + graph (Graph): :math:`n` graph(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``node_feature`` and ``graph_feature`` fields: + node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)` + """ + hiddens = [] + layer_input = input + + for layer in self.layers: + hidden = layer(graph, layer_input) + assert not torch.isnan(hidden).any() + if self.short_cut and hidden.shape == layer_input.shape: + hidden = hidden + layer_input + hiddens.append(hidden) + layer_input = hidden + + if self.concat_hidden: + node_feature = torch.cat(hiddens, dim=-1) + else: + node_feature = hiddens[-1] + graph_feature = self.readout(graph, node_feature) + + return { + "graph_feature": graph_feature, + "node_feature": node_feature + } \ No newline at end of file diff --git a/build/lib/torchdrug/models/cnn.py b/build/lib/torchdrug/models/cnn.py new file mode 100644 index 00000000..24ec7437 --- /dev/null +++ b/build/lib/torchdrug/models/cnn.py @@ -0,0 +1,214 @@ +from collections.abc import Sequence + +import torch +from torch import nn +from torch.nn import functional as F + +from torchdrug import core, layers +from torchdrug.layers import functional +from torchdrug.core import Registry as R + + +@R.register("models.ProteinResNet") +class ProteinResNet(nn.Module, core.Configurable): + """ + Protein ResNet proposed in `Evaluating Protein Transfer Learning with TAPE`_. + + .. _Evaluating Protein Transfer Learning with TAPE: + https://arxiv.org/pdf/1906.08230.pdf + + Parameters: + input_dim (int): input dimension + hidden_dims (list of int): hidden dimensions + kernel_size (int, optional): size of convolutional kernel + stride (int, optional): stride of convolution + padding (int, optional): padding added to both sides of the input + activation (str or function, optional): activation function + short_cut (bool, optional): use short cut or not + concat_hidden (bool, optional): concat hidden representations from all layers as output + layer_norm (bool, optional): apply layer normalization or not + dropout (float, optional): dropout ratio of input features + readout (str, optional): readout function. Available functions are ``sum``, ``mean`` and ``attention``. 
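+
+    A minimal construction sketch (the dimensions are illustrative):
+
+    >>> from torchdrug import models
+    >>> model = models.ProteinResNet(input_dim=21, hidden_dims=[512, 512])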
+ """ + + def __init__(self, input_dim, hidden_dims, kernel_size=3, stride=1, padding=1, + activation="gelu", short_cut=False, concat_hidden=False, layer_norm=False, + dropout=0, readout="attention"): + super(ProteinResNet, self).__init__() + if not isinstance(hidden_dims, Sequence): + hidden_dims = [hidden_dims] + self.input_dim = input_dim + self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1] + self.dims = list(hidden_dims) + self.short_cut = short_cut + self.concat_hidden = concat_hidden + self.padding_id = input_dim - 1 + + self.embedding = nn.Linear(input_dim, hidden_dims[0]) + self.position_embedding = layers.SinusoidalPositionEmbedding(hidden_dims[0]) + if layer_norm: + self.layer_norm = nn.LayerNorm(hidden_dims[0]) + else: + self.layer_norm = None + if dropout: + self.dropout = nn.Dropout(dropout) + else: + self.dropout = None + + self.layers = nn.ModuleList() + for i in range(len(self.dims) - 1): + self.layers.append(layers.ProteinResNetBlock(self.dims[i], self.dims[i + 1], kernel_size, + stride, padding, activation)) + + if readout == "sum": + self.readout = layers.SumReadout("residue") + elif readout == "mean": + self.readout = layers.MeanReadout("residue") + elif readout == "attention": + self.readout = layers.AttentionReadout(self.output_dim, "residue") + else: + raise ValueError("Unknown readout `%s`" % readout) + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the residue representations and the graph representation(s). + + Parameters: + graph (Protein): :math:`n` protein(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``residue_feature`` and ``graph_feature`` fields: + residue representations of shape :math:`(|V_{res}|, d)`, graph representations of shape :math:`(n, d)` + """ + input = graph.residue_feature.float() + input, mask = functional.variadic_to_padded(input, graph.num_residues, value=self.padding_id) + mask = mask.unsqueeze(-1) + + input = self.embedding(input) + self.position_embedding(input).unsqueeze(0) + if self.layer_norm: + input = self.layer_norm(input) + if self.dropout: + input = self.dropout(input) + input = input * mask + + hiddens = [] + layer_input = input + + for layer in self.layers: + hidden = layer(layer_input, mask) + if self.short_cut and hidden.shape == layer_input.shape: + hidden = hidden + layer_input + hiddens.append(hidden) + layer_input = hidden + + if self.concat_hidden: + hidden = torch.cat(hiddens, dim=-1) + else: + hidden = hiddens[-1] + + residue_feature = functional.padded_to_variadic(hidden, graph.num_residues) + graph_feature = self.readout(graph, residue_feature) + + return { + "graph_feature": graph_feature, + "residue_feature": residue_feature + } + + +@R.register("models.ProteinConvolutionalNetwork") +class ProteinConvolutionalNetwork(nn.Module, core.Configurable): + """ + Protein Shallow CNN proposed in `Is Transfer Learning Necessary for Protein Landscape Prediction?`_. + + .. 
_Is Transfer Learning Necessary for Protein Landscape Prediction?: + https://arxiv.org/pdf/2011.03443.pdf + + Parameters: + input_dim (int): input dimension + hidden_dims (list of int): hidden dimensions + kernel_size (int, optional): size of convolutional kernel + stride (int, optional): stride of convolution + padding (int, optional): padding added to both sides of the input + activation (str or function, optional): activation function + short_cut (bool, optional): use short cut or not + concat_hidden (bool, optional): concat hidden representations from all layers as output + readout (str, optional): readout function. Available functions are ``sum``, ``mean``, ``max`` and ``attention``. + """ + + def __init__(self, input_dim, hidden_dims, kernel_size=3, stride=1, padding=1, + activation='relu', short_cut=False, concat_hidden=False, readout="max"): + super(ProteinConvolutionalNetwork, self).__init__() + if not isinstance(hidden_dims, Sequence): + hidden_dims = [hidden_dims] + self.input_dim = input_dim + self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1] + self.dims = [input_dim] + list(hidden_dims) + self.short_cut = short_cut + self.concat_hidden = concat_hidden + self.padding_id = input_dim - 1 + + if isinstance(activation, str): + self.activation = getattr(F, activation) + else: + self.activation = activation + + self.layers = nn.ModuleList() + for i in range(len(self.dims) - 1): + self.layers.append( + nn.Conv1d(self.dims[i], self.dims[i+1], kernel_size, stride, padding) + ) + + if readout == "sum": + self.readout = layers.SumReadout("residue") + elif readout == "mean": + self.readout = layers.MeanReadout("residue") + elif readout == "max": + self.readout = layers.MaxReadout("residue") + elif readout == "attention": + self.readout = layers.AttentionReadout(self.output_dim, "residue") + else: + raise ValueError("Unknown readout `%s`" % readout) + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the residue representations and the graph representation(s). 
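+
+        Residue features are padded into a dense ``(batch, length, dim)``
+        tensor, convolved along the sequence dimension, converted back to the
+        variadic layout, and pooled by the readout. A minimal sketch
+        (``input_dim=21`` and the toy sequence are illustrative assumptions):
+
+        >>> from torchdrug import data, models
+        >>> model = models.ProteinCNN(input_dim=21, hidden_dims=[64])
+        >>> batch = data.Protein.pack([data.Protein.from_sequence("KALTARQ")])
+        >>> output = model(batch, batch.residue_feature.float())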
+ + Parameters: + graph (Protein): :math:`n` protein(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``residue_feature`` and ``graph_feature`` fields: + residue representations of shape :math:`(|V_{res}|, d)`, graph representations of shape :math:`(n, d)` + """ + input = graph.residue_feature.float() + input = functional.variadic_to_padded(input, graph.num_residues, value=self.padding_id)[0] + + hiddens = [] + layer_input = input + + for layer in self.layers: + hidden = layer(layer_input.transpose(1, 2)).transpose(1, 2) + hidden = self.activation(hidden) + if self.short_cut and hidden.shape == layer_input.shape: + hidden = hidden + layer_input + hiddens.append(hidden) + layer_input = hidden + + if self.concat_hidden: + hidden = torch.cat(hiddens, dim=-1) + else: + hidden = hiddens[-1] + + residue_feature = functional.padded_to_variadic(hidden, graph.num_residues) + graph_feature = self.readout(graph, residue_feature) + + return { + "graph_feature": graph_feature, + "residue_feature": residue_feature + } diff --git a/build/lib/torchdrug/models/embedding.py b/build/lib/torchdrug/models/embedding.py new file mode 100644 index 00000000..ce8ef09f --- /dev/null +++ b/build/lib/torchdrug/models/embedding.py @@ -0,0 +1,242 @@ +import torch +from torch import nn + +from torchdrug import core +from torchdrug.layers import functional +from torchdrug.core import Registry as R + + +@R.register("models.TransE") +class TransE(nn.Module, core.Configurable): + """ + TransE embedding proposed in `Translating Embeddings for Modeling Multi-relational Data`_. + + .. _Translating Embeddings for Modeling Multi-relational Data: + https://proceedings.neurips.cc/paper/2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf + + Parameters: + num_entity (int): number of entities + num_relation (int): number of relations + embedding_dim (int): dimension of embeddings + max_score (float, optional): maximal score for triplets + """ + + def __init__(self, num_entity, num_relation, embedding_dim, max_score=12): + super(TransE, self).__init__() + self.num_entity = num_entity + self.num_relation = num_relation + self.max_score = max_score + + self.entity = nn.Parameter(torch.empty(num_entity, embedding_dim)) + self.relation = nn.Parameter(torch.empty(num_relation, embedding_dim)) + + nn.init.uniform_(self.entity, -self.max_score / embedding_dim, self.max_score / embedding_dim) + nn.init.uniform_(self.relation, -self.max_score / embedding_dim, self.max_score / embedding_dim) + + def forward(self, graph, h_index, t_index, r_index, all_loss=None, metric=None): + """ + Compute the score for each triplet. + + Parameters: + graph (Graph): fact graph + h_index (Tensor): indexes of head entities + t_index (Tensor): indexes of tail entities + r_index (Tensor): indexes of relations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + """ + score = functional.transe_score(self.entity, self.relation, h_index, t_index, r_index) + return self.max_score - score + + +@R.register("models.DistMult") +class DistMult(nn.Module, core.Configurable): + """ + DistMult embedding proposed in `Embedding Entities and Relations for Learning and Inference in Knowledge Bases`_. + + .. 
_Embedding Entities and Relations for Learning and Inference in Knowledge Bases: + https://arxiv.org/pdf/1412.6575.pdf + + Parameters: + num_entity (int): number of entities + num_relation (int): number of relations + embedding_dim (int): dimension of embeddings + l3_regularization (float, optional): weight for l3 regularization + """ + + def __init__(self, num_entity, num_relation, embedding_dim, l3_regularization=0): + super(DistMult, self).__init__() + self.num_entity = num_entity + self.num_relation = num_relation + self.l3_regularization = l3_regularization + + self.entity = nn.Parameter(torch.empty(num_entity, embedding_dim)) + self.relation = nn.Parameter(torch.empty(num_relation, embedding_dim)) + + nn.init.uniform_(self.entity, -0.5, 0.5) + nn.init.uniform_(self.relation, -0.5, 0.5) + + def forward(self, graph, h_index, t_index, r_index, all_loss=None, metric=None): + """ + Compute the score for each triplet. + + Parameters: + graph (Graph): fact graph + h_index (Tensor): indexes of head entities + t_index (Tensor): indexes of tail entities + r_index (Tensor): indexes of relations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + """ + score = functional.distmult_score(self.entity, self.relation, h_index, t_index, r_index) + + if all_loss is not None and self.l3_regularization > 0: + loss = (self.entity.abs() ** 3).sum() + (self.relation.abs() ** 3).sum() + all_loss += loss * self.l3_regularization + metric["l3 regularization"] = loss / (self.num_entity + self.num_relation) + + return score + + +@R.register("models.ComplEx") +class ComplEx(nn.Module, core.Configurable): + """ + ComplEx embedding proposed in `Complex Embeddings for Simple Link Prediction`_. + + .. _Complex Embeddings for Simple Link Prediction: + http://proceedings.mlr.press/v48/trouillon16.pdf + + Parameters: + num_entity (int): number of entities + num_relation (int): number of relations + embedding_dim (int): dimension of embeddings + l3_regularization (float, optional): weight for l3 regularization + """ + + def __init__(self, num_entity, num_relation, embedding_dim, l3_regularization=0): + super(ComplEx, self).__init__() + self.num_entity = num_entity + self.num_relation = num_relation + self.l3_regularization = l3_regularization + + self.entity = nn.Parameter(torch.empty(num_entity, embedding_dim)) + self.relation = nn.Parameter(torch.empty(num_relation, embedding_dim)) + + nn.init.uniform_(self.entity, -0.5, 0.5) + nn.init.uniform_(self.relation, -0.5, 0.5) + + def forward(self, graph, h_index, t_index, r_index, all_loss=None, metric=None): + """ + Compute the score for triplets. 
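+
+        The score is the standard ComplEx bilinear form, the real part of the
+        trilinear product :math:`\mathrm{Re}(\langle h, r, \bar{t} \rangle)`
+        over the complex-valued embeddings, computed here by
+        ``functional.complex_score``. A minimal sketch (the fact graph is not
+        used by the score itself, so ``None`` is passed purely for
+        illustration):
+
+        >>> import torch
+        >>> from torchdrug import models
+        >>> model = models.ComplEx(num_entity=10, num_relation=2, embedding_dim=8)
+        >>> h, t, r = torch.tensor([0]), torch.tensor([1]), torch.tensor([0])
+        >>> score = model(None, h, t, r)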
+ + Parameters: + graph (Graph): fact graph + h_index (Tensor): indexes of head entities + t_index (Tensor): indexes of tail entities + r_index (Tensor): indexes of relations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + """ + score = functional.complex_score(self.entity, self.relation, h_index, t_index, r_index) + + if all_loss is not None and self.l3_regularization > 0: + loss = (self.entity.abs() ** 3).sum() + (self.relation.abs() ** 3).sum() + all_loss += loss * self.l3_regularization + metric["l3 regularization"] = loss / (self.num_entity + self.num_relation) + + return score + + +@R.register("models.RotatE") +class RotatE(nn.Module, core.Configurable): + """ + RotatE embedding proposed in `RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space`_. + + .. _RotatE\: Knowledge Graph Embedding by Relational Rotation in Complex Space: + https://arxiv.org/pdf/1902.10197.pdf + + Parameters: + num_entity (int): number of entities + num_relation (int): number of relations + embedding_dim (int): dimension of embeddings + max_score (float, optional): maximal score for triplets + """ + + def __init__(self, num_entity, num_relation, embedding_dim, max_score=12): + super(RotatE, self).__init__() + self.num_entity = num_entity + self.num_relation = num_relation + self.max_score = max_score + + self.entity = nn.Parameter(torch.empty(num_entity, embedding_dim)) + self.relation = nn.Parameter(torch.empty(num_relation, embedding_dim // 2)) + + nn.init.uniform_(self.entity, -max_score * 2 / embedding_dim, max_score * 2 / embedding_dim) + nn.init.uniform_(self.relation, -max_score * 2 / embedding_dim, max_score * 2 / embedding_dim) + pi = torch.acos(torch.zeros(1)).item() * 2 + self.relation_scale = pi * embedding_dim / max_score / 2 + + def forward(self, graph, h_index, t_index, r_index, all_loss=None, metric=None): + """ + Compute the score for each triplet. + + Parameters: + graph (Graph): fact graph + h_index (Tensor): indexes of head entities + t_index (Tensor): indexes of tail entities + r_index (Tensor): indexes of relations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + """ + score = functional.rotate_score(self.entity, self.relation * self.relation_scale, + h_index, t_index, r_index) + return self.max_score - score + + +@R.register("models.SimplE") +class SimplE(nn.Module, core.Configurable): + """ + SimplE embedding proposed in `SimplE Embedding for Link Prediction in Knowledge Graphs`_. + + .. 
_SimplE Embedding for Link Prediction in Knowledge Graphs:
+        https://papers.nips.cc/paper/2018/file/b2ab001909a8a6f04b51920306046ce5-Paper.pdf
+
+    Parameters:
+        num_entity (int): number of entities
+        num_relation (int): number of relations
+        embedding_dim (int): dimension of embeddings
+        l3_regularization (float, optional): weight for l3 regularization
+    """
+
+    def __init__(self, num_entity, num_relation, embedding_dim, l3_regularization=0):
+        super(SimplE, self).__init__()
+        self.num_entity = num_entity
+        self.num_relation = num_relation
+        self.l3_regularization = l3_regularization
+
+        self.entity = nn.Parameter(torch.empty(num_entity, embedding_dim))
+        self.relation = nn.Parameter(torch.empty(num_relation, embedding_dim))
+
+        nn.init.uniform_(self.entity, -0.5, 0.5)
+        nn.init.uniform_(self.relation, -0.5, 0.5)
+
+    def forward(self, graph, h_index, t_index, r_index, all_loss=None, metric=None):
+        """
+        Compute the score for each triplet.
+
+        Parameters:
+            graph (Graph): fact graph
+            h_index (Tensor): indexes of head entities
+            t_index (Tensor): indexes of tail entities
+            r_index (Tensor): indexes of relations
+            all_loss (Tensor, optional): if specified, add loss to this tensor
+            metric (dict, optional): if specified, output metrics to this dict
+        """
+        score = functional.simple_score(self.entity, self.relation, h_index, t_index, r_index)
+
+        if all_loss is not None and self.l3_regularization > 0:
+            loss = (self.entity.abs() ** 3).sum() + (self.relation.abs() ** 3).sum()
+            all_loss += loss * self.l3_regularization
+            metric["l3 regularization"] = loss / (self.num_entity + self.num_relation)
+
+        return score
\ No newline at end of file
diff --git a/build/lib/torchdrug/models/esm.py b/build/lib/torchdrug/models/esm.py
new file mode 100644
index 00000000..d80aaace
--- /dev/null
+++ b/build/lib/torchdrug/models/esm.py
@@ -0,0 +1,172 @@
+import os
+import warnings
+
+import torch
+from torch import nn
+import esm
+
+from torchdrug import core, layers, utils, data
+from torchdrug.layers import functional
+from torchdrug.core import Registry as R
+
+
+@R.register("models.ESM")
+class EvolutionaryScaleModeling(nn.Module, core.Configurable):
+    """
+    The protein language model, Evolutionary Scale Modeling (ESM) proposed in
+    `Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences`_.
+
+    .. _Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences:
+        https://www.biorxiv.org/content/10.1101/622803v1.full.pdf
+
+    Parameters:
+        path (str): path to store ESM model weights
+        model (str, optional): model name. Available model names are ``ESM-1b``, ``ESM-1v``,
+            ``ESM-1b-regression``, ``ESM-2-8M``, ``ESM-2-35M``, ``ESM-2-150M``, ``ESM-2-650M``,
+            ``ESM-2-3B`` and ``ESM-2-15B``.
+        readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
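+
+    A minimal usage sketch (hypothetical weight directory; constructing the model
+    downloads the pre-trained weights, which requires network access):
+
+    >>> from torchdrug import data
+    >>> model = EvolutionaryScaleModeling("~/esm-model-weights", model="ESM-2-650M")
+    >>> protein = data.Protein.from_sequence("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")
+    >>> output = model(data.Protein.pack([protein]), None)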
+ """ + + url = { + "ESM-1b": "https://dl.fbaipublicfiles.com/fair-esm/models/esm1b_t33_650M_UR50S.pt", + "ESM-1v": "https://dl.fbaipublicfiles.com/fair-esm/models/esm1v_t33_650M_UR90S_1.pt", + "ESM-1b-regression": + "https://dl.fbaipublicfiles.com/fair-esm/regression/esm1b_t33_650M_UR50S-contact-regression.pt", + "ESM-2-8M": "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t6_8M_UR50D.pt", + "ESM-2-35M": "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t12_35M_UR50D.pt", + "ESM-2-150M": "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t30_150M_UR50D.pt", + "ESM-2-650M": "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt", + "ESM-2-3B": "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t36_3B_UR50D.pt", + "ESM-2-15B": "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t48_15B_UR50D.pt", + } + + md5 = { + "ESM-1b": "ba8914bc3358cae2254ebc8874ee67f6", + "ESM-1v": "1f04c2d2636b02b544ecb5fbbef8fefd", + "ESM-1b-regression": "e7fe626dfd516fb6824bd1d30192bdb1", + "ESM-2-8M": "8039fc9cee7f71cd2633b13b5a38ff50", + "ESM-2-35M": "a894ddb31522e511e1273abb23b5f974", + "ESM-2-150M": "229fcf8f9f3d4d442215662ca001b906", + "ESM-2-650M": "ba6d997e29db07a2ad9dca20e024b102", + "ESM-2-3B": "d37a0d0dbe7431e48a72072b9180b16b", + "ESM-2-15B": "af61a9c0b792ae50e244cde443b7f4ac", + } + + output_dim = { + "ESM-1b": 1280, + "ESM-1v": 1280, + "ESM-2-8M": 320, + "ESM-2-35M": 480, + "ESM-2-150M": 640, + "ESM-2-650M": 1280, + "ESM-2-3B": 2560, + "ESM-2-15B": 5120, + } + + num_layer = { + "ESM-1b": 33, + "ESM-1v": 33, + "ESM-2-8M": 6, + "ESM-2-35M": 12, + "ESM-2-150M": 30, + "ESM-2-650M": 33, + "ESM-2-3B": 36, + "ESM-2-15B": 48, + } + + max_input_length = 1024 - 2 + + def __init__(self, path, model="ESM-1b", readout="mean"): + super(EvolutionaryScaleModeling, self).__init__() + path = os.path.expanduser(path) + if not os.path.exists(path): + os.makedirs(path) + self.path = path + + _model, alphabet = self.load_weight(path, model) + self.alphabet = alphabet + mapping = self.construct_mapping(alphabet) + self.output_dim = self.output_dim[model] + self.model = _model + self.alphabet = alphabet + self.repr_layer = self.num_layer[model] + self.register_buffer("mapping", mapping) + + if readout == "sum": + self.readout = layers.SumReadout("residue") + elif readout == "mean": + self.readout = layers.MeanReadout("residue") + else: + raise ValueError("Unknown readout `%s`" % readout) + + def load_weight(self, path, model): + if model not in self.url: + raise ValueError("Unknown model `%s`" % model) + model_file = utils.download(self.url[model], path, md5=self.md5[model]) + model_data = torch.load(model_file, map_location="cpu") + if model != "ESM-1v" and not model.startswith("ESM-2"): + regression_model = "%s-regression" % model + regression_file = utils.download(self.url[regression_model], path, md5=self.md5[regression_model]) + regression_data = torch.load(regression_file, map_location="cpu") + else: + regression_data = None + model_name = os.path.basename(self.url[model]) + return esm.pretrained.load_model_and_alphabet_core(model_name, model_data, regression_data) + + def construct_mapping(self, alphabet): + mapping = [-1] * max(len(data.Protein.id2residue_symbol), len(self.alphabet)) + for i, token in data.Protein.id2residue_symbol.items(): + mapping[i] = alphabet.get_idx(token) + mapping = torch.tensor(mapping) + return mapping + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the residue representations and the graph representation(s). 
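+
+        Proteins longer than ``max_input_length`` residues are truncated to fit the
+        ESM positional embeddings, and BOS/EOS tokens are added around each sequence
+        when the alphabet requires them.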
+ + Parameters: + graph (Protein): :math:`n` protein(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``residue_feature`` and ``graph_feature`` fields: + residue representations of shape :math:`(|V_{res}|, d)`, graph representations of shape :math:`(n, d)` + """ + input = graph.residue_type + input = self.mapping[input] + input[input == -1] = graph.residue_type[input == -1] + size = graph.num_residues + if (size > self.max_input_length).any(): + warnings.warn("ESM can only encode proteins within %d residues. Truncate the input to fit into ESM." + % self.max_input_length) + starts = size.cumsum(0) - size + size = size.clamp(max=self.max_input_length) + ends = starts + size + mask = functional.multi_slice_mask(starts, ends, graph.num_residue) + input = input[mask] + graph = graph.subresidue(mask) + size_ext = size + if self.alphabet.prepend_bos: + bos = torch.ones(graph.batch_size, dtype=torch.long, device=self.device) * self.alphabet.cls_idx + input, size_ext = functional._extend(bos, torch.ones_like(size_ext), input, size_ext) + if self.alphabet.append_eos: + eos = torch.ones(graph.batch_size, dtype=torch.long, device=self.device) * self.alphabet.eos_idx + input, size_ext = functional._extend(input, size_ext, eos, torch.ones_like(size_ext)) + input = functional.variadic_to_padded(input, size_ext, value=self.alphabet.padding_idx)[0] + + output = self.model(input, repr_layers=[self.repr_layer]) + residue_feature = output["representations"][self.repr_layer] + + residue_feature = functional.padded_to_variadic(residue_feature, size_ext) + starts = size_ext.cumsum(0) - size_ext + if self.alphabet.prepend_bos: + starts = starts + 1 + ends = starts + size + mask = functional.multi_slice_mask(starts, ends, len(residue_feature)) + residue_feature = residue_feature[mask] + graph_feature = self.readout(graph, residue_feature) + + return { + "graph_feature": graph_feature, + "residue_feature": residue_feature + } diff --git a/build/lib/torchdrug/models/flow.py b/build/lib/torchdrug/models/flow.py new file mode 100644 index 00000000..a1184f90 --- /dev/null +++ b/build/lib/torchdrug/models/flow.py @@ -0,0 +1,121 @@ +import torch +from torch import nn + +from torchdrug import core, layers +from torchdrug.layers import functional +from torchdrug.core import Registry as R + + +@R.register("models.GraphAF") +class GraphAutoregressiveFlow(nn.Module, core.Configurable): + """ + Graph autoregressive flow proposed in `GraphAF: a Flow-based Autoregressive Model for Molecular Graph Generation`_. + + .. 
_GraphAF\: a Flow-based Autoregressive Model for Molecular Graph Generation:
+        https://arxiv.org/pdf/2001.09382.pdf
+
+    Parameters:
+        model (nn.Module): graph representation model
+        prior (nn.Module): prior distribution
+        use_edge (bool, optional): use edge or not
+        num_layer (int, optional): number of conditional flow layers
+        num_mlp_layer (int, optional): number of MLP layers in each conditional flow
+        dequantization_noise (float, optional): scale of dequantization noise
+    """
+
+    def __init__(self, model, prior, use_edge=False, num_layer=6, num_mlp_layer=2, dequantization_noise=0.9):
+        super(GraphAutoregressiveFlow, self).__init__()
+        self.model = model
+        self.prior = prior
+        self.use_edge = use_edge
+        self.input_dim = self.output_dim = prior.dim
+        self.dequantization_noise = dequantization_noise
+        assert dequantization_noise < 1
+
+        self.layers = nn.ModuleList()
+        for i in range(num_layer):
+            condition_dim = model.output_dim * (3 if use_edge else 1)
+            self.layers.append(layers.ConditionalFlow(self.input_dim, condition_dim,
+                                                      [model.output_dim] * (num_mlp_layer - 1)))

+    def _standarize_edge(self, graph, edge):
+        if edge is not None:
+            edge = edge.clone()
+            if (edge[:, :2] >= graph.num_nodes.unsqueeze(-1)).any():
+                raise ValueError("Edge index exceeds the number of nodes in the graph")
+            edge[:, :2] += (graph.num_cum_nodes - graph.num_nodes).unsqueeze(-1)
+        return edge
+
+    def forward(self, graph, input, edge=None, all_loss=None, metric=None):
+        """
+        Compute the log-likelihood for the input given the graph(s).
+
+        Parameters:
+            graph (Graph): :math:`n` graph(s)
+            input (Tensor): discrete data of shape :math:`(n,)`
+            edge (Tensor, optional): edge list of shape :math:`(n, 2)`.
+                If specified, additionally condition on the edge for each input.
+            all_loss (Tensor, optional): if specified, add loss to this tensor
+            metric (dict, optional): if specified, output metrics to this dict
+        """
+        if self.use_edge and edge is None:
+            raise ValueError("`use_edge` is true, but no edge is provided")
+
+        edge = self._standarize_edge(graph, edge)
+
+        node_feature = functional.one_hot(graph.atom_type, self.model.input_dim)
+        feature = self.model(graph, node_feature, all_loss, metric)
+        node_feature = feature["node_feature"]
+        graph_feature = feature["graph_feature"]
+        if self.use_edge:
+            condition = torch.cat([node_feature[edge], graph_feature.unsqueeze(1)], dim=1).flatten(1)
+        else:
+            condition = graph_feature
+
+        x = functional.one_hot(input, self.input_dim)
+        x = x + self.dequantization_noise * torch.rand_like(x)
+
+        log_dets = []
+        for layer in self.layers:
+            x, log_det = layer(x, condition)
+            log_dets.append(log_det)
+
+        log_prior = self.prior(x)
+        log_det = torch.stack(log_dets).sum(dim=0)
+        log_likelihood = log_prior + log_det
+        log_likelihood = log_likelihood.sum(dim=-1)
+
+        return log_likelihood  # (batch_size,)
+
+    def sample(self, graph, edge=None, all_loss=None, metric=None):
+        """
+        Sample discrete data based on the given graph(s).
+
+        Parameters:
+            graph (Graph): :math:`n` graph(s)
+            edge (Tensor, optional): edge list of shape :math:`(n, 2)`.
+                If specified, additionally condition on the edge for each input.
+ all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + """ + if self.use_edge and edge is None: + raise ValueError("`use_edge` is true, but no edge is provided") + + edge = self._standarize_edge(graph, edge) + + node_feature = functional.one_hot(graph.atom_type, self.model.input_dim) + feature = self.model(graph, node_feature, all_loss, metric) + node_feature = feature["node_feature"] + graph_feature = feature["graph_feature"] + if self.use_edge: + condition = torch.cat([node_feature[edge], graph_feature.unsqueeze(1)], dim=1).flatten(1) + else: + condition = graph_feature + + x = self.prior.sample(len(graph)) + for layer in self.layers[::-1]: + x, log_det = layer.reverse(x, condition) + + output = x.argmax(dim=-1) + + return output # (batch_size,) \ No newline at end of file diff --git a/build/lib/torchdrug/models/gat.py b/build/lib/torchdrug/models/gat.py new file mode 100644 index 00000000..9154fcce --- /dev/null +++ b/build/lib/torchdrug/models/gat.py @@ -0,0 +1,88 @@ +from collections.abc import Sequence + +import torch +from torch import nn + +from torchdrug import core, layers +from torchdrug.core import Registry as R + + +@R.register("models.GAT") +class GraphAttentionNetwork(nn.Module, core.Configurable): + """ + Graph Attention Network proposed in `Graph Attention Networks`_. + + .. _Graph Attention Networks: + https://arxiv.org/pdf/1710.10903.pdf + + Parameters: + input_dim (int): input dimension + hidden_dims (list of int): hidden dimensions + edge_input_dim (int, optional): dimension of edge features + num_head (int, optional): number of attention heads + negative_slope (float, optional): negative slope of leaky relu activation + short_cut (bool, optional): use short cut or not + batch_norm (bool, optional): apply batch normalization or not + activation (str or function, optional): activation function + concat_hidden (bool, optional): concat hidden representations from all layers as output + readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + """ + + def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False, + batch_norm=False, activation="relu", concat_hidden=False, readout="sum"): + super(GraphAttentionNetwork, self).__init__() + + if not isinstance(hidden_dims, Sequence): + hidden_dims = [hidden_dims] + self.input_dim = input_dim + self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1] + self.dims = [input_dim] + list(hidden_dims) + self.short_cut = short_cut + self.concat_hidden = concat_hidden + + self.layers = nn.ModuleList() + for i in range(len(self.dims) - 1): + self.layers.append(layers.GraphAttentionConv(self.dims[i], self.dims[i + 1], edge_input_dim, num_head, + negative_slope, batch_norm, activation)) + + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the node representations and the graph representation(s). 
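+
+        A minimal usage sketch (dimensions are illustrative; molecules expose
+        ``node_feature`` directly):
+
+        >>> from torchdrug import data, models
+        >>> graph = data.Molecule.pack([data.Molecule.from_smiles("C1=CC=CC=C1")])
+        >>> model = models.GraphAttentionNetwork(input_dim=graph.node_feature.shape[-1],
+        >>>                                      hidden_dims=[128, 128], num_head=4)
+        >>> output = model(graph, graph.node_feature.float())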
+ + Parameters: + graph (Graph): :math:`n` graph(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``node_feature`` and ``graph_feature`` fields: + node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)` + """ + hiddens = [] + layer_input = input + + for layer in self.layers: + hidden = layer(graph, layer_input) + if self.short_cut and hidden.shape == layer_input.shape: + hidden = hidden + layer_input + hiddens.append(hidden) + layer_input = hidden + + if self.concat_hidden: + node_feature = torch.cat(hiddens, dim=-1) + else: + node_feature = hiddens[-1] + graph_feature = self.readout(graph, node_feature) + + return { + "graph_feature": graph_feature, + "node_feature": node_feature + } \ No newline at end of file diff --git a/build/lib/torchdrug/models/gcn.py b/build/lib/torchdrug/models/gcn.py new file mode 100644 index 00000000..5415d346 --- /dev/null +++ b/build/lib/torchdrug/models/gcn.py @@ -0,0 +1,168 @@ +from collections.abc import Sequence + +import torch +from torch import nn + +from torchdrug import core, layers +from torchdrug.core import Registry as R + + +@R.register("models.GCN") +class GraphConvolutionalNetwork(nn.Module, core.Configurable): + """ + Graph Convolutional Network proposed in `Semi-Supervised Classification with Graph Convolutional Networks`_. + + .. _Semi-Supervised Classification with Graph Convolutional Networks: + https://arxiv.org/pdf/1609.02907.pdf + + Parameters: + input_dim (int): input dimension + hidden_dims (list of int): hidden dimensions + edge_input_dim (int, optional): dimension of edge features + short_cut (bool, optional): use short cut or not + batch_norm (bool, optional): apply batch normalization or not + activation (str or function, optional): activation function + concat_hidden (bool, optional): concat hidden representations from all layers as output + readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + """ + + def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, + activation="relu", concat_hidden=False, readout="sum"): + super(GraphConvolutionalNetwork, self).__init__() + + if not isinstance(hidden_dims, Sequence): + hidden_dims = [hidden_dims] + self.input_dim = input_dim + self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1] + self.dims = [input_dim] + list(hidden_dims) + self.short_cut = short_cut + self.concat_hidden = concat_hidden + + self.layers = nn.ModuleList() + for i in range(len(self.dims) - 1): + self.layers.append(layers.GraphConv(self.dims[i], self.dims[i + 1], edge_input_dim, batch_norm, activation)) + + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the node representations and the graph representation(s). 
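+
+        A minimal usage sketch (dimensions are illustrative):
+
+        >>> from torchdrug import data, models
+        >>> graph = data.Molecule.pack([data.Molecule.from_smiles("CCO")])
+        >>> model = models.GraphConvolutionalNetwork(input_dim=graph.node_feature.shape[-1],
+        >>>                                          hidden_dims=[256, 256])
+        >>> output = model(graph, graph.node_feature.float())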
+
+        Parameters:
+            graph (Graph): :math:`n` graph(s)
+            input (Tensor): input node representations
+            all_loss (Tensor, optional): if specified, add loss to this tensor
+            metric (dict, optional): if specified, output metrics to this dict
+
+        Returns:
+            dict with ``node_feature`` and ``graph_feature`` fields:
+                node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)`
+        """
+        hiddens = []
+        layer_input = input
+
+        for layer in self.layers:
+            hidden = layer(graph, layer_input)
+            if self.short_cut and hidden.shape == layer_input.shape:
+                hidden = hidden + layer_input
+            hiddens.append(hidden)
+            layer_input = hidden
+
+        if self.concat_hidden:
+            node_feature = torch.cat(hiddens, dim=-1)
+        else:
+            node_feature = hiddens[-1]
+        graph_feature = self.readout(graph, node_feature)
+
+        return {
+            "graph_feature": graph_feature,
+            "node_feature": node_feature
+        }
+
+
+@R.register("models.RGCN")
+class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
+    """
+    Relational Graph Convolutional Network proposed in `Modeling Relational Data with Graph Convolutional Networks`_.
+
+    .. _Modeling Relational Data with Graph Convolutional Networks:
+        https://arxiv.org/pdf/1703.06103.pdf
+
+    Parameters:
+        input_dim (int): input dimension
+        hidden_dims (list of int): hidden dimensions
+        num_relation (int): number of relations
+        edge_input_dim (int, optional): dimension of edge features
+        short_cut (bool, optional): use short cut or not
+        batch_norm (bool, optional): apply batch normalization or not
+        activation (str or function, optional): activation function
+        concat_hidden (bool, optional): concat hidden representations from all layers as output
+        readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
+    """
+
+    def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False,
+                 activation="relu", concat_hidden=False, readout="sum"):
+        super(RelationalGraphConvolutionalNetwork, self).__init__()
+
+        if not isinstance(hidden_dims, Sequence):
+            hidden_dims = [hidden_dims]
+        self.input_dim = input_dim
+        self.output_dim = hidden_dims[-1] * (len(hidden_dims) if concat_hidden else 1)
+        self.dims = [input_dim] + list(hidden_dims)
+        self.num_relation = num_relation
+        self.short_cut = short_cut
+        self.concat_hidden = concat_hidden
+
+        self.layers = nn.ModuleList()
+        for i in range(len(self.dims) - 1):
+            self.layers.append(layers.RelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation, edge_input_dim,
+                                                          batch_norm, activation))
+
+        if readout == "sum":
+            self.readout = layers.SumReadout()
+        elif readout == "mean":
+            self.readout = layers.MeanReadout()
+        else:
+            raise ValueError("Unknown readout `%s`" % readout)
+
+    def forward(self, graph, input, all_loss=None, metric=None):
+        """
+        Compute the node representations and the graph representation(s).
+
+        Require the graph(s) to have the same number of relations as this module.
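+
+        A minimal usage sketch (for molecules, bond types play the role of
+        relations; dimensions are illustrative):
+
+        >>> from torchdrug import data, models
+        >>> graph = data.Molecule.pack([data.Molecule.from_smiles("CCO")])
+        >>> model = models.RelationalGraphConvolutionalNetwork(
+        >>>     input_dim=graph.node_feature.shape[-1], hidden_dims=[256],
+        >>>     num_relation=graph.num_relation)
+        >>> output = model(graph, graph.node_feature.float())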
+ + Parameters: + graph (Graph): :math:`n` graph(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``node_feature`` and ``graph_feature`` fields: + node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)` + """ + hiddens = [] + layer_input = input + + for layer in self.layers: + hidden = layer(graph, layer_input) + if self.short_cut and hidden.shape == layer_input.shape: + hidden = hidden + layer_input + hiddens.append(hidden) + layer_input = hidden + + if self.concat_hidden: + node_feature = torch.cat(hiddens, dim=-1) + else: + node_feature = hiddens[-1] + graph_feature = self.readout(graph, node_feature) + + return { + "graph_feature": graph_feature, + "node_feature": node_feature + } \ No newline at end of file diff --git a/build/lib/torchdrug/models/gearnet.py b/build/lib/torchdrug/models/gearnet.py new file mode 100644 index 00000000..82fe0224 --- /dev/null +++ b/build/lib/torchdrug/models/gearnet.py @@ -0,0 +1,123 @@ +from collections.abc import Sequence + +import torch +from torch import nn +from torch_scatter import scatter_add + +from torchdrug import core, layers +from torchdrug.core import Registry as R + + +@R.register("models.GearNet") +class GeometryAwareRelationalGraphNeuralNetwork(nn.Module, core.Configurable): + """ + Geometry Aware Relational Graph Neural Network proposed in + `Protein Representation Learning by Geometric Structure Pretraining`_. + + .. _Protein Representation Learning by Geometric Structure Pretraining: + https://arxiv.org/pdf/2203.06125.pdf + + Parameters: + input_dim (int): input dimension + hidden_dims (list of int): hidden dimensions + num_relation (int): number of relations + edge_input_dim (int, optional): dimension of edge features + num_angle_bin (int, optional): number of bins to discretize angles between edges. + The discretized angles are used as relations in edge message passing. + If not provided, edge message passing is disabled. + short_cut (bool, optional): use short cut or not + batch_norm (bool, optional): apply batch normalization or not + activation (str or function, optional): activation function + concat_hidden (bool, optional): concat hidden representations from all layers as output + readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. 
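+
+    A minimal usage sketch (all sizes are illustrative; ``graph`` is assumed to be
+    a residue-level protein graph with ``num_relation`` edge types, e.g. built by a
+    graph construction layer):
+
+    >>> model = GeometryAwareRelationalGraphNeuralNetwork(
+    >>>     input_dim=21, hidden_dims=[512, 512, 512], num_relation=7,
+    >>>     num_angle_bin=8, batch_norm=True, readout="sum")
+    >>> output = model(graph, graph.node_feature.float())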
+ """ + + def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, num_angle_bin=None, + short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"): + super(GeometryAwareRelationalGraphNeuralNetwork, self).__init__() + + if not isinstance(hidden_dims, Sequence): + hidden_dims = [hidden_dims] + self.input_dim = input_dim + self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1] + self.dims = [input_dim] + list(hidden_dims) + self.edge_dims = [edge_input_dim] + self.dims[:-1] + self.num_relation = num_relation + self.num_angle_bin = num_angle_bin + self.short_cut = short_cut + self.concat_hidden = concat_hidden + self.batch_norm = batch_norm + + self.layers = nn.ModuleList() + for i in range(len(self.dims) - 1): + self.layers.append(layers.GeometricRelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation, + None, batch_norm, activation)) + if num_angle_bin: + self.spatial_line_graph = layers.SpatialLineGraph(num_angle_bin) + self.edge_layers = nn.ModuleList() + for i in range(len(self.edge_dims) - 1): + self.edge_layers.append(layers.GeometricRelationalGraphConv( + self.edge_dims[i], self.edge_dims[i + 1], num_angle_bin, None, batch_norm, activation)) + + if batch_norm: + self.batch_norms = nn.ModuleList() + for i in range(len(self.dims) - 1): + self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1])) + + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the node representations and the graph representation(s). + + Parameters: + graph (Graph): :math:`n` graph(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``node_feature`` and ``graph_feature`` fields: + node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)` + """ + hiddens = [] + layer_input = input + if self.num_angle_bin: + line_graph = self.spatial_line_graph(graph) + edge_input = line_graph.node_feature.float() + + for i in range(len(self.layers)): + hidden = self.layers[i](graph, layer_input) + if self.short_cut and hidden.shape == layer_input.shape: + hidden = hidden + layer_input + if self.num_angle_bin: + edge_hidden = self.edge_layers[i](line_graph, edge_input) + edge_weight = graph.edge_weight.unsqueeze(-1) + node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2] + update = scatter_add(edge_hidden * edge_weight, node_out, dim=0, + dim_size=graph.num_node * self.num_relation) + update = update.view(graph.num_node, self.num_relation * edge_hidden.shape[1]) + update = self.layers[i].linear(update) + update = self.layers[i].activation(update) + hidden = hidden + update + edge_input = edge_hidden + if self.batch_norm: + hidden = self.batch_norms[i](hidden) + hiddens.append(hidden) + layer_input = hidden + + if self.concat_hidden: + node_feature = torch.cat(hiddens, dim=-1) + else: + node_feature = hiddens[-1] + graph_feature = self.readout(graph, node_feature) + + return { + "graph_feature": graph_feature, + "node_feature": node_feature + } \ No newline at end of file diff --git a/build/lib/torchdrug/models/gin.py b/build/lib/torchdrug/models/gin.py new file mode 100644 index 00000000..990bb4b2 --- /dev/null +++ 
b/build/lib/torchdrug/models/gin.py
@@ -0,0 +1,91 @@
+from collections.abc import Sequence
+
+import torch
+from torch import nn
+
+from torchdrug import core, layers
+from torchdrug.core import Registry as R
+
+
+@R.register("models.GIN")
+class GraphIsomorphismNetwork(nn.Module, core.Configurable):
+    """
+    Graph Isomorphism Network proposed in `How Powerful are Graph Neural Networks?`_
+
+    .. _How Powerful are Graph Neural Networks?:
+        https://arxiv.org/pdf/1810.00826.pdf
+
+    Parameters:
+        input_dim (int): input dimension
+        hidden_dims (list of int): hidden dimensions
+        edge_input_dim (int, optional): dimension of edge features
+        num_mlp_layer (int, optional): number of MLP layers
+        eps (int, optional): initial epsilon
+        learn_eps (bool, optional): learn epsilon or not
+        short_cut (bool, optional): use short cut or not
+        batch_norm (bool, optional): apply batch normalization or not
+        activation (str or function, optional): activation function
+        concat_hidden (bool, optional): concat hidden representations from all layers as output
+        readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
+    """
+
+    def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
+                 short_cut=False, batch_norm=False, activation="relu", concat_hidden=False,
+                 readout="sum"):
+        super(GraphIsomorphismNetwork, self).__init__()
+
+        if not isinstance(hidden_dims, Sequence):
+            hidden_dims = [hidden_dims]
+        self.input_dim = input_dim
+        self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
+        self.dims = [input_dim] + list(hidden_dims)
+        self.short_cut = short_cut
+        self.concat_hidden = concat_hidden
+
+        self.layers = nn.ModuleList()
+        for i in range(len(self.dims) - 1):
+            layer_hidden_dims = [self.dims[i + 1]] * (num_mlp_layer - 1)
+            self.layers.append(layers.GraphIsomorphismConv(self.dims[i], self.dims[i + 1], edge_input_dim,
+                                                           layer_hidden_dims, eps, learn_eps, batch_norm, activation))
+
+        if readout == "sum":
+            self.readout = layers.SumReadout()
+        elif readout == "mean":
+            self.readout = layers.MeanReadout()
+        else:
+            raise ValueError("Unknown readout `%s`" % readout)
+
+    def forward(self, graph, input, all_loss=None, metric=None):
+        """
+        Compute the node representations and the graph representation(s).
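+
+        A minimal usage sketch (dimensions are illustrative):
+
+        >>> from torchdrug import data, models
+        >>> graph = data.Molecule.pack([data.Molecule.from_smiles("CCO")])
+        >>> model = models.GraphIsomorphismNetwork(input_dim=graph.node_feature.shape[-1],
+        >>>                                        hidden_dims=[256, 256], batch_norm=True)
+        >>> output = model(graph, graph.node_feature.float())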
+ + Parameters: + graph (Graph): :math:`n` graph(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``node_feature`` and ``graph_feature`` fields: + node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)` + """ + hiddens = [] + layer_input = input + + for layer in self.layers: + hidden = layer(graph, layer_input) + if self.short_cut and hidden.shape == layer_input.shape: + hidden = hidden + layer_input + hiddens.append(hidden) + layer_input = hidden + + if self.concat_hidden: + node_feature = torch.cat(hiddens, dim=-1) + else: + node_feature = hiddens[-1] + graph_feature = self.readout(graph, node_feature) + + return { + "graph_feature": graph_feature, + "node_feature": node_feature + } \ No newline at end of file diff --git a/build/lib/torchdrug/models/infograph.py b/build/lib/torchdrug/models/infograph.py new file mode 100644 index 00000000..935845d7 --- /dev/null +++ b/build/lib/torchdrug/models/infograph.py @@ -0,0 +1,167 @@ +import copy +import random + +import torch +from torch import nn +from torch.nn import functional as F + +from torchdrug import core, layers +from torchdrug.core import Registry as R + + +@R.register("models.InfoGraph") +class InfoGraph(nn.Module, core.Configurable): + """ + InfoGraph proposed in + `InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information + Maximization`_. + + .. _InfoGraph\: + Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization: + https://arxiv.org/pdf/1908.01000.pdf + + Parameters: + model (nn.Module): node & graph representation model + num_mlp_layer (int, optional): number of MLP layers in mutual information estimators + activation (str or function, optional): activation function + loss_weight (float, optional): weight of both unsupervised & transfer losses + separate_model (bool, optional): separate supervised and unsupervised encoders. + If true, the unsupervised loss will be applied on a separate encoder, + and a transfer loss is applied between the two encoders. + """ + + def __init__(self, model, num_mlp_layer=2, activation="relu", loss_weight=1, separate_model=False): + super(InfoGraph, self).__init__() + self.model = model + self.separate_model = separate_model + self.loss_weight = loss_weight + self.output_dim = self.model.output_dim + + if separate_model: + self.unsupervised_model = copy.deepcopy(model) + self.transfer_mi = layers.MutualInformation(model.output_dim, num_mlp_layer, activation) + else: + self.unsupervised_model = model + self.unsupervised_mi = layers.MutualInformation(model.output_dim, num_mlp_layer, activation) + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the node representations and the graph representation(s). + Add the mutual information between graph and nodes to the loss. 
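+
+        A minimal usage sketch (the wrapped encoder is illustrative; the mutual
+        information loss is only accumulated when ``all_loss`` is provided):
+
+        >>> gin = models.GraphIsomorphismNetwork(input_dim=graph.node_feature.shape[-1],
+        >>>                                      hidden_dims=[256] * 4)
+        >>> model = models.InfoGraph(gin)
+        >>> all_loss, metric = torch.zeros(1), {}
+        >>> output = model(graph, graph.node_feature.float(), all_loss, metric)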
+ + Parameters: + graph (Graph): :math:`n` graph(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``node_feature`` and ``graph_feature`` fields: + node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)` + """ + output = self.model(graph, input) + + if all_loss is not None: + if self.separate_model: + unsupervised_output = self.unsupervised_model(graph, input) + mutual_info = self.transfer_mi(output["graph_feature"], unsupervised_output["graph_feature"]) + + metric["distillation mutual information"] = mutual_info + if self.loss_weight > 0: + all_loss -= mutual_info * self.loss_weight + else: + unsupervised_output = output + + graph_index = graph.node2graph + node_index = torch.arange(graph.num_node, device=graph.device) + pair_index = torch.stack([graph_index, node_index], dim=-1) + + mutual_info = self.unsupervised_mi(unsupervised_output["graph_feature"], + unsupervised_output["node_feature"], pair_index) + + metric["graph-node mutual information"] = mutual_info + if self.loss_weight > 0: + all_loss -= mutual_info * self.loss_weight + + return output + + +@R.register("models.MultiviewContrast") +class MultiviewContrast(nn.Module, core.Configurable): + """ + Multiview Contrast proposed in `Protein Representation Learning by Geometric Structure Pretraining`_. + + .. _Protein Representation Learning by Geometric Structure Pretraining: + https://arxiv.org/pdf/2203.06125.pdf + + Parameters: + model (nn.Module): node & graph representation model + crop_funcs (list of nn.Module): list of cropping functions + noise_funcs (list of nn.Module): list of noise functions + num_mlp_layer (int, optional): number of MLP layers in mutual information estimators + activation (str or function, optional): activation function + tau (float, optional): temperature in InfoNCE loss + """ + + eps = 1e-10 + + def __init__(self, model, crop_funcs, noise_funcs, num_mlp_layer=2, activation="relu", tau=0.07): + super(MultiviewContrast, self).__init__() + self.model = model + self.crop_funcs = crop_funcs + self.noise_funcs = noise_funcs + self.tau = tau + + self.mlp = layers.MLP(model.output_dim, [model.output_dim] * num_mlp_layer, activation=activation) + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the graph representations of two augmented views. + Each view is generated by randomly picking a cropping function and a noise function. + Add the mutual information between two augmented views to the loss. 
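+
+        The objective is an InfoNCE-style loss: the two views of the same protein
+        form the positive pair, all other cross-view pairs in the batch serve as
+        negatives, and cosine similarities are scaled by the temperature ``tau``.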
+ + Parameters: + graph (Graph): :math:`n` graph(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``node_feature1``, ``node_feature2``, ``graph_feature1`` and ``graph_feature2`` fields: + node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)` + for two augmented views respectively + """ + # Get two augmented views + graph = copy.copy(graph) + if graph.view == "residue": + with graph.residue(): + graph.input = input + else: + with graph.atom(): + graph.input = input + crop_func1, noise_func1 = random.sample(self.crop_funcs, 1)[0], random.sample(self.noise_funcs, 1)[0] + graph1 = crop_func1(graph) + graph1 = noise_func1(graph1) + output1 = self.model(graph1, graph1.input) + + crop_func2, noise_func2 = random.sample(self.crop_funcs, 1)[0], random.sample(self.noise_funcs, 1)[0] + graph2 = crop_func2(graph) + graph2 = noise_func2(graph2) + output2 = self.model(graph2, graph2.input) + + # Compute mutual information loss + if all_loss is not None: + x = self.mlp(output1["graph_feature"]) + y = self.mlp(output2["graph_feature"]) + + score = F.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0), dim=-1) + score = score / self.tau + is_positive = torch.diag(torch.ones(len(x), dtype=torch.bool, device=self.device)) + mutual_info = (score[is_positive] - score.logsumexp(dim=-1)).mean() + + metric["multiview mutual information"] = mutual_info + all_loss -= mutual_info + + output = {"node_feature1": output1["node_feature"], "graph_feature1": output1["graph_feature"], + "node_feature2": output2["node_feature"], "graph_feature2": output2["graph_feature"]} + return output diff --git a/build/lib/torchdrug/models/kbgat.py b/build/lib/torchdrug/models/kbgat.py new file mode 100644 index 00000000..4a4a44b1 --- /dev/null +++ b/build/lib/torchdrug/models/kbgat.py @@ -0,0 +1,62 @@ +import torch +from torch import nn + +from torchdrug import core, models, utils +from torchdrug.layers import functional +from torchdrug.core import Registry as R + + +@R.register("models.KBGAT") +@utils.copy_args(models.GraphAttentionNetwork) +class KnowledgeBaseGraphAttentionNetwork(models.GraphAttentionNetwork, core.Configurable): + """ + Knowledge Base Graph Attention Network proposed in + `Learning Attention-based Embeddings for Relation Prediction in Knowledge Graphs`_. + + .. 
_Learning Attention-based Embeddings for Relation Prediction in Knowledge Graphs: + https://arxiv.org/pdf/1906.01195.pdf + + Parameters: + num_entity (int): number of entities + num_relation (int): number of relations + embedding_dim (int): dimension of embeddings + hidden_dims (list of int): hidden dimensions + max_score (float, optional): maximal score for triplets + **kwargs + """ + + def __init__(self, num_entity, num_relation, embedding_dim, hidden_dims, max_score=12, **kwargs): + super(KnowledgeBaseGraphAttentionNetwork, self).__init__(embedding_dim, hidden_dims, embedding_dim, **kwargs) + self.num_entity = num_entity + self.num_relation = num_relation + self.max_score = max_score + + self.linear = nn.Linear(self.output_dim, embedding_dim) + self.output_dim = embedding_dim + + self.entity = nn.Parameter(torch.zeros(num_entity, embedding_dim)) + self.relation = nn.Parameter(torch.zeros(num_relation, embedding_dim)) + + nn.init.uniform_(self.entity, -max_score / embedding_dim, max_score / embedding_dim) + nn.init.uniform_(self.relation, -max_score / embedding_dim, max_score / embedding_dim) + + def forward(self, graph, h_index, t_index, r_index, all_loss=None, metric=None): + """ + Compute the score for triplets. + + Parameters: + graph (Graph): fact graph + h_index (Tensor): indexes of head entities + t_index (Tensor): indexes of tail entities + r_index (Tensor): indexes of relations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + """ + with graph.edge(): + graph.edge_feature = self.relation[graph.edge_list[:, 2]].detach() + + output = super(KnowledgeBaseGraphAttentionNetwork, self).forward(graph, self.entity, all_loss, metric) + entity = self.linear(output["node_feature"]) + score = functional.transe_score(entity, self.relation, h_index, t_index, r_index) + return self.max_score - score + diff --git a/build/lib/torchdrug/models/lstm.py b/build/lib/torchdrug/models/lstm.py new file mode 100644 index 00000000..eab1406b --- /dev/null +++ b/build/lib/torchdrug/models/lstm.py @@ -0,0 +1,90 @@ +from torch import nn +from torch.nn import functional as F + +from torchdrug import core +from torchdrug.layers import functional +from torchdrug.core import Registry as R + + +@R.register("models.ProteinLSTM") +class ProteinLSTM(nn.Module, core.Configurable): + """ + Protein LSTM proposed in `Evaluating Protein Transfer Learning with TAPE`_. + + .. 
_Evaluating Protein Transfer Learning with TAPE:
+        https://arxiv.org/pdf/1906.08230.pdf
+
+    Parameters:
+        input_dim (int): input dimension
+        hidden_dim (int): hidden dimension
+        num_layers (int): number of LSTM layers
+        activation (str or function, optional): activation function
+        layer_norm (bool, optional): apply layer normalization or not
+        dropout (float, optional): dropout ratio of input features
+    """
+
+    def __init__(self, input_dim, hidden_dim, num_layers, activation='tanh', layer_norm=False,
+                 dropout=0):
+        super(ProteinLSTM, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = hidden_dim  # output_dim for node feature is 2 * hidden_dim
+        self.node_output_dim = 2 * hidden_dim
+        self.num_layers = num_layers
+        self.padding_id = input_dim - 1
+
+        self.embedding = nn.Linear(input_dim, hidden_dim)
+        if layer_norm:
+            self.layer_norm = nn.LayerNorm(hidden_dim)
+        else:
+            self.layer_norm = None
+        if dropout:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+        if isinstance(activation, str):
+            self.activation = getattr(F, activation)
+        else:
+            self.activation = activation
+
+        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout,
+                            bidirectional=True)
+
+        self.reweight = nn.Linear(2 * num_layers, 1)
+        self.linear = nn.Linear(hidden_dim, hidden_dim)
+
+    def forward(self, graph, input, all_loss=None, metric=None):
+        """
+        Compute the residue representations and the graph representation(s).
+
+        Parameters:
+            graph (Protein): :math:`n` protein(s)
+            input (Tensor): input node representations
+            all_loss (Tensor, optional): if specified, add loss to this tensor
+            metric (dict, optional): if specified, output metrics to this dict
+
+        Returns:
+            dict with ``residue_feature`` and ``graph_feature`` fields:
+                residue representations of shape :math:`(|V_{res}|, d)`, graph representations of shape :math:`(n, d)`
+        """
+        input = graph.residue_feature.float()
+        input = functional.variadic_to_padded(input, graph.num_residues, value=self.padding_id)[0]
+
+        input = self.embedding(input)
+        if self.layer_norm:
+            input = self.layer_norm(input)
+        if self.dropout:
+            input = self.dropout(input)
+
+        output, hidden = self.lstm(input)
+
+        residue_feature = functional.padded_to_variadic(output, graph.num_residues)
+
+        # (2 * num_layer, B, d)
+        graph_feature = self.reweight(hidden[0].permute(1, 2, 0)).squeeze(-1)
+        graph_feature = self.linear(graph_feature)
+        graph_feature = self.activation(graph_feature)
+
+        return {
+            "graph_feature": graph_feature,
+            "residue_feature": residue_feature
+        }
diff --git a/build/lib/torchdrug/models/mpnn.py b/build/lib/torchdrug/models/mpnn.py
new file mode 100644
index 00000000..db8873ad
--- /dev/null
+++ b/build/lib/torchdrug/models/mpnn.py
@@ -0,0 +1,91 @@
+import torch
+from torch import nn
+
+from torchdrug import core, layers
+from torchdrug.core import Registry as R
+
+
+@R.register("models.MPNN")
+class MessagePassingNeuralNetwork(nn.Module, core.Configurable):
+    """
+    Message Passing Neural Network proposed in `Neural Message Passing for Quantum Chemistry`_.
+
+    This implements the enn-s2s variant in the original paper.
+
+    ..
_Neural Message Passing for Quantum Chemistry: + https://arxiv.org/pdf/1704.01212.pdf + + Parameters: + input_dim (int): input dimension + hidden_dim (int): hidden dimension + edge_input_dim (int): dimension of edge features + num_layer (int, optional): number of hidden layers + num_gru_layer (int, optional): number of GRU layers in each node update + num_mlp_layer (int, optional): number of MLP layers in each message function + num_s2s_step (int, optional): number of processing steps in set2set + short_cut (bool, optional): use short cut or not + batch_norm (bool, optional): apply batch normalization or not + activation (str or function, optional): activation function + concat_hidden (bool, optional): concat hidden representations from all layers as output + """ + + def __init__(self, input_dim, hidden_dim, edge_input_dim, num_layer=1, num_gru_layer=1, num_mlp_layer=2, + num_s2s_step=3, short_cut=False, batch_norm=False, activation="relu", concat_hidden=False): + super(MessagePassingNeuralNetwork, self).__init__() + + self.input_dim = input_dim + self.edge_input_dim = edge_input_dim + if concat_hidden: + feature_dim = hidden_dim * num_layer + else: + feature_dim = hidden_dim + self.output_dim = feature_dim * 2 + self.node_output_dim = feature_dim + self.num_layer = num_layer + self.short_cut = short_cut + self.concat_hidden = concat_hidden + + self.linear = nn.Linear(input_dim, hidden_dim) + self.layer = layers.MessagePassing(hidden_dim, edge_input_dim, [hidden_dim] * (num_mlp_layer - 1), + batch_norm, activation) + self.gru = nn.GRU(hidden_dim, hidden_dim, num_gru_layer) + + self.readout = layers.Set2Set(feature_dim, num_step=num_s2s_step) + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the node representations and the graph representation(s). + + Parameters: + graph (Graph): :math:`n` graph(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``node_feature`` and ``graph_feature`` fields: + node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)` + """ + hiddens = [] + layer_input = self.linear(input) + hx = layer_input.repeat(self.gru.num_layers, 1, 1) + + for i in range(self.num_layer): + x = self.layer(graph, layer_input) + hidden, hx = self.gru(x.unsqueeze(0), hx) + hidden = hidden.squeeze(0) + if self.short_cut and hidden.shape == layer_input.shape: + hidden = hidden + layer_input + hiddens.append(hidden) + layer_input = hidden + + if self.concat_hidden: + node_feature = torch.cat(hiddens, dim=-1) + else: + node_feature = hiddens[-1] + graph_feature = self.readout(graph, node_feature) + + return { + "graph_feature": graph_feature, + "node_feature": node_feature + } diff --git a/build/lib/torchdrug/models/neuralfp.py b/build/lib/torchdrug/models/neuralfp.py new file mode 100644 index 00000000..ec47c2c5 --- /dev/null +++ b/build/lib/torchdrug/models/neuralfp.py @@ -0,0 +1,96 @@ +from collections.abc import Sequence + +import torch +from torch import nn +from torch.nn import functional as F + +from torchdrug import core, layers +from torchdrug.core import Registry as R + + +@R.register("models.NeuralFP") +class NeuralFingerprint(nn.Module, core.Configurable): + """ + Neural Fingerprints from `Convolutional Networks on Graphs for Learning Molecular Fingerprints`_. + + .. 
_Convolutional Networks on Graphs for Learning Molecular Fingerprints: + https://arxiv.org/pdf/1509.09292.pdf + + Parameters: + input_dim (int): input dimension + output_dim (int): fingerprint dimension + hidden_dims (list of int): hidden dimensions + edge_input_dim (int, optional): dimension of edge features + short_cut (bool, optional): use short cut or not + batch_norm (bool, optional): apply batch normalization or not + activation (str or function, optional): activation function + concat_hidden (bool, optional): concat hidden representations from all layers as output + readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + """ + + def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, + activation="relu", concat_hidden=False, readout="sum"): + super(NeuralFingerprint, self).__init__() + + if not isinstance(hidden_dims, Sequence): + hidden_dims = [hidden_dims] + self.input_dim = input_dim + self.output_dim = output_dim * (len(hidden_dims) if concat_hidden else 1) + self.dims = [input_dim] + list(hidden_dims) + self.short_cut = short_cut + self.concat_hidden = concat_hidden + + self.layers = nn.ModuleList() + self.linears = nn.ModuleList() + for i in range(len(self.dims) - 1): + self.layers.append(layers.NeuralFingerprintConv(self.dims[i], self.dims[i + 1], edge_input_dim, + batch_norm, activation)) + self.linears.append(nn.Linear(self.dims[i + 1], output_dim)) + + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the node representations and the graph representation(s). 
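+
+        A minimal usage sketch (fingerprint and hidden dimensions are illustrative):
+
+        >>> from torchdrug import data, models
+        >>> graph = data.Molecule.pack([data.Molecule.from_smiles("CCO")])
+        >>> model = models.NeuralFingerprint(input_dim=graph.node_feature.shape[-1],
+        >>>                                  output_dim=2048, hidden_dims=[124, 124])
+        >>> fingerprint = model(graph, graph.node_feature.float())["graph_feature"]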
+ + Parameters: + graph (Graph): :math:`n` graph(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``node_feature`` and ``graph_feature`` fields: + node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)` + """ + hiddens = [] + outputs = [] + layer_input = input + + for layer, linear in zip(self.layers, self.linears): + hidden = layer(graph, layer_input) + if self.short_cut and hidden.shape == layer_input.shape: + hidden = hidden + layer_input + output = F.softmax(linear(hidden), dim=-1) + hiddens.append(hidden) + outputs.append(output) + layer_input = hidden + + if self.concat_hidden: + node_feature = torch.cat(hiddens, dim=-1) + graph_feature = torch.cat(outputs, dim=-1) + else: + node_feature = hiddens[-1] + graph_feature = torch.stack(outputs).sum(dim=0) + + graph_feature = self.readout(graph, graph_feature) + + return { + "graph_feature": graph_feature, + "node_feature": node_feature + } \ No newline at end of file diff --git a/build/lib/torchdrug/models/neurallp.py b/build/lib/torchdrug/models/neurallp.py new file mode 100644 index 00000000..da4ce58f --- /dev/null +++ b/build/lib/torchdrug/models/neurallp.py @@ -0,0 +1,112 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch_scatter import scatter_add + +from torchdrug import core, utils +from torchdrug.layers import functional +from torchdrug.core import Registry as R + + +@R.register("models.NeuralLP") +class NeuralLogicProgramming(nn.Module, core.Configurable): + """ + Neural Logic Programming proposed in `Differentiable Learning of Logical Rules for Knowledge Base Reasoning`_. + + .. 
_Differentiable Learning of Logical Rules for Knowledge Base Reasoning: + https://papers.nips.cc/paper/2017/file/0e55666a4ad822e0e34299df3591d979-Paper.pdf + + Parameters: + num_relation (int): number of relations + hidden_dim (int): dimension of hidden units in LSTM + num_step (int): number of recurrent steps + num_lstm_layer (int, optional): number of LSTM layers + """ + + eps = 1e-10 + + def __init__(self, num_relation, hidden_dim, num_step, num_lstm_layer=1): + super(NeuralLogicProgramming, self).__init__() + + num_relation = int(num_relation) + self.num_relation = num_relation + self.num_step = num_step + + self.query = nn.Embedding(num_relation * 2 + 1, hidden_dim) + self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_lstm_layer) + self.weight_linear = nn.Linear(hidden_dim, num_relation * 2) + self.linear = nn.Linear(1, 1) + + def negative_sample_to_tail(self, h_index, t_index, r_index): + # convert p(h | t, r) to p(t' | h', r') + # h' = t, r' = r^{-1}, t' = h + is_t_neg = (h_index == h_index[:, [0]]).all(dim=-1, keepdim=True) + new_h_index = torch.where(is_t_neg, h_index, t_index) + new_t_index = torch.where(is_t_neg, t_index, h_index) + new_r_index = torch.where(is_t_neg, r_index, r_index + self.num_relation) + return new_h_index, new_t_index, new_r_index + + @utils.cached + def get_t_output(self, graph, h_index, r_index): + end_index = torch.ones_like(r_index) * graph.num_relation + q_index = torch.stack([r_index] * (self.num_step - 1) + [end_index], dim=0) + query = self.query(q_index) + + hidden, hx = self.lstm(query) + memory = functional.one_hot(h_index, graph.num_node).unsqueeze(0) + + for i in range(self.num_step): + key = hidden[i] + value = hidden[:i + 1] + x = torch.einsum("bd, tbd -> bt", key, value) + attention = F.softmax(x, dim=-1) + input = torch.einsum("bt, tbn -> nb", attention, memory) + weight = F.softmax(self.weight_linear(key), dim=-1).t() + + node_in, node_out, relation = graph.edge_list.t() + if graph.num_node * graph.num_relation < graph.num_edge: + # O(|V|d) memory + node_out = node_out * graph.num_relation + relation + adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), graph.edge_weight, + (graph.num_node, graph.num_node * graph.num_relation)) + output = adjacency.t() @ input + output = output.view(graph.num_node, graph.num_relation, -1) + output = (output * weight).sum(dim=1) + else: + # O(|E|) memory + message = input[node_in] + message = message * weight[relation] + edge_weight = graph.edge_weight.unsqueeze(-1) + output = scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node) + output = output / output.sum(dim=0, keepdim=True).clamp(self.eps) + + memory = torch.cat([memory, output.t().unsqueeze(0)]) + + return output + + def forward(self, graph, h_index, t_index, r_index, all_loss=None, metric=None): + """ + Compute the score for triplets. 
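+
+        A minimal usage sketch (assumes a knowledge graph dataset that provides the
+        fact graph and triplet index tensors; sizes are illustrative):
+
+        >>> model = NeuralLogicProgramming(num_relation=graph.num_relation,
+        >>>                                hidden_dim=128, num_step=3)
+        >>> score = model(graph, h_index, t_index, r_index)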
+
+        Parameters:
+            graph (Graph): fact graph
+            h_index (Tensor): indexes of head entities
+            t_index (Tensor): indexes of tail entities
+            r_index (Tensor): indexes of relations
+            all_loss (Tensor, optional): if specified, add loss to this tensor
+            metric (dict, optional): if specified, output metrics to this dict
+        """
+        assert graph.num_relation == self.num_relation
+        graph = graph.undirected(add_inverse=True)
+
+        h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index)
+        hr_index = h_index * graph.num_relation + r_index
+        hr_index_set, hr_inverse = hr_index.unique(return_inverse=True)
+        h_index_set = torch.div(hr_index_set, graph.num_relation, rounding_mode="floor")
+        r_index_set = hr_index_set % graph.num_relation
+
+        output = self.get_t_output(graph, h_index_set, r_index_set)
+
+        score = output[t_index, hr_inverse]
+        score = self.linear(score.unsqueeze(-1)).squeeze(-1)
+        return score
\ No newline at end of file
diff --git a/build/lib/torchdrug/models/physicochemical.py b/build/lib/torchdrug/models/physicochemical.py
new file mode 100644
index 00000000..b9e5b6eb
--- /dev/null
+++ b/build/lib/torchdrug/models/physicochemical.py
@@ -0,0 +1,145 @@
+import os
+
+import torch
+from torch import nn
+
+from torch_scatter import scatter_mean, scatter_add
+
+from torchdrug import core, layers, utils, data
+from torchdrug.layers import functional
+from torchdrug.core import Registry as R
+
+
+@R.register("models.Physicochemical")
+class Physicochemical(nn.Module, core.Configurable):
+    """
+    The physicochemical feature engineering for protein sequences proposed in
+    `Prediction of Membrane Protein Types based on the Hydrophobic Index of Amino Acids`_.
+
+    .. _Prediction of Membrane Protein Types based on the Hydrophobic Index of Amino Acids:
+        https://link.springer.com/article/10.1023/A:1007091128394
+
+    Parameters:
+        path (str): path to store feature file
+        type (str, optional): physicochemical feature. Available features are ``moran``, ``geary`` and ``nmbroto``.
+        nlag (int, optional): maximum position interval to compute features
+        hidden_dims (list of int, optional): hidden dimensions
+    """
+
+    url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/documents/AAidx.txt"
+    md5 = "ec612f4df41b93ae03c31ae376c23ce0"
+
+    property_key = ["CIDH920105", "BHAR880101", "CHAM820101", "CHAM820102",
+                    "CHOC760101", "BIGC670101", "CHAM810101", "DAYM780201"]
+    num_residue_type = len(data.Protein.id2residue_symbol)
+
+    def __init__(self, path, type="moran", nlag=30, hidden_dims=(512,)):
+        super(Physicochemical, self).__init__()
+        self.type = type
+        path = os.path.expanduser(path)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        self.path = path
+        index_file = utils.download(self.url, path, md5=self.md5)
+        property = self.read_property(index_file)
+        self.register_buffer("property", property)
+
+        self.nlag = nlag
+        self.input_dim = len(self.property_key) * nlag
+        self.output_dim = hidden_dims[-1]
+
+        self.mlp = layers.Sequential(
+            layers.MultiLayerPerceptron(self.input_dim, hidden_dims),
+            nn.ReLU()
+        )
+
+    def read_property(self, file):
+        with open(file, "r") as fin:
+            lines = fin.readlines()
+        vocab = lines[0].strip().split("\t")[1:]
+
+        property_dict = {}
+        for line in lines[1:]:
+            line = line.strip().split("\t")
+            property_dict[line[0]] = [float(x) if x != "NA" else 0 for x in line[1:]]
+
+        _property = []
+        for key in self.property_key:
+            _property.append(property_dict[key])
+        _property = torch.tensor(_property)
+        mapping = [data.Protein.residue_symbol2id[residue] for residue in vocab]
+        property = torch.zeros((len(self.property_key), self.num_residue_type), dtype=torch.float)
+        property[:, mapping] = _property
+
+        property = (property - property.mean(dim=1, keepdim=True)) / \
+                   (property.std(dim=1, keepdim=True) + 1e-10)
+        return property.transpose(0, 1)
+
+    def forward(self, graph, input, all_loss=None, metric=None):
+        """
+        Compute the graph representation(s).
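+
+        The features are sequence autocorrelations (Moran, Geary or normalized
+        Moreau-Broto) of the eight physicochemical property indices, computed for
+        every lag from 1 to ``nlag`` and passed through an MLP.
+
+        A minimal usage sketch (hypothetical cache directory; ``graph`` is a packed
+        protein, and ``input`` is ignored in favor of ``graph.residue_type``):
+
+        >>> model = Physicochemical("~/protein-features", type="moran", nlag=30)
+        >>> output = model(graph, None)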
+
+        Parameters:
+            graph (Protein): :math:`n` protein(s)
+            input (Tensor): input node representations
+            all_loss (Tensor, optional): if specified, add loss to this tensor
+            metric (dict, optional): if specified, output metrics to this dict
+
+        Returns:
+            dict with ``graph_feature`` field: graph representations of shape :math:`(n, d)`
+        """
+        input = graph.residue_type
+
+        x = self.property[input]  # num_residue * 8
+        x_mean = scatter_mean(x, graph.residue2graph, dim=0, dim_size=graph.batch_size)  # batch_size * 8
+
+        size = graph.num_residues
+        starts = size.cumsum(0) - size  # batch_size
+        starts = starts.unsqueeze(-1).expand(-1, self.nlag)  # batch_size * nlag
+        steps = torch.arange(self.nlag, dtype=torch.long, device=graph.device) + 1
+        steps = (graph.num_residues.unsqueeze(-1) - steps.unsqueeze(0)).clamp(min=0)
+        ends = starts + steps
+        mask_0 = functional.multi_slice_mask(starts, ends, graph.num_residue)  # num_residue * nlag
+
+        ends = size.cumsum(0)  # batch_size
+        ends = ends.unsqueeze(-1).expand(-1, self.nlag)  # batch_size * nlag
+        starts = ends - steps
+        mask_1 = functional.multi_slice_mask(starts, ends, graph.num_residue)  # num_residue * nlag
+
+        index2sample = torch.repeat_interleave(size)
+        numerator = torch.zeros((graph.num_residue, self.nlag, x.shape[-1]), dtype=torch.float, device=graph.device)
+        if self.type == "moran":
+            _numerator = (x - x_mean[index2sample]).unsqueeze(1).expand(-1, self.nlag, -1)[mask_0] * \
+                         (x - x_mean[index2sample]).unsqueeze(1).expand(-1, self.nlag, -1)[mask_1]
+            numerator[mask_0] = _numerator
+            numerator = numerator / (steps[index2sample].unsqueeze(-1) + 1e-10)
+            numerator = scatter_add(numerator, graph.residue2graph, dim=0, dim_size=graph.batch_size)  # batch_size * nlag * 8
+            denominator = scatter_add((x - x_mean[index2sample]) ** 2, graph.residue2graph, dim=0, dim_size=graph.batch_size)
+            denominator = denominator / graph.num_residues.unsqueeze(-1)
+            denominator = denominator.unsqueeze(1)  # batch_size * 1 * 8
+        elif self.type == "geary":
+            _numerator = x.unsqueeze(1).expand(-1, self.nlag, -1)[mask_0] - \
+                         x.unsqueeze(1).expand(-1, self.nlag, -1)[mask_1]
+            _numerator = _numerator ** 2
+            numerator[mask_0] = _numerator
+            numerator = numerator / (steps[index2sample].unsqueeze(-1) * 2 + 1e-10)
+            numerator = scatter_add(numerator, graph.residue2graph, dim=0, dim_size=graph.batch_size)  # batch_size * nlag * 8
+            denominator = scatter_add((x - x_mean[index2sample]) ** 2, graph.residue2graph, dim=0, dim_size=graph.batch_size)
+            denominator = denominator / (graph.num_residues - 1 + 1e-10).unsqueeze(-1)
+            denominator = denominator.unsqueeze(1)  # batch_size * 1 * 8
+        elif self.type == "nmbroto":
+            _numerator = x.unsqueeze(1).expand(-1, self.nlag, -1)[mask_0] * \
+                         x.unsqueeze(1).expand(-1, self.nlag, -1)[mask_1]
+            numerator[mask_0] = _numerator
+            numerator = scatter_add(numerator, graph.residue2graph, dim=0, dim_size=graph.batch_size)  # batch_size * nlag * 8
+            denominator = steps.unsqueeze(-1)  # batch_size * nlag * 1
+        else:
+            raise ValueError("Unknown physicochemical feature type `%s`" % self.type)
+        feature = numerator / (denominator + 1e-10)
+        feature = feature.flatten(1, 2)
+
+        graph_feature = self.mlp(feature)
+
+        return {
+            "graph_feature": graph_feature,
+        }
\ No newline at end of file
diff --git a/build/lib/torchdrug/models/schnet.py b/build/lib/torchdrug/models/schnet.py
new file mode 100644
index 00000000..0bfb2cf9
--- /dev/null
+++ b/build/lib/torchdrug/models/schnet.py
@@ -0,0 +1,84 @@
+from collections.abc import Sequence
+
+import torch
+from torch import nn
+
+from torchdrug
import core, layers +from torchdrug.core import Registry as R + + +@R.register("models.SchNet") +class SchNet(nn.Module, core.Configurable): + """ + SchNet from `SchNet: A continuous-filter convolutional neural network for modeling quantum interactions`_. + + .. _SchNet\: A continuous-filter convolutional neural network for modeling quantum interactions: + https://arxiv.org/pdf/1706.08566.pdf + + Parameters: + input_dim (int): input dimension + hidden_dims (list of int): hidden dimensions + edge_input_dim (int, optional): dimension of edge features + cutoff (float, optional): maximal scale for RBF kernels + num_gaussian (int, optional): number of RBF kernels + short_cut (bool, optional): use short cut or not + batch_norm (bool, optional): apply batch normalization or not + activation (str or function, optional): activation function + concat_hidden (bool, optional): concat hidden representations from all layers as output + """ + + def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True, + batch_norm=False, activation="shifted_softplus", concat_hidden=False): + super(SchNet, self).__init__() + + if not isinstance(hidden_dims, Sequence): + hidden_dims = [hidden_dims] + self.input_dim = input_dim + self.output_dim = hidden_dims[-1] * (len(hidden_dims) if concat_hidden else 1) + self.dims = [input_dim] + list(hidden_dims) + self.short_cut = short_cut + self.concat_hidden = concat_hidden + + self.layers = nn.ModuleList() + for i in range(len(self.dims) - 1): + self.layers.append(layers.ContinuousFilterConv(self.dims[i], self.dims[i + 1], edge_input_dim, None, cutoff, + num_gaussian, batch_norm, activation)) + + self.readout = layers.SumReadout() + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the node representations and the graph representation(s). + + Require the graph(s) to have node attribute ``node_position``. + + Parameters: + graph (Graph): :math:`n` graph(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``node_feature`` and ``graph_feature`` fields: + node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)` + """ + hiddens = [] + layer_input = input + + for layer in self.layers: + hidden = layer(graph, layer_input) + if self.short_cut and hidden.shape == layer_input.shape: + hidden = hidden + layer_input + hiddens.append(hidden) + layer_input = hidden + + if self.concat_hidden: + node_feature = torch.cat(hiddens, dim=-1) + else: + node_feature = hiddens[-1] + graph_feature = self.readout(graph, node_feature) + + return { + "graph_feature": graph_feature, + "node_feature": node_feature + } \ No newline at end of file diff --git a/build/lib/torchdrug/models/statistic.py b/build/lib/torchdrug/models/statistic.py new file mode 100644 index 00000000..941cde81 --- /dev/null +++ b/build/lib/torchdrug/models/statistic.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + +from torch_scatter import scatter_add + +from torchdrug import core, layers, data +from torchdrug.core import Registry as R + + +@R.register("models.Statistic") +class Statistic(nn.Module, core.Configurable): + """ + The statistic feature engineering for protein sequence proposed in + `Harnessing Computational Biology for Exact Linear B-cell Epitope Prediction`_. + + .. 
_Harnessing Computational Biology for Exact Linear B-cell Epitope Prediction: + https://www.liebertpub.com/doi/abs/10.1089/omi.2015.0095 + + Parameters: + type (str, optional): statistic feature. Available feature is ``DDE``. + hidden_dims (list of int, optional): hidden dimensions + """ + + num_residue_type = len(data.Protein.id2residue_symbol) + input_dim = num_residue_type ** 2 + _codons = {"A": 4, "C": 2, "D": 2, "E": 2, "F": 2, "G": 4, "H": 2, "I": 3, "K": 2, "L": 6, + "M": 1, "N": 2, "P": 4, "Q": 2, "R": 6, "S": 6, "T": 4, "V": 4, "W": 1, "Y": 2} + + def __init__(self, type="DDE", hidden_dims=(512,)): + super(Statistic, self).__init__() + self.type = type + self.output_dim = hidden_dims[-1] + + codons = self.calculate_codons() + self.register_buffer("codons", codons) + self.mlp = layers.Sequential( + layers.MultiLayerPerceptron(self.input_dim, hidden_dims), + nn.ReLU() + ) + + def calculate_codons(self): + codons = [0] * self.num_residue_type + for i, token in data.Protein.id2residue_symbol.items(): + codons[i] = self._codons[token] + codons = torch.tensor(codons) + return codons + + def forward(self, graph, input, all_loss=None, metric=None): + """ + Compute the residue representations and the graph representation(s). + + Parameters: + graph (Protein): :math:`n` protein(s) + input (Tensor): input node representations + all_loss (Tensor, optional): if specified, add loss to this tensor + metric (dict, optional): if specified, output metrics to this dict + + Returns: + dict with ``graph_feature`` field: graph representations of shape :math:`(n, d)` + """ + input = graph.residue_type + + index = input[:-1] * self.num_residue_type + input[1:] + index = graph.residue2graph[:-1] * self.input_dim + index + value = torch.ones(graph.num_residue - 1, dtype=torch.float, device=graph.device) + mask = graph.residue2graph[:-1] == graph.residue2graph[1:] + feature = scatter_add(value * mask.float(), index, dim=0, dim_size=graph.batch_size * self.input_dim) + feature = feature.view(graph.batch_size, self.input_dim) + feature = feature / (feature.sum(dim=-1, keepdim=True) + 1e-10) + if self.type == "DDE": + TM = self.codons.unsqueeze(0) * self.codons.unsqueeze(1) / 61 ** 2 + TM = TM.flatten() + TV = (TM * (1 - TM)).unsqueeze(0) / (graph.num_residues - 1 + 1e-10).unsqueeze(1) + feature = (feature - TM.unsqueeze(0)) / (TV.sqrt() + 1e-10) + else: + raise ValueError("Unknown statistic feature type `%s`" % self.type) + + graph_feature = self.mlp(feature) + + return { + "graph_feature": graph_feature, + } \ No newline at end of file diff --git a/build/lib/torchdrug/patch.py b/build/lib/torchdrug/patch.py new file mode 100644 index 00000000..d2f6ace3 --- /dev/null +++ b/build/lib/torchdrug/patch.py @@ -0,0 +1,145 @@ +import os +import inspect +import importlib + +import torch +from torch import nn +from torch import optim +from torch.optim import lr_scheduler as scheduler +from torch.utils.data import dataset +from torch.utils import cpp_extension +from torch import distributed as dist + +from torchdrug import core, data +from torchdrug.core import Registry as R + + +class PatchedModule(nn.Module): + + def __init__(self): + super(PatchedModule, self).__init__() + # TODO: these hooks are bugged. 
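+        # (When enabled, these hooks would flatten data.Graph buffers into plain
+        # tensors when saving a state dict and rebuild them on load; see
+        # graph_state_dict / load_graph_state_dict below.)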
+        # self._register_state_dict_hook(PatchedModule.graph_state_dict)
+        # self._register_load_state_dict_pre_hook(PatchedModule.load_graph_state_dict)
+
+    def graph_state_dict(self, destination, prefix, local_metadata):
+        local_graphs = []
+        for name, param in self._buffers.items():
+            if isinstance(param, data.Graph):
+                local_graphs.append(name)
+                destination.pop(prefix + name)
+                for t_name, tensor in zip(data.Graph._tensor_names, param.to_tensors()):
+                    if tensor is not None:
+                        destination[prefix + name + "." + t_name] = tensor
+        if local_graphs:
+            local_metadata["graph"] = local_graphs
+        return destination
+
+    @classmethod
+    def load_graph_state_dict(cls, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        if "graph" not in local_metadata:
+            return
+
+        for name in local_metadata["graph"]:
+            tensors = []
+            for t_name in data.Graph._tensor_names:
+                key = prefix + name + "." + t_name
+                input_tensor = state_dict.get(key, None)
+                tensors.append(input_tensor)
+            try:
+                state_dict[prefix + name] = data.Graph.from_tensors(tensors)
+            except Exception:
+                error_msgs.append("Can't construct Graph `%s` from tensors in the state dict" % (prefix + name))
+        return state_dict
+
+    @property
+    def device(self):
+        try:
+            tensor = next(self.parameters())
+        except StopIteration:
+            tensor = next(self.buffers())
+        return tensor.device
+
+    def register_buffer(self, name, tensor, persistent=True):
+        if persistent is False and isinstance(self, torch.jit.ScriptModule):
+            raise RuntimeError("ScriptModule does not support non-persistent buffers")
+
+        if '_buffers' not in self.__dict__:
+            raise AttributeError(
+                "cannot assign buffer before Module.__init__() call")
+        elif not isinstance(name, str):
+            raise TypeError("buffer name should be a string. "
+                            "Got {}".format(torch.typename(name)))
+        elif '.' in name:
+            raise KeyError("buffer name can't contain \".\"")
+        elif name == '':
+            raise KeyError("buffer name can't be empty string \"\"")
+        elif hasattr(self, name) and name not in self._buffers:
+            raise KeyError("attribute '{}' already exists".format(name))
+        elif tensor is not None and not isinstance(tensor, torch.Tensor) and not isinstance(tensor, data.Graph):
+            raise TypeError("cannot assign '{}' object to buffer '{}' "
+                            "(torch.Tensor, torchdrug.data.Graph or None required)"
+                            .format(torch.typename(tensor), name))
+        else:
+            self._buffers[name] = tensor
+            if persistent:
+                self._non_persistent_buffers_set.discard(name)
+            else:
+                self._non_persistent_buffers_set.add(name)
+
+
+def _get_build_directory(name, verbose):
+    root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR')
+    if root_extensions_directory is None:
+        root_extensions_directory = cpp_extension.get_default_build_root()
+
+    if verbose:
+        print('Using {} as PyTorch extensions root...'.format(
+            root_extensions_directory))
+
+    build_directory = os.path.join(root_extensions_directory, name)
+    if not os.path.exists(build_directory):
+        if verbose:
+            print('Creating extension directory {}...'.format(build_directory))
+        # This is like mkdir -p, i.e. will also create parent directories.
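+        # FileBaton acts as a cooperative file lock: the first process creates the
+        # directory while concurrent workers block in wait(), avoiding a race when
+        # several processes JIT-compile the same extension.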
+        baton = cpp_extension.FileBaton("lock_%s" % name)
+        if baton.try_acquire():
+            os.makedirs(build_directory)
+            baton.release()
+        else:
+            baton.wait()
+
+    return build_directory
+
+
+def patch(module, name, cls):
+    backup = getattr(module, name)
+    setattr(module, "_%s" % name, backup)
+    setattr(module, name, cls)
+
+
+patch(nn, "Module", PatchedModule)
+patch(cpp_extension, "_get_build_directory", _get_build_directory)
+
+Optimizer = optim.Optimizer
+for name, cls in inspect.getmembers(optim):
+    if inspect.isclass(cls) and issubclass(cls, Optimizer):
+        cls = core.make_configurable(cls, ignore_args=("params",))
+        cls = R.register("optim.%s" % name)(cls)
+        patch(optim, name, cls)
+
+Scheduler = scheduler._LRScheduler
+for name, cls in inspect.getmembers(scheduler):
+    if inspect.isclass(cls) and issubclass(cls, Scheduler):
+        cls = core.make_configurable(cls, ignore_args=("optimizer",))
+        cls = R.register("scheduler.%s" % name)(cls)
+        patch(scheduler, name, cls)
+
+Dataset = dataset.Dataset
+for name, cls in inspect.getmembers(dataset):
+    if inspect.isclass(cls) and issubclass(cls, Dataset):
+        cls = core.make_configurable(cls)
+        cls = R.register("dataset.%s" % name)(cls)
+        patch(dataset, name, cls)
+importlib.reload(torch.utils.data)
diff --git a/build/lib/torchdrug/tasks/__init__.py b/build/lib/torchdrug/tasks/__init__.py
new file mode 100644
index 00000000..0682a543
--- /dev/null
+++ b/build/lib/torchdrug/tasks/__init__.py
@@ -0,0 +1,50 @@
+from .task import Task
+
+from .property_prediction import PropertyPrediction, MultipleBinaryClassification, \
+    NodePropertyPrediction, InteractionPrediction, Unsupervised
+from .pretrain import EdgePrediction, AttributeMasking, ContextPrediction, DistancePrediction, \
+    AnglePrediction, DihedralPrediction
+from .generation import AutoregressiveGeneration, GCPNGeneration
+from .retrosynthesis import CenterIdentification, SynthonCompletion, Retrosynthesis
+from .reasoning import KnowledgeGraphCompletion
+from .contact_prediction import ContactPrediction
+
+
+_criterion_name = {
+    "mse": "mean squared error",
+    "mae": "mean absolute error",
+    "bce": "binary cross entropy",
+    "ce": "cross entropy",
+}
+
+_metric_name = {
+    "mae": "mean absolute error",
+    "mse": "mean squared error",
+    "rmse": "root mean squared error",
+    "acc": "accuracy",
+    "mcc": "matthews correlation coefficient",
+}
+
+
+def _get_criterion_name(criterion):
+    if criterion in _criterion_name:
+        return _criterion_name[criterion]
+    return "%s loss" % criterion
+
+
+def _get_metric_name(metric):
+    if metric in _metric_name:
+        return _metric_name[metric]
+    return metric
+
+
+__all__ = [
+    "PropertyPrediction", "MultipleBinaryClassification", "NodePropertyPrediction", "InteractionPrediction",
+    "Unsupervised",
+    "EdgePrediction", "AttributeMasking", "ContextPrediction", "DistancePrediction", "AnglePrediction",
+    "DihedralPrediction",
+    "AutoregressiveGeneration", "GCPNGeneration",
+    "CenterIdentification", "SynthonCompletion", "Retrosynthesis",
+    "KnowledgeGraphCompletion",
+    "ContactPrediction",
+]
\ No newline at end of file
diff --git a/build/lib/torchdrug/tasks/contact_prediction.py b/build/lib/torchdrug/tasks/contact_prediction.py
new file mode 100644
index 00000000..5bf0dd4c
--- /dev/null
+++ b/build/lib/torchdrug/tasks/contact_prediction.py
@@ -0,0 +1,172 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from torchdrug import core, layers, tasks, metrics
+from torchdrug.core import Registry as R
+from torchdrug.layers import functional
+
+
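+# A minimal usage sketch (hypothetical; the encoder and its arguments are
+# placeholders, not prescribed by this file):
+#
+#   from torchdrug import models, tasks
+#
+#   model = models.ProteinCNN(input_dim=21, hidden_dims=[1024, 1024])
+#   task = tasks.ContactPrediction(model, max_length=500, gap=6,
+#                                  metric=("accuracy", "prec@L5"))
+#
+# Every residue pair at sequence separation >= gap is scored by an MLP over the
+# elementwise product and absolute difference of the two residue embeddings.
+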
+@R.register("tasks.ContactPrediction")
+class ContactPrediction(tasks.Task, core.Configurable):
+    """
+    Predict whether each pair of amino acids is in contact in the folded structure.
+
+    Parameters:
+        model (nn.Module): protein sequence representation model
+        max_length (int, optional): maximal length of sequence. Truncate the sequence if it exceeds this limit.
+        random_truncate (bool, optional): truncate the sequence at a random position.
+            If not, truncate the suffix of the sequence.
+        threshold (float, optional): distance threshold for contact
+        gap (int, optional): sequential distance cutoff for evaluation
+        criterion (str or dict, optional): training criterion. For dict, the key is criterion and the value
+            is the corresponding weight. Available criterion is ``bce``.
+        metric (str or list of str, optional): metric(s).
+            Available metrics are ``accuracy``, ``prec@Lk`` and ``prec@k``.
+        num_mlp_layer (int, optional): number of layers in mlp prediction head
+        verbose (int, optional): output verbose level
+    """
+
+    eps = 1e-10
+    _option_members = {"criterion", "metric"}
+
+    def __init__(self, model, max_length=500, random_truncate=True, threshold=8.0, gap=6, criterion="bce",
+                 metric=("accuracy", "prec@L5"), num_mlp_layer=1, verbose=0):
+        super(ContactPrediction, self).__init__()
+        self.model = model
+        self.max_length = max_length
+        self.random_truncate = random_truncate
+        self.threshold = threshold
+        self.gap = gap
+        self.criterion = criterion
+        self.metric = metric
+        self.num_mlp_layer = num_mlp_layer
+        self.verbose = verbose
+
+        if hasattr(self.model, "node_output_dim"):
+            model_output_dim = self.model.node_output_dim
+        else:
+            model_output_dim = self.model.output_dim
+        hidden_dims = [model_output_dim] * (self.num_mlp_layer - 1)
+        self.mlp = layers.MLP(2 * model_output_dim, hidden_dims + [1])
+
+    def truncate(self, batch):
+        graph = batch["graph"]
+        size = graph.num_residues
+        if (size > self.max_length).any():
+            if self.random_truncate:
+                starts = (torch.rand(graph.batch_size, device=graph.device) * \
+                          (graph.num_residues - self.max_length).clamp(min=0)).long()
+                ends = torch.min(starts + self.max_length, graph.num_residues)
+                starts = starts + (graph.num_cum_residues - graph.num_residues)
+                ends = ends + (graph.num_cum_residues - graph.num_residues)
+                mask = functional.multi_slice_mask(starts, ends, graph.num_residue)
+            else:
+                starts = size.cumsum(0) - size
+                size = size.clamp(max=self.max_length)
+                ends = starts + size
+                mask = functional.multi_slice_mask(starts, ends, graph.num_residue)
+            graph = graph.subresidue(mask)
+
+        return {
+            "graph": graph
+        }
+
+    def forward(self, batch):
+        """"""
+        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
+        metric = {}
+
+        batch = self.truncate(batch)
+        pred = self.predict(batch, all_loss, metric)
+        target = self.target(batch)
+
+        for criterion, weight in self.criterion.items():
+            if criterion == "bce":
+                loss = F.binary_cross_entropy_with_logits(pred, target["label"], reduction="none")
+                loss = functional.variadic_mean(loss * target["mask"].float(), size=target["size"])
+            else:
+                raise ValueError("Unknown criterion `%s`" % criterion)
+            loss = loss.mean()
+
+            name = tasks._get_criterion_name(criterion)
+            metric[name] = loss
+            all_loss += loss * weight
+
+        return all_loss, metric
+
+    def predict(self, batch, all_loss=None, metric=None):
+        graph = batch["graph"]
+        output = self.model(graph, graph.residue_feature.float(), all_loss=all_loss, metric=metric)
+        output = output["residue_feature"]
+
+        range = torch.arange(graph.num_residue, device=self.device)
+        node_in, node_out = functional.variadic_meshgrid(range, graph.num_residues, range, graph.num_residues)
+        if all_loss is None and node_in.shape[0] > (self.max_length ** 2) * graph.batch_size:
+            # test time: split the large pairwise input to reduce memory cost
+            size = (self.max_length ** 2) * graph.batch_size
+            node_in_splits = node_in.split(size, dim=0)
+            node_out_splits = node_out.split(size, dim=0)
+            pred = []
+            for _node_in, _node_out in zip(node_in_splits, node_out_splits):
+                prod = output[_node_in] * output[_node_out]
+                diff = (output[_node_in] - output[_node_out]).abs()
+                pairwise_features = torch.cat((prod, diff), -1)
+                _pred = self.mlp(pairwise_features)
+                pred.append(_pred)
+            pred = torch.cat(pred, dim=0)
+        else:
+            prod = output[node_in] * output[node_out]
+            diff = (output[node_in] - output[node_out]).abs()
+            pairwise_features = torch.cat((prod, diff), -1)
+            pred = self.mlp(pairwise_features)
+
+        return pred.squeeze(-1)
+
+    def target(self, batch):
+        graph = batch["graph"]
+        valid_mask = graph.mask
+        residue_position = graph.residue_position
+
+        range = torch.arange(graph.num_residue, device=self.device)
+        node_in, node_out = functional.variadic_meshgrid(range, graph.num_residues, range, graph.num_residues)
+        dist = (residue_position[node_in] - residue_position[node_out]).norm(p=2, dim=-1)
+        label = (dist < self.threshold).float()
+
+        mask = valid_mask[node_in] & valid_mask[node_out] & ((node_in - node_out).abs() >= self.gap)
+
+        return {
+            "label": label,
+            "mask": mask,
+            "size": graph.num_residues ** 2
+        }
+
+    def evaluate(self, pred, target):
+        label = target["label"]
+        mask = target["mask"]
+        size = functional.variadic_sum(mask.long(), target["size"])
+        label = label[mask]
+        pred = pred[mask]
+
+        metric = {}
+        for _metric in self.metric:
+            if _metric == "accuracy":
+                score = (pred > 0) == label
+                score = functional.variadic_mean(score.float(), size).mean()
+            elif _metric.startswith("prec@L"):
+                l = target["size"].float().sqrt().long()
+                k = int(_metric[6:]) if len(_metric) > 6 else 1
+                l = torch.div(l, k, rounding_mode="floor")
+                score = metrics.variadic_top_precision(pred, label, size, l).mean()
+            elif _metric.startswith("prec@"):
+                k = int(_metric[5:])
+                k = torch.full_like(size, k)
+                score = metrics.variadic_top_precision(pred, label, size, k).mean()
+            else:
+                raise ValueError("Unknown metric `%s`" % _metric)
+
+            name = tasks._get_metric_name(_metric)
+            metric[name] = score
+
+        return metric
diff --git a/build/lib/torchdrug/tasks/generation.py b/build/lib/torchdrug/tasks/generation.py
new file mode 100644
index 00000000..47b19031
--- /dev/null
+++ b/build/lib/torchdrug/tasks/generation.py
@@ -0,0 +1,1552 @@
+import copy
+import logging
+import warnings
+from collections import defaultdict
+
+from tqdm import tqdm
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch_scatter import scatter_add, scatter_max
+from torch_scatter.composite import scatter_log_softmax, scatter_softmax
+
+from torchdrug import core, data, tasks, metrics, transforms
+from torchdrug.core import Registry as R
+from torchdrug.layers import functional
+from torchdrug import layers
+
+
+logger = logging.getLogger(__name__)
+
+
+@R.register("tasks.AutoregressiveGeneration")
+class AutoregressiveGeneration(tasks.Task, core.Configurable):
+    """
+    Autoregressive graph generation task.
+
+    This class can be used to implement GraphAF proposed in
+    `GraphAF: A Flow-based Autoregressive Model for Molecular Graph Generation`_.
+ To do so, instantiate the node model and the edge model with two + :class:`GraphAutoregressiveFlow ` models. + + .. _GraphAF\: A Flow-based Autoregressive Model for Molecular Graph Generation: + https://arxiv.org/pdf/2001.09382.pdf + + Parameters: + node_model (nn.Module): node likelihood model + edge_model (nn.Module): edge likelihood model + task (str or list of str, optional): property optimization task(s). Available tasks are ``plogp`` and ``qed``. + num_node_sample (int, optional): number of node samples per graph. -1 for all samples. + num_edge_sample (int, optional): number of edge samples per graph. -1 for all samples. + max_edge_unroll (int, optional): max node id difference. + If not provided, use the statistics from the training set. + max_node (int, optional): max number of node. + If not provided, use the statistics from the training set. + criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions and the values + are the corresponding weights. Available criterions are ``nll`` and ``ppo``. + agent_update_interval (int, optional): update agent every n batch + gamma (float, optional): reward discount rate + reward_temperature (float, optional): temperature for reward. Higher temperature encourages larger mean reward, + while lower temperature encourages larger maximal reward. + baseline_momentum (float, optional): momentum for value function baseline + """ + + eps = 1e-10 + top_k = 10 + _option_members = {"task", "criterion"} + + def __init__(self, node_model, edge_model, task=(), num_node_sample=-1, num_edge_sample=-1, + max_edge_unroll=None, max_node=None, criterion="nll", agent_update_interval=5, gamma=0.9, + reward_temperature=1, baseline_momentum=0.9): + super(AutoregressiveGeneration, self).__init__() + self.node_model = node_model + self.edge_model = edge_model + self.agent_node_model = copy.deepcopy(node_model) + self.agent_edge_model = copy.deepcopy(edge_model) + self.task = task + self.num_atom_type = self.node_model.input_dim + self.num_bond_type = self.edge_model.input_dim + self.num_node_sample = num_node_sample + self.num_edge_sample = num_edge_sample + self.max_edge_unroll = max_edge_unroll + self.max_node = max_node + self.criterion = criterion + self.agent_update_interval = agent_update_interval + self.gamma = gamma + self.reward_temperature = reward_temperature + self.baseline_momentum = baseline_momentum + self.best_results = defaultdict(list) + self.batch_id = 0 + + def preprocess(self, train_set, valid_set, test_set): + """ + Add atom id mapping and random BFS order to the training set. + + Compute ``max_edge_unroll`` and ``max_node`` on the training set if not provided. 
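+
+        Here ``max_edge_unroll`` is the largest difference between the two node ids
+        of any edge, measured under the sampled BFS node ordering.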
+ """ + remap_atom_type = transforms.RemapAtomType(train_set.atom_types) + train_set.transform = transforms.Compose([ + train_set.transform, + remap_atom_type, + transforms.RandomBFSOrder(), + ]) + self.register_buffer("id2atom", remap_atom_type.id2atom) + self.register_buffer("atom2id", remap_atom_type.atom2id) + + if self.max_edge_unroll is None or self.max_node is None: + self.max_edge_unroll = 0 + self.max_node = 0 + + train_set = tqdm(train_set, "Computing max number of nodes and edge unrolling") + for sample in train_set: + graph = sample["graph"] + if graph.edge_list.numel(): + edge_unroll = (graph.edge_list[:, 0] - graph.edge_list[:, 1]).abs().max().item() + self.max_edge_unroll = max(self.max_edge_unroll, edge_unroll) + self.max_node = max(self.max_node, graph.num_node) + + logger.warning("max node = %d, max edge unroll = %d" % (self.max_node, self.max_edge_unroll)) + + self.register_buffer("node_baseline", torch.zeros(self.max_node + 1)) + self.register_buffer("edge_baseline", torch.zeros(self.max_node + 1)) + + def forward(self, batch): + """""" + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + for criterion, weight in self.criterion.items(): + if criterion == "nll": + _loss, _metric = self.density_estimation_forward(batch) + all_loss += _loss * weight + metric.update(_metric) + elif criterion == "ppo": + _loss, _metric = self.reinforce_forward(batch) + all_loss += _loss * weight + metric.update(_metric) + else: + raise ValueError("Unknown criterion `%s`" % criterion) + + return all_loss, metric + + def reinforce_forward(self, batch): + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + if self.batch_id % self.agent_update_interval == 0: + self.agent_node_model.load_state_dict(self.node_model.state_dict()) + self.agent_edge_model.load_state_dict(self.edge_model.state_dict()) + self.batch_id += 1 + + # generation takes less time when early_stop=True + graph = self.generate(len(batch["graph"]), off_policy=True, early_stop=True) + if len(graph) == 0 or graph.num_nodes.max() == 1: + logger.error("Generation results collapse to singleton molecules") + + all_loss.requires_grad_() + nan = torch.tensor(float("nan"), device=self.device) + for task in self.task: + if task == "plogp": + metric["Penalized logP"] = nan + metric["Penalized logP (max)"] = nan + elif task == "qed": + metric["QED"] = nan + metric["QED (max)"] = nan + metric["node PPO objective"] = nan + metric["edge PPO objective"] = nan + + return all_loss, metric + + reward = torch.zeros(len(graph), device=self.device) + for task in self.task: + if task == "plogp": + plogp = metrics.penalized_logP(graph) + metric["Penalized logP"] = plogp.mean() + metric["Penalized logP (max)"] = plogp.max() + self.update_best_result(graph, plogp, "Penalized logP") + reward += (plogp / self.reward_temperature).exp() + + if plogp.max().item() > 5: + print("Penalized logP max = %s" % plogp.max().item()) + print(self.best_results["Penalized logP"]) + + elif task == "qed": + qed = metrics.QED(graph) + metric["QED"] = qed.mean() + metric["QED (max)"] = qed.max() + self.update_best_result(graph, qed, "QED") + reward += (qed / self.reward_temperature).exp() + #reward += qed * 3 + + if qed.max().item() > 0.93: + print("QED max = %s" % qed.max().item()) + print(self.best_results["QED"]) + + else: + raise ValueError("Unknown task `%s`" % task) + + # these graph-level features will broadcast to all masked graphs + with graph.graph(): + graph.reward = reward + 
graph.original_num_nodes = graph.num_nodes + graph.atom_type = self.atom2id[graph.atom_type] + + is_training = self.training + # easily got nan if BN is trained + self.bn_eval() + + masked_graph, node_target = self.mask_node(graph, metric) + # reward reshaping + reward = masked_graph.reward + masked_graph.atom_type = self.id2atom[masked_graph.atom_type] + reward = reward * self.gamma ** (masked_graph.original_num_nodes - masked_graph.num_nodes).float() + + # per graph size reward baseline + weight = torch.ones_like(masked_graph.num_nodes, dtype=torch.float) + baseline = scatter_add(reward, masked_graph.num_nodes, dim_size=self.max_node + 1) / \ + (scatter_add(weight, masked_graph.num_nodes, dim_size=self.max_node + 1) + self.eps) + self.node_baseline = self.node_baseline * self.baseline_momentum + baseline * (1 - self.baseline_momentum) + reward -= self.node_baseline[masked_graph.num_nodes] + reward += masked_graph.is_valid + masked_graph.atom_type = self.atom2id[masked_graph.atom_type] + + log_likelihood = self.node_model(masked_graph, node_target, None, all_loss, metric) + agent_log_likelihood = self.agent_node_model(masked_graph, node_target, None, all_loss, metric) + objective = functional.clipped_policy_gradient_objective(log_likelihood, agent_log_likelihood, reward) + objective = objective.mean() + metric["node PPO objective"] = objective + all_loss += -objective + + masked_graph, edge_target, edge = self.mask_edge(graph, metric) + # reward reshaping + reward = masked_graph.reward + masked_graph.atom_type = self.id2atom[masked_graph.atom_type] + reward = reward * self.gamma ** (masked_graph.original_num_nodes - masked_graph.num_nodes).float() + + # per graph size reward baseline + weight = torch.ones_like(masked_graph.num_nodes, dtype=torch.float) + baseline = scatter_add(reward, masked_graph.num_nodes, dim_size=self.max_node + 1) / \ + (scatter_add(weight, masked_graph.num_nodes, dim_size=self.max_node + 1) + self.eps) + self.edge_baseline = self.edge_baseline * self.baseline_momentum + baseline * (1 - self.baseline_momentum) + reward -= self.edge_baseline[masked_graph.num_nodes] + reward += masked_graph.is_valid + masked_graph.atom_type = self.atom2id[masked_graph.atom_type] + + log_likelihood = self.edge_model(masked_graph, edge_target, edge, all_loss, metric) + agent_log_likelihood = self.agent_edge_model(masked_graph, edge_target, edge, all_loss, metric) + objective = functional.clipped_policy_gradient_objective(log_likelihood, agent_log_likelihood, reward) + objective = objective.mean() + metric["edge PPO objective"] = objective + all_loss += -objective + + self.bn_train(is_training) + + return all_loss, metric + + def density_estimation_forward(self, batch): + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + graph = batch["graph"] + masked_graph, node_target = self.mask_node(graph, metric) + log_likelihood = self.node_model(masked_graph, node_target, None, all_loss, metric) + log_likelihood = log_likelihood.mean() + metric["node log likelihood"] = log_likelihood + all_loss += -log_likelihood + + masked_graph, edge_target, edge = self.mask_edge(graph, metric) + log_likelihood = self.edge_model(masked_graph, edge_target, edge, all_loss, metric) + log_likelihood = log_likelihood.mean() + metric["edge log likelihood"] = log_likelihood + all_loss += -log_likelihood + + return all_loss, metric + + def evaluate(self, batch): + pred = None + metric = {} + + graph, target = self.all_node(batch["graph"]) + log_likelihood = self.node_model(graph, 
target) + log_likelihood = log_likelihood.mean() + metric["node log likelihood"] = log_likelihood + + graph, target = self.all_edge(batch["graph"]) + log_likelihood = self.edge_model(graph, target) + log_likelihood = log_likelihood.mean() + metric["edge log likelihood"] = log_likelihood + + return pred, metric + + def bn_train(self, mode=True): + for module in self.modules(): + if isinstance(module, nn.BatchNorm1d): + module.train(mode) + + def bn_eval(self): + for module in self.modules(): + if isinstance(module, nn.BatchNorm1d): + module.eval() + + def update_best_result(self, graph, score, task): + score = score.cpu() + best_results = self.best_results[task] + for s, i in zip(*score.sort(descending=True)): + s = s.item() + i = i.item() + if len(best_results) == self.top_k and s < best_results[-1][0]: + break + best_results.append((s, graph[i].to_smiles())) + best_results.sort(reverse=True) + best_results = best_results[:self.top_k] + self.best_results[task] = best_results + + @torch.no_grad() + def generate(self, num_sample, max_resample=20, off_policy=False, early_stop=False, verbose=0): + num_relation = self.num_bond_type - 1 + is_training = self.training + self.eval() + + if off_policy: + node_model = self.agent_node_model + edge_model = self.agent_edge_model + else: + node_model = self.node_model + edge_model = self.edge_model + + edge_list = torch.zeros(0, 3, dtype=torch.long, device=self.device) + num_nodes = torch.zeros(num_sample, dtype=torch.long, device=self.device) + num_edges = torch.zeros_like(num_nodes) + atom_type = torch.zeros(0, dtype=torch.long, device=self.device) + graph = data.PackedMolecule(edge_list, atom_type, edge_list[:, -1], num_nodes, num_edges, + num_relation=num_relation) + completed = torch.zeros(num_sample, dtype=torch.bool, device=self.device) + + for node_in in range(self.max_node): + atom_pred = node_model.sample(graph) + # why we add atom_pred even if it is completed? 
+ # because we need to batch edge model over (node_in, node_out), even on completed graphs + atom_type, num_nodes = self._append(atom_type, num_nodes, atom_pred) + graph = node_graph = data.PackedMolecule(edge_list, atom_type, edge_list[:, -1], num_nodes, num_edges, + num_relation=num_relation) + + start = max(0, node_in - self.max_edge_unroll) + for node_out in range(start, node_in): + is_valid = completed.clone() + edge = torch.tensor([node_in, node_out], device=self.device).repeat(num_sample, 1) + # default: non-edge + bond_pred = (self.num_bond_type - 1) * torch.ones(num_sample, dtype=torch.long, device=self.device) + for i in range(max_resample): + # only resample invalid graphs + mask = ~is_valid + bond_pred[mask] = edge_model.sample(graph, edge)[mask] + # check valency + mask = (bond_pred < edge_model.input_dim - 1) & ~completed + edge_pred = torch.cat([edge, bond_pred.unsqueeze(-1)], dim=-1) + tmp_edge_list, tmp_num_edges = self._append(edge_list, num_edges, edge_pred, mask) + edge_pred = torch.cat([edge.flip(-1), bond_pred.unsqueeze(-1)], dim=-1) + tmp_edge_list, tmp_num_edges = self._append(tmp_edge_list, tmp_num_edges, edge_pred, mask) + tmp_graph = data.PackedMolecule(tmp_edge_list, self.id2atom[atom_type], tmp_edge_list[:, -1], + num_nodes, tmp_num_edges, num_relation=num_relation) + + is_valid = tmp_graph.is_valid | completed + + if is_valid.all(): + break + + if not is_valid.all() and verbose: + num_invalid = num_sample - is_valid.sum().item() + num_working = num_sample - completed.sum().item() + logger.warning("edge (%d, %d): %d / %d molecules are invalid even after %d resampling" % + (node_in, node_out, num_invalid, num_working, max_resample)) + + mask = (bond_pred < edge_model.input_dim - 1) & ~completed + edge_pred = torch.cat([edge, bond_pred.unsqueeze(-1)], dim=-1) + edge_list, num_edges = self._append(edge_list, num_edges, edge_pred, mask) + edge_pred = torch.cat([edge.flip(-1), bond_pred.unsqueeze(-1)], dim=-1) + edge_list, num_edges = self._append(edge_list, num_edges, edge_pred, mask) + graph = data.PackedMolecule(edge_list, atom_type, edge_list[:, -1], num_nodes, num_edges, + num_relation=num_relation) + + if node_in > 0: + assert (graph.num_edges[completed] == node_graph.num_edges[completed]).all() + completed |= graph.num_edges == node_graph.num_edges + if early_stop: + graph.atom_type = self.id2atom[graph.atom_type] + completed |= ~graph.is_valid + graph.atom_type = self.atom2id[graph.atom_type] + if completed.all(): + break + + self.train(is_training) + + # remove isolated atoms + index = graph.degree_out > 0 + # keep at least the first atom for each graph + index[graph.num_cum_nodes - graph.num_nodes] = 1 + graph = graph.subgraph(index) + graph.atom_type = self.id2atom[graph.atom_type] + + graph = graph[graph.is_valid_rdkit] + return graph + + def _append(self, data, num_xs, input, mask=None): + if mask is None: + mask = torch.ones_like(num_xs, dtype=torch.bool) + new_num_xs = num_xs + mask + new_num_cum_xs = new_num_xs.cumsum(0) + new_num_x = new_num_cum_xs[-1].item() + new_data = torch.zeros(new_num_x, *data.shape[1:], dtype=data.dtype, device=data.device) + starts = new_num_cum_xs - new_num_xs + ends = starts + num_xs + index = functional.multi_slice_mask(starts, ends, new_num_x) + new_data[index] = data + new_data[~index] = input[mask] + return new_data, new_num_xs + + @torch.no_grad() + def mask_node(self, graph, metric=None): + if self.num_node_sample == -1: + masked_graph, node_target = self.all_node(graph) + if metric is not None: + metric["node mask 
/ graph"] = torch.tensor([len(masked_graph) / len(graph)], device=graph.device) + else: + masked_graph, node_target = self.sample_node(graph, self.num_node_sample) + return masked_graph, node_target + + @torch.no_grad() + def mask_edge(self, graph, metric=None): + if self.num_edge_sample == -1: + masked_graph, edge_target, edge = self.all_edge(graph) + if metric is not None: + metric["edge mask / graph"] = torch.tensor([len(masked_graph) / len(graph)], device=graph.device) + else: + masked_graph, edge_target, edge = self.sample_edge(graph, self.num_edge_sample) + return masked_graph, edge_target, edge + + @torch.no_grad() + def sample_node(self, graph, num_sample): + graph = graph.repeat(num_sample) + num_nodes = graph.num_nodes + num_keep_nodes = torch.rand(len(graph), device=graph.device) * num_nodes # [0, num_nodes) + num_keep_nodes = num_keep_nodes.long() # [0, num_nodes - 1] + + starts = graph.num_cum_nodes - graph.num_nodes + ends = starts + num_keep_nodes + mask = functional.multi_slice_mask(starts, ends, graph.num_node) + + new_graph = graph.subgraph(mask) + target = graph.subgraph(ends).atom_type + return new_graph, target + + @torch.no_grad() + def all_node(self, graph): + starts, ends, valid = self._all_prefix_slice(graph.num_nodes) + + num_repeat = len(starts) // len(graph) + graph = graph.repeat(num_repeat) + mask = functional.multi_slice_mask(starts, ends, graph.num_node) + new_graph = graph.subgraph(mask) + target = graph.subgraph(ends).atom_type + + return new_graph[valid], target[valid] + + @torch.no_grad() + def sample_edge(self, graph, num_sample): + if (graph.num_nodes < 2).any(): + graph = graph[graph.num_nodes >= 2] + warnings.warn("Graphs with less than 2 nodes can't be used for edge generation learning. Dropped") + + lengths = self._valid_edge_prefix_lengths(graph) + graph = graph.repeat(num_sample) + + num_max_node = graph.num_nodes.max().item() + num_node2num_dense_edge = torch.arange(num_max_node + 1, device=graph.device) ** 2 + num_node2length_idx = (lengths.unsqueeze(-1) < num_node2num_dense_edge.unsqueeze(0)).sum(dim=0) + # uniformly sample a mask from each graph's valid masks + length_indexes = torch.rand(len(graph), device=graph.device) * num_node2length_idx[graph.num_nodes] + length_indexes = length_indexes.long() + num_keep_dense_edges = lengths[length_indexes] + + # undirected: all upper triangular edge ids are flipped to lower triangular ids + # 1 -> 2, 4 -> 6, 5 -> 7 + node_index = graph.edge_list[:, :2] - graph._offsets.unsqueeze(-1) + node_in, node_out = node_index.t() + node_large = node_index.max(dim=-1)[0] + node_small = node_index.min(dim=-1)[0] + edge_id = node_large ** 2 + (node_in >= node_out) * node_large + node_small + undirected_edge_id = node_large * (node_large + 1) + node_small + + edge_mask = undirected_edge_id < num_keep_dense_edges[graph.edge2graph] + circum_box_size = (num_keep_dense_edges + 1.0).sqrt().ceil().long() + starts = graph.num_cum_nodes - graph.num_nodes + ends = starts + circum_box_size + node_mask = functional.multi_slice_mask(starts, ends, graph.num_node) + # compact nodes so that succeeding nodes won't affect graph pooling + new_graph = graph.edge_mask(edge_mask).node_mask(node_mask, compact=True) + + positive_edge = edge_id == num_keep_dense_edges[graph.edge2graph] + positive_graph = scatter_add(positive_edge.long(), graph.edge2graph, dim=0, dim_size=len(graph)).bool() + # default: non-edge + target = (self.num_bond_type - 1) * torch.ones(graph.batch_size, dtype=torch.long, device=graph.device) + target[positive_graph] 
= graph.edge_list[positive_edge, 2] + + node_in = circum_box_size - 1 + node_out = num_keep_dense_edges - node_in * circum_box_size + edge = torch.stack([node_in, node_out], dim=-1) + + return new_graph, target, edge + + @torch.no_grad() + def all_edge(self, graph): + if (graph.num_nodes < 2).any(): + graph = graph[graph.num_nodes >= 2] + warnings.warn("Graphs with less than 2 nodes can't be used for edge generation learning. Dropped") + + lengths = self._valid_edge_prefix_lengths(graph) + + starts, ends, valid = self._all_prefix_slice(graph.num_nodes ** 2, lengths) + + num_keep_dense_edges = ends - starts + num_repeat = len(starts) // len(graph) + graph = graph.repeat(num_repeat) + + # undirected: all upper triangular edge ids are flipped to lower triangular ids + # 1 -> 2, 4 -> 6, 5 -> 7 + node_index = graph.edge_list[:, :2] - graph._offsets.unsqueeze(-1) + node_in, node_out = node_index.t() + node_large = node_index.max(dim=-1)[0] + node_small = node_index.min(dim=-1)[0] + edge_id = node_large ** 2 + (node_in >= node_out) * node_large + node_small + undirected_edge_id = node_large * (node_large + 1) + node_small + + edge_mask = undirected_edge_id < num_keep_dense_edges[graph.edge2graph] + circum_box_size = (num_keep_dense_edges + 1.0).sqrt().ceil().long() + starts = graph.num_cum_nodes - graph.num_nodes + ends = starts + circum_box_size + node_mask = functional.multi_slice_mask(starts, ends, graph.num_node) + # compact nodes so that succeeding nodes won't affect graph pooling + new_graph = graph.edge_mask(edge_mask).node_mask(node_mask, compact=True) + + positive_edge = edge_id == num_keep_dense_edges[graph.edge2graph] + positive_graph = scatter_add(positive_edge.long(), graph.edge2graph, dim=0, dim_size=len(graph)).bool() + # default: non-edge + target = (self.num_bond_type - 1) * torch.ones(graph.batch_size, dtype=torch.long, device=graph.device) + target[positive_graph] = graph.edge_list[positive_edge, 2] + + node_in = circum_box_size - 1 + node_out = num_keep_dense_edges - node_in * circum_box_size + edge = torch.stack([node_in, node_out], dim=-1) + + return new_graph[valid], target[valid], edge[valid] + + @torch.no_grad() + def _all_prefix_slice(self, num_xs, lengths=None): + # extract a bunch of slices that correspond to the following num_repeat * n masks + # ------ repeat 0 ----- + # graphs[0]: [0, 0, ..., 0] + # ... + # graphs[-1]: [0, 0, ..., 0] + # ------ repeat 1 ----- + # graphs[0]: [1, 0, ..., 0] + # ... + # graphs[-1]: [1, 0, ..., 0] + # ... + # ------ repeat -1 ----- + # graphs[0]: [1, ..., 1, 0] + # ... 
+ # graphs[-1]: [1, ..., 1, 0] + num_cum_xs = num_xs.cumsum(0) + starts = num_cum_xs - num_xs + if lengths is None: + num_max_x = num_xs.max().item() + lengths = torch.arange(num_max_x, device=num_xs.device) + + pack_offsets = torch.arange(len(lengths), device=num_xs.device) * num_cum_xs[-1] + # starts, lengths, ends: (num_repeat, num_graph) + starts = starts.unsqueeze(0) + pack_offsets.unsqueeze(-1) + valid = lengths.unsqueeze(-1) <= num_xs.unsqueeze(0) - 1 + lengths = torch.min(lengths.unsqueeze(-1), num_xs.unsqueeze(0) - 1) + ends = starts + lengths + + starts = starts.flatten() + ends = ends.flatten() + valid = valid.flatten() + + return starts, ends, valid + + @torch.no_grad() + def _valid_edge_prefix_lengths(self, graph): + # valid prefix lengths are across a batch, according to the largest graph + num_max_node = graph.num_nodes.max().item() + # edge id in an adjacency (snake pattern) + # in + # o 0 1 4 + # u 2 3 5 + # t 6 7 8 + lengths = torch.arange(num_max_node ** 2, device=graph.device) + circum_box_size = (lengths + 1.0).sqrt().ceil().long() + # only keep lengths that ends in the lower triangular part of adjacency matrix + lengths = lengths[lengths >= circum_box_size * (circum_box_size - 1)] + # lengths: [0, 2, 3, 6, 7, 8, ...] + # num_node2length_idx: [0, 1, 4, 6, ...] + # num_edge_unrolls + # 0 + # 1 0 + # 2 1 0 + num_edge_unrolls = (lengths + 1.0).sqrt().ceil().long() ** 2 - lengths - 1 + # num_edge_unrolls: [0, 1, 0, 2, 1, 0, ...] + # remove lengths that unroll too much. they always lead to empty targets + lengths = lengths[(num_edge_unrolls <= self.max_edge_unroll) & (num_edge_unrolls > 0)] + + return lengths + + +@R.register("tasks.GCPNGeneration") +class GCPNGeneration(tasks.Task, core.Configurable): + """ + The graph generative model from `Graph Convolutional Policy Network for Goal-Directed Molecular Graph Generation`_. + + .. _Graph Convolutional Policy Network for Goal-Directed Molecular Graph Generation: + https://papers.nips.cc/paper/7877-graph-convolutional-policy-network-for-goal-directed-molecular-graph-generation.pdf + + Parameters: + model (nn.Module): graph representation model + atom_types (list or set): set of all possible atom types + task (str or list of str, optional): property optimization task(s) + max_edge_unroll (int, optional): max node id difference. + If not provided, use the statistics from the training set. + max_node (int, optional): max number of node. + If not provided, use the statistics from the training set. + criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions and the values + are the corresponding weights. Available criterions are ``nll`` and ``ppo``. + agent_update_interval (int, optional): update the agent every n batch + gamma (float, optional): reward discount rate + reward_temperature (float, optional): temperature for reward. Higher temperature encourages larger mean reward, + while lower temperature encourages larger maximal reward. 
+ baseline_momentum (float, optional): momentum for value function baseline + """ + + eps = 1e-10 + top_k = 10 + _option_members = {"task", "criterion"} + + def __init__(self, model, atom_types, max_edge_unroll=None, max_node=None, task=(), criterion="nll", + hidden_dim_mlp=128, agent_update_interval=10, gamma=0.9, reward_temperature=1, baseline_momentum=0.9): + super(GCPNGeneration, self).__init__() + self.model = model + self.task = task + self.max_edge_unroll = max_edge_unroll + self.max_node = max_node + self.criterion = criterion + self.hidden_dim_mlp = hidden_dim_mlp + self.agent_update_interval = agent_update_interval + self.gamma = gamma + self.reward_temperature = reward_temperature + self.baseline_momentum = baseline_momentum + self.best_results = defaultdict(list) + self.batch_id = 0 + + + remap_atom_type = transforms.RemapAtomType(atom_types) + self.register_buffer("id2atom", remap_atom_type.id2atom) + self.register_buffer("atom2id", remap_atom_type.atom2id) + + self.new_atom_embeddings = nn.Parameter(torch.zeros(self.id2atom.size(0), self.model.output_dim)) + nn.init.normal_(self.new_atom_embeddings, mean=0, std=0.1) + self.inp_dim_stop = self.model.output_dim + self.mlp_stop = layers.MultiLayerPerceptron(self.inp_dim_stop, [self.hidden_dim_mlp, 2], activation='tanh') + + self.inp_dim_node1 = self.model.output_dim + self.model.output_dim + self.mlp_node1 = layers.MultiLayerPerceptron(self.inp_dim_node1, [self.hidden_dim_mlp, 1], activation='tanh') + self.inp_dim_node2 = 2 * self.model.output_dim + self.model.output_dim + self.mlp_node2 = layers.MultiLayerPerceptron(self.inp_dim_node2, [self.hidden_dim_mlp, 1], activation='tanh') + self.inp_dim_edge = 2 * self.model.output_dim + self.mlp_edge = layers.MultiLayerPerceptron(self.inp_dim_edge, [self.hidden_dim_mlp, self.model.num_relation], activation='tanh') + + self.agent_model = copy.deepcopy(self.model) + self.agent_new_atom_embeddings = copy.deepcopy(self.new_atom_embeddings) + self.agent_mlp_stop = copy.deepcopy(self.mlp_stop) + self.agent_mlp_node1 = copy.deepcopy(self.mlp_node1) + self.agent_mlp_node2 = copy.deepcopy(self.mlp_node2) + self.agent_mlp_edge = copy.deepcopy(self.mlp_edge) + + + def preprocess(self, train_set, valid_set, test_set): + """ + Add atom id mapping and random BFS order to the training set. + + Compute ``max_edge_unroll`` and ``max_node`` on the training set if not provided. 
+ """ + remap_atom_type = transforms.RemapAtomType(train_set.atom_types) + train_set.transform = transforms.Compose([ + train_set.transform, + transforms.RandomBFSOrder(), + ]) + + if self.max_edge_unroll is None or self.max_node is None: + self.max_edge_unroll = 0 + self.max_node = 0 + + train_set = tqdm(train_set, "Computing max number of nodes and edge unrolling") + for sample in train_set: + graph = sample["graph"] + if graph.edge_list.numel(): + edge_unroll = (graph.edge_list[:, 0] - graph.edge_list[:, 1]).abs().max().item() + self.max_edge_unroll = max(self.max_edge_unroll, edge_unroll) + self.max_node = max(self.max_node, graph.num_node) + + logger.warning("max node = %d, max edge unroll = %d" % (self.max_node, self.max_edge_unroll)) + + self.register_buffer("moving_baseline", torch.zeros(self.max_node + 1)) + + def forward(self, batch): + """""" + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + for criterion, weight in self.criterion.items(): + if criterion == "nll": + _loss, _metric = self.MLE_forward(batch) + all_loss += _loss * weight + metric.update(_metric) + elif criterion == "ppo": + _loss, _metric = self.reinforce_forward(batch) + all_loss += _loss * weight + metric.update(_metric) + else: + raise ValueError("Unknown criterion `%s`" % criterion) + + return all_loss, metric + + def predict(self, graph, label_dict, use_agent=False): + # step1: get node/graph embeddings + if not use_agent: + output = self.model(graph, graph.node_feature.float()) + else: + output = self.agent_model(graph, graph.node_feature.float()) + + extended_node2graph = torch.arange(graph.num_nodes.size(0), + device=self.device).unsqueeze(1).repeat([1, self.id2atom.size(0)]).view(-1) # (num_graph * 16) + extended_node2graph = torch.cat((graph.node2graph, extended_node2graph)) # (num_node + 16 * num_graph) + + graph_feature_per_node = output["graph_feature"][extended_node2graph] + + # step2: predict stop + stop_feature = output["graph_feature"] #(num_graph, n_out) + if not use_agent: + stop_logits = self.mlp_stop(stop_feature) #(num_graph, 2) + else: + stop_logits = self.agent_mlp_stop(stop_feature) #(num_graph, 2) + + if label_dict == None: + return stop_logits + # step3: predict first node: node1 + node1_feature = output["node_feature"] #(num_node, n_out) + + node1_feature = torch.cat((node1_feature, + self.new_atom_embeddings.repeat([graph.num_nodes.size(0), 1])), 0) # (num_node + 16 * num_graph, n_out) + + node2_feature_node2 = node1_feature.clone() # (num_node + 16 * num_graph, n_out) + # cat graph emb + node1_feature = torch.cat((node1_feature, graph_feature_per_node), 1) + + if not use_agent: + node1_logits = self.mlp_node1(node1_feature).squeeze(1) #(num_node + 16 * num_graph) + else: + node1_logits = self.agent_mlp_node1(node1_feature).squeeze(1) #(num_node + 16 * num_graph) + + #mask the extended part + mask = torch.zeros(node1_logits.size(), device=self.device) + mask[:graph.num_node] = 1 + node1_logits = torch.where(mask>0, node1_logits, -10000.0*torch.ones(node1_logits.size(), device=self.device)) + + # step4: predict second node: node2 + + node1_index_per_graph = (graph.num_cum_nodes - graph.num_nodes) + label_dict["label1"] #(num_graph) + node1_index = node1_index_per_graph[extended_node2graph] # (num_node + 16 * num_graph) + node2_feature_node1 = node1_feature[node1_index] #(num_node + 16 * num_graph, n_out) + node2_feature = torch.cat((node2_feature_node1, node2_feature_node2), 1) #(num_node + 16 * num_graph, 2n_out) + if not use_agent: + node2_logits = 
self.mlp_node2(node2_feature).squeeze(1) #(num_node + 16 * num_graph) + else: + node2_logits = self.agent_mlp_node2(node2_feature).squeeze(1) #(num_node + 16 * num_graph) + + #mask the selected node1 + mask = torch.zeros(node2_logits.size(), device=self.device) + mask[node1_index_per_graph] = 1 + node2_logits = torch.where(mask==0, node2_logits, -10000.0*torch.ones(node2_logits.size(), device=self.device)) + + # step5: predict edge type + is_new_node = label_dict["label2"] - graph.num_nodes # if an entry is non-negative, this is a new added node. (num_graph) + graph_offset = torch.arange(graph.num_nodes.size(0), device=self.device) + node2_index_per_graph = torch.where(is_new_node >= 0, + graph.num_node + graph_offset * self.id2atom.size(0) + is_new_node, + label_dict["label2"] + graph.num_cum_nodes - graph.num_nodes) # (num_graph) + node2_index = node2_index_per_graph[extended_node2graph] + + edge_feature_node1 = node2_feature_node2[node1_index_per_graph] #(num_graph, n_out) + edge_feature_node2 = node2_feature_node2[node2_index_per_graph] # #(num_graph, n_out) + edge_feature = torch.cat((edge_feature_node1, edge_feature_node2), 1) #(num_graph, 2n_out) + if not use_agent: + edge_logits = self.mlp_edge(edge_feature) # (num_graph, num_relation) + else: + edge_logits = self.agent_mlp_edge(edge_feature) # (num_graph, num_relation) + + index_dict = { + "node1_index_per_graph": node1_index_per_graph, + "node2_index_per_graph": node2_index_per_graph, + "extended_node2graph": extended_node2graph + } + return stop_logits, node1_logits, node2_logits, edge_logits, index_dict + + def reinforce_forward(self, batch): + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + if self.batch_id % self.agent_update_interval == 0: + self.agent_model.load_state_dict(self.model.state_dict()) + self.agent_mlp_stop.load_state_dict(self.mlp_stop.state_dict()) + self.agent_mlp_node1.load_state_dict(self.mlp_node1.state_dict()) + self.agent_mlp_node2.load_state_dict(self.mlp_node2.state_dict()) + self.agent_mlp_edge.load_state_dict(self.mlp_edge.state_dict()) + self.agent_new_atom_embeddings.data = self.new_atom_embeddings.data.clone() + + self.batch_id += 1 + + # generation takes less time when early_stop=True + graph = self.generate(len(batch["graph"]), max_resample=20, off_policy=True, max_step=40 * 2, verbose=1) + if len(graph) == 0 or graph.num_nodes.max() == 1: + logger.error("Generation results collapse to singleton molecules") + + all_loss.requires_grad_() + nan = torch.tensor(float("nan"), device=self.device) + for task in self.task: + if task == "plogp": + metric["Penalized logP"] = nan + metric["Penalized logP (max)"] = nan + elif task == "qed": + metric["QED"] = nan + metric["QED (max)"] = nan + metric["PPO objective"] = nan + + return all_loss, metric + + reward = torch.zeros(len(graph), device=self.device) + for task in self.task: + if task == "plogp": + plogp = metrics.penalized_logP(graph) + metric["Penalized logP"] = plogp.mean() + metric["Penalized logP (max)"] = plogp.max() + self.update_best_result(graph, plogp, "Penalized logP") + # TODO: + reward += (plogp / self.reward_temperature).exp() + + if plogp.max().item() > 5: + print("Penalized logP max = %s" % plogp.max().item()) + print(self.best_results["Penalized logP"]) + + elif task == "qed": + qed = metrics.QED(graph) + metric["QED"] = qed.mean() + metric["QED (max)"] = qed.max() + self.update_best_result(graph, qed, "QED") + # TODO: + #reward += ((qed - 0.9) * 20).exp() + #reward += ((qed - 0.4) * 4 / 
+        stop_logits, node1_logits, node2_logits, edge_logits, index_dict = self.predict(graph, label_dict)
+        with torch.no_grad():
+            old_stop_logits, old_node1_logits, old_node2_logits, old_edge_logits, old_index_dict = self.predict(graph, label_dict, use_agent=True)
+
+        stop_prob = F.log_softmax(stop_logits, dim=-1)
+        node1_prob = scatter_log_softmax(node1_logits, index_dict["extended_node2graph"])
+        node2_prob = scatter_log_softmax(node2_logits, index_dict["extended_node2graph"])
+        edge_prob = F.log_softmax(edge_logits, dim=-1)
+        old_stop_prob = F.log_softmax(old_stop_logits, dim=-1)
+        old_node1_prob = scatter_log_softmax(old_node1_logits, old_index_dict["extended_node2graph"])
+        old_node2_prob = scatter_log_softmax(old_node2_logits, old_index_dict["extended_node2graph"])
+        old_edge_prob = F.log_softmax(old_edge_logits, dim=-1)
+
+        cur_logp = stop_prob[:, 0] + node1_prob[index_dict["node1_index_per_graph"]] \
+                   + node2_prob[index_dict["node2_index_per_graph"]] + torch.gather(edge_prob, -1, label3_target.view(-1, 1)).view(-1)
+        cur_logp[label4_target == 1] = stop_prob[:, 1][label4_target == 1]
+
+        old_logp = old_stop_prob[:, 0] + old_node1_prob[old_index_dict["node1_index_per_graph"]] \
+                   + old_node2_prob[old_index_dict["node2_index_per_graph"]] + torch.gather(old_edge_prob, -1, label3_target.view(-1, 1)).view(-1)
+        old_logp[label4_target == 1] = old_stop_prob[:, 1][label4_target == 1]
+        objective = functional.clipped_policy_gradient_objective(cur_logp, old_logp, reward)
+        objective = objective.mean()
+        metric["PPO objective"] = objective
+        all_loss += (-objective)
+
+        self.bn_train(is_training)
+
+        return all_loss, metric
+
+    def MLE_forward(self, batch):
+        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
+        metric = {}
+
+        graph = batch["graph"]
+        stop_graph, 
stop_label1, stop_label2, stop_label3, stop_label4 = self.all_stop(graph) + edge_graph, edge_label1, edge_label2, edge_label3, edge_label4 = self.all_edge(graph) + + graph = self._cat([stop_graph, edge_graph]) + label1_target = torch.cat([stop_label1, edge_label1]) + label2_target = torch.cat([stop_label2, edge_label2]) + label3_target = torch.cat([stop_label3, edge_label3]) + label4_target = torch.cat([stop_label4, edge_label4]) + label_dict = {"label1": label1_target, "label2": label2_target, "label3": label3_target, "label4": label4_target} + stop_logits, node1_logits, node2_logits, edge_logits, index_dict = self.predict(graph, label_dict) + + loss_stop = F.nll_loss(F.log_softmax(stop_logits, dim=-1), label4_target, reduction='none') + loss_stop = 0.5 * (torch.mean(loss_stop[label4_target==0]) + torch.mean(loss_stop[label4_target==1])) + #loss_stop = torch.mean(loss_stop) + metric["stop bce loss"] = loss_stop + all_loss += loss_stop + + loss_node1 = -(scatter_log_softmax(node1_logits, index_dict["extended_node2graph"])[index_dict["node1_index_per_graph"]]) + loss_node1 = torch.mean(loss_node1[label4_target==0]) + metric["node1 loss"] = loss_node1 + all_loss += loss_node1 + + loss_node2 = -(scatter_log_softmax(node2_logits, index_dict["extended_node2graph"])[index_dict["node2_index_per_graph"]]) + loss_node2 = torch.mean(loss_node2[label4_target==0]) + metric["node2 loss"] = loss_node2 + all_loss += loss_node2 + + loss_edge = F.nll_loss(F.log_softmax(edge_logits, dim=-1), label3_target, reduction='none') + + loss_edge = torch.mean(loss_edge[label4_target==0]) + metric["edge loss"] = loss_edge + all_loss += loss_edge + + metric["total loss"] = all_loss + + pred = stop_logits, node1_logits, node2_logits, edge_logits + target = label1_target, label2_target, label3_target, label4_target, index_dict + + metric.update(self.evaluate(pred, target)) + + return all_loss, metric + + def evaluate(self, pred, target): + stop_logits, node1_logits, node2_logits, edge_logits = pred + label1_target, label2_target, label3_target, label4_target, index_dict = target + metric = {} + stop_acc = torch.argmax(stop_logits, -1) == label4_target + metric["stop acc"] = stop_acc.float().mean() + + node1_pred = scatter_max(node1_logits, index_dict["extended_node2graph"])[1] + node1_acc = node1_pred == index_dict["node1_index_per_graph"] + metric["node1 acc"] = node1_acc[label4_target == 0].float().mean() + + node2_pred = scatter_max(node2_logits, index_dict["extended_node2graph"])[1] + node2_acc = node2_pred == index_dict["node2_index_per_graph"] + metric["node2 acc"] = node2_acc[label4_target == 0].float().mean() + + edge_acc = torch.argmax(edge_logits, -1) == label3_target + metric["edge acc"] = edge_acc[label4_target == 0].float().mean() + return metric + + # generation step + # 1. top-1 action + # 2. 
apply action + + @torch.no_grad() + def _construct_dist(self, prob_, graph): + max_size = max(graph.num_nodes) + self.id2atom.size(0) + probs = torch.zeros((len(graph), max_size), device=prob_.device).view(-1) + start = (graph.num_cum_nodes - graph.num_nodes)[graph.node2graph] + start = torch.arange(graph.num_node, device=self.device) - start + index = torch.arange(graph.num_nodes.size(0), device=self.device) * max_size + index = index[graph.node2graph] + start + probs[index] = prob_[:graph.num_node] + + start_extend = torch.arange(len(self.id2atom), device=self.device).repeat(graph.num_nodes.size()) # (num_graph * 16) + index_extend = torch.arange(len(graph), device=self.device) * max_size + graph.num_nodes + index2graph = torch.arange(len(graph), device=self.device).repeat_interleave(len(self.id2atom)) + index_extend = index_extend[index2graph] + start_extend + probs[index_extend] = prob_[graph.num_node:] + probs = probs.view(len(graph.num_nodes), max_size) + return torch.distributions.Categorical(probs), probs # (n_graph, max_size) + + @torch.no_grad() + def _sample_action(self, graph, off_policy): + if off_policy: + model = self.agent_model + new_atom_embeddings = self.agent_new_atom_embeddings + mlp_stop = self.agent_mlp_stop + mlp_node1 = self.agent_mlp_node1 + mlp_node2 = self.agent_mlp_node2 + mlp_edge = self.agent_mlp_edge + else: + model = self.model + new_atom_embeddings = self.new_atom_embeddings + mlp_stop = self.mlp_stop + mlp_node1 = self.mlp_node1 + mlp_node2 = self.mlp_node2 + mlp_edge = self.mlp_edge + + # step1: get feature + output = model(graph, graph.node_feature.float()) + + extended_node2graph = torch.arange(len(graph), device=self.device).repeat_interleave(len(self.id2atom)) # (num_graph * 16) + extended_node2graph = torch.cat((graph.node2graph, extended_node2graph)) # (num_node + 16 * num_graph) + + graph_feature_per_node = output["graph_feature"][extended_node2graph] + + # step2: predict stop + stop_feature = output["graph_feature"] # (num_graph, n_out) + stop_logits = mlp_stop(stop_feature) # (num_graph, 2) + stop_prob = F.softmax(stop_logits, -1) # (num_graph, 2) + stop_prob_dist = torch.distributions.Categorical(stop_prob) + stop_pred = stop_prob_dist.sample() + # step3: predict first node: node1 + + node1_feature = output["node_feature"] #(num_node, n_out) + + node1_feature = torch.cat((node1_feature, + new_atom_embeddings.repeat([graph.num_nodes.size(0), 1])), 0) # (num_node + 16 * num_graph, n_out) + node2_feature_node2 = node1_feature.clone() # (num_node + 16 * num_graph, n_out) + + node1_feature = torch.cat((node1_feature, graph_feature_per_node), 1) + + node1_logits = mlp_node1(node1_feature).squeeze(1) #(num_node + 16 * num_graph) + #mask the extended part + mask = torch.zeros(node1_logits.size(), device=self.device) + mask[:graph.num_node] = 1 + node1_logits = torch.where(mask>0, node1_logits, -10000.0*torch.ones(node1_logits.size(), device=self.device)) + + node1_prob = scatter_softmax(node1_logits, extended_node2graph) # (num_node + 16 * num_graph) + node1_prob_dist, tmp = self._construct_dist(node1_prob, graph) # (num_graph, max) + + node1_pred = node1_prob_dist.sample() #(num_graph) + node1_index_per_graph = node1_pred + (graph.num_cum_nodes - graph.num_nodes) + # step4: predict second node: node2 + node1_index = node1_index_per_graph[extended_node2graph] # (num_node + 16 * num_graph) + node2_feature_node1 = node1_feature[node1_index] # (num_node + 16 * num_graph, n_out) + + node2_feature = torch.cat((node2_feature_node1, node2_feature_node2), 1) # 
(num_node + 16 * num_graph, 2n_out) + node2_logits = mlp_node2(node2_feature).squeeze(1) # (num_node + 16 * num_graph) + + # mask the selected node1 + mask = torch.zeros(node2_logits.size(), device=self.device) + mask[node1_index_per_graph] = 1 + node2_logits = torch.where(mask==0, node2_logits, -10000.0*torch.ones(node2_logits.size(), device=self.device)) + node2_prob = scatter_softmax(node2_logits, extended_node2graph) # (num_node + 16 * num_graph) + node2_prob_dist, tmp = self._construct_dist(node2_prob, graph) # (num_graph, max) + node2_pred = node2_prob_dist.sample() # (num_graph,) + is_new_node = node2_pred - graph.num_nodes + graph_offset = torch.arange(graph.num_nodes.size(0), device=self.device) + node2_index_per_graph = torch.where(is_new_node >= 0, + graph.num_node + graph_offset * self.id2atom.size(0) + is_new_node, + node2_pred + graph.num_cum_nodes - graph.num_nodes) + + + # step5: predict edge type + edge_feature_node1 = node2_feature_node2[node1_index_per_graph] # (num_graph, n_out) + edge_feature_node2 = node2_feature_node2[node2_index_per_graph] # (num_graph, n_out) + edge_feature = torch.cat((edge_feature_node1, edge_feature_node2), 1) # (num_graph, 2n_out) + edge_logits = mlp_edge(edge_feature) + edge_prob = F.softmax(edge_logits, -1) # (num_graph, 3) + edge_prob_dist = torch.distributions.Categorical(edge_prob) + edge_pred = edge_prob_dist.sample() + + return stop_pred, node1_pred, node2_pred, edge_pred + + @torch.no_grad() + def _top1_action(self, graph, off_policy): + + if off_policy: + model = self.agent_model + new_atom_embeddings = self.agent_new_atom_embeddings + mlp_stop = self.agent_mlp_stop + mlp_node1 = self.agent_mlp_node1 + mlp_node2 = self.agent_mlp_node2 + mlp_edge = self.agent_mlp_edge + else: + model = self.model + new_atom_embeddings = self.new_atom_embeddings + mlp_stop = self.mlp_stop + mlp_node1 = self.mlp_node1 + mlp_node2 = self.mlp_node2 + mlp_edge = self.mlp_edge + + # step1: get feature + output = model(graph, graph.node_feature.float()) + + extended_node2graph = torch.arange(graph.num_nodes.size(0), + device=self.device).unsqueeze(1).repeat([1, self.id2atom.size(0)]).view(-1) # (num_graph * 16) + extended_node2graph = torch.cat((graph.node2graph, extended_node2graph)) # (num_node + 16 * num_graph) + + graph_feature_per_node = output["graph_feature"][extended_node2graph] + + # step2: predict stop + stop_feature = output["graph_feature"] # (num_graph, n_out) + stop_logits = mlp_stop(stop_feature) # (num_graph, 2) + stop_pred = torch.argmax(stop_logits, -1) # (num_graph,) + # step3: predict first node: node1 + + node1_feature = output["node_feature"] #(num_node, n_out) + + node1_feature = torch.cat((node1_feature, + new_atom_embeddings.repeat([graph.num_nodes.size(0), 1])), 0) # (num_node + 16 * num_graph, n_out) + node2_feature_node2 = node1_feature.clone() # (num_node + 16 * num_graph, n_out) + + node1_feature = torch.cat((node1_feature, graph_feature_per_node), 1) + + node1_logits = mlp_node1(node1_feature).squeeze(1) # (num_node + 16 * num_graph) + # mask the extended part + mask = torch.zeros(node1_logits.size(), device=self.device) + mask[:graph.num_node] = 1 + node1_logits = torch.where(mask>0, node1_logits, -10000.0*torch.ones(node1_logits.size(), device=self.device)) + + node1_index_per_graph = scatter_max(node1_logits, extended_node2graph)[1] # (num_node + 16 * num_graph) + node1_pred = node1_index_per_graph - (graph.num_cum_nodes - graph.num_nodes) + + # step4: predict second node: node2 + node1_index = 
node1_index_per_graph[extended_node2graph] # (num_node + 16 * num_graph) + node2_feature_node1 = node1_feature[node1_index] # (num_node + 16 * num_graph, n_out) + + node2_feature = torch.cat((node2_feature_node1, node2_feature_node2), 1) # (num_node + 16 * num_graph, 2n_out + node2_logits = mlp_node2(node2_feature).squeeze(1) # (num_node + 16 * num_graph) + + #mask the selected node1 + mask = torch.zeros(node2_logits.size(), device=self.device) + mask[node1_index_per_graph] = 1 + node2_logits = torch.where(mask==0, node2_logits, -10000.0*torch.ones(node2_logits.size(), device=self.device)) + node2_index_per_graph = scatter_max(node2_logits, extended_node2graph)[1] # (num_node + 16 * num_graph) + + is_new_node = node2_index_per_graph - graph.num_node # non negative if is new node + graph_offset = torch.arange(graph.num_nodes.size(0), device=self.device) + node2_pred = torch.where(is_new_node>=0, graph.num_nodes + is_new_node - graph_offset * self.id2atom.size(0), + node2_index_per_graph - (graph.num_cum_nodes - graph.num_nodes)) + + # step5: predict edge type + edge_feature_node1 = node2_feature_node2[node1_index_per_graph] #(num_graph, n_out) + edge_feature_node2 = node2_feature_node2[node2_index_per_graph] # #(num_graph, n_out) + edge_feature = torch.cat((edge_feature_node1, edge_feature_node2), 1) #(num_graph, 2n_out) + edge_logits = mlp_edge(edge_feature) + edge_pred = torch.argmax(edge_logits, -1) + + return stop_pred, node1_pred, node2_pred, edge_pred + + @torch.no_grad() + def _apply_action(self, graph, off_policy, max_resample=10, verbose=0, min_node=5): + # action (num_graph, 4) + + # stopped graph is removed, initialize is_valid as False + is_valid = torch.zeros(len(graph), dtype=torch.bool, device=self.device) + stop_action = torch.zeros(len(graph), dtype=torch.long, device=self.device) + node1_action = torch.zeros(len(graph), dtype=torch.long, device=self.device) + node2_action = torch.zeros(len(graph), dtype=torch.long, device=self.device) + edge_action = torch.zeros(len(graph), dtype=torch.long, device=self.device) + + for i in range(max_resample): + # maximal resample time + mask = ~is_valid + if max_resample == 1: + tmp_stop_action, tmp_node1_action, tmp_node2_action, tmp_edge_action = \ + self._top1_action(graph, off_policy) + else: + tmp_stop_action, tmp_node1_action, tmp_node2_action, tmp_edge_action = \ + self._sample_action(graph, off_policy) + + stop_action[mask] = tmp_stop_action[mask] + node1_action[mask] = tmp_node1_action[mask] + node2_action[mask] = tmp_node2_action[mask] + edge_action[mask] = tmp_edge_action[mask] + + stop_action[graph.num_nodes <= 5] = 0 + # tmp add new nodes + has_new_node = (node2_action >= graph.num_nodes) & (stop_action == 0) + new_atom_id = (node2_action - graph.num_nodes)[has_new_node] + new_atom_type = self.id2atom[new_atom_id] + + atom_type, num_nodes = functional._extend(graph.atom_type, graph.num_nodes, new_atom_type, has_new_node) + + # tmp cast to regular node ids + node2_action = torch.where(has_new_node, graph.num_nodes, node2_action) + + # tmp modify edges + new_edge = torch.stack([node1_action, node2_action], dim=-1) + edge_list = graph.edge_list.clone() + bond_type = graph.bond_type.clone() + edge_list[:, :2] -= graph._offsets.unsqueeze(-1) + is_modified_edge = (edge_list[:, :2] == new_edge[graph.edge2graph]).all(dim=-1) & \ + (stop_action[graph.edge2graph] == 0) + has_modified_edge = scatter_max(is_modified_edge.long(), graph.edge2graph, dim_size=len(graph))[0] > 0 + bond_type[is_modified_edge] = edge_action[has_modified_edge] 
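+            # column 2 of edge_list stores the relation type; keep it consistent with bond_type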
+ edge_list[is_modified_edge, 2] = edge_action[has_modified_edge] + # tmp modify reverse edges + new_edge = new_edge.flip(-1) + is_modified_edge = (edge_list[:, :2] == new_edge[graph.edge2graph]).all(dim=-1) & \ + (stop_action[graph.edge2graph] == 0) + bond_type[is_modified_edge] = edge_action[has_modified_edge] + edge_list[is_modified_edge, 2] = edge_action[has_modified_edge] + + + # tmp add new edges + has_new_edge = (~has_modified_edge) & (stop_action == 0) + new_edge_list = torch.stack([node1_action, node2_action, edge_action], dim=-1)[has_new_edge] + bond_type = functional._extend(bond_type, graph.num_edges, edge_action[has_new_edge], has_new_edge)[0] + edge_list, num_edges = functional._extend(edge_list, graph.num_edges, new_edge_list, has_new_edge) + + # tmp add reverse edges + new_edge_list = torch.stack([node2_action, node1_action, edge_action], dim=-1)[has_new_edge] + bond_type = functional._extend(bond_type, num_edges, edge_action[has_new_edge], has_new_edge)[0] + edge_list, num_edges = functional._extend(edge_list, num_edges, new_edge_list, has_new_edge) + + tmp_graph = type(graph)(edge_list, atom_type=atom_type, bond_type=bond_type, num_nodes=num_nodes, + num_edges=num_edges, num_relation=graph.num_relation) + is_valid = tmp_graph.is_valid | (stop_action == 1) + if is_valid.all(): + break + if not is_valid.all() and verbose: + num_invalid = len(graph) - is_valid.sum().item() + num_working = len(graph) + logger.warning("%d / %d molecules are invalid even after %d resampling" % + (num_invalid, num_working, max_resample)) + + # apply the true action + # inherit attributes + data_dict = graph.data_dict + meta_dict = graph.meta_dict + for key in ["atom_type", "bond_type"]: + data_dict.pop(key) + # pad 0 for node / edge attributes + for k, v in data_dict.items(): + if "node" in meta_dict[k]: + shape = (len(new_atom_type), *v.shape[1:]) + new_data = torch.zeros(shape, dtype=v.dtype, device=self.device) + data_dict[k] = functional._extend(v, graph.num_nodes, new_data, has_new_node)[0] + if "edge" in meta_dict[k]: + shape = (len(new_edge_list) * 2, *v.shape[1:]) + new_data = torch.zeros(shape, dtype=v.dtype, device=self.device) + data_dict[k] = functional._extend(v, graph.num_edges, new_data, has_new_edge * 2)[0] + + new_graph = type(graph)(edge_list, atom_type=atom_type, bond_type=bond_type, num_nodes=num_nodes, + num_edges=num_edges, num_relation=graph.num_relation, + meta_dict=meta_dict, **data_dict) + with new_graph.graph(): + new_graph.is_stopped = stop_action == 1 + + new_graph, feature_valid = self._update_molecule_feature(new_graph) + + return new_graph[feature_valid] + + def _update_molecule_feature(self, graphs): + # This function is very slow + mols = graphs.to_molecule(ignore_error=True) + valid = [mol is not None for mol in mols] + valid = torch.tensor(valid, device=graphs.device) + new_graphs = type(graphs).from_molecule(mols, kekulize=True, atom_feature="symbol") + + node_feature = torch.zeros(graphs.num_node, *new_graphs.node_feature.shape[1:], + dtype=new_graphs.node_feature.dtype, device=graphs.device) + edge_feature = torch.zeros(graphs.num_edge, *new_graphs.edge_feature.shape[1:], + dtype=new_graphs.edge_feature.dtype, device=graphs.device) + bond_type = torch.zeros_like(graphs.bond_type) + node_mask = valid[graphs.node2graph] + edge_mask = valid[graphs.edge2graph] + node_feature[node_mask] = new_graphs.node_feature.to(device=graphs.device) + edge_feature[edge_mask] = new_graphs.edge_feature.to(device=graphs.device) + bond_type[edge_mask] = 
new_graphs.bond_type.to(device=graphs.device) + + with graphs.node(): + graphs.node_feature = node_feature + with graphs.edge(): + graphs.edge_feature = edge_feature + graphs.bond_type = bond_type + + return graphs, valid + + def bn_train(self, mode=True): + for module in self.modules(): + if isinstance(module, nn.BatchNorm1d): + module.train(mode) + + def bn_eval(self): + for module in self.modules(): + if isinstance(module, nn.BatchNorm1d): + module.eval() + + def update_best_result(self, graph, score, task): + score = score.cpu() + best_results = self.best_results[task] + for s, i in zip(*score.sort(descending=True)): + s = s.item() + i = i.item() + if len(best_results) == self.top_k and s < best_results[-1][0]: + break + best_results.append((s, graph[i].to_smiles())) + best_results.sort(reverse=True) + best_results = best_results[:self.top_k] + self.best_results[task] = best_results + + @torch.no_grad() + def generate(self, num_sample, max_resample=20, off_policy=False, max_step=30 * 2, initial_smiles="C", verbose=0): + is_training = self.training + self.eval() + + graph = data.Molecule.from_smiles(initial_smiles, kekulize=True, atom_feature="symbol").repeat(num_sample) + + # TODO: workaround + if self.device.type == "cuda": + graph = graph.cuda(self.device) + + result = [] + for i in range(max_step): + new_graph = self._apply_action(graph, off_policy, max_resample, verbose=1) + if i == max_step - 1: + # last step, collect all graph that is valid + result.append(new_graph[(new_graph.num_nodes <= (self.max_node))]) + else: + result.append(new_graph[new_graph.is_stopped | (new_graph.num_nodes == (self.max_node))]) + + is_continue = (~new_graph.is_stopped) & (new_graph.num_nodes < (self.max_node)) + graph = new_graph[is_continue] + if len(graph) == 0: + break + + self.train(is_training) + + result = self._cat(result) + return result + + def _append(self, data, num_xs, input, mask=None): + if mask is None: + mask = torch.ones_like(num_xs, dtype=torch.bool) + new_num_xs = num_xs + mask + new_num_cum_xs = new_num_xs.cumsum(0) + new_num_x = new_num_cum_xs[-1].item() + new_data = torch.zeros(new_num_x, *data.shape[1:], dtype=data.dtype, device=data.device) + starts = new_num_cum_xs - new_num_xs + ends = starts + num_xs + index = functional.multi_slice_mask(starts, ends, new_num_x) + new_data[index] = data + new_data[~index] = input[mask] + return new_data, new_num_xs + + @torch.no_grad() + def all_stop(self, graph): + if (graph.num_nodes < 2).any(): + graph = graph[graph.num_nodes >= 2] + warnings.warn("Graphs with less than 2 nodes can't be used for stop prediction learning. Dropped") + + label1 = torch.zeros(len(graph), dtype=torch.long, device=self.device) + label2 = torch.zeros_like(label1) + label3 = torch.zeros_like(label1) + return graph, label1, label2, label3, torch.ones(len(graph), dtype=torch.long, device=self.device) + + + @torch.no_grad() + def all_edge(self, graph): + if (graph.num_nodes < 2).any(): + graph = graph[graph.num_nodes >= 2] + warnings.warn("Graphs with less than 2 nodes can't be used for edge generation learning. 
Dropped") + + lengths = self._valid_edge_prefix_lengths(graph) + + starts, ends, valid = self._all_prefix_slice(graph.num_nodes ** 2, lengths) + + num_keep_dense_edges = ends - starts + num_repeat = len(starts) // len(graph) + graph = graph.repeat(num_repeat) + + # undirected: all upper triangular edge ids are flipped to lower triangular ids + # 1 -> 2, 4 -> 6, 5 -> 7 + node_index = graph.edge_list[:, :2] - graph._offsets.unsqueeze(-1) + node_in, node_out = node_index.t() + node_large = node_index.max(dim=-1)[0] + node_small = node_index.min(dim=-1)[0] + edge_id = node_large ** 2 + (node_in >= node_out) * node_large + node_small + undirected_edge_id = node_large * (node_large + 1) + node_small + + edge_mask = undirected_edge_id < num_keep_dense_edges[graph.edge2graph] + circum_box_size = (num_keep_dense_edges + 1.0).sqrt().ceil().long() + + # check whether we need to add a new node for the current edge + masked_undirected_edge_id = torch.where(edge_mask, undirected_edge_id, -torch.ones(undirected_edge_id.size(), + dtype=torch.long, device=graph.device)) + current_circum_box_size = scatter_max(masked_undirected_edge_id, graph.edge2graph, dim=0)[0] + current_circum_box_size = (current_circum_box_size + 1.0).sqrt().ceil().long() + is_new_node_edge = (circum_box_size > current_circum_box_size).long() + + starts = graph.num_cum_nodes - graph.num_nodes + ends = starts + circum_box_size - is_new_node_edge + node_mask = functional.multi_slice_mask(starts, ends, graph.num_node) + # compact nodes so that succeeding nodes won't affect graph pooling + new_graph = graph.edge_mask(edge_mask).node_mask(node_mask, compact=True) + + positive_edge = edge_id == num_keep_dense_edges[graph.edge2graph] + positive_graph = scatter_add(positive_edge.long(), graph.edge2graph, dim=0, dim_size=len(graph)).bool() + # default: non-edge + target = (self.model.num_relation) * torch.ones(graph.batch_size, dtype=torch.long, device=graph.device) + target[positive_graph] = graph.edge_list[:, 2][positive_edge] + + # node_in > node_out + node_in = circum_box_size - 1 + node_out = num_keep_dense_edges - node_in * circum_box_size + # if we need to add a new node, what will be its atomid? + new_node_atomid = self.atom2id[graph.atom_type[starts +node_in]] + + # keep only the positive graph, as we will add an edge at each step + new_graph = new_graph[positive_graph] + target = target[positive_graph] + node_in = node_in[positive_graph] + node_out = node_out[positive_graph] + is_new_node_edge = is_new_node_edge[positive_graph] + new_node_atomid = new_node_atomid[positive_graph] + + node_in_extend = new_graph.num_nodes + new_node_atomid + node_in_final = torch.where(is_new_node_edge == 0, node_in, node_in_extend) + + return new_graph, node_out, node_in_final, target, torch.zeros_like(node_out) + + @torch.no_grad() + def _all_prefix_slice(self, num_xs, lengths=None): + # extract a bunch of slices that correspond to the following num_repeat * n masks + # ------ repeat 0 ----- + # graphs[0]: [0, 0, ..., 0] + # ... + # graphs[-1]: [0, 0, ..., 0] + # ------ repeat 1 ----- + # graphs[0]: [1, 0, ..., 0] + # ... + # graphs[-1]: [1, 0, ..., 0] + # ... + # ------ repeat -1 ----- + # graphs[0]: [1, ..., 1, 0] + # ... 
+ # graphs[-1]: [1, ..., 1, 0] + num_cum_xs = num_xs.cumsum(0) + starts = num_cum_xs - num_xs + if lengths is None: + num_max_x = num_xs.max().item() + lengths = torch.arange(num_max_x, device=num_xs.device) + + pack_offsets = torch.arange(len(lengths), device=num_xs.device) * num_cum_xs[-1] + # starts, lengths, ends: (num_repeat, num_graph) + starts = starts.unsqueeze(0) + pack_offsets.unsqueeze(-1) + valid = lengths.unsqueeze(-1) <= num_xs.unsqueeze(0) - 1 + lengths = torch.min(lengths.unsqueeze(-1), num_xs.unsqueeze(0) - 1).clamp(0) + ends = starts + lengths + + starts = starts.flatten() + ends = ends.flatten() + valid = valid.flatten() + + return starts, ends, valid + + @torch.no_grad() + def _valid_edge_prefix_lengths(self, graph): + num_max_node = graph.num_nodes.max().item() + # edge id in an adjacency (snake pattern) + # in + # o 0 1 4 + # u 2 3 5 + # t 6 7 8 + lengths = torch.arange(num_max_node ** 2, device=graph.device) + circum_box_size = (lengths + 1.0).sqrt().ceil().long() + # only keep lengths that ends in the lower triangular part of adjacency matrix + lengths = lengths[lengths >= circum_box_size * (circum_box_size - 1)] + # lengths: [0, 2, 3, 6, 7, 8, ...] + # num_node2length_idx: [0, 1, 4, 6, ...] + # num_edge_unrolls + # 0 + # 1 0 + # 2 1 0 + num_edge_unrolls = (lengths + 1.0).sqrt().ceil().long() ** 2 - lengths - 1 + # num_edge_unrolls: [0, 1, 0, 2, 1, 0, ...] + # remove lengths that unroll too much. they always lead to empty targets. + lengths = lengths[(num_edge_unrolls <= self.max_edge_unroll) & (num_edge_unrolls > 0)] + + return lengths + + def _cat(self, graphs): + for i, graph in enumerate(graphs): + if not isinstance(graph, data.PackedGraph): + graphs[i] = graph.pack([graph]) + + edge_list = torch.cat([graph.edge_list for graph in graphs]) + pack_num_nodes = torch.stack([graph.num_node for graph in graphs]) + pack_num_edges = torch.stack([graph.num_edge for graph in graphs]) + pack_num_cum_edges = pack_num_edges.cumsum(0) + graph_index = pack_num_cum_edges < len(edge_list) + pack_offsets = scatter_add(pack_num_nodes[graph_index], pack_num_cum_edges[graph_index], + dim_size=len(edge_list)) + pack_offsets = pack_offsets.cumsum(0) + + edge_list[:, :2] += pack_offsets.unsqueeze(-1) + offsets = torch.cat([graph._offsets for graph in graphs]) + pack_offsets + + edge_weight = torch.cat([graph.edge_weight for graph in graphs]) + num_nodes = torch.cat([graph.num_nodes for graph in graphs]) + num_edges = torch.cat([graph.num_edges for graph in graphs]) + num_relation = graphs[0].num_relation + assert all(graph.num_relation == num_relation for graph in graphs) + + # only keep attributes that exist in all graphs + keys = set(graphs[0].meta_dict.keys()) + for graph in graphs: + keys = keys.intersection(graph.meta_dict.keys()) + + meta_dict = {k: graphs[0].meta_dict[k] for k in keys} + data_dict = {} + for k in keys: + data_dict[k] = torch.cat([graph.data_dict[k] for graph in graphs]) + + return type(graphs[0])(edge_list, edge_weight=edge_weight, + num_nodes=num_nodes, num_edges=num_edges, num_relation=num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) \ No newline at end of file diff --git a/build/lib/torchdrug/tasks/pretrain.py b/build/lib/torchdrug/tasks/pretrain.py new file mode 100644 index 00000000..df974a50 --- /dev/null +++ b/build/lib/torchdrug/tasks/pretrain.py @@ -0,0 +1,584 @@ +import copy +import math + +import torch +from torch import nn +from torch.nn import functional as F +from torch_scatter import scatter_min + +from torchdrug import core, 
tasks, layers +from torchdrug.data import constant +from torchdrug.layers import functional +from torchdrug.core import Registry as R + + +@R.register("tasks.EdgePrediction") +class EdgePrediction(tasks.Task, core.Configurable): + """ + Edge prediction task proposed in `Inductive Representation Learning on Large Graphs`_. + + .. _Inductive Representation Learning on Large Graphs: + https://arxiv.org/abs/1706.02216 + + Parameters: + model (nn.Module): node representation model + """ + + def __init__(self, model): + super(EdgePrediction, self).__init__() + self.model = model + + def _get_directed(self, graph): + mask = graph.edge_list[:, 0] < graph.edge_list[:, 1] + graph = graph.edge_mask(mask) + return graph + + def predict(self, batch, all_loss=None, metric=None): + graph = batch["graph"] + + output = self.model(graph, graph.node_feature.float(), all_loss, metric) + node_feature = output["node_feature"] + + graph = self._get_directed(graph) + node_in, node_out = graph.edge_list.t()[:2] + neg_index = (torch.rand(2, graph.num_edge, device=self.device) * graph.num_nodes[graph.edge2graph]).long() + neg_index = neg_index + (graph.num_cum_nodes - graph.num_nodes)[graph.edge2graph] + node_in = torch.cat([node_in, neg_index[0]]) + node_out = torch.cat([node_out, neg_index[1]]) + + pred = torch.einsum("bd, bd -> b", node_feature[node_in], node_feature[node_out]) + return pred + + def target(self, batch): + graph = batch["graph"] + target = torch.ones(graph.num_edge, device=self.device) + target[len(target) // 2:] = 0 + return target + + def evaluate(self, pred, target): + metric = {} + accuracy = ((pred > 0) == (target > 0.5)).float().mean() + + name = tasks._get_metric_name("acc") + metric[name] = accuracy + + return metric + + def forward(self, batch): + """""" + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + pred = self.predict(batch, all_loss, metric) + target = self.target(batch) + loss = F.binary_cross_entropy_with_logits(pred, target) + name = tasks._get_criterion_name("bce") + metric[name] = loss + metric.update(self.evaluate(pred, target)) + + all_loss += loss + + return all_loss, metric + + +@R.register("tasks.AttributeMasking") +class AttributeMasking(tasks.Task, core.Configurable): + """ + Attribute masking proposed in `Strategies for Pre-training Graph Neural Networks`_. + + .. 
_Strategies for Pre-training Graph Neural Networks: + https://arxiv.org/abs/1905.12265 + + Parameters: + model (nn.Module): node representation model + mask_rate (float, optional): rate of masked nodes + num_mlp_layer (int, optional): number of MLP layers + """ + + def __init__(self, model, mask_rate=0.15, num_mlp_layer=2, graph_construction_model=None): + super(AttributeMasking, self).__init__() + self.model = model + self.mask_rate = mask_rate + self.num_mlp_layer = num_mlp_layer + self.graph_construction_model = graph_construction_model + + def preprocess(self, train_set, valid_set, test_set): + data = train_set[0] + self.view = getattr(data["graph"], "view", "atom") + if hasattr(self.model, "node_output_dim"): + model_output_dim = self.model.node_output_dim + else: + model_output_dim = self.model.output_dim + if self.view == "atom": + num_label = constant.NUM_ATOM + else: + num_label = constant.NUM_AMINO_ACID + self.mlp = layers.MLP(model_output_dim, [model_output_dim] * (self.num_mlp_layer - 1) + [num_label]) + + def predict_and_target(self, batch, all_loss=None, metric=None): + graph = batch["graph"] + if self.graph_construction_model: + graph = self.graph_construction_model.apply_node_layer(graph) + + num_nodes = graph.num_nodes if self.view in ["atom", "node"] else graph.num_residues + num_cum_nodes = num_nodes.cumsum(0) + num_samples = (num_nodes * self.mask_rate).long().clamp(1) + num_sample = num_samples.sum() + sample2graph = torch.repeat_interleave(num_samples) + node_index = (torch.rand(num_sample, device=self.device) * num_nodes[sample2graph]).long() + node_index = node_index + (num_cum_nodes - num_nodes)[sample2graph] + + if self.view == "atom": + target = graph.atom_type[node_index] + input = graph.node_feature.float() + input[node_index] = 0 + else: + target = graph.residue_type[node_index] + with graph.residue(): + graph.residue_feature[node_index] = 0 + graph.residue_type[node_index] = 0 + # Generate masked edge features. Any better implementation? + if self.graph_construction_model: + graph = self.graph_construction_model.apply_edge_layer(graph) + input = graph.residue_feature.float() + + output = self.model(graph, input, all_loss, metric) + if self.view in ["node", "atom"]: + node_feature = output["node_feature"] + else: + node_feature = output.get("residue_feature", output.get("node_feature")) + node_feature = node_feature[node_index] + pred = self.mlp(node_feature) + + return pred, target + + def evaluate(self, pred, target): + metric = {} + accuracy = (pred.argmax(dim=-1) == target).float().mean() + + name = tasks._get_metric_name("acc") + metric[name] = accuracy + + return metric + + def forward(self, batch): + """""" + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + pred, target = self.predict_and_target(batch, all_loss, metric) + metric.update(self.evaluate(pred, target)) + + loss = F.cross_entropy(pred, target) + name = tasks._get_criterion_name("ce") + metric[name] = loss + + all_loss += loss + + return all_loss, metric + + +@R.register("tasks.ContextPrediction") +class ContextPrediction(tasks.Task, core.Configurable): + """ + Context prediction task proposed in `Strategies for Pre-training Graph Neural Networks`_. + + .. _Strategies for Pre-training Graph Neural Networks: + https://arxiv.org/abs/1905.12265 + + For a given center node, the subgraph is defined as a k-hop neighborhood (inclusive) around the selected node. 
+ The context graph is defined as the surrounding graph structure between r1- (exclusive) and r2-hop (inclusive) + from the center node. Nodes between k- and r1-hop are picked as anchor nodes for the context representation. + + Parameters: + model (nn.Module): node representation model for subgraphs. + context_model (nn.Module, optional): node representation model for context graphs. + By default, use the same architecture as ``model`` without parameter sharing. + k (int, optional): radius for subgraphs + r1 (int, optional): inner radius for context graphs + r2 (int, optional): outer radius for context graphs + readout (nn.Module, optional): readout function over context anchor nodes + num_negative (int, optional): number of negative samples per positive sample + """ + + def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1): + super(ContextPrediction, self).__init__() + self.model = model + self.k = k + self.r1 = r1 + self.r2 = r2 + self.num_negative = num_negative + assert r1 < k < r2 + + if context_model is None: + self.context_model = copy.deepcopy(model) + else: + self.context_model = context_model + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) + + def substruct_and_context(self, graph): + center_index = (torch.rand(len(graph), device=self.device) * graph.num_nodes).long() + center_index = center_index + graph.num_cum_nodes - graph.num_nodes + dist = torch.full((graph.num_node,), self.r2 + 1, dtype=torch.long, device=self.device) + dist[center_index] = 0 + + # single source shortest path + node_in, node_out = graph.edge_list.t()[:2] + for i in range(self.r2): + new_dist = scatter_min(dist[node_in], node_out, dim_size=graph.num_node)[0] + 1 + dist = torch.min(dist, new_dist) + + substruct_mask = dist <= self.k + context_mask = (dist > self.r1) & (dist <= self.r2) + is_center_node = functional.as_mask(center_index, graph.num_node) + is_anchor_node = (dist > self.r1) & (dist <= self.k) + + substruct = graph.clone() + context = graph.clone() + with substruct.node(): + substruct.is_center_node = is_center_node + with context.node(): + context.is_anchor_node = is_anchor_node + + substruct = substruct.subgraph(substruct_mask) + context = context.subgraph(context_mask) + valid = context.num_nodes > 0 + substruct = substruct[valid] + context = context[valid] + + return substruct, context + + def predict_and_target(self, batch, all_loss=None, metric=None): + graph = batch["graph"] + substruct, context = self.substruct_and_context(graph) + anchor = context.subgraph(context.is_anchor_node) + + substruct_output = self.model(substruct, substruct.node_feature.float(), all_loss, metric) + substruct_feature = substruct_output["node_feature"][substruct.is_center_node] + + context_output = self.context_model(context, context.node_feature.float(), all_loss, metric) + anchor_feature = context_output["node_feature"][context.is_anchor_node] + context_feature = self.readout(anchor, anchor_feature) + + shift = torch.arange(self.num_negative, device=self.device) + 1 + neg_index = (torch.arange(len(context), device=self.device).unsqueeze(-1) + shift) % len(context) # (batch_size, num_negative) + context_feature = torch.cat([context_feature.unsqueeze(1), context_feature[neg_index]], dim=1) + substruct_feature = substruct_feature.unsqueeze(1).expand_as(context_feature) + + pred = torch.einsum("bnd, bnd -> bn", substruct_feature, 
context_feature) + target = torch.zeros_like(pred) + target[:, 0] = 1 + return pred, target + + def evaluate(self, pred, target): + metric = {} + accuracy = ((pred > 0) == (target > 0.5)).float().mean() + + name = tasks._get_metric_name("acc") + metric[name] = accuracy + + return metric + + def forward(self, batch): + """""" + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + pred, target = self.predict_and_target(batch, all_loss, metric) + metric.update(self.evaluate(pred, target)) + + loss = F.binary_cross_entropy_with_logits(pred, target) + name = tasks._get_criterion_name("bce") + metric[name] = loss + + all_loss += loss + + return all_loss, metric + + +@R.register("tasks.DistancePrediction") +class DistancePrediction(tasks.Task, core.Configurable): + """ + Pairwise spatial distance prediction task proposed in + `Protein Representation Learning by Geometric Structure Pretraining`_. + + .. _Protein Representation Learning by Geometric Structure Pretraining: + https://arxiv.org/pdf/2203.06125.pdf + + Randomly select some edges and predict the lengths of the edges using the representations of two nodes. + The selected edges are removed from the input graph to prevent trivial solutions. + + Parameters: + model (nn.Module): node representation model + num_sample (int, optional): number of edges selected from each graph + num_mlp_layer (int, optional): number of MLP layers in distance predictor + graph_construction_model (nn.Module, optional): graph construction model + """ + + def __init__(self, model, num_sample=256, num_mlp_layer=2, graph_construction_model=None): + super(DistancePrediction, self).__init__() + self.model = model + self.num_sample = num_sample + self.num_mlp_layer = num_mlp_layer + self.graph_construction_model = graph_construction_model + + self.mlp = layers.MLP(2 * model.output_dim, [model.output_dim] * (num_mlp_layer - 1) + [1]) + + def predict_and_target(self, batch, all_loss=None, metric=None): + graph = batch["graph"] + if self.graph_construction_model: + graph = self.graph_construction_model(graph) + + node_in, node_out = graph.edge_list[:, :2].t() + indices = torch.arange(graph.num_edge, device=self.device) + indices = functional.variadic_sample(indices, graph.num_edges, self.num_sample).flatten(-2, -1) + node_i = node_in[indices] + node_j = node_out[indices] + graph = graph.edge_mask(~functional.as_mask(indices, graph.num_edge)) + + # Calculate distance + target = (graph.node_position[node_i] - graph.node_position[node_j]).norm(p=2, dim=-1) + + output = self.model(graph, graph.node_feature.float() , all_loss, metric)["node_feature"] + node_feature = torch.cat([output[node_i], output[node_j]], dim=-1) + pred = self.mlp(node_feature).squeeze(-1) + + return pred, target + + def evaluate(self, pred, target): + metric = {} + mse = F.mse_loss(pred, target) + + name = tasks._get_metric_name("mse") + metric[name] = mse + + return metric + + def forward(self, batch): + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + pred, target = self.predict_and_target(batch, all_loss, metric) + metric.update(self.evaluate(pred, target)) + + loss = F.mse_loss(pred, target) + name = tasks._get_criterion_name("mse") + metric[name] = loss + + all_loss += loss + + return all_loss, metric + + +@R.register("tasks.AnglePrediction") +class AnglePrediction(tasks.Task, core.Configurable): + """ + Angle prediction task proposed in `Protein Representation Learning by Geometric Structure Pretraining`_. + + .. 
_Protein Representation Learning by Geometric Structure Pretraining: + https://arxiv.org/pdf/2203.06125.pdf + + Randomly select pairs of adjacent edges and predict the angles between them using the representations of three + nodes. The selected edges are removed from the input graph to prevent trivial solutions. + + Parameters: + model (nn.Module): node representation model + num_sample (int, optional): number of edge pairs selected from each graph + num_class (int, optional): number of classes to discretize the angles + num_mlp_layer (int, optional): number of MLP layers in angle predictor + graph_construction_model (nn.Module, optional): graph construction model + """ + + def __init__(self, model, num_sample=256, num_class=8, num_mlp_layer=2, graph_construction_model=None): + super(AnglePrediction, self).__init__() + self.model = model + self.num_sample = num_sample + self.num_mlp_layer = num_mlp_layer + self.graph_construction_model = graph_construction_model + + boundary = torch.arange(0, math.pi, math.pi / num_class) + self.register_buffer("boundary", boundary) + + self.mlp = layers.MLP(3 * model.output_dim, [model.output_dim] * (num_mlp_layer - 1) + [num_class]) + + def predict_and_target(self, batch, all_loss=None, metric=None): + graph = batch["graph"] + if self.graph_construction_model: + graph = self.graph_construction_model(graph) + + node_in, node_out = graph.edge_list[:, :2].t() + + line_graph = graph.line_graph() + edge_in, edge_out = line_graph.edge_list[:, :2].t() + is_self_loop1 = (edge_in == edge_out) + is_self_loop2 = (node_in[edge_in] == node_out[edge_out]) + is_remove = is_self_loop1 | is_self_loop2 + line_graph = line_graph.edge_mask(~is_remove) + edge_in, edge_out = line_graph.edge_list[:, :2].t() + # (k->j) - (j->i) + node_i = node_out[edge_out] + node_j = node_in[edge_out] + node_k = node_in[edge_in] + indices = torch.arange(line_graph.num_edge, device=self.device) + indices = functional.variadic_sample(indices, line_graph.num_edges, self.num_sample).flatten(-2, -1) + node_i = node_i[indices] + node_j = node_j[indices] + node_k = node_k[indices] + + mask = torch.ones((graph.num_edge,), device=graph.device, dtype=torch.bool) + mask[edge_out[indices]] = 0 + mask[edge_in[indices]] = 0 + graph = graph.edge_mask(mask) + + # Calculate angles + vector1 = graph.node_position[node_i] - graph.node_position[node_j] + vector2 = graph.node_position[node_k] - graph.node_position[node_j] + x = (vector1 * vector2).sum(dim=-1) + y = torch.cross(vector1, vector2).norm(dim=-1) + angle = torch.atan2(y, x) + target = torch.bucketize(angle, self.boundary, right=True) - 1 + + output = self.model(graph, graph.node_feature.float() , all_loss, metric)["node_feature"] + node_feature = torch.cat([output[node_i], output[node_j], output[node_k]], dim=-1) + pred = self.mlp(node_feature) + + return pred, target + + def evaluate(self, pred, target): + metric = {} + accuracy = (pred.argmax(dim=-1) == target).float().mean() + + name = tasks._get_metric_name("acc") + metric[name] = accuracy + + return metric + + def forward(self, batch): + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + pred, target = self.predict_and_target(batch, all_loss, metric) + metric.update(self.evaluate(pred, target)) + + loss = F.cross_entropy(pred, target) + name = tasks._get_criterion_name("ce") + metric[name] = loss + + all_loss += loss + + return all_loss, metric + + +@R.register("tasks.DihedralPrediction") +class DihedralPrediction(tasks.Task, core.Configurable): + """ + Dihedral 
prediction task proposed in `Protein Representation Learning by Geometric Structure Pretraining`_. + + .. _Protein Representation Learning by Geometric Structure Pretraining: + https://arxiv.org/pdf/2203.06125.pdf + + Randomly select three consecutive edges and predict the dihedrals among them using the representations of four + nodes. The selected edges are removed from the input graph to prevent trivial solutions. + + Parameters: + model (nn.Module): node representation model + num_sample (int, optional): number of edge triplets selected from each graph + num_class (int, optional): number of classes for discretizing the dihedrals + num_mlp_layer (int, optional): number of MLP layers in dihedral angle predictor + graph_construction_model (nn.Module, optional): graph construction model + """ + + def __init__(self, model, num_sample=256, num_class=8, num_mlp_layer=2, graph_construction_model=None): + super(DihedralPrediction, self).__init__() + self.model = model + self.num_sample = num_sample + self.num_mlp_layer = num_mlp_layer + self.graph_construction_model = graph_construction_model + + boundary = torch.arange(0, math.pi, math.pi / num_class) + self.register_buffer("boundary", boundary) + + self.mlp = layers.MLP(4 * model.output_dim, [model.output_dim] * (num_mlp_layer - 1) + [num_class]) + + def predict_and_target(self, batch, all_loss=None, metric=None): + graph = batch["graph"] + if self.graph_construction_model: + graph = self.graph_construction_model(graph) + + node_in, node_out = graph.edge_list[:, :2].t() + line_graph = graph.line_graph() + edge_in, edge_out = line_graph.edge_list[:, :2].t() + is_self_loop1 = (edge_in == edge_out) + is_self_loop2 = (node_in[edge_in] == node_out[edge_out]) + is_remove = is_self_loop1 | is_self_loop2 + line_graph = line_graph.edge_mask(~is_remove) + edge_in, edge_out = line_graph.edge_list[:, :2].t() + + line2_graph = line_graph.line_graph() + edge2_in, edge2_out = line2_graph.edge_list.t()[:2] + is_self_loop1 = (edge2_in == edge2_out) + is_self_loop2 = (edge_in[edge2_in] == edge_out[edge2_out]) + is_remove = is_self_loop1 | is_self_loop2 + line2_graph = line2_graph.edge_mask(~is_remove) + edge2_in, edge2_out = line2_graph.edge_list[:, :2].t() + # (k->t->j) - (t->j->i) + node_i = node_out[edge_out[edge2_out]] + node_j = node_in[edge_out[edge2_out]] + node_t = node_in[edge_out[edge2_in]] + node_k = node_in[edge_in[edge2_in]] + indices = torch.arange(line2_graph.num_edge, device=self.device) + indices = functional.variadic_sample(indices, line2_graph.num_edges, self.num_sample).flatten(-2, -1) + node_i = node_i[indices] + node_j = node_j[indices] + node_t = node_t[indices] + node_k = node_k[indices] + mask = torch.ones((graph.num_edge,), device=graph.device, dtype=torch.bool) + mask[edge_out[edge2_out[indices]]] = 0 + mask[edge_out[edge2_in[indices]]] = 0 + mask[edge_in[edge2_in[indices]]] = 0 + graph = graph.edge_mask(mask) + + v_ctr = graph.node_position[node_t] - graph.node_position[node_j] # (A, 3) + v1 = graph.node_position[node_i] - graph.node_position[node_j] + v2 = graph.node_position[node_k] - graph.node_position[node_t] + n1 = torch.cross(v_ctr, v1, dim=-1) # Normal vectors of the two planes + n2 = torch.cross(v_ctr, v2, dim=-1) + a = (n1 * n2).sum(dim=-1) + b = torch.cross(n1, n2).norm(dim=-1) + dihedral = torch.atan2(b, a) + target = torch.bucketize(dihedral, self.boundary, right=True) - 1 + + output = self.model(graph, graph.node_feature.float() , all_loss, metric)["node_feature"] + node_feature = torch.cat([output[node_i], 
output[node_j], output[node_k], output[node_t]], dim=-1)
+        pred = self.mlp(node_feature)
+
+        return pred, target
+
+    def evaluate(self, pred, target):
+        metric = {}
+        accuracy = (pred.argmax(dim=-1) == target).float().mean()
+
+        name = tasks._get_metric_name("acc")
+        metric[name] = accuracy
+
+        return metric
+
+    def forward(self, batch):
+        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
+        metric = {}
+
+        pred, target = self.predict_and_target(batch, all_loss, metric)
+        metric.update(self.evaluate(pred, target))
+
+        loss = F.cross_entropy(pred, target)
+        name = tasks._get_criterion_name("ce")
+        metric[name] = loss
+
+        all_loss += loss
+
+        return all_loss, metric
diff --git a/build/lib/torchdrug/tasks/property_prediction.py b/build/lib/torchdrug/tasks/property_prediction.py
new file mode 100644
index 00000000..fb60fe2d
--- /dev/null
+++ b/build/lib/torchdrug/tasks/property_prediction.py
@@ -0,0 +1,569 @@
+import math
+from collections import defaultdict
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from torchdrug import core, layers, tasks, metrics, utils
+from torchdrug.core import Registry as R
+from torchdrug.layers import functional
+
+
+@R.register("tasks.PropertyPrediction")
+class PropertyPrediction(tasks.Task, core.Configurable):
+    """
+    Graph / molecule / protein property prediction task.
+
+    This class is also compatible with semi-supervised learning.
+
+    Parameters:
+        model (nn.Module): graph representation model
+        task (str, list or dict, optional): training task(s).
+            For dict, the keys are tasks and the values are the corresponding weights.
+        criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions and the
+            values are the corresponding weights. Available criterions are ``mse``, ``bce`` and ``ce``.
+        metric (str or list of str, optional): metric(s).
+            Available metrics are ``mae``, ``rmse``, ``auprc`` and ``auroc``.
+        num_mlp_layer (int, optional): number of layers in the MLP prediction head
+        normalization (bool, optional): whether to normalize the target
+        num_class (int, optional): number of classes
+        mlp_batch_norm (bool, optional): whether to apply batch normalization in the MLP
+        mlp_dropout (float, optional): dropout rate in the MLP
+        graph_construction_model (nn.Module, optional): graph construction model
+        verbose (int, optional): output verbose level
+    """
+
+    eps = 1e-10
+    _option_members = {"task", "criterion", "metric"}
+
+    def __init__(self, model, task=(), criterion="mse", metric=("mae", "rmse"), num_mlp_layer=1,
+                 normalization=True, num_class=None, mlp_batch_norm=False, mlp_dropout=0,
+                 graph_construction_model=None, verbose=0):
+        super(PropertyPrediction, self).__init__()
+        self.model = model
+        self.task = task
+        self.criterion = criterion
+        self.metric = metric
+        self.num_mlp_layer = num_mlp_layer
+        # For classification tasks, we disable normalization tricks.
+        self.normalization = normalization and ("ce" not in criterion) and ("bce" not in criterion)
+        self.num_class = (num_class,) if isinstance(num_class, int) else num_class
+        self.mlp_batch_norm = mlp_batch_norm
+        self.mlp_dropout = mlp_dropout
+        self.graph_construction_model = graph_construction_model
+        self.verbose = verbose
+
+    def preprocess(self, train_set, valid_set, test_set):
+        """
+        Compute the mean and standard deviation for each task on the training set.
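+
+        A minimal usage sketch (``MyModel`` and the dataset splits are placeholders,
+        not part of this module; the :class:`core.Engine` constructor normally
+        invokes this hook):
+
+        >>> task = PropertyPrediction(MyModel(), task={"mu": 1.0}, criterion="mse",
+        >>>                           metric=("mae", "rmse"), num_mlp_layer=2)
+        >>> task.preprocess(train_set, valid_set, test_set)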
+ """ + values = defaultdict(list) + for sample in train_set: + if not sample.get("labeled", True): + continue + for task in self.task: + if not math.isnan(sample[task]): + values[task].append(sample[task]) + mean = [] + std = [] + weight = [] + num_class = [] + for task, w in self.task.items(): + value = torch.tensor(values[task]) + mean.append(value.float().mean()) + std.append(value.float().std()) + weight.append(w) + if value.ndim > 1: + num_class.append(value.shape[1]) + elif value.dtype == torch.long: + task_class = value.max().item() + if task_class == 1 and "bce" in self.criterion: + num_class.append(1) + else: + num_class.append(task_class + 1) + else: + num_class.append(1) + + self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float)) + self.register_buffer("std", torch.as_tensor(std, dtype=torch.float)) + self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float)) + self.num_class = self.num_class or num_class + + hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1) + self.mlp = layers.MLP(self.model.output_dim, hidden_dims + [sum(self.num_class)], + batch_norm=self.mlp_batch_norm, dropout=self.mlp_dropout) + + def forward(self, batch): + """""" + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + pred = self.predict(batch, all_loss, metric) + + if all([t not in batch for t in self.task]): + # unlabeled data + return all_loss, metric + + target = self.target(batch) + labeled = ~torch.isnan(target) + target[~labeled] = 0 + + for criterion, weight in self.criterion.items(): + if criterion == "mse": + if self.normalization: + loss = F.mse_loss((pred - self.mean) / self.std, (target - self.mean) / self.std, reduction="none") + else: + loss = F.mse_loss(pred, target, reduction="none") + elif criterion == "bce": + loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none") + elif criterion == "ce": + loss = F.cross_entropy(pred, target.long().squeeze(-1), reduction="none").unsqueeze(-1) + else: + raise ValueError("Unknown criterion `%s`" % criterion) + loss = functional.masked_mean(loss, labeled, dim=0) + + name = tasks._get_criterion_name(criterion) + if self.verbose > 0: + for t, l in zip(self.task, loss): + metric["%s [%s]" % (name, t)] = l + loss = (loss * self.weight).sum() / self.weight.sum() + metric[name] = loss + all_loss += loss * weight + + return all_loss, metric + + def predict(self, batch, all_loss=None, metric=None): + graph = batch["graph"] + if self.graph_construction_model: + graph = self.graph_construction_model(graph) + output = self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric) + pred = self.mlp(output["graph_feature"]) + if self.normalization: + pred = pred * self.std + self.mean + return pred + + def target(self, batch): + target = torch.stack([batch[t].float() for t in self.task], dim=-1) + labeled = batch.get("labeled", torch.ones(len(target), dtype=torch.bool, device=target.device)) + target[~labeled] = math.nan + return target + + def evaluate(self, pred, target): + labeled = ~torch.isnan(target) + + metric = {} + for _metric in self.metric: + if _metric == "mae": + score = F.l1_loss(pred, target, reduction="none") + score = functional.masked_mean(score, labeled, dim=0) + elif _metric == "rmse": + score = F.mse_loss(pred, target, reduction="none") + score = functional.masked_mean(score, labeled, dim=0).sqrt() + elif _metric == "acc": + score = [] + num_class = 0 + for i, cur_num_class in enumerate(self.num_class): + _pred = pred[:, 
num_class:num_class + cur_num_class] + _target = target[:, i] + _labeled = labeled[:, i] + _score = metrics.accuracy(_pred[_labeled], _target[_labeled].long()) + score.append(_score) + num_class += cur_num_class + score = torch.stack(score) + elif _metric == "mcc": + score = [] + num_class = 0 + for i, cur_num_class in enumerate(self.num_class): + _pred = pred[:, num_class:num_class + cur_num_class] + _target = target[:, i] + _labeled = labeled[:, i] + _score = metrics.matthews_corrcoef(_pred[_labeled], _target[_labeled].long()) + score.append(_score) + num_class += cur_num_class + score = torch.stack(score) + elif _metric == "auroc": + score = [] + for _pred, _target, _labeled in zip(pred.t(), target.long().t(), labeled.t()): + _score = metrics.area_under_roc(_pred[_labeled], _target[_labeled]) + score.append(_score) + score = torch.stack(score) + elif _metric == "auprc": + score = [] + for _pred, _target, _labeled in zip(pred.t(), target.long().t(), labeled.t()): + _score = metrics.area_under_prc(_pred[_labeled], _target[_labeled]) + score.append(_score) + score = torch.stack(score) + elif _metric == "r2": + score = [] + for _pred, _target, _labeled in zip(pred.t(), target.t(), labeled.t()): + _score = metrics.r2(_pred[_labeled], _target[_labeled]) + score.append(_score) + score = torch.stack(score) + elif _metric == "spearmanr": + score = [] + for _pred, _target, _labeled in zip(pred.t(), target.t(), labeled.t()): + _score = metrics.spearmanr(_pred[_labeled], _target[_labeled]) + score.append(_score) + score = torch.stack(score) + elif _metric == "pearsonr": + score = [] + for _pred, _target, _labeled in zip(pred.t(), target.t(), labeled.t()): + _score = metrics.pearsonr(_pred[_labeled], _target[_labeled]) + score.append(_score) + score = torch.stack(score) + else: + raise ValueError("Unknown metric `%s`" % _metric) + + name = tasks._get_metric_name(_metric) + for t, s in zip(self.task, score): + metric["%s [%s]" % (name, t)] = s + + return metric + + +@R.register("tasks.MultipleBinaryClassification") +class MultipleBinaryClassification(tasks.Task, core.Configurable): + """ + Multiple binary classification task for graphs / molecules / proteins. + + Parameters: + model (nn.Module): graph representation model + task (list of int, optional): training task id(s). + criterion (list or dict, optional): training criterion(s). For dict, the keys are criterions and the values + are the corresponding weights. Available criterions are ``bce``. + metric (str or list of str, optional): metric(s). + Available metrics are ``auroc@macro``, ``auprc@macro``, ``auroc@micro``, ``auprc@micro`` and ``f1_max``. 
+        num_mlp_layer (int, optional): number of layers in the MLP prediction head
+        normalization (bool, optional): whether to normalize the target
+        reweight (bool, optional): whether to re-weight tasks according to the number of positive samples
+        graph_construction_model (nn.Module, optional): graph construction model
+        verbose (int, optional): output verbose level
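+
+    >>> # A minimal sketch; the encoder, feature dimension and task indices below
+    >>> # are placeholders, not defaults.
+    >>> model = models.GIN(input_dim=21, hidden_dims=[256, 256])
+    >>> task = tasks.MultipleBinaryClassification(model, task=list(range(100)),
+    ...                                           criterion="bce",
+    ...                                           metric=("auprc@micro", "f1_max"))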
+    """
+
+    eps = 1e-10
+    _option_members = {"criterion", "metric"}
+
+    def __init__(self, model, task=(), criterion="bce", metric=("auprc@micro", "f1_max"), num_mlp_layer=1,
+                 normalization=True, reweight=False, graph_construction_model=None, verbose=0):
+        super(MultipleBinaryClassification, self).__init__()
+        self.model = model
+        self.task = task
+        self.register_buffer("task_indices", torch.LongTensor(task))
+        self.criterion = criterion
+        self.metric = metric
+        self.num_mlp_layer = num_mlp_layer
+        self.normalization = normalization
+        self.reweight = reweight
+        self.graph_construction_model = graph_construction_model
+        self.verbose = verbose
+
+        hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1)
+        self.mlp = layers.MLP(self.model.output_dim, hidden_dims + [len(task)])
+
+    def preprocess(self, train_set, valid_set, test_set):
+        """
+        Compute the weight for each task on the training set.
+        """
+        values = []
+        for data in train_set:
+            values.append(data["targets"][self.task_indices])
+        values = torch.stack(values, dim=0)
+
+        if self.reweight:
+            num_positive = values.sum(dim=0)
+            weight = (num_positive.mean() / num_positive).clamp(1, 10)
+        else:
+            weight = torch.ones(len(self.task), dtype=torch.float)
+
+        self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float))
+
+    def forward(self, batch):
+        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
+        metric = {}
+
+        pred = self.predict(batch, all_loss, metric)
+        target = self.target(batch)
+
+        for criterion, weight in self.criterion.items():
+            if criterion == "bce":
+                loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
+            else:
+                raise ValueError("Unknown criterion `%s`" % criterion)
+            loss = loss.mean(dim=0)
+            loss = (loss * self.weight).sum() / self.weight.sum()
+
+            name = tasks._get_criterion_name(criterion)
+            metric[name] = loss
+            all_loss += loss * weight
+
+        return all_loss, metric
+
+    def predict(self, batch, all_loss=None, metric=None):
+        graph = batch["graph"]
+        if self.graph_construction_model:
+            graph = self.graph_construction_model(graph)
+        output = self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
+        pred = self.mlp(output["graph_feature"])
+        return pred
+
+    def target(self, batch):
+        target = batch["targets"][:, self.task_indices]
+        return target
+
+    def evaluate(self, pred, target):
+        metric = {}
+        for _metric in self.metric:
+            if _metric == "auroc@micro":
+                score = metrics.area_under_roc(pred.flatten(), target.long().flatten())
+            elif _metric == "auroc@macro":
+                score = metrics.variadic_area_under_roc(pred, target.long(), dim=0).mean()
+            elif _metric == "auprc@micro":
+                score = metrics.area_under_prc(pred.flatten(), target.long().flatten())
+            elif _metric == "auprc@macro":
+                score = metrics.variadic_area_under_prc(pred, target.long(), dim=0).mean()
+            elif _metric == "f1_max":
+                score = metrics.f1_max(pred, target)
+            else:
+                raise ValueError("Unknown metric `%s`" % _metric)
+
+            name = tasks._get_metric_name(_metric)
+            metric[name] = score
+
+        return metric
+
+
+@R.register("tasks.NodePropertyPrediction")
+class NodePropertyPrediction(tasks.Task, core.Configurable):
+    """
+    Node / atom / residue property prediction task.
+
+    The prediction level (``node``, ``atom`` or ``residue``) is inferred from the
+    ``view`` attribute of the training graphs.
+
+    Parameters:
+        model (nn.Module): graph representation model
+        criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions and the values
+            are the corresponding weights. Available criterions are ``mse``, ``bce`` and ``ce``.
+        metric (str or list of str, optional): metric(s).
+            Available metrics are ``mae``, ``rmse``, ``auprc`` and ``auroc``.
+        num_mlp_layer (int, optional): number of layers in mlp prediction head
+        normalization (bool, optional): whether to normalize the target
+        num_class (int, optional): number of classes
+        verbose (int, optional): output verbose level
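+
+    >>> # A minimal sketch; the encoder and dimensions are placeholders. The task
+    >>> # infers the prediction level and the number of classes in ``preprocess``.
+    >>> model = models.GCN(input_dim=21, hidden_dims=[128, 128])
+    >>> task = tasks.NodePropertyPrediction(model, criterion="bce",
+    ...                                     metric=("macro_auprc", "macro_auroc"))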
reduction="none") + elif criterion == "ce": + loss = F.cross_entropy(pred, target["label"], reduction="none") + else: + raise ValueError("Unknown criterion `%s`" % criterion) + loss = functional.masked_mean(loss, labeled, dim=0) + + name = tasks._get_criterion_name(criterion) + metric[name] = loss + all_loss += loss * weight + + all_loss += loss + + return all_loss, metric + + def evaluate(self, pred, target): + metric = {} + _target = target["label"] + _labeled = ~torch.isnan(_target) & target["mask"] + _size = functional.variadic_sum(_labeled.long(), target["size"]) + for _metric in self.metric: + if _metric == "micro_acc": + score = metrics.accuracy(pred[_labeled], _target[_labeled].long()) + elif metric == "micro_auroc": + score = metrics.area_under_roc(pred[_labeled], _target[_labeled]) + elif metric == "micro_auprc": + score = metrics.area_under_prc(pred[_labeled], _target[_labeled]) + elif _metric == "macro_auroc": + score = metrics.variadic_area_under_roc(pred[_labeled], _target[_labeled], _size).mean() + elif _metric == "macro_auprc": + score = metrics.variadic_area_under_prc(pred[_labeled], _target[_labeled], _size).mean() + elif _metric == "macro_acc": + score = pred[_labeled].argmax(-1) == _target[_labeled] + score = functional.variadic_mean(score.float(), _size).mean() + else: + raise ValueError("Unknown criterion `%s`" % _metric) + + name = tasks._get_metric_name(_metric) + metric[name] = score + + return metric + + +@R.register("tasks.InteractionPrediction") +@utils.copy_args(PropertyPrediction, ignore=("graph_construction_model",)) +class InteractionPrediction(PropertyPrediction): + """ + Predict the interaction property of graph pairs. + + Parameters: + model (nn.Module): graph representation model + model2 (nn.Module, optional): graph representation model for the second item. If ``None``, use tied-weight + model for the second item. + **kwargs + """ + + def __init__(self, model, model2=None, **kwargs): + super(InteractionPrediction, self).__init__(model, **kwargs) + self.model2 = model2 or model + + def preprocess(self, train_set, valid_set, test_set): + """ + Compute the mean and derivation for each task on the training set. 
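+
+        The statistics mirror :meth:`PropertyPrediction.preprocess`; the only
+        difference is that the prediction head is sized for the concatenated
+        graph features of ``model`` and ``model2``.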
+ """ + values = defaultdict(list) + for sample in train_set: + if not sample.get("labeled", True): + continue + for task in self.task: + if not math.isnan(sample[task]): + values[task].append(sample[task]) + mean = [] + std = [] + weight = [] + num_class = [] + for task, w in self.task.items(): + value = torch.tensor(values[task]) + mean.append(value.float().mean()) + std.append(value.float().std()) + weight.append(w) + if value.ndim > 1: + num_class.append(value.shape[1]) + elif value.dtype == torch.long: + task_class = value.max().item() + if task_class == 1 and "bce" in self.criterion: + num_class.append(1) + else: + num_class.append(task_class + 1) + else: + num_class.append(1) + + self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float)) + self.register_buffer("std", torch.as_tensor(std, dtype=torch.float)) + self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float)) + self.num_class = self.num_class or num_class + + hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1) + self.mlp = layers.MLP(self.model.output_dim + self.model2.output_dim, hidden_dims + [sum(self.num_class)]) + + def predict(self, batch, all_loss=None, metric=None): + graph1 = batch["graph1"] + output1 = self.model(graph1, graph1.node_feature.float(), all_loss=all_loss, metric=metric) + graph2 = batch["graph2"] + output2 = self.model2(graph2, graph2.node_feature.float(), all_loss=all_loss, metric=metric) + pred = self.mlp(torch.cat([output1["graph_feature"], output2["graph_feature"]], dim=-1)) + if self.normalization: + pred = pred * self.std + self.mean + return pred + + +@R.register("tasks.Unsupervised") +class Unsupervised(nn.Module, core.Configurable): + """ + Wrapper task for unsupervised learning. + + The unsupervised loss should be computed by the model. + + Parameters: + model (nn.Module): any model + """ + + def __init__(self, model, graph_construction_model=None): + super(Unsupervised, self).__init__() + self.model = model + self.graph_construction_model = graph_construction_model + + def forward(self, batch): + """""" + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + pred = self.predict(batch, all_loss, metric) + + return all_loss, metric + + def predict(self, batch, all_loss=None, metric=None): + graph = batch["graph"] + if self.graph_construction_model: + graph = self.graph_construction_model(graph) + pred = self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric) + return pred diff --git a/build/lib/torchdrug/tasks/reasoning.py b/build/lib/torchdrug/tasks/reasoning.py new file mode 100644 index 00000000..802c44c0 --- /dev/null +++ b/build/lib/torchdrug/tasks/reasoning.py @@ -0,0 +1,253 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import data as torch_data + +from torchdrug import core, tasks +from torchdrug.layers import functional +from torchdrug.core import Registry as R + + +@R.register("tasks.KnowledgeGraphCompletion") +class KnowledgeGraphCompletion(tasks.Task, core.Configurable): + """ + Knowledge graph completion task. + + This class provides routines for the family of knowledge graph embedding models. + + Parameters: + model (nn.Module): knowledge graph completion model + criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions and the values + are the corresponding weights. Available criterions are ``bce``, ``ce`` and ``ranking``. + metric (str or list of str, optional): metric(s). 
Available metrics are ``mr``, ``mrr`` and ``hits@K``. + num_negative (int, optional): number of negative samples per positive sample + margin (float, optional): margin in ranking criterion + adversarial_temperature (float, optional): temperature for self-adversarial negative sampling. + Set ``0`` to disable self-adversarial negative sampling. + strict_negative (bool, optional): use strict negative sampling or not + fact_ratio (float, optional): split the training set into facts and labels. + Set ``None`` to use the whole training set as both facts and labels. + sample_weight (bool, optional): whether to down-weight triplets from entities of large degrees + filtered_ranking (bool, optional): use filtered or unfiltered ranking for evaluation + full_batch_eval (bool, optional): whether to feed test negative samples by full batch or mini batch. + Full batch speeds up evaluation significantly, but may cause OOM problems for some models and datasets. + """ + _option_members = {"criterion", "metric"} + + def __init__(self, model, criterion="bce", metric=("mr", "mrr", "hits@1", "hits@3", "hits@10"), + num_negative=128, margin=6, adversarial_temperature=0, strict_negative=True, fact_ratio=None, + sample_weight=True, filtered_ranking=True, full_batch_eval=False): + super(KnowledgeGraphCompletion, self).__init__() + self.model = model + self.criterion = criterion + self.metric = metric + self.num_negative = num_negative + self.margin = margin + self.adversarial_temperature = adversarial_temperature + self.strict_negative = strict_negative + self.fact_ratio = fact_ratio + self.sample_weight = sample_weight + self.filtered_ranking = filtered_ranking + self.full_batch_eval = full_batch_eval + + def preprocess(self, train_set, valid_set, test_set): + if isinstance(train_set, torch_data.Subset): + dataset = train_set.dataset + else: + dataset = train_set + self.num_entity = dataset.num_entity + self.num_relation = dataset.num_relation + self.register_buffer("graph", dataset.graph) + fact_mask = torch.ones(len(dataset), dtype=torch.bool) + fact_mask[valid_set.indices] = 0 + fact_mask[test_set.indices] = 0 + if self.fact_ratio: + length = int(len(train_set) * self.fact_ratio) + index = torch.randperm(len(train_set))[length:] + train_indices = torch.tensor(train_set.indices) + fact_mask[train_indices[index]] = 0 + train_set = torch_data.Subset(train_set, index) + self.register_buffer("fact_graph", dataset.graph.edge_mask(fact_mask)) + + if self.sample_weight: + degree_hr = torch.zeros(self.num_entity, self.num_relation, dtype=torch.long) + degree_tr = torch.zeros(self.num_entity, self.num_relation, dtype=torch.long) + for h, t, r in train_set: + degree_hr[h, r] += 1 + degree_tr[t, r] += 1 + self.register_buffer("degree_hr", degree_hr) + self.register_buffer("degree_tr", degree_tr) + + return train_set, valid_set, test_set + + def forward(self, batch, all_loss=None, metric=None): + """""" + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + pred = self.predict(batch, all_loss, metric) + pos_h_index, pos_t_index, pos_r_index = batch.t() + + for criterion, weight in self.criterion.items(): + if criterion == "bce": + target = torch.zeros_like(pred) + target[:, 0] = 1 + loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none") + + neg_weight = torch.ones_like(pred) + if self.adversarial_temperature > 0: + with torch.no_grad(): + neg_weight[:, 1:] = F.softmax(pred[:, 1:] / self.adversarial_temperature, dim=-1) + else: + neg_weight[:, 1:] = 1 / self.num_negative + 
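+                # Self-adversarial negative sampling (Sun et al., RotatE): the
+                # weights above are softmax(score / temperature) computed without
+                # gradient, so harder negatives receive larger weights; with a
+                # temperature of 0, every negative contributes 1 / num_negative
+                # to the weighted mean below.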
loss = (loss * neg_weight).sum(dim=-1) / neg_weight.sum(dim=-1) + elif criterion == "ce": + target = torch.zeros(len(pred), dtype=torch.long, device=self.device) + loss = F.cross_entropy(pred, target, reduction="none") + elif criterion == "ranking": + positive = pred[:, :1] + negative = pred[:, 1:] + target = torch.ones_like(negative) + loss = F.margin_ranking_loss(positive, negative, target, margin=self.margin) + else: + raise ValueError("Unknown criterion `%s`" % criterion) + + if self.sample_weight: + sample_weight = self.degree_hr[pos_h_index, pos_r_index] * self.degree_tr[pos_t_index, pos_r_index] + sample_weight = 1 / sample_weight.float().sqrt() + loss = (loss * sample_weight).sum() / sample_weight.sum() + else: + loss = loss.mean() + + name = tasks._get_criterion_name(criterion) + metric[name] = loss + all_loss += loss * weight + + return all_loss, metric + + def predict(self, batch, all_loss=None, metric=None): + pos_h_index, pos_t_index, pos_r_index = batch.t() + batch_size = len(batch) + + if all_loss is None: + # test + all_index = torch.arange(self.num_entity, device=self.device) + t_preds = [] + h_preds = [] + num_negative = self.num_entity if self.full_batch_eval else self.num_negative + for neg_index in all_index.split(num_negative): + r_index = pos_r_index.unsqueeze(-1).expand(-1, len(neg_index)) + h_index, t_index = torch.meshgrid(pos_h_index, neg_index) + t_pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric) + t_preds.append(t_pred) + t_pred = torch.cat(t_preds, dim=-1) + for neg_index in all_index.split(num_negative): + r_index = pos_r_index.unsqueeze(-1).expand(-1, len(neg_index)) + t_index, h_index = torch.meshgrid(pos_t_index, neg_index) + h_pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric) + h_preds.append(h_pred) + h_pred = torch.cat(h_preds, dim=-1) + pred = torch.stack([t_pred, h_pred], dim=1) + # in case of GPU OOM + pred = pred.cpu() + else: + # train + if self.strict_negative: + neg_index = self._strict_negative(pos_h_index, pos_t_index, pos_r_index) + else: + neg_index = torch.randint(self.num_entity, (batch_size, self.num_negative), device=self.device) + h_index = pos_h_index.unsqueeze(-1).repeat(1, self.num_negative + 1) + t_index = pos_t_index.unsqueeze(-1).repeat(1, self.num_negative + 1) + r_index = pos_r_index.unsqueeze(-1).repeat(1, self.num_negative + 1) + t_index[:batch_size // 2, 1:] = neg_index[:batch_size // 2] + h_index[batch_size // 2:, 1:] = neg_index[batch_size // 2:] + pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric) + + return pred + + def target(self, batch): + # test target + batch_size = len(batch) + pos_h_index, pos_t_index, pos_r_index = batch.t() + any = -torch.ones_like(pos_h_index) + + pattern = torch.stack([pos_h_index, any, pos_r_index], dim=-1) + edge_index, num_t_truth = self.graph.match(pattern) + t_truth_index = self.graph.edge_list[edge_index, 1] + pos_index = torch.repeat_interleave(num_t_truth) + t_mask = torch.ones(batch_size, self.num_entity, dtype=torch.bool, device=self.device) + t_mask[pos_index, t_truth_index] = 0 + + pattern = torch.stack([any, pos_t_index, pos_r_index], dim=-1) + edge_index, num_h_truth = self.graph.match(pattern) + h_truth_index = self.graph.edge_list[edge_index, 0] + pos_index = torch.repeat_interleave(num_h_truth) + h_mask = torch.ones(batch_size, self.num_entity, dtype=torch.bool, device=self.device) + h_mask[pos_index, h_truth_index] = 0 + + mask = 
torch.stack([t_mask, h_mask], dim=1) + target = torch.stack([pos_t_index, pos_h_index], dim=1) + + # in case of GPU OOM + return mask.cpu(), target.cpu() + + def evaluate(self, pred, target): + mask, target = target + + pos_pred = pred.gather(-1, target.unsqueeze(-1)) + if self.filtered_ranking: + ranking = torch.sum((pos_pred <= pred) & mask, dim=-1) + 1 + else: + ranking = torch.sum(pos_pred <= pred, dim=-1) + 1 + + metric = {} + for _metric in self.metric: + if _metric == "mr": + score = ranking.float().mean() + elif _metric == "mrr": + score = (1 / ranking.float()).mean() + elif _metric.startswith("hits@"): + threshold = int(_metric[5:]) + score = (ranking <= threshold).float().mean() + else: + raise ValueError("Unknown metric `%s`" % _metric) + + name = tasks._get_metric_name(_metric) + metric[name] = score + + return metric + + def visualize(self, batch): + h_index, t_index, r_index = batch.t() + return self.model.visualize(self.fact_graph, h_index, t_index, r_index) + + @torch.no_grad() + def _strict_negative(self, pos_h_index, pos_t_index, pos_r_index): + batch_size = len(pos_h_index) + any = -torch.ones_like(pos_h_index) + + pattern = torch.stack([pos_h_index, any, pos_r_index], dim=-1) + pattern = pattern[:batch_size // 2] + edge_index, num_t_truth = self.fact_graph.match(pattern) + t_truth_index = self.fact_graph.edge_list[edge_index, 1] + pos_index = torch.repeat_interleave(num_t_truth) + t_mask = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device) + t_mask[pos_index, t_truth_index] = 0 + neg_t_candidate = t_mask.nonzero()[:, 1] + num_t_candidate = t_mask.sum(dim=-1) + neg_t_index = functional.variadic_sample(neg_t_candidate, num_t_candidate, self.num_negative) + + pattern = torch.stack([any, pos_t_index, pos_r_index], dim=-1) + pattern = pattern[batch_size // 2:] + edge_index, num_h_truth = self.fact_graph.match(pattern) + h_truth_index = self.fact_graph.edge_list[edge_index, 0] + pos_index = torch.repeat_interleave(num_h_truth) + h_mask = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device) + h_mask[pos_index, h_truth_index] = 0 + neg_h_candidate = h_mask.nonzero()[:, 1] + num_h_candidate = h_mask.sum(dim=-1) + neg_h_index = functional.variadic_sample(neg_h_candidate, num_h_candidate, self.num_negative) + + neg_index = torch.cat([neg_t_index, neg_h_index]) + + return neg_index diff --git a/build/lib/torchdrug/tasks/retrosynthesis.py b/build/lib/torchdrug/tasks/retrosynthesis.py new file mode 100644 index 00000000..019ceef7 --- /dev/null +++ b/build/lib/torchdrug/tasks/retrosynthesis.py @@ -0,0 +1,1208 @@ +import inspect +from collections import deque + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import data as torch_data +from torch_scatter import scatter_max, scatter_add + +from torchdrug import core, tasks, data, metrics, transforms +from torchdrug.layers import functional +from torchdrug.core import Registry as R +from torchdrug import layers + +import logging +logger = logging.getLogger(__name__) + + +@R.register("tasks.CenterIdentification") +class CenterIdentification(tasks.Task, core.Configurable): + """ + Reaction center identification task. + + This class is a part of retrosynthesis prediction. + + Parameters: + model (nn.Module): graph representation model + feature (str or list of str, optional): additional features for prediction. 
Available features are + reaction: type of the reaction + graph: graph representation of the product + atom: original atom feature + bond: original bond feature + num_mlp_layer (int, optional): number of MLP layers + """ + + _option_members = {"feature"} + + def __init__(self, model, feature=("reaction", "graph", "atom", "bond"), num_mlp_layer=2): + super(CenterIdentification, self).__init__() + self.model = model + self.num_mlp_layer = num_mlp_layer + self.feature = feature + + def preprocess(self, train_set, valid_set, test_set): + reaction_types = set() + bond_types = set() + for sample in train_set: + reaction_types.add(sample["reaction"]) + for graph in sample["graph"]: + bond_types.update(graph.edge_list[:, 2].tolist()) + self.num_reaction = len(reaction_types) + self.num_relation = len(bond_types) + node_feature_dim = train_set[0]["graph"][0].node_feature.shape[-1] + edge_feature_dim = train_set[0]["graph"][0].edge_feature.shape[-1] + + node_dim = self.model.output_dim + edge_dim = 0 + graph_dim = 0 + for _feature in sorted(self.feature): + if _feature == "reaction": + graph_dim += self.num_reaction + elif _feature == "graph": + graph_dim += self.model.output_dim + elif _feature == "atom": + node_dim += node_feature_dim + elif _feature == "bond": + edge_dim += edge_feature_dim + else: + raise ValueError("Unknown feature `%s`" % _feature) + + node_dim += graph_dim # inherit graph features + edge_dim += node_dim * 2 # inherit node features + + hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1) + self.edge_mlp = layers.MLP(edge_dim, hidden_dims + [1]) + self.node_mlp = layers.MLP(node_dim, hidden_dims + [1]) + + def forward(self, batch): + """""" + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + pred = self.predict(batch, all_loss, metric) + target = self.target(batch) + metric.update(self.evaluate(pred, target)) + + target, size = target + target = functional.variadic_max(target, size)[1] + loss = functional.variadic_cross_entropy(pred, target, size) + + name = tasks._get_criterion_name("ce") + metric[name] = loss + + all_loss += loss + + return all_loss, metric + + def _collate(self, edge_data, node_data, graph): + new_data = torch.zeros(len(edge_data) + len(node_data), *edge_data.shape[1:], + dtype=edge_data.dtype, device=edge_data.device) + num_cum_xs = graph.num_cum_edges + graph.num_cum_nodes + num_xs = graph.num_edges + graph.num_nodes + starts = num_cum_xs - num_xs + ends = starts + graph.num_edges + index = functional.multi_slice_mask(starts, ends, num_cum_xs[-1]) + new_data[index] = edge_data + new_data[~index] = node_data + return new_data + + def target(self, batch): + reactant, product = batch["graph"] + graph = product.directed() + + target = self._collate(graph.edge_label, graph.node_label, graph) + size = graph.num_edges + graph.num_nodes + return target, size + + def predict(self, batch, all_loss=None, metric=None): + reactant, product = batch["graph"] + output = self.model(product, product.node_feature.float(), all_loss, metric) + + graph = product.directed() + + node_feature = [output["node_feature"]] + edge_feature = [] + graph_feature = [] + for _feature in sorted(self.feature): + if _feature == "reaction": + reaction_feature = torch.zeros(len(graph), self.num_reaction, dtype=torch.float32, device=self.device) + reaction_feature.scatter_(1, batch["reaction"].unsqueeze(-1), 1) + graph_feature.append(reaction_feature) + elif _feature == "graph": + graph_feature.append(output["graph_feature"]) + elif _feature == 
"atom": + node_feature.append(graph.node_feature.float()) + elif _feature == "bond": + edge_feature.append(graph.edge_feature.float()) + else: + raise ValueError("Unknown feature `%s`" % _feature) + + graph_feature = torch.cat(graph_feature, dim=-1) + # inherit graph features + node_feature.append(graph_feature[graph.node2graph]) + node_feature = torch.cat(node_feature, dim=-1) + # inherit node features + edge_feature.append(node_feature[graph.edge_list[:, :2]].flatten(1)) + edge_feature = torch.cat(edge_feature, dim=-1) + + edge_pred = self.edge_mlp(edge_feature).squeeze(-1) + node_pred = self.node_mlp(node_feature).squeeze(-1) + + pred = self._collate(edge_pred, node_pred, graph) + + return pred + + def evaluate(self, pred, target): + target, size = target + + metric = {} + target = functional.variadic_max(target, size)[1] + accuracy = metrics.variadic_accuracy(pred, target, size).mean() + + name = tasks._get_metric_name("acc") + metric[name] = accuracy + + return metric + + @torch.no_grad() + def predict_synthon(self, batch, k=1): + """ + Predict top-k synthons from target molecules. + + Parameters: + batch (dict): batch of target molecules + k (int, optional): return top-k results + + Returns: + list of dict: top k records. + Each record is a batch dict of keys ``synthon``, ``num_synthon``, ``reaction_center``, + ``log_likelihood`` and ``reaction``. + """ + pred = self.predict(batch) + target, size = self.target(batch) + logp = functional.variadic_log_softmax(pred, size) + + reactant, product = batch["graph"] + graph = product.directed() + with graph.graph(): + graph.product_id = torch.arange(len(graph), device=self.device) + + graph = graph.repeat_interleave(k) + reaction = batch["reaction"].repeat_interleave(k) + with graph.graph(): + graph.split_id = torch.arange(k, device=self.device).repeat(len(graph) // k) + + logp, center_topk = functional.variadic_topk(logp, size, k) + logp = logp.flatten() + center_topk = center_topk.flatten() + + is_edge = center_topk < graph.num_edges + node_index = center_topk + graph.num_cum_nodes - graph.num_nodes - graph.num_edges + edge_index = center_topk + graph.num_cum_edges - graph.num_edges + center_topk_shifted = torch.cat([-torch.ones(1, dtype=torch.long, device=self.device), + center_topk[:-1]]) + product_id_shifted = torch.cat([-torch.ones(1, dtype=torch.long, device=self.device), + graph.product_id[:-1]]) + is_duplicate = (center_topk == center_topk_shifted) & (graph.product_id == product_id_shifted) + node_index = node_index[~is_edge] + edge_index = edge_index[is_edge] + edge_mask = ~functional.as_mask(edge_index, graph.num_edge) + + reaction_center = torch.zeros(len(graph), 2, dtype=torch.long, device=self.device) + reaction_center[is_edge] = graph.atom_map[graph.edge_list[edge_index, :2]] + reaction_center[~is_edge, 0] = graph.atom_map[node_index] + + # remove the edges from products + graph = graph.edge_mask(edge_mask) + graph = graph[~is_duplicate] + reaction_center = reaction_center[~is_duplicate] + logp = logp[~is_duplicate] + reaction = reaction[~is_duplicate] + synthon, num_synthon = graph.connected_components() + synthon = synthon.undirected() # (< num_graph * k) + + result = { + "synthon": synthon, + "num_synthon": num_synthon, + "reaction_center": reaction_center, + "log_likelihood": logp, + "reaction": reaction, + } + + return result + + +class RandomBFSOrder(object): + + def __call__(self, item): + assert hasattr(item["graph"][0], "reaction_center") + reactant, synthon = item["graph"] + + edge_list = reactant.edge_list[:, 
:2].tolist() + neighbor = [[] for _ in range(reactant.num_node)] + for h, t in edge_list: + neighbor[h].append(t) + depth = [-1] * reactant.num_node + + # select a mapped atom as BFS root + reactant2id = reactant.atom_map + id2synthon = -torch.ones(synthon.atom_map.max() + 1, dtype=torch.long, device=synthon.device) + id2synthon[synthon.atom_map] = torch.arange(synthon.num_node, device=synthon.device) + reactant2synthon = id2synthon[reactant2id] + + candidate = (reactant2synthon != -1).nonzero().squeeze(-1) + i = candidate[torch.randint(len(candidate), (1,))].item() + + queue = deque([i]) + depth[i] = 0 + order = [] + while queue: + h = queue.popleft() + order.append(h) + for t in neighbor[h]: + if depth[t] == -1: + depth[t] = depth[h] + 1 + queue.append(t) + + reactant = reactant.subgraph(order) + + if reactant.num_edge > 0: + node_index = reactant.edge_list[:, :2] + node_large = node_index.max(dim=-1)[0] + node_small = node_index.min(dim=-1)[0] + undirected_edge_id = node_large * (node_large + 1) + node_small + undirected_edge_id = undirected_edge_id * 2 + (node_index[:, 0] > node_index[:, 1]) + + # rearrange edges into autoregressive order + edge_order = undirected_edge_id.argsort() + reactant = reactant.edge_mask(edge_order) + + assert hasattr(reactant, "reaction_center") + + item = item.copy() + item["graph"] = (reactant, synthon) + + return item + + +@R.register("tasks.SynthonCompletion") +class SynthonCompletion(tasks.Task, core.Configurable): + """ + Synthon completion task. + + This class is a part of retrosynthesis prediction. + + Parameters: + model (nn.Module): graph representation model + feature (str or list of str, optional): additional features for prediction. Available features are + reaction: type of the reaction + graph: graph representation of the synthon + atom: original atom feature + num_mlp_layer (int, optional): number of MLP layers + """ + + _option_members = {"feature"} + + def __init__(self, model, feature=("reaction", "graph", "atom"), num_mlp_layer=2): + super(SynthonCompletion, self).__init__() + self.model = model + self.num_mlp_layer = num_mlp_layer + self.feature = feature + self.input_linear = nn.Linear(2, self.model.input_dim) + + def preprocess(self, train_set, valid_set, test_set): + reaction_types = set() + atom_types = set() + bond_types = set() + for sample in train_set: + reaction_types.add(sample["reaction"]) + for graph in sample["graph"]: + atom_types.update(graph.atom_type.tolist()) + bond_types.update(graph.edge_list[:, 2].tolist()) + # TODO: only for fast debugging, to remove + # atom_types = torch.tensor([5, 6, 7, 8, 9, 12, 14, 15, 16, 17, 29, 30, 34, 35, 50, 53]) + # bond_types = torch.tensor([0, 1, 2]) + atom_types = torch.tensor(sorted(atom_types)) + atom2id = -torch.ones(atom_types.max() + 1, dtype=torch.long) + atom2id[atom_types] = torch.arange(len(atom_types)) + self.register_buffer("id2atom", atom_types) + self.register_buffer("atom2id", atom2id) + self.num_reaction = len(reaction_types) + self.num_atom_type = len(atom_types) + self.num_bond_type = len(bond_types) + node_feature_dim = train_set[0]["graph"][0].node_feature.shape[-1] + + if isinstance(train_set, torch_data.Subset): + dataset = train_set.dataset + else: + dataset = train_set + dataset.transform = transforms.Compose([ + dataset.transform, + RandomBFSOrder(), + ]) + sig = inspect.signature(data.PackedMolecule.from_molecule) + keys = set(sig.parameters.keys()) + kwargs = dataset.config_dict() + feature_kwargs = {} + for k, v in kwargs.items(): + if k in keys: + 
feature_kwargs[k] = v + self.feature_kwargs = feature_kwargs + + node_dim = self.model.output_dim + edge_dim = 0 + graph_dim = 0 + for _feature in sorted(self.feature): + if _feature == "reaction": + graph_dim += self.num_reaction + elif _feature == "graph": + graph_dim += self.model.output_dim + elif _feature == "atom": + node_dim += node_feature_dim + else: + raise ValueError("Unknown feature `%s`" % _feature) + + self.new_atom_feature = nn.Embedding(self.num_atom_type, node_dim) + + node_dim += graph_dim # inherit graph features + edge_dim += node_dim * 2 # inherit node features + + hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1) + self.node_in_mlp = layers.MLP(node_dim, hidden_dims + [1]) + self.node_out_mlp = layers.MLP(edge_dim, hidden_dims + [1]) + self.edge_mlp = layers.MLP(edge_dim, hidden_dims + [1]) + self.bond_mlp = layers.MLP(edge_dim, hidden_dims + [self.num_bond_type]) + self.stop_mlp = layers.MLP(graph_dim, hidden_dims + [1]) + + def _update_molecule_feature(self, graphs): + # This function is very slow + graphs = graphs.ion_to_molecule() + mols = graphs.to_molecule(ignore_error=True) + valid = [mol is not None for mol in mols] + valid = torch.tensor(valid, device=graphs.device) + new_graphs = type(graphs).from_molecule(mols, **self.feature_kwargs) + + node_feature = torch.zeros(graphs.num_node, *new_graphs.node_feature.shape[1:], + dtype=new_graphs.node_feature.dtype, device=graphs.device) + edge_feature = torch.zeros(graphs.num_edge, *new_graphs.edge_feature.shape[1:], + dtype=new_graphs.edge_feature.dtype, device=graphs.device) + bond_type = torch.zeros_like(graphs.bond_type) + node_mask = valid[graphs.node2graph] + edge_mask = valid[graphs.edge2graph] + node_feature[node_mask] = new_graphs.node_feature.to(device=graphs.device) + edge_feature[edge_mask] = new_graphs.edge_feature.to(device=graphs.device) + bond_type[edge_mask] = new_graphs.bond_type.to(device=graphs.device) + + with graphs.node(): + graphs.node_feature = node_feature + with graphs.edge(): + graphs.edge_feature = edge_feature + graphs.bond_type = bond_type + + return graphs, valid + + @torch.no_grad() + def _all_prefix_slice(self, num_xs, lengths=None): + # extract a bunch of slices that correspond to the following num_repeat * n masks + # ------ repeat 0 ----- + # graphs[0]: [0, 0, ..., 0] + # ... + # graphs[-1]: [0, 0, ..., 0] + # ------ repeat 1 ----- + # graphs[0]: [1, 0, ..., 0] + # ... + # graphs[-1]: [1, 0, ..., 0] + # ... + # ------ repeat -1 ----- + # graphs[0]: [1, ..., 1, 0] + # ... 
+ # graphs[-1]: [1, ..., 1, 0] + num_cum_xs = num_xs.cumsum(0) + starts = num_cum_xs - num_xs + if lengths is None: + num_max_x = num_xs.max().item() + lengths = torch.arange(0, num_max_x, 2, device=num_xs.device) + + pack_offsets = torch.arange(len(lengths), device=num_xs.device) * num_cum_xs[-1] + # starts, lengths, ends: (num_repeat, num_graph) + starts = starts.unsqueeze(0) + pack_offsets.unsqueeze(-1) + valid = lengths.unsqueeze(-1) <= num_xs.unsqueeze(0) - 2 + lengths = torch.min(lengths.unsqueeze(-1), num_xs.unsqueeze(0) - 2).clamp(0) + ends = starts + lengths + + starts = starts.flatten() + ends = ends.flatten() + valid = valid.flatten() + + return starts, ends, valid + + @torch.no_grad() + def _get_reaction_feature(self, reactant, synthon): + + def get_edge_map(graph, num_nodes): + node_in, node_out = graph.edge_list.t()[:2] + node_in2id = graph.atom_map[node_in] + node_out2id = graph.atom_map[node_out] + edge_map = node_in2id * num_nodes[graph.edge2graph] + node_out2id + # edges containing any unmapped node is considered to be unmapped + edge_map[(node_in2id == 0) | (node_out2id == 0)] = 0 + return edge_map + + def get_mapping(reactant_x, synthon_x, reactant_x2graph, synthon_x2graph): + num_xs = scatter_max(reactant_x, reactant_x2graph)[0] + num_xs = num_xs.clamp(0) + 1 + num_cum_xs = num_xs.cumsum(0) + offset = num_cum_xs - num_xs + reactant2id = reactant_x + offset[reactant_x2graph] + synthon2id = synthon_x + offset[synthon_x2graph] + assert synthon2id.min() > 0 + id2synthon = -torch.ones(num_cum_xs[-1], dtype=torch.long, device=self.device) + id2synthon[synthon2id] = torch.arange(len(synthon2id), device=self.device) + reactant2synthon = id2synthon[reactant2id] + + return reactant2synthon + + # reactant & synthon may have different number of nodes + # reactant.num_nodes >= synthon.num_nodes + assert (reactant.num_nodes >= synthon.num_nodes).all() + reactant_edge_map = get_edge_map(reactant, reactant.num_nodes) + synthon_edge_map = get_edge_map(synthon, reactant.num_nodes) + + node_r2s = get_mapping(reactant.atom_map, synthon.atom_map, reactant.node2graph, synthon.node2graph) + edge_r2s = get_mapping(reactant_edge_map, synthon_edge_map, reactant.edge2graph, synthon.edge2graph) + + is_new_node = node_r2s == -1 + is_new_edge = edge_r2s == -1 + is_modified_edge = (edge_r2s != -1) & (reactant.bond_type != synthon.bond_type[edge_r2s]) + is_reaction_center = (reactant.atom_map > 0) & \ + (reactant.atom_map.unsqueeze(-1) == + reactant.reaction_center[reactant.node2graph]).any(dim=-1) + + return node_r2s, edge_r2s, is_new_node, is_new_edge, is_modified_edge, is_reaction_center + + @torch.no_grad() + def all_edge(self, reactant, synthon): + graph = reactant.clone() + node_r2s, edge_r2s, is_new_node, is_new_edge, is_modified_edge, is_reaction_center = \ + self._get_reaction_feature(reactant, synthon) + with graph.node(): + graph.node_r2s = node_r2s + graph.is_new_node = is_new_node + graph.is_reaction_center = is_reaction_center + with graph.edge(): + graph.edge_r2s = edge_r2s + graph.is_new_edge = is_new_edge + graph.is_modified_edge = is_modified_edge + + starts, ends, valid = self._all_prefix_slice(reactant.num_edges) + num_repeat = len(starts) // len(reactant) + graph = graph.repeat(num_repeat) + + # autoregressive condition range for each sample + condition_mask = functional.multi_slice_mask(starts, ends, graph.num_edge) + # special case: end == graph.num_edge. 
In this case, valid is always false + assert ends.max() <= graph.num_edge + ends = ends.clamp(0, graph.num_edge - 1) + node_in, node_out, bond_target = graph.edge_list[ends].t() + # modified edges which don't appear in conditions should keep their old bond types + # i.e. bond types in synthons + unmodified = ~condition_mask & graph.is_modified_edge + unmodified = unmodified.nonzero().squeeze(-1) + assert not (graph.bond_type[unmodified] == synthon.bond_type[graph.edge_r2s[unmodified]]).any() + graph.edge_list[unmodified, 2] = synthon.edge_list[graph.edge_r2s[unmodified], 2] + + reverse_target = graph.edge_list[ends][:, [1, 0, 2]] + is_reverse_target = (graph.edge_list == reverse_target[graph.edge2graph]).all(dim=-1) + # keep edges that exist in the synthon + # remove the reverse of new target edges + edge_mask = (condition_mask & ~is_reverse_target) | ~graph.is_new_edge + + atom_in = graph.atom_type[node_in] + atom_out = graph.atom_type[node_out] + # keep one supervision for undirected edges + # remove samples that try to predict existing edges + valid &= (node_in < node_out) & (graph.is_new_edge[ends] | graph.is_modified_edge[ends]) + graph = graph.edge_mask(edge_mask) + + # sanitize the molecules + # this will change atom index, so we manually remap the target nodes + compact_mapping = -torch.ones(graph.num_node, dtype=torch.long, device=self.device) + node_mask = graph.degree_in + graph.degree_out > 0 + # special case: for graphs without any edge, the first node should be kept + index = torch.arange(graph.num_node, device=self.device) + single_node_mask = (graph.num_edges == 0)[graph.node2graph] & \ + (index == (graph.num_cum_nodes - graph.num_nodes)[graph.node2graph]) + node_index = (node_mask | single_node_mask).nonzero().squeeze(-1) + compact_mapping[node_index] = torch.arange(len(node_index), device=self.device) + node_in = compact_mapping[node_in] + node_out = compact_mapping[node_out] + graph = graph.subgraph(node_index) + + node_in_target = node_in - graph.num_cum_nodes + graph.num_nodes + assert (node_in_target[valid] < graph.num_nodes[valid]).all() and (node_in_target[valid] >= 0).all() + # node2 might be a new node + node_out_target = torch.where(node_out == -1, self.atom2id[atom_out] + graph.num_nodes, + node_out - graph.num_cum_nodes + graph.num_nodes) + stop_target = torch.zeros(len(node_in_target), device=self.device) + + graph = graph[valid] + node_in_target = node_in_target[valid] + node_out_target = node_out_target[valid] + bond_target = bond_target[valid] + stop_target = stop_target[valid] + + assert (graph.num_edges % 2 == 0).all() + # node / edge features may change because we mask some nodes / edges + graph, feature_valid = self._update_molecule_feature(graph) + + return graph[feature_valid], node_in_target[feature_valid], node_out_target[feature_valid], \ + bond_target[feature_valid], stop_target[feature_valid] + + @torch.no_grad() + def all_stop(self, reactant, synthon): + graph = reactant.clone() + node_r2s, edge_r2s, is_new_node, is_new_edge, is_modified_edge, is_reaction_center = \ + self._get_reaction_feature(reactant, synthon) + with graph.node(): + graph.node_r2s = node_r2s + graph.is_new_node = is_new_node + graph.is_reaction_center = is_reaction_center + with graph.edge(): + graph.edge_r2s = edge_r2s + graph.is_new_edge = is_new_edge + graph.is_modified_edge = is_modified_edge + + node_in_target = torch.zeros(len(graph), dtype=torch.long, device=self.device) + node_out_target = torch.zeros_like(node_in_target) + bond_target = 
torch.zeros_like(node_in_target) + stop_target = torch.ones(len(graph), device=self.device) + + # keep consistent with other training data + graph, feature_valid = self._update_molecule_feature(graph) + + return graph[feature_valid], node_in_target[feature_valid], node_out_target[feature_valid], \ + bond_target[feature_valid], stop_target[feature_valid] + + def forward(self, batch): + """""" + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + pred, target = self.predict_and_target(batch, all_loss, metric) + node_in_pred, node_out_pred, bond_pred, stop_pred = pred + node_in_target, node_out_target, bond_target, stop_target, size = target + + loss = functional.variadic_cross_entropy(node_in_pred, node_in_target, size, reduction="none") + loss = functional.masked_mean(loss, stop_target == 0) + metric["node in ce loss"] = loss + all_loss += loss + + loss = functional.variadic_cross_entropy(node_out_pred, node_out_target, size, reduction="none") + loss = functional.masked_mean(loss, stop_target == 0) + metric["node out ce loss"] = loss + all_loss += loss + + loss = F.cross_entropy(bond_pred, bond_target, reduction="none") + loss = functional.masked_mean(loss, stop_target == 0) + metric["bond ce loss"] = loss + all_loss += loss + + # Do we need to balance stop pred? + loss = F.binary_cross_entropy_with_logits(stop_pred, stop_target) + metric["stop bce loss"] = loss + all_loss += loss + + metric["total loss"] = all_loss + metric.update(self.evaluate(pred, target)) + + return all_loss, metric + + def evaluate(self, pred, target): + node_in_pred, node_out_pred, bond_pred, stop_pred = pred + node_in_target, node_out_target, bond_target, stop_target, size = target + + metric = {} + + node_in_acc = metrics.variadic_accuracy(node_in_pred, node_in_target, size) + accuracy = functional.masked_mean(node_in_acc, stop_target == 0) + metric["node in accuracy"] = accuracy + + node_out_acc = metrics.variadic_accuracy(node_out_pred, node_out_target, size) + accuracy = functional.masked_mean(node_out_acc, stop_target == 0) + metric["node out accuracy"] = accuracy + + bond_acc = (bond_pred.argmax(-1) == bond_target).float() + accuracy = functional.masked_mean(bond_acc, stop_target == 0) + metric["bond accuracy"] = accuracy + + stop_acc = ((stop_pred > 0.5) == (stop_target > 0.5)).float() + metric["stop accuracy"] = stop_acc.mean() + + total_acc = (node_in_acc > 0.5) & (node_out_acc > 0.5) & (bond_acc > 0.5) & (stop_acc > 0.5) + total_acc = torch.where(stop_target == 0, total_acc, stop_acc > 0.5).float() + metric["total accuracy"] = total_acc.mean() + + return metric + + def _cat(self, graphs): + for i, graph in enumerate(graphs): + if not isinstance(graph, data.PackedGraph): + graphs[i] = graph.pack([graph]) + + edge_list = torch.cat([graph.edge_list for graph in graphs]) + pack_num_nodes = torch.stack([graph.num_node for graph in graphs]) + pack_num_edges = torch.stack([graph.num_edge for graph in graphs]) + pack_num_cum_edges = pack_num_edges.cumsum(0) + graph_index = pack_num_cum_edges < len(edge_list) + pack_offsets = scatter_add(pack_num_nodes[graph_index], pack_num_cum_edges[graph_index], + dim_size=len(edge_list)) + pack_offsets = pack_offsets.cumsum(0) + + edge_list[:, :2] += pack_offsets.unsqueeze(-1) + offsets = torch.cat([graph._offsets for graph in graphs]) + pack_offsets + + edge_weight = torch.cat([graph.edge_weight for graph in graphs]) + num_nodes = torch.cat([graph.num_nodes for graph in graphs]) + num_edges = torch.cat([graph.num_edges for graph in graphs]) + 
num_relation = graphs[0].num_relation + assert all(graph.num_relation == num_relation for graph in graphs) + + # only keep attributes that exist in all graphs + keys = set(graphs[0].meta_dict.keys()) + for graph in graphs: + keys = keys.intersection(graph.meta_dict.keys()) + + meta_dict = {k: graphs[0].meta_dict[k] for k in keys} + data_dict = {} + for k in keys: + data_dict[k] = torch.cat([graph.data_dict[k] for graph in graphs]) + + return type(graphs[0])(edge_list, edge_weight=edge_weight, + num_nodes=num_nodes, num_edges=num_edges, num_relation=num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) + + def target(self, batch): + reactant, synthon = batch["graph"] + + graph1, node_in_target1, node_out_target1, bond_target1, stop_target1 = self.all_edge(reactant, synthon) + graph2, node_in_target2, node_out_target2, bond_target2, stop_target2 = self.all_stop(reactant, synthon) + + node_in_target = torch.cat([node_in_target1, node_in_target2]) + node_out_target = torch.cat([node_out_target1, node_out_target2]) + bond_target = torch.cat([bond_target1, bond_target2]) + stop_target = torch.cat([stop_target1, stop_target2]) + size = torch.cat([graph1.num_nodes, graph2.num_nodes]) + # add new atom candidates into the size of each graph + size_ext = size + self.num_atom_type + + return node_in_target, node_out_target, bond_target, stop_target, size_ext + + def _topk_action(self, graph, k): + synthon_feature = torch.stack([graph.is_new_node, graph.is_reaction_center], dim=-1).float() + node_feature = graph.node_feature.float() + self.input_linear(synthon_feature) + output = self.model(graph, node_feature) + + node_feature = [output["node_feature"]] + graph_feature = [] + for _feature in sorted(self.feature): + if _feature == "reaction": + reaction_feature = torch.zeros(len(graph), self.num_reaction, dtype=torch.float32, device=self.device) + reaction_feature.scatter_(1, graph.reaction.unsqueeze(-1), 1) + graph_feature.append(reaction_feature) + elif _feature == "graph": + graph_feature.append(output["graph_feature"]) + elif _feature == "atom": + node_feature.append(graph.node_feature.float()) + else: + raise ValueError("Unknown feature `%s`" % _feature) + + graph_feature = torch.cat(graph_feature, dim=-1) + # inherit graph features + node_feature.append(graph_feature[graph.node2graph]) + node_feature = torch.cat(node_feature, dim=-1) + + new_node_feature = self.new_atom_feature.weight.repeat(len(graph), 1) + new_graph_feature = graph_feature.unsqueeze(1).repeat(1, self.num_atom_type, 1).flatten(0, 1) + new_node_feature = torch.cat([new_node_feature, new_graph_feature], dim=-1) + node_feature, num_nodes_ext = self._extend(node_feature, graph.num_nodes, new_node_feature) + + node2graph_ext = torch.repeat_interleave(num_nodes_ext) + num_cum_nodes_ext = num_nodes_ext.cumsum(0) + starts = num_cum_nodes_ext - num_nodes_ext + graph.num_nodes + ends = num_cum_nodes_ext + is_new_node = functional.multi_slice_mask(starts, ends, num_cum_nodes_ext[-1]) + infinity = float("inf") + + node_in_pred = self.node_in_mlp(node_feature).squeeze(-1) + stop_pred = self.stop_mlp(graph_feature).squeeze(-1) + + # mask out node-in prediction on new atoms + node_in_pred[is_new_node] = -infinity + node_in_logp = functional.variadic_log_softmax(node_in_pred, num_nodes_ext) # (num_node,) + stop_logp = F.logsigmoid(stop_pred) + act_logp = F.logsigmoid(-stop_pred) + node_in_topk = functional.variadic_topk(node_in_logp, num_nodes_ext, k)[1] + assert (node_in_topk >= 0).all() and (node_in_topk < 
num_nodes_ext.unsqueeze(-1)).all() + node_in = node_in_topk + (num_cum_nodes_ext - num_nodes_ext).unsqueeze(-1) # (num_graph, k) + + # (num_node, node_in_k, feature_dim) + node_out_feature = torch.cat([node_feature[node_in][node2graph_ext], + node_feature.unsqueeze(1).expand(-1, k, -1)], dim=-1) + node_out_pred = self.node_out_mlp(node_out_feature).squeeze(-1) + # mask out node-out prediction on self-loops + node_out_pred.scatter_(0, node_in, -infinity) + # (num_node, node_in_k) + node_out_logp = functional.variadic_log_softmax(node_out_pred, num_nodes_ext) + # (num_graph, node_out_k, node_in_k) + node_out_topk = functional.variadic_topk(node_out_logp, num_nodes_ext, k)[1] + assert (node_out_topk >= 0).all() and (node_out_topk < num_nodes_ext.view(-1, 1, 1)).all() + node_out = node_out_topk + (num_cum_nodes_ext - num_nodes_ext).view(-1, 1, 1) + + # (num_graph, node_out_k, node_in_k, feature_dim * 2) + edge = torch.stack([node_in.unsqueeze(1).expand_as(node_out), node_out], dim=-1) + bond_feature = node_feature[edge].flatten(-2) + bond_pred = self.bond_mlp(bond_feature).squeeze(-1) + bond_logp = F.log_softmax(bond_pred, dim=-1) # (num_graph, node_out_k, node_in_k, num_relation) + bond_type = torch.arange(bond_pred.shape[-1], device=self.device) + bond_type = bond_type.view(1, 1, 1, -1).expand_as(bond_logp) + + # (num_graph, node_out_k, node_in_k, num_relation) + node_in_logp = node_in_logp.gather(0, node_in.flatten(0, 1)).view(-1, 1, k, 1) + node_out_logp = node_out_logp.gather(0, node_out.flatten(0, 1)).view(-1, k, k, 1) + act_logp = act_logp.view(-1, 1, 1, 1) + logp = node_in_logp + node_out_logp + bond_logp + act_logp + + # (num_graph, node_out_k, node_in_k, num_relation, 4) + node_in_topk = node_in_topk.view(-1, 1, k, 1).expand_as(logp) + node_out_topk = node_out_topk.view(-1, k, k, 1).expand_as(logp) + action = torch.stack([node_in_topk, node_out_topk, bond_type, torch.zeros_like(bond_type)], dim=-1) + + # add stop action + logp = torch.cat([logp.flatten(1), stop_logp.unsqueeze(-1)], dim=1) + stop = torch.tensor([0, 0, 0, 1], device=self.device) + stop = stop.view(1, 1, -1).expand(len(graph), -1, -1) + action = torch.cat([action.flatten(1, -2), stop], dim=1) + topk = logp.topk(k, dim=-1)[1] + + return action.gather(1, topk.unsqueeze(-1).expand(-1, -1, 4)), logp.gather(1, topk) + + def _apply_action(self, graph, action, logp): + # only support non-variadic k-actions + assert len(graph) == len(action) + num_action = action.shape[1] + + graph = graph.repeat_interleave(num_action) + + action = action.flatten(0, 1) # (num_graph * k, 4) + logp = logp.flatten(0, 1) # (num_graph * k) + new_node_in, new_node_out, new_bond_type, stop = action.t() + + # add new nodes + has_new_node = (new_node_out >= graph.num_nodes) & (stop == 0) + new_atom_id = (new_node_out - graph.num_nodes)[has_new_node] + new_atom_type = self.id2atom[new_atom_id] + is_new_node = torch.ones(len(new_atom_type), dtype=torch.bool, device=self.device) + is_reaction_center = torch.zeros(len(new_atom_type), dtype=torch.bool, device=self.device) + atom_type, num_nodes = functional._extend(graph.atom_type, graph.num_nodes, new_atom_type, has_new_node) + is_new_node = functional._extend(graph.is_new_node, graph.num_nodes, is_new_node, has_new_node)[0] + is_reaction_center = functional._extend(graph.is_reaction_center, graph.num_nodes, is_reaction_center, has_new_node)[0] + + # cast to regular node ids + new_node_out = torch.where(has_new_node, graph.num_nodes, new_node_out) + + # modify edges + new_edge = torch.stack([new_node_in, 
new_node_out], dim=-1) + edge_list = graph.edge_list.clone() + bond_type = graph.bond_type.clone() + edge_list[:, :2] -= graph._offsets.unsqueeze(-1) + is_modified_edge = (edge_list[:, :2] == new_edge[graph.edge2graph]).all(dim=-1) & \ + (stop[graph.edge2graph] == 0) + has_modified_edge = scatter_max(is_modified_edge.long(), graph.edge2graph, dim_size=len(graph))[0] > 0 + bond_type[is_modified_edge] = new_bond_type[has_modified_edge] + edge_list[is_modified_edge, 2] = new_bond_type[has_modified_edge] + # modify reverse edges + new_edge = new_edge.flip(-1) + is_modified_edge = (edge_list[:, :2] == new_edge[graph.edge2graph]).all(dim=-1) & \ + (stop[graph.edge2graph] == 0) + bond_type[is_modified_edge] = new_bond_type[has_modified_edge] + edge_list[is_modified_edge, 2] = new_bond_type[has_modified_edge] + + # add new edges + has_new_edge = (~has_modified_edge) & (stop == 0) + new_edge_list = torch.stack([new_node_in, new_node_out, new_bond_type], dim=-1)[has_new_edge] + bond_type = functional._extend(bond_type, graph.num_edges, new_bond_type[has_new_edge], has_new_edge)[0] + edge_list, num_edges = functional._extend(edge_list, graph.num_edges, new_edge_list, has_new_edge) + # add reverse edges + new_edge_list = torch.stack([new_node_out, new_node_in, new_bond_type], dim=-1)[has_new_edge] + bond_type = functional._extend(bond_type, num_edges, new_bond_type[has_new_edge], has_new_edge)[0] + edge_list, num_edges = functional._extend(edge_list, num_edges, new_edge_list, has_new_edge) + + logp = logp + graph.logp + + # inherit attributes + data_dict = graph.data_dict + meta_dict = graph.meta_dict + for key in ["atom_type", "bond_type", "is_new_node", "is_reaction_center", "logp"]: + data_dict.pop(key) + # pad 0 for node / edge attributes + for k, v in data_dict.items(): + if "node" in meta_dict[k]: + shape = (len(new_atom_type), *v.shape[1:]) + new_data = torch.zeros(shape, dtype=v.dtype, device=self.device) + data_dict[k] = functional._extend(v, graph.num_nodes, new_data, has_new_node)[0] + if "edge" in meta_dict[k]: + shape = (len(new_edge_list) * 2, *v.shape[1:]) + new_data = torch.zeros(shape, dtype=v.dtype, device=self.device) + data_dict[k] = functional._extend(v, graph.num_edges, new_data, has_new_edge * 2)[0] + + new_graph = type(graph)(edge_list, atom_type=atom_type, bond_type=bond_type, num_nodes=num_nodes, + num_edges=num_edges, num_relation=graph.num_relation, + is_new_node=is_new_node, is_reaction_center=is_reaction_center, logp=logp, + meta_dict=meta_dict, **data_dict) + with new_graph.graph(): + new_graph.is_stopped = stop == 1 + valid = logp > float("-inf") + new_graph = new_graph[valid] + + new_graph, feature_valid = self._update_molecule_feature(new_graph) + return new_graph[feature_valid] + + @torch.no_grad() + def predict_reactant(self, batch, num_beam=10, max_prediction=20, max_step=20): + if "synthon" in batch: + synthon = batch["synthon"] + synthon2product = torch.repeat_interleave(batch["num_synthon"]) + assert (synthon2product < len(batch["reaction"])).all() + reaction = batch["reaction"][synthon2product] + else: + reactant, synthon = batch["graph"] + reaction = batch["reaction"] + + # In any case, ensure that the synthon is a molecule rather than an ion + # This is consistent across train/test routines in synthon completion + synthon, feature_valid = self._update_molecule_feature(synthon) + synthon = synthon[feature_valid] + reaction = reaction[feature_valid] + + graph = synthon + with graph.graph(): + # for convenience, because we need to manipulate graph a lot + 
graph.reaction = reaction + graph.synthon_id = torch.arange(len(graph), device=graph.device) + if not hasattr(graph, "logp"): + graph.logp = torch.zeros(len(graph), device=graph.device) + with graph.node(): + graph.is_new_node = torch.zeros(graph.num_node, dtype=torch.bool, device=graph.device) + graph.is_reaction_center = (graph.atom_map > 0) & \ + (graph.atom_map.unsqueeze(-1) == + graph.reaction_center[graph.node2graph]).any(dim=-1) + + result = [] + num_prediction = torch.zeros(len(synthon), dtype=torch.long, device=self.device) + for i in range(max_step): + logger.warning("action step: %d" % i) + logger.warning("batched beam size: %d" % len(graph)) + # each candidate has #beam actions + action, logp = self._topk_action(graph, num_beam) + + # each candidate is expanded to at most #beam (depending on validity) new candidates + new_graph = self._apply_action(graph, action, logp) + # assert (new_graph[is_stopped].logp > float("-inf")).all() + offset = -2 * (new_graph.logp.max() - new_graph.logp.min()) + key = new_graph.synthon_id * offset + new_graph.logp + order = key.argsort(descending=True) + new_graph = new_graph[order] + + num_candidate = new_graph.synthon_id.bincount(minlength=len(synthon)) + topk = functional.variadic_topk(new_graph.logp, num_candidate, num_beam)[1] + topk_index = topk + (num_candidate.cumsum(0) - num_candidate).unsqueeze(-1) + topk_index = torch.unique(topk_index) + new_graph = new_graph[topk_index] + result.append(new_graph[new_graph.is_stopped]) + num_added = scatter_add(new_graph.is_stopped.long(), new_graph.synthon_id, dim_size=len(synthon)) + num_prediction += num_added + + # remove samples that already hit max prediction + is_continue = (~new_graph.is_stopped) & (num_prediction[new_graph.synthon_id] < max_prediction) + graph = new_graph[is_continue] + if len(graph) == 0: + break + + result = self._cat(result) + # sort by synthon id + order = result.synthon_id.argsort() + result = result[order] + + # remove duplicate predictions + is_duplicate = [] + synthon_id = -1 + for graph in result: + if graph.synthon_id != synthon_id: + synthon_id = graph.synthon_id + smiles_set = set() + smiles = graph.to_smiles(isomeric=False, atom_map=False, canonical=True) + is_duplicate.append(smiles in smiles_set) + smiles_set.add(smiles) + is_duplicate = torch.tensor(is_duplicate, device=self.device) + result = result[~is_duplicate] + num_prediction = result.synthon_id.bincount(minlength=len(synthon)) + + # remove extra predictions + topk = functional.variadic_topk(result.logp, num_prediction, max_prediction)[1] + topk_index = topk + (num_prediction.cumsum(0) - num_prediction).unsqueeze(-1) + topk_index = topk_index.flatten(0) + topk_index_shifted = torch.cat([-torch.ones(1, dtype=torch.long, device=self.device), topk_index[:-1]]) + is_duplicate = topk_index == topk_index_shifted + result = result[topk_index[~is_duplicate]] + + return result # (< num_graph * max_prediction) + + def _extend(self, data, num_xs, input, input2graph=None): + if input2graph is None: + num_input_per_graph = len(input) // len(num_xs) + input2graph = torch.arange(len(num_xs), device=data.device).unsqueeze(-1) + input2graph = input2graph.repeat(1, num_input_per_graph).flatten() + num_inputs = input2graph.bincount(minlength=len(num_xs)) + new_num_xs = num_xs + num_inputs + new_num_cum_xs = new_num_xs.cumsum(0) + new_num_x = new_num_cum_xs[-1].item() + new_data = torch.zeros(new_num_x, *data.shape[1:], dtype=data.dtype, device=data.device) + starts = new_num_cum_xs - new_num_xs + ends = starts + num_xs + 
index = functional.multi_slice_mask(starts, ends, new_num_x) + new_data[index] = data + new_data[~index] = input + return new_data, new_num_xs + + def predict_and_target(self, batch, all_loss=None, metric=None): + reactant, synthon = batch["graph"] + reactant = reactant.clone() + with reactant.graph(): + reactant.reaction = batch["reaction"] + + graph1, node_in_target1, node_out_target1, bond_target1, stop_target1 = self.all_edge(reactant, synthon) + graph2, node_in_target2, node_out_target2, bond_target2, stop_target2 = self.all_stop(reactant, synthon) + + graph = self._cat([graph1, graph2]) + + node_in_target = torch.cat([node_in_target1, node_in_target2]) + node_out_target = torch.cat([node_out_target1, node_out_target2]) + bond_target = torch.cat([bond_target1, bond_target2]) + stop_target = torch.cat([stop_target1, stop_target2]) + size = graph.num_nodes + # add new atom candidates into the size of each graph + size_ext = size + self.num_atom_type + + synthon_feature = torch.stack([graph.is_new_node, graph.is_reaction_center], dim=-1).float() + node_feature = graph.node_feature.float() + self.input_linear(synthon_feature) + output = self.model(graph, node_feature, all_loss, metric) + + node_feature = [output["node_feature"]] + graph_feature = [] + for _feature in sorted(self.feature): + if _feature == "reaction": + reaction_feature = torch.zeros(len(graph), self.num_reaction, dtype=torch.float32, device=self.device) + reaction_feature.scatter_(1, graph.reaction.unsqueeze(-1), 1) + graph_feature.append(reaction_feature) + elif _feature == "graph": + graph_feature.append(output["graph_feature"]) + elif _feature == "atom": + node_feature.append(graph.node_feature) + else: + raise ValueError("Unknown feature `%s`" % _feature) + + graph_feature = torch.cat(graph_feature, dim=-1) + # inherit graph features + node_feature.append(graph_feature[graph.node2graph]) + node_feature = torch.cat(node_feature, dim=-1) + + new_node_feature = self.new_atom_feature.weight.repeat(len(graph), 1) + new_graph_feature = graph_feature.unsqueeze(1).repeat(1, self.num_atom_type, 1).flatten(0, 1) + new_node_feature = torch.cat([new_node_feature, new_graph_feature], dim=-1) + node_feature, num_nodes_ext = self._extend(node_feature, graph.num_nodes, new_node_feature) + assert (num_nodes_ext == size_ext).all() + + node2graph_ext = torch.repeat_interleave(num_nodes_ext) + num_cum_nodes_ext = num_nodes_ext.cumsum(0) + starts = num_cum_nodes_ext - num_nodes_ext + graph.num_nodes + ends = num_cum_nodes_ext + is_new_node = functional.multi_slice_mask(starts, ends, num_cum_nodes_ext[-1]) + + node_in = node_in_target + num_cum_nodes_ext - num_nodes_ext + node_out = node_out_target + num_cum_nodes_ext - num_nodes_ext + edge = torch.stack([node_in, node_out], dim=-1) + + node_out_feature = torch.cat([node_feature[node_in][node2graph_ext], node_feature], dim=-1) + bond_feature = node_feature[edge].flatten(-2) + node_in_pred = self.node_in_mlp(node_feature).squeeze(-1) + node_out_pred = self.node_out_mlp(node_out_feature).squeeze(-1) + bond_pred = self.bond_mlp(bond_feature).squeeze(-1) + stop_pred = self.stop_mlp(graph_feature).squeeze(-1) + + infinity = torch.tensor(float("inf"), device=self.device) + # mask out node-in prediction on new atoms + node_in_pred[is_new_node] = -infinity + # mask out node-out prediction on self-loops + node_out_pred[node_in] = -infinity + + return (node_in_pred, node_out_pred, bond_pred, stop_pred), \ + (node_in_target, node_out_target, bond_target, stop_target, size_ext) + + 
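The `_extend` helper above is the core of this scoring scheme: each graph's node list is padded with `num_atom_type` candidate slots for new atoms, and `functional.multi_slice_mask` marks which positions of the extended layout still belong to the original nodes. Below is a minimal self-contained sketch of that layout in plain PyTorch; the two-graph batch and the inline mask loop are illustrative stand-ins, not TorchDrug code.

import torch

# Hypothetical batch: graph 0 has 2 nodes, graph 1 has 3; 4 new-atom candidates per graph.
num_nodes = torch.tensor([2, 3])
num_atom_type = 4
data = torch.arange(int(num_nodes.sum()))                 # stand-in for per-node features
new = torch.full((len(num_nodes) * num_atom_type,), -1)   # stand-in for new-atom rows

num_nodes_ext = num_nodes + num_atom_type                 # each graph grows by num_atom_type slots
cum_ext = num_nodes_ext.cumsum(0)
total = int(cum_ext[-1])
starts = cum_ext - num_nodes_ext                          # first slot of each extended graph
ends = starts + num_nodes                                 # original nodes keep the leading slots

# inline stand-in for functional.multi_slice_mask(starts, ends, total)
index = torch.zeros(total, dtype=torch.bool)
for s, e in zip(starts.tolist(), ends.tolist()):
    index[s:e] = True

out = torch.empty(total, dtype=torch.long)
out[index] = data                                         # original node features
out[~index] = new                                         # new-atom candidate rows
print(out.tolist())   # [0, 1, -1, -1, -1, -1, 2, 3, 4, -1, -1, -1, -1]

Graph i's original node j therefore sits at offset (num_cum_nodes_ext - num_nodes_ext)[i] + j in the extended layout, which is exactly the shift applied to node_in and node_out in the code above.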
+@R.register("tasks.Retrosynthesis") +class Retrosynthesis(tasks.Task, core.Configurable): + """ + Retrosynthesis task. + + This class wraps pretrained center identification and synthon completion modules into a pipeline. + + Parameters: + center_identification (CenterIdentification): subtask of center identification + synthon_completion (SynthonCompletion): subtask of synthon completion + center_topk (int, optional): number of reaction centers to predict for each product + num_synthon_beam (int, optional): size of beam search for each synthon + max_prediction (int, optional): max number of final predictions for each product + metric (str or list of str, optional): metric(s). Available metrics are ``top-K``. + """ + + _option_members = {"metric"} + + def __init__(self, center_identification, synthon_completion, center_topk=2, num_synthon_beam=10, max_prediction=20, + metric=("top-1", "top-3", "top-5", "top-10")): + super(Retrosynthesis, self).__init__() + self.center_identification = center_identification + self.synthon_completion = synthon_completion + self.center_topk = center_topk + self.num_synthon_beam = num_synthon_beam + self.max_prediction = max_prediction + self.metric = metric + + def load_state_dict(self, state_dict, strict=True): + if not strict: + raise ValueError("Retrosynthesis only supports load_state_dict() with strict=True") + keys = set(state_dict.keys()) + for model in [self.center_identification, self.synthon_completion]: + if set(model.state_dict().keys()) == keys: + return model.load_state_dict(state_dict, strict) + raise RuntimeError("Neither of sub modules matches with state_dict") + + def predict(self, batch, all_loss=None, metric=None): + synthon_batch = self.center_identification.predict_synthon(batch, self.center_topk) + + synthon = synthon_batch["synthon"] + num_synthon = synthon_batch["num_synthon"] + assert (num_synthon >= 1).all() and (num_synthon <= 2).all() + synthon2split = torch.repeat_interleave(num_synthon) + with synthon.graph(): + synthon.reaction_center = synthon_batch["reaction_center"][synthon2split] + synthon.split_logp = synthon_batch["log_likelihood"][synthon2split] + + reactant = self.synthon_completion.predict_reactant(synthon_batch, self.num_synthon_beam, self.max_prediction) + + logps = [] + reactant_ids = [] + product_ids = [] + + # case 1: one synthon + is_single = num_synthon[synthon2split[reactant.synthon_id]] == 1 + reactant_id = is_single.nonzero().squeeze(-1) + logps.append(reactant.split_logp[reactant_id] + reactant.logp[reactant_id]) + product_ids.append(reactant.product_id[reactant_id]) + # pad -1 + reactant_ids.append(torch.stack([reactant_id, -torch.ones_like(reactant_id)], dim=-1)) + + # case 2: two synthons + # use proposal to avoid O(n^2) complexity + reactant1 = torch.arange(len(reactant), device=self.device) + reactant1 = reactant1.unsqueeze(-1).expand(-1, self.max_prediction * 2) + reactant2 = reactant1 + torch.arange(self.max_prediction * 2, device=self.device) + valid = reactant2 < len(reactant) + reactant1 = reactant1[valid] + reactant2 = reactant2[valid] + synthon1 = reactant.synthon_id[reactant1] + synthon2 = reactant.synthon_id[reactant2] + valid = (synthon1 < synthon2) & (synthon2split[synthon1] == synthon2split[synthon2]) + reactant1 = reactant1[valid] + reactant2 = reactant2[valid] + logps.append(reactant.split_logp[reactant1] + reactant.logp[reactant1] + reactant.logp[reactant2]) + product_ids.append(reactant.product_id[reactant1]) + reactant_ids.append(torch.stack([reactant1, reactant2], dim=-1)) + + # 
combine case 1 & 2 + logps = torch.cat(logps) + reactant_ids = torch.cat(reactant_ids) + product_ids = torch.cat(product_ids) + + order = product_ids.argsort() + logps = logps[order] + reactant_ids = reactant_ids[order] + num_prediction = product_ids.bincount() + logps, topk = functional.variadic_topk(logps, num_prediction, self.max_prediction) + topk_index = topk + (num_prediction.cumsum(0) - num_prediction).unsqueeze(-1) + topk_index_shifted = torch.cat([-torch.ones(len(topk_index), 1, dtype=torch.long, device=self.device), + topk_index[:, :-1]], dim=-1) + is_duplicate = topk_index == topk_index_shifted + reactant_id = reactant_ids[topk_index] # (num_graph, k, 2) + + # why we need to repeat the graph? + # because reactant_id may be duplicated, which is not directly supported by graph indexing + is_padding = reactant_id == -1 + num_synthon = (~is_padding).sum(dim=-1) + num_synthon = num_synthon[~is_duplicate] + logps = logps[~is_duplicate] + offset = torch.arange(self.max_prediction, device=self.device) * len(reactant) + reactant_id = reactant_id + offset.view(1, -1, 1) + reactant_id = reactant_id[~(is_padding | is_duplicate.unsqueeze(-1))] + reactant = reactant.repeat(self.max_prediction) + reactant = reactant[reactant_id] + assert num_synthon.sum() == len(reactant) + synthon2graph = torch.repeat_interleave(num_synthon) + first_synthon = num_synthon.cumsum(0) - num_synthon + # inherit graph attributes from the first synthon + data_dict = reactant.data_mask(graph_index=first_synthon, include="graph")[0] + # merge synthon pairs from the same split into a single graph + reactant = reactant.merge(synthon2graph) + with reactant.graph(): + for k, v in data_dict.items(): + setattr(reactant, k, v) + reactant.logps = logps + + num_prediction = reactant.product_id.bincount() + + return reactant, num_prediction # (num_graph * k) + + def target(self, batch): + reactant, product = batch["graph"] + reactant = reactant.ion_to_molecule() + return reactant + + def evaluate(self, pred, target): + pred, num_prediction = pred + infinity = torch.iinfo(torch.long).max - 1 + + metric = {} + ranking = [] + # any better solution for parallel graph isomorphism? 
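+        # In practice the check below sidesteps explicit graph isomorphism: predictions and +        # targets are serialized with to_smiles(isomeric=False, atom_map=False, canonical=True), +        # and equality of the canonical SMILES strings serves as the equivalence test (the same +        # trick used for deduplication in predict_reactant above).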
+ num_cum_prediction = num_prediction.cumsum(0) + for i in range(len(target)): + target_smiles = target[i].to_smiles(isomeric=False, atom_map=False, canonical=True) + offset = (num_cum_prediction[i] - num_prediction[i]).item() + for j in range(num_prediction[i]): + pred_smiles = pred[offset + j].to_smiles(isomeric=False, atom_map=False, canonical=True) + if pred_smiles == target_smiles: + break + else: + j = infinity + ranking.append(j + 1) + + ranking = torch.tensor(ranking, device=self.device) + for _metric in self.metric: + if _metric.startswith("top-"): + threshold = int(_metric[4:]) + score = (ranking <= threshold).float().mean() + metric["top-%d accuracy" % threshold] = score + else: + raise ValueError("Unknown metric `%s`" % _metric) + + return metric \ No newline at end of file diff --git a/build/lib/torchdrug/tasks/task.py b/build/lib/torchdrug/tasks/task.py new file mode 100644 index 00000000..72c9dbdc --- /dev/null +++ b/build/lib/torchdrug/tasks/task.py @@ -0,0 +1,39 @@ +from collections.abc import Mapping, Sequence + +from torch import nn + + +class Task(nn.Module): + + _option_members = set() + + def _standarize_option(self, x, name): + if x is None: + x = {} + elif isinstance(x, str): + x = {x: 1} + elif isinstance(x, Sequence): + x = dict.fromkeys(x, 1) + elif not isinstance(x, Mapping): + raise ValueError("Invalid value `%s` for option member `%s`" % (x, name)) + return x + + def __setattr__(self, key, value): + if key in self._option_members: + value = self._standarize_option(value, key) + super(Task, self).__setattr__(key, value) + + def preprocess(self, train_set, valid_set, test_set): + pass + + def predict_and_target(self, batch, all_loss=None, metric=None): + return self.predict(batch, all_loss, metric), self.target(batch) + + def predict(self, batch, all_loss=None, metric=None): + raise NotImplementedError + + def target(self, batch): + raise NotImplementedError + + def evaluate(self, pred, target): + raise NotImplementedError \ No newline at end of file diff --git a/build/lib/torchdrug/transforms/__init__.py b/build/lib/torchdrug/transforms/__init__.py new file mode 100644 index 00000000..f8109df9 --- /dev/null +++ b/build/lib/torchdrug/transforms/__init__.py @@ -0,0 +1,7 @@ +from .transform import NormalizeTarget, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, \ + VirtualAtom, TruncateProtein, ProteinView, Compose + +__all__ = [ + "NormalizeTarget", "RemapAtomType", "RandomBFSOrder", "Shuffle", + "VirtualNode", "VirtualAtom", "TruncateProtein", "ProteinView", "Compose", +] diff --git a/build/lib/torchdrug/transforms/transform.py b/build/lib/torchdrug/transforms/transform.py new file mode 100644 index 00000000..a971f0a4 --- /dev/null +++ b/build/lib/torchdrug/transforms/transform.py @@ -0,0 +1,314 @@ +import copy +import logging +from collections import deque + +import torch + +from torchdrug import core +from torchdrug.core import Registry as R + + +logger = logging.getLogger(__name__) + + +@R.register("transforms.NormalizeTarget") +class NormalizeTarget(core.Configurable): + """ + Normalize the target values in a sample. 
+ + Parameters: + mean (dict of float): mean of targets + std (dict of float): standard deviation of targets + """ + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, item): + item = item.copy() + for k in self.mean: + if k in item: + item[k] = (item[k] - self.mean[k]) / self.std[k] + else: + raise ValueError("Can't find target `%s` in data item" % k) + return item + + +@R.register("transforms.RemapAtomType") +class RemapAtomType(core.Configurable): + """ + Map atom types to their index in a vocabulary. Atom types that are not present in the vocabulary are mapped to -1. + + Parameters: + atom_types (array_like): vocabulary of atom types + """ + + def __init__(self, atom_types): + atom_types = torch.as_tensor(atom_types) + self.id2atom = atom_types + self.atom2id = - torch.ones(atom_types.max() + 1, dtype=torch.long, device=atom_types.device) + self.atom2id[atom_types] = torch.arange(len(atom_types), device=atom_types.device) + + def __call__(self, item): + graph = copy.copy(item["graph"]) + graph.atom_type = self.atom2id[graph.atom_type] + item = item.copy() + item["graph"] = graph + return item + + +@R.register("transforms.RandomBFSOrder") +class RandomBFSOrder(core.Configurable): + """ + Order the nodes in a graph according to a random BFS order. + """ + + def __call__(self, item): + graph = item["graph"] + edge_list = graph.edge_list[:, :2].tolist() + neighbor = [[] for _ in range(graph.num_node)] + for h, t in edge_list: + neighbor[h].append(t) + depth = [-1] * graph.num_node + + i = torch.randint(graph.num_node, (1,)).item() + queue = deque([i]) + depth[i] = 0 + order = [] + while queue: + h = queue.popleft() + order.append(h) + for t in neighbor[h]: + if depth[t] == -1: + depth[t] = depth[h] + 1 + queue.append(t) + + item = item.copy() + item["graph"] = graph.subgraph(order) + return item + + +@R.register("transforms.Shuffle") +class Shuffle(core.Configurable): + """ + Shuffle the order of nodes and edges in a graph. + + Parameters: + shuffle_node (bool, optional): shuffle node order or not + shuffle_edge (bool, optional): shuffle edge order or not + """ + + def __init__(self, shuffle_node=True, shuffle_edge=True): + self.shuffle_node = shuffle_node + self.shuffle_edge = shuffle_edge + + def __call__(self, item): + graph = item["graph"] + data = self.transform_data(graph.data_dict, graph.meta) + + item = item.copy() + item["graph"] = type(graph)(**data) + return item + + def transform_data(self, data, meta): + edge_list = data["edge_list"] + num_node = data["num_node"] + num_edge = data["num_edge"] + if self.shuffle_node: + node_perm = torch.randperm(num_node, device=edge_list.device) + else: + node_perm = torch.arange(num_node, device=edge_list.device) + if self.shuffle_edge: + edge_perm = torch.randperm(num_edge, device=edge_list.device) + else: + edge_perm = torch.arange(num_edge, device=edge_list.device) + new_data = {} + for key in data: + if meta[key] == "node": + new_data[key] = data[key][node_perm] + elif meta[key] == "edge": + new_data[key] = node_perm[data[key][edge_perm]] + else: + new_data[key] = data[key] + + return new_data + + +@R.register("transforms.VirtualNode") +class VirtualNode(core.Configurable): + """ + Add a virtual node and connect it with every node in the graph. + + Parameters: + relation (int, optional): relation of virtual edges. + By default, use the maximal relation in the graph plus 1. 
+ weight (int, optional): weight of virtual edges + node_feature (array_like, optional): feature of the virtual node + edge_feature (array_like, optional): feature of virtual edges + kwargs: other attributes of the virtual node or virtual edges + """ + + def __init__(self, relation=None, weight=1, node_feature=None, edge_feature=None, **kwargs): + self.relation = relation + self.weight = weight + + self.default = {k: torch.as_tensor(v) for k, v in kwargs.items()} + if node_feature is not None: + self.default["node_feature"] = torch.as_tensor(node_feature) + if edge_feature is not None: + self.default["edge_feature"] = torch.as_tensor(edge_feature) + + def __call__(self, item): + graph = item["graph"] + edge_list = graph.edge_list + edge_weight = graph.edge_weight + num_node = graph.num_node + num_relation = graph.num_relation + + existing_node = torch.arange(num_node, device=edge_list.device) + virtual_node = torch.ones(num_node, dtype=torch.long, device=edge_list.device) * num_node + node_in = torch.cat([virtual_node, existing_node]) + node_out = torch.cat([existing_node, virtual_node]) + if edge_list.shape[1] == 2: + new_edge = torch.stack([node_in, node_out], dim=-1) + else: + if self.relation is None: + relation = num_relation + num_relation = num_relation + 1 + else: + relation = self.relation + relation = relation * torch.ones(num_node * 2, dtype=torch.long, device=edge_list.device) + new_edge = torch.stack([node_in, node_out, relation], dim=-1) + edge_list = torch.cat([edge_list, new_edge]) + new_edge_weight = self.weight * torch.ones(num_node * 2, device=edge_weight.device) + edge_weight = torch.cat([edge_weight, new_edge_weight]) + + # add default node/edge attributes + data = graph.data_dict.copy() + for key, value in graph.meta.items(): + if value == "node": + if key in self.default: + new_data = self.default[key].unsqueeze(0) + else: + new_data = torch.zeros(1, *data[key].shape[1:], dtype=data[key].dtype, device=data[key].device) + data[key] = torch.cat([data[key], new_data]) + elif value == "edge": + if key in self.default: + repeat = [-1] * (data[key].ndim - 1) + new_data = self.default[key].expand(num_node * 2, *repeat) + else: + new_data = torch.zeros(num_node * 2, *data[key].shape[1:], + dtype=data[key].dtype, device=data[key].device) + data[key] = torch.cat([data[key], new_data]) + + graph = type(graph)(edge_list, edge_weight=edge_weight, num_node=num_node + 1, + num_relation=num_relation, meta=graph.meta, **data) + + item = item.copy() + item["graph"] = graph + return item + + +@R.register("transforms.VirtualAtom") +class VirtualAtom(VirtualNode, core.Configurable): + """ + Add a virtual atom and connect it with every atom in the molecule. + + Parameters: + atom_type (int, optional): type of the virtual atom + bond_type (int, optional): type of the virtual bonds + node_feature (array_like, optional): feature of the virtual atom + edge_feature (array_like, optional): feature of virtual bonds + kwargs: other attributes of the virtual atoms or virtual bonds + """ + + def __init__(self, atom_type=None, bond_type=None, node_feature=None, edge_feature=None, **kwargs): + super(VirtualAtom, self).__init__(relation=bond_type, weight=1, node_feature=node_feature, + edge_feature=edge_feature, atom_type=atom_type, **kwargs) + + +@R.register("transforms.TruncateProtein") +class TruncateProtein(core.Configurable): + """ + Truncate over long protein sequences into a fixed length. + + Parameters: + max_length (int, optional): maximal length of the sequence. 
Truncate the sequence if it exceeds this limit. + random (bool, optional): truncate the sequence at a random position. + If not, truncate the suffix of the sequence. + keys (str or list of str, optional): keys for the items that require truncation in a sample + """ + + def __init__(self, max_length=None, random=False, keys="graph"): + self.truncate_length = max_length + self.random = random + if isinstance(keys, str): + keys = [keys] + self.keys = keys + + def __call__(self, item): + new_item = item.copy() + for key in self.keys: + graph = item[key] + if graph.num_residue > self.truncate_length: + if self.random: + start = torch.randint(graph.num_residue - self.truncate_length, (1,)).item() + else: + start = 0 + end = start + self.truncate_length + mask = torch.zeros(graph.num_residue, dtype=torch.bool, device=graph.device) + mask[start:end] = True + graph = graph.subresidue(mask) + + new_item[key] = graph + return new_item + + +@R.register("transforms.ProteinView") +class ProteinView(core.Configurable): + """ + Convert proteins to a specific view. + + Parameters: + view (str): protein view. Can be ``atom`` or ``residue``. + keys (str or list of str, optional): keys for the items that require view change in a sample + """ + + def __init__(self, view, keys="graph"): + self.view = view + if isinstance(keys, str): + keys = [keys] + self.keys = keys + + def __call__(self, item): + item = item.copy() + for key in self.keys: + graph = copy.copy(item[key]) + graph.view = self.view + item[key] = graph + return item + + +@R.register("transforms.Compose") +class Compose(core.Configurable): + """ + Compose a list of transforms into one. + + Parameters: + transforms (list of callable): list of transforms + """ + + def __init__(self, transforms): + # flatten recursive composition + new_transforms = [] + for transform in transforms: + if isinstance(transform, Compose): + new_transforms += transform.transforms + elif transform is not None: + new_transforms.append(transform) + self.transforms = new_transforms + + def __call__(self, item): + for transform in self.transforms: + item = transform(item) + return item diff --git a/build/lib/torchdrug/utils/__init__.py b/build/lib/torchdrug/utils/__init__.py new file mode 100644 index 00000000..e1b7727f --- /dev/null +++ b/build/lib/torchdrug/utils/__init__.py @@ -0,0 +1,13 @@ +from .io import input_choice, literal_eval, no_rdkit_log, capture_rdkit_log +from .file import download, smart_open, extract, compute_md5, get_line_count +from .torch import load_extension, cpu, cuda, detach, clone, mean, cat, stack, sparse_coo_tensor +from .decorator import copy_args, cached_property, cached, deprecated_alias +from . import pretty, comm, plot + +__all__ = [ + "input_choice", "literal_eval", "no_rdkit_log", "capture_rdkit_log", + "download", "smart_open", "extract", "compute_md5", "get_line_count", + "load_extension", "cpu", "cuda", "detach", "clone", "mean", "cat", "stack", "sparse_coo_tensor", + "copy_args", "cached_property", "cached", "deprecated_alias", + "pretty", "comm", "plot", +] \ No newline at end of file diff --git a/build/lib/torchdrug/utils/comm.py b/build/lib/torchdrug/utils/comm.py new file mode 100644 index 00000000..817c2812 --- /dev/null +++ b/build/lib/torchdrug/utils/comm.py @@ -0,0 +1,269 @@ +import os +import multiprocessing +from collections import defaultdict + +import torch +from torch import distributed as dist + + +cpu_group = None +gpu_group = None + + +def get_rank(): + """ + Get the rank of this process in distributed processes. 
+ + Return 0 for single process case. + """ + if dist.is_initialized(): + return dist.get_rank() + if "RANK" in os.environ: + return int(os.environ["RANK"]) + return 0 + + +def get_world_size(): + """ + Get the total number of distributed processes. + + Return 1 for single process case. + """ + if dist.is_initialized(): + return dist.get_world_size() + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) + return 1 + + +def get_group(device): + """ + Get the process group corresponding to the given device. + + Parameters: + device (torch.device): query device + """ + group = cpu_group if device.type == "cpu" else gpu_group + if group is None: + raise ValueError("%s group is not initialized. Use comm.init_process_group() to initialize it" + % device.type.upper()) + return group + + +def init_process_group(backend, init_method=None, **kwargs): + """ + Initialize CPU and/or GPU process groups. + + Parameters: + backend (str): Communication backend. Use ``nccl`` for GPUs and ``gloo`` for CPUs. + init_method (str, optional): URL specifying how to initialize the process group + """ + global cpu_group + global gpu_group + + dist.init_process_group(backend, init_method, **kwargs) + gpu_group = dist.group.WORLD + if backend == "nccl": + cpu_group = dist.new_group(backend="gloo") + else: + cpu_group = gpu_group + + +def get_cpu_count(): + """ + Get the number of CPUs on this node. + """ + return multiprocessing.cpu_count() + + +def synchronize(): + """ + Synchronize among all distributed processes. + """ + if get_world_size() > 1: + dist.barrier() + + +def _recursive_read(obj): + values = defaultdict(list) + sizes = defaultdict(list) + if isinstance(obj, torch.Tensor): + values[obj.dtype] += [obj.flatten()] + sizes[obj.dtype] += [torch.tensor([obj.numel()], device=obj.device)] + elif isinstance(obj, dict): + for v in obj.values(): + child_values, child_sizes = _recursive_read(v) + for k, v in child_values.items(): + values[k] += v + for k, v in child_sizes.items(): + sizes[k] += v + elif isinstance(obj, list) or isinstance(obj, tuple): + for v in obj: + child_values, child_sizes = _recursive_read(v) + for k, v in child_values.items(): + values[k] += v + for k, v in child_sizes.items(): + sizes[k] += v + else: + raise ValueError("Unknown type `%s`" % type(obj)) + return values, sizes + + +def _recursive_write(obj, values, sizes=None): + if isinstance(obj, torch.Tensor): + if sizes is None: + size = torch.tensor([obj.numel()], device=obj.device) + else: + s = sizes[obj.dtype] + size, s = s.split([1, len(s) - 1]) + sizes[obj.dtype] = s + v = values[obj.dtype] + new_obj, v = v.split([size, v.shape[-1] - size], dim=-1) + # compatible with reduce / stack / cat + new_obj = new_obj.view(new_obj.shape[:-1] + (-1,) + obj.shape[1:]) + values[obj.dtype] = v + return new_obj, values + elif isinstance(obj, dict): + new_obj = {} + for k, v in obj.items(): + new_obj[k], values = _recursive_write(v, values, sizes) + elif isinstance(obj, list) or isinstance(obj, tuple): + new_obj = [] + for v in obj: + new_v, values = _recursive_write(v, values, sizes) + new_obj.append(new_v) + else: + raise ValueError("Unknown type `%s`" % type(obj)) + return new_obj, values + + +def reduce(obj, op="sum", dst=None): + """ + Reduce any nested container of tensors. + + Parameters: + obj (Object): any container object. Can be nested list, tuple or dict. + op (str, optional): element-wise reduction operator. + Available operators are ``sum``, ``mean``, ``min``, ``max``, ``product``. 
+ dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. + + Example:: + + >>> # assume 4 workers + >>> rank = comm.get_rank() + >>> x = torch.rand(5) + >>> obj = {"polynomial": x ** rank} + >>> obj = comm.reduce(obj) + >>> assert torch.allclose(obj["polynomial"], x ** 3 + x ** 2 + x + 1) + """ + values = _recursive_read(obj)[0] + values = {k: torch.cat(v) for k, v in values.items()} + + is_mean = op == "mean" + if is_mean: + op = "sum" + op = getattr(dist.ReduceOp, op.upper()) + + reduced = {} + for k, v in values.items(): + dtype = v.dtype + # NCCL can't handle bool. Cast them to byte + if dtype == torch.bool: + v = v.byte() + group = get_group(v.device) + if dst is None: + dist.all_reduce(v, op=op, group=group) + else: + dist.reduce(v, op=op, dst=dst, group=group) + if is_mean: + v = v / get_world_size() + reduced[k] = v.type(dtype) + + return _recursive_write(obj, reduced)[0] + + +def stack(obj, dst=None): + """ + Stack any nested container of tensors. The new dimension will be added at the 0-th axis. + + Parameters: + obj (Object): any container object. Can be nested list, tuple or dict. + dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. + + Example:: + + >>> # assume 4 workers + >>> rank = comm.get_rank() + >>> x = torch.rand(5) + >>> obj = {"exponent": x ** rank} + >>> obj = comm.stack(obj) + >>> truth = torch.stack([torch.ones_like(x), x, x ** 2, x ** 3]) + >>> assert torch.allclose(obj["exponent"], truth) + """ + values = _recursive_read(obj)[0] + values = {k: torch.cat(v) for k, v in values.items()} + + stacked = {} + for k, v in values.items(): + dtype = v.dtype + # NCCL can't handle bool. Cast them to byte + if dtype == torch.bool: + dtype = torch.uint8 + s = torch.zeros(get_world_size(), *v.shape, dtype=dtype, device=v.device) + s[get_rank()] = v + group = get_group(s.device) + if dst is None: + dist.all_reduce(s, op=dist.ReduceOp.SUM, group=group) + else: + dist.reduce(s, op=dist.ReduceOp.SUM, dst=dst, group=group) + stacked[k] = s.type(v.dtype) + + return _recursive_write(obj, stacked)[0] + + +def cat(obj, dst=None): + """ + Concatenate any nested container of tensors along the 0-th axis. + + Parameters: + obj (Object): any container object. Can be nested list, tuple or dict. + dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. + + Example:: + + >>> # assume 4 workers + >>> rank = comm.get_rank() + >>> rng = torch.arange(10) + >>> obj = {"range": rng[rank * (rank + 1) // 2: (rank + 1) * (rank + 2) // 2]} + >>> obj = comm.cat(obj) + >>> assert torch.allclose(obj["range"], rng) + """ + values, sizes = _recursive_read(obj) + sizes = {k: torch.cat(v) for k, v in sizes.items()} + + sizes = stack(sizes) + cated = {} + for k, value in values.items(): + size = sizes[k].t().flatten() # sizes[k]: (num_worker, num_obj) + dtype = value[0].dtype + # NCCL can't handle bool. 
Cast them to byte + if dtype == torch.bool: + dtype = torch.uint8 + s = torch.zeros(size.sum(), dtype=dtype, device=value[0].device) + obj_id = get_rank() + world_size = get_world_size() + offset = size[:obj_id].sum() + for v in value: + assert offset + v.numel() <= len(s) + s[offset: offset + v.numel()] = v + offset += size[obj_id: obj_id + world_size].sum() + obj_id += world_size + group = get_group(s.device) + if dst is None: + dist.all_reduce(s, op=dist.ReduceOp.SUM, group=group) + else: + dist.reduce(s, op=dist.ReduceOp.SUM, dst=dst, group=group) + cated[k] = s.type(value[0].dtype) + sizes = {k: v.sum(dim=0) for k, v in sizes.items()} + + return _recursive_write(obj, cated, sizes)[0] \ No newline at end of file diff --git a/build/lib/torchdrug/utils/decorator.py b/build/lib/torchdrug/utils/decorator.py new file mode 100644 index 00000000..bd486ebe --- /dev/null +++ b/build/lib/torchdrug/utils/decorator.py @@ -0,0 +1,297 @@ +import re +import inspect +import warnings +import functools + +from decorator import decorator + +import torch +from torch import nn + +from torchdrug import data + + +def copy_args(obj, args=None, ignore=None): + """ + Copy argument documentation from another function to fill the document of \*\*kwargs in this function. + + This class should be applied as a decorator. + + Parameters: + obj (object): object to copy document from + args (tuple of str, optional): arguments to copy. + By default, it copies all argument documentation from ``obj``, + except those already exist in the current function. + ignore (tuple of str, optional): arguments to ignore + """ + + def wrapper(obj): + sig = get_signature(obj) + parameters = list(sig.parameters.values()) + if parameters[0].name == "cls" or parameters[0].name == "self": + parameters.pop(0) + docs = get_param_docs(obj) + if len(docs) != len(parameters): + raise ValueError("Fail to parse the docstring of `%s`. " + "Inconsistent number of parameters in signature and docstring." % obj.__name__) + new_params = [] + new_docs = [] + param_names = {p.name for p in parameters} + for param, doc in zip(parameters, docs): + if param.kind == inspect.Parameter.VAR_POSITIONAL: + for arg in from_args: + if arg.name in param_names: + continue + new_params.append(arg) + new_docs.append(from_docs[arg.name]) + elif param.kind == inspect.Parameter.VAR_KEYWORD: + for kwarg in from_kwargs: + if kwarg.name in param_names: + continue + new_params.append(kwarg) + new_docs.append(from_docs[kwarg.name]) + else: + new_params.append(param) + new_docs.append(doc) + + new_sig = sig.replace(parameters=new_params) + set_signature(obj, new_sig) + set_param_docs(obj, new_docs) + + return obj + + from_obj = obj + if args is not None: + args = set(args) + if ignore is not None: + ignore = set(ignore) + + sig = get_signature(from_obj) + parameters = list(sig.parameters.values()) + if parameters[0].name == "cls" or parameters[0].name == "self": + parameters.pop(0) + from_args = [] + from_kwargs = [] + for param in parameters: + if (args is None or param.name in args) and (ignore is None or param.name not in ignore): + if param.default == inspect._empty: + from_args.append(param) + else: + from_kwargs.append(param) + + from_docs = get_param_docs(from_obj, as_dict=True) + if len(from_docs) != len(parameters): + raise ValueError("Fail to parse the docstring of `%s`. " + "Inconsistent number of parameters in signature and docstring." % from_obj.__name__) + + return wrapper + + +class cached_property(property): + """ + Cache the property once computed. 
+ """ + + def __init__(self, func): + self.func = func + self.__doc__ = func.__doc__ + + def __get__(self, obj, cls): + if obj is None: + return self + result = self.func(obj) + obj.__dict__[self.func.__name__] = result + return result + + +def cached(forward, debug=False): + """ + Cache the result of last function call. + """ + + @decorator + def wrapper(forward, self, *args, **kwargs): + + def equal(x, y): + if isinstance(x, nn.Parameter): + x = x.data + if isinstance(y, nn.Parameter): + y = y.data + if type(x) != type(y): + return False + if isinstance(x, torch.Tensor): + return x.shape == y.shape and (x == y).all() + elif isinstance(x, data.Graph): + if x.num_node != y.num_node or x.num_edge != y.num_edge or x.num_relation != y.num_relation: + return False + edge_feature = getattr(x, "edge_feature", torch.tensor(0, device=x.device)) + y_edge_feature = getattr(y, "edge_feature", torch.tensor(0, device=y.device)) + if edge_feature.shape != y_edge_feature.shape: + return False + return (x.edge_list == y.edge_list).all() and (x.edge_weight == y.edge_weight).all() \ + and (edge_feature == y_edge_feature).all() + else: + return x == y + + if self.training: + return forward(self, *args, **kwargs) + + sig = inspect.signature(forward) + func = sig.bind(self, *args, **kwargs) + func.apply_defaults() + arguments = func.arguments.copy() + arguments.pop(next(iter(arguments.keys()))) + + if hasattr(self, "_forward_cache"): + hit = True + message = [] + for k, v in arguments.items(): + if not equal(self._forward_cache[k], v): + hit = False + message.append("%s: miss" % k) + break + message.append("%s: hit" % k) + if debug: + print("[cache] %s" % ", ".join(message)) + else: + hit = False + if debug: + print("[cache] cold start") + if hit: + return self._forward_cache["result"] + else: + self._forward_cache = {} + + for k, v in arguments.items(): + if isinstance(v, torch.Tensor) or isinstance(v, data.Graph): + v = v.detach() + self._forward_cache[k] = v + result = forward(self, *args, **kwargs) + self._forward_cache["result"] = result + return result + + return wrapper(forward) + + +def deprecated_alias(**alias): + """ + Handle argument alias for a function and output deprecated warnings. 
+ """ + + def decorate(obj): + + @functools.wraps(obj) + def wrapper(*args, **kwargs): + for key, value in alias.items(): + if key in kwargs: + if value in kwargs: + raise TypeError("%s() got values for both `%s` and `%s`" % (obj.__name__, value, key)) + warnings.warn("%s(): argument `%s` is deprecated in favor of `%s`" % (obj.__name__, key, value)) + kwargs[value] = kwargs.pop(key) + + return obj(*args, **kwargs) + + sig = get_signature(obj) + parameters = list(sig.parameters.values()) + param_docs = get_param_docs(obj, as_dict=True) + docs = list(param_docs.values()) + alias_params = [] + alias_docs = [] + for key, value in alias.items(): + param = inspect.Parameter(key, inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=None, annotation=sig.parameters[value].annotation) + alias_params.append(param) + param_doc = param_docs[value] + match = re.search(r" \(.*?\)", param_doc) + if match: + type_str = match.group() + else: + type_str = "" + alias_docs.append("%s%s: deprecated alias of ``%s``" % (key, type_str, value)) + + if parameters[-1].kind == inspect.Parameter.VAR_KEYWORD: + new_params = parameters[:-1] + alias_params + parameters[-1:] + new_docs = docs[:-1] + alias_docs + docs[-1:] + else: + new_params = parameters + alias_params + new_docs = docs + alias_docs + new_sig = sig.replace(parameters=new_params) + set_signature(wrapper, new_sig) + set_param_docs(wrapper, new_docs) + + return wrapper + + return decorate + + +def get_param_docs(obj, as_dict=False): + doc = obj.__doc__ or "" + + match = re.search(r"Parameters:\n", doc) + if not match: + return [] + begin = match.end() + indent = re.search(r"\s+", doc[begin:]).group() + match = re.search(r"^(?!%s)" % indent, doc[begin:]) + if match: + end = begin + match.start() + else: + end = None + param_docs = [] + pattern = r"^%s\S.*(?:\n%s\s+\S.*)*" % (indent, indent) + for match in re.finditer(pattern, doc[begin:end], re.MULTILINE): + doc = match.group() + doc = re.sub("^%s" % indent, "", doc, re.MULTILINE) # remove indent + param_docs.append(doc) + if as_dict: + param_docs = {re.search("\S+", doc).group(): doc for doc in param_docs} + + return param_docs + + +def set_param_docs(obj, param_docs): + doc = obj.__doc__ or "" + if isinstance(param_docs, dict): + param_docs = param_docs.values() + + match = re.search(r"Parameters:\n", doc) + if not match: + indent = None + for match in re.finditer(r"^(\s*)", doc): + if indent is None or len(match.group(1)) < len(indent): + indent = match.group(1) + param_docs = [re.sub("^", indent, doc, re.MULTILINE) for doc in param_docs] # add indent + param_docs = "\n".join(param_docs) + doc = "\n".join([doc, "%sParameters" % indent, param_docs]) + else: + begin = match.end() + indent = re.search(r"\s*", doc[begin:]).group() + pattern = r"^%s\S.*(?:\n%s\s+\S.*)*(?:\n%s\S.*(?:\n%s\s+\S.*)*)*" % ((indent,) * 4) + end = begin + re.search(pattern, doc[begin:], re.MULTILINE).end() + param_docs = [re.sub("^", indent, doc, re.MULTILINE) for doc in param_docs] # add indent + param_docs = "\n".join(param_docs) + doc = "".join([doc[:begin], param_docs, doc[end:]]) + obj.__doc__ = doc + + +def get_signature(obj): + if hasattr(obj, "__signature__"): # already overrided + sig = obj.__signature__ + elif inspect.isclass(obj): + sig = inspect.signature(obj.__init__) + else: + sig = inspect.signature(obj) + + return sig + + +def set_signature(obj, sig): + doc = obj.__doc__ or "" + match = re.search(r"^\s*\W+\(.*?\)( *-> *\W+)?", doc, re.MULTILINE) + if not match: + doc = "%s%s\n%s" % (obj.__name__, sig, doc) + else: + 
begin, end = match.span() + doc = "".join([doc[:begin], obj.__name__, str(sig), doc[end:]]) + obj.__doc__ = doc + obj.__signature__ = sig diff --git a/build/lib/torchdrug/utils/extension/__init__.py b/build/lib/torchdrug/utils/extension/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/torchdrug/utils/extension/torch_ext.cpp b/build/lib/torchdrug/utils/extension/torch_ext.cpp new file mode 100644 index 00000000..15d01514 --- /dev/null +++ b/build/lib/torchdrug/utils/extension/torch_ext.cpp @@ -0,0 +1,14 @@ +#include <torch/extension.h> + +namespace at { + +Tensor sparse_coo_tensor_unsafe(const Tensor &indices, const Tensor &values, IntArrayRef size) { + return _sparse_coo_tensor_unsafe(indices, values, size, values.options().layout(kSparse)); +} + +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("sparse_coo_tensor_unsafe", &at::sparse_coo_tensor_unsafe, + "Construct sparse COO tensor without index check"); +} \ No newline at end of file diff --git a/build/lib/torchdrug/utils/file.py b/build/lib/torchdrug/utils/file.py new file mode 100644 index 00000000..100a775e --- /dev/null +++ b/build/lib/torchdrug/utils/file.py @@ -0,0 +1,168 @@ +import os +import struct +import logging +from tqdm import tqdm + + +logger = logging.getLogger(__name__) + + +def download(url, path, save_file=None, md5=None): + """ + Download a file from the specified url. + Skip the downloading step if there exists a file satisfying the given MD5. + + Parameters: + url (str): URL to download + path (str): path to store the downloaded file + save_file (str, optional): name of save file. If not specified, infer the file name from the URL. + md5 (str, optional): MD5 of the file + """ + from six.moves.urllib.request import urlretrieve + + if save_file is None: + save_file = os.path.basename(url) + if "?" in save_file: + save_file = save_file[:save_file.find("?")] + save_file = os.path.join(path, save_file) + + if not os.path.exists(save_file) or compute_md5(save_file) != md5: + logger.info("Downloading %s to %s" % (url, save_file)) + urlretrieve(url, save_file) + return save_file + + +def smart_open(file_name, mode="rb"): + """ + Open a regular file or a zipped file. + + This function can be used as a drop-in replacement for the builtin function `open()`. + + Parameters: + file_name (str): file name + mode (str, optional): open mode for the file stream + """ + import bz2 + import gzip + + extension = os.path.splitext(file_name)[1] + if extension == '.bz2': + return bz2.BZ2File(file_name, mode) + elif extension == '.gz': + return gzip.GzipFile(file_name, mode) + else: + return open(file_name, mode) + + +def extract(zip_file, member=None): + """ + Extract files from a zip file. Currently, ``zip``, ``gz``, ``tar.gz``, ``tar`` file types are supported. + + Parameters: + zip_file (str): file name + member (str, optional): extract specific member from the zip file. + If not specified, extract all members. + """ + import gzip + import shutil + import zipfile + import tarfile + + zip_name, extension = os.path.splitext(zip_file) + if zip_name.endswith(".tar"): + extension = ".tar" + extension + zip_name = zip_name[:-4] + save_path = os.path.dirname(zip_file) + + if extension == ".gz": + member = os.path.basename(zip_name) + members = [member] + save_files = [os.path.join(save_path, member)] + for _member, save_file in zip(members, save_files): + with open(zip_file, "rb") as fin: + fin.seek(-4, 2) + file_size = struct.unpack("<I", fin.read(4))[0] + [... the remainder of utils/file.py (rest of extract(), compute_md5(), get_line_count()) and the beginning of utils/io.py were lost in extraction; only the tail of the capture_rdkit_log() docstring survives below ...] + >>> with utils.capture_rdkit_log() as log: + >>> ... 
+ >>> print(log.content) + """ + return CaptureStdIO(True, True) \ No newline at end of file diff --git a/build/lib/torchdrug/utils/plot.py b/build/lib/torchdrug/utils/plot.py new file mode 100644 index 00000000..6ed102df --- /dev/null +++ b/build/lib/torchdrug/utils/plot.py @@ -0,0 +1,175 @@ +import io +import os +import json +import jinja2 +from PIL import Image + +from rdkit.Chem import AllChem, Draw + + +path = os.path.join(os.path.dirname(__file__), "template") + + +def reaction(reactants, products, save_file=None, figure_size=(3, 3), atom_map=False): + """ + Visualize a chemical reaction. + + Parameters: + reactants (list of Molecule): list of reactants + products (list of Molecule): list of products + save_file (str, optional): ``png`` file to save visualization. + If not provided, show the figure in a window. + figure_size (tuple of int, optional): width and height of the figure + atom_map (bool, optional): visualize atom mapping or not + """ + rxn = AllChem.ChemicalReaction() + for reactant in reactants: + mol = reactant.to_molecule() + if not atom_map: + for atom in mol.GetAtoms(): + atom.SetAtomMapNum(0) + rxn.AddReactantTemplate(mol) + for product in products: + mol = product.to_molecule() + if not atom_map: + for atom in mol.GetAtoms(): + atom.SetAtomMapNum(0) + rxn.AddProductTemplate(mol) + size = [100 * s for s in figure_size] + img = Draw.ReactionToImage(rxn, size) + + if save_file is None: + img.show() + else: + img.save(save_file) + + +def highlight(molecule, atoms=None, bonds=None, atom_colors=None, bond_colors=None, save_file=None, figure_size=(3, 3), + atom_map=False): + """ + Visualize a molecule with highlighted atoms or bonds. + + Parameters: + molecule (Molecule): molecule to visualize + atoms (list of int): indexes of atoms to highlight + bonds (list of int): indexes of bonds to highlight + atom_colors (tuple or dict): highlight color for atoms. + Can be a tuple of 3 float between 0 and 1, or a dict that maps each index to a different color. + bond_colors (tuple or dict): highlight color for bonds. + Can be a tuple of 3 float between 0 and 1, or a dict that maps each index to a different color. + save_file (str, optional): ``png`` file to save visualization. + If not provided, show the figure in a window. + figure_size (tuple of int, optional): width and height of the figure + atom_map (bool, optional): visualize atom mapping or not + """ + if not isinstance(atom_colors, dict): + atom_colors = dict.fromkeys(atoms, atom_colors) + if not isinstance(bond_colors, dict): + bond_colors = dict.fromkeys(bonds, bond_colors) + + mol = molecule.to_molecule() + if not atom_map: + for atom in mol.GetAtoms(): + atom.SetAtomMapNum(0) + size = [100 * s for s in figure_size] + canvas = Draw.rdMolDraw2D.MolDraw2DCairo(*size) + Draw.rdMolDraw2D.PrepareAndDrawMolecule(canvas, mol, highlightAtoms=atoms, highlightBonds=bonds, + highlightAtomColors=atom_colors, highlightBondColors=bond_colors) + + if save_file is None: + stream = io.BytesIO(canvas.GetDrawingText()) + img = Image.open(stream) + img.show() + else: + canvas.WriteDrawingText(save_file) + + +def echarts(graph, title=None, node_colors=None, edge_colors=None, node_labels=None, relation_labels=None, + node_types=None, type_labels=None, dynamic_size=False, dynamic_width=False, save_file=None): + """ + Visualize a graph in ECharts. 
+ + Parameters: + graph (Graph): graph to visualize + title (str, optional): title of the graph + node_colors (dict, optional): specify colors for some nodes. + Each color is either a tuple of 3 integers between 0 and 255, or a hex color code. + edge_colors (dict, optional): specify colors for some edges. + Each color is either a tuple of 3 integers between 0 and 255, or a hex color code. + node_labels (list of str, optional): labels for each node + relation_labels (list of str, optional): labels for each relation + node_types (list of int, optional): type for each node + type_labels (list of str, optional): labels for each node type + dynamic_size (bool, optional): if true, set the size of nodes based on the logarithm of degrees + dynamic_width (bool, optional): if true, set the width of edges based on the edge weights + save_file (str, optional): ``html`` file to save visualization, accompanied by a ``json`` file + """ + if dynamic_size: + symbol_size = (graph.degree_in + graph.degree_out + 2).log() + symbol_size = symbol_size / symbol_size.mean() * 10 + symbol_size = symbol_size.tolist() + else: + symbol_size = [10] * graph.num_node + nodes = [] + node_colors = node_colors or {} + for i in range(graph.num_node): + node = { + "id": i, + "symbolSize": symbol_size[i], + } + if i in node_colors: + color = node_colors[i] + if isinstance(color, tuple): + color = "rgb%s" % (color,) + node["itemStyle"] = {"color": color} + if node_labels: + node["name"] = node_labels[i] + if node_types: + node["category"] = node_types[i] + nodes.append(node) + + if dynamic_width: + width = graph.edge_weight / graph.edge_weight.mean() * 3 + width = width.tolist() + else: + width = [3] * graph.num_edge + edges = [] + if graph.num_relation: + node_in, node_out, relation = graph.edge_list.t().tolist() + else: + node_in, node_out = graph.edge_list.t().tolist() + relation = None + edge_colors = edge_colors or {} + for i in range(graph.num_edge): + edge = { + "source": node_in[i], + "target": node_out[i], + "lineStyle": {"width": width[i]}, + } + if i in edge_colors: + color = edge_colors[i] + if isinstance(color, tuple): + color = "rgb%s" % (color,) + edge["lineStyle"] = {"color": color} + if relation_labels: + edge["value"] = relation_labels[relation[i]] + edges.append(edge) + + json_file = os.path.splitext(save_file)[0] + ".json" + data = { + "title": title, + "nodes": nodes, + "edges": edges, + } + if type_labels: + data["categories"] = [{"name": label} for label in type_labels] + variables = { + "data_file": os.path.basename(json_file), + "show_label": "true" if node_labels else "false", + } + with open(os.path.join(path, "echarts.html"), "r") as fin, open(save_file, "w") as fout: + template = jinja2.Template(fin.read()) + instance = template.render(variables) + fout.write(instance) + with open(json_file, "w") as fout: + json.dump(data, fout, sort_keys=True, indent=4) \ No newline at end of file diff --git a/build/lib/torchdrug/utils/pretty.py b/build/lib/torchdrug/utils/pretty.py new file mode 100644 index 00000000..9fc622b6 --- /dev/null +++ b/build/lib/torchdrug/utils/pretty.py @@ -0,0 +1,75 @@ +import pprint +from itertools import islice, chain + + +separator = ">" * 30 +line = "-" * 30 + + +class Ellipsis(object): + + def __repr__(self): + return "..." 
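+# A worked example of the truncated display (illustration only, not part of the module): with +# compact=True and the defaults truncation = 10 and display = 3, pretty.print(list(range(100))) +# renders roughly as [0, 1, 2, ..., 97, 98, 99]; the Ellipsis singleton below supplies the "...".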
+ + +ellipsis = Ellipsis() + + +class PrettyPrinter(pprint.PrettyPrinter): + + truncation = 10 + display = 3 + + def _format_items(self, items, stream, indent, allowance, context, level): + if self._compact and len(items) > self.truncation: + items = chain(islice(items, self.display), [ellipsis], islice(items, len(items) - self.display, None)) + super(PrettyPrinter, self)._format_items(items, stream, indent, allowance, context, level) + + +def print(object, *args, **kwargs): + """ + Print a python object to a stream. + """ + return PrettyPrinter(*args, **kwargs).pprint(object) + + +def format(object, *args, **kwargs): + """ + Format a python object as a string. + """ + return PrettyPrinter(*args, **kwargs).pformat(object) + + +def time(seconds): + """ + Format time as a string. + + Parameters: + seconds (float): time in seconds + """ + sec_per_min = 60 + sec_per_hour = 60 * 60 + sec_per_day = 24 * 60 * 60 + + if seconds > sec_per_day: + return "%.2f days" % (seconds / sec_per_day) + elif seconds > sec_per_hour: + return "%.2f hours" % (seconds / sec_per_hour) + elif seconds > sec_per_min: + return "%.2f mins" % (seconds / sec_per_min) + else: + return "%.2f secs" % seconds + + +def long_array(array, truncation=10, display=3): + """ + Format an array as a string. + + Parameters: + array (array_like): array-like data + truncation (int, optional): truncate array if its length exceeds this threshold + display (int, optional): number of elements to display at the beginning and the end in truncated mode + """ + if len(array) <= truncation: + return "%s" % array + return "%s, ..., %s" % (str(array[:display])[:-1], str(array[-display:])[1:]) \ No newline at end of file diff --git a/build/lib/torchdrug/utils/template/echarts.html b/build/lib/torchdrug/utils/template/echarts.html new file mode 100644 index 00000000..d88eb603 --- /dev/null +++ b/build/lib/torchdrug/utils/template/echarts.html @@ -0,0 +1,63 @@ + + + + + + +
+[... 63-line ECharts HTML template stripped in extraction; only its blank added lines survive above ...] \ No newline at end of file diff --git a/build/lib/torchdrug/utils/torch.py b/build/lib/torchdrug/utils/torch.py new file mode 100644 index 00000000..ff7088a7 --- /dev/null +++ b/build/lib/torchdrug/utils/torch.py @@ -0,0 +1,190 @@ +import os + +import torch +from torch.utils import cpp_extension + +from torchdrug import data +from . import decorator, comm + + +class LazyExtensionLoader(object): + + def __init__(self, name, sources, extra_cflags=None, extra_cuda_cflags=None, extra_ldflags=None, + extra_include_paths=None, build_directory=None, verbose=False, **kwargs): + self.name = name + self.sources = sources + self.extra_cflags = extra_cflags + self.extra_cuda_cflags = extra_cuda_cflags + self.extra_ldflags = extra_ldflags + self.extra_include_paths = extra_include_paths + worker_name = "%s_%d" % (name, comm.get_rank()) + self.build_directory = build_directory or cpp_extension._get_build_directory(worker_name, verbose) + self.verbose = verbose + self.kwargs = kwargs + + def __getattr__(self, key): + return getattr(self.module, key) + + @decorator.cached_property + def module(self): + return cpp_extension.load(self.name, self.sources, self.extra_cflags, self.extra_cuda_cflags, + self.extra_ldflags, self.extra_include_paths, self.build_directory, + self.verbose, **self.kwargs) + + +def load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs): + """ + Load a PyTorch C++ extension just-in-time (JIT). + Automatically decide the compilation flags if not specified. + + This function performs lazy evaluation and is multi-process-safe. + + See `torch.utils.cpp_extension.load`_ for more details. + + .. _torch.utils.cpp_extension.load: + https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load + """ + if extra_cflags is None: + extra_cflags = ["-Ofast"] + if torch.backends.openmp.is_available(): + extra_cflags += ["-fopenmp", "-DAT_PARALLEL_OPENMP"] + else: + extra_cflags.append("-DAT_PARALLEL_NATIVE") + if extra_cuda_cflags is None: + if torch.cuda.is_available(): + extra_cuda_cflags = ["-O3"] + extra_cflags.append("-DCUDA_OP") + else: + new_sources = [] + for source in sources: + if not cpp_extension._is_cuda_file(source): + new_sources.append(source) + sources = new_sources + + return LazyExtensionLoader(name, sources, extra_cflags, extra_cuda_cflags, **kwargs) + + +def cpu(obj, *args, **kwargs): + """ + Transfer any nested container of tensors to CPU. + """ + if hasattr(obj, "cpu"): + return obj.cpu(*args, **kwargs) + elif isinstance(obj, (str, bytes)): + return obj + elif isinstance(obj, dict): + return type(obj)({k: cpu(v, *args, **kwargs) for k, v in obj.items()}) + elif isinstance(obj, (list, tuple)): + return type(obj)(cpu(x, *args, **kwargs) for x in obj) + + raise TypeError("Can't transfer object type `%s`" % type(obj)) + + +def cuda(obj, *args, **kwargs): + """ + Transfer any nested container of tensors to CUDA. + """ + if hasattr(obj, "cuda"): + return obj.cuda(*args, **kwargs) + elif isinstance(obj, (str, bytes)): + return obj + elif isinstance(obj, dict): + return type(obj)({k: cuda(v, *args, **kwargs) for k, v in obj.items()}) + elif isinstance(obj, (list, tuple)): + return type(obj)(cuda(x, *args, **kwargs) for x in obj) + + raise TypeError("Can't transfer object type `%s`" % type(obj)) + + +def detach(obj): + """ + Detach tensors in any nested container. 
+
+
+def detach(obj):
+    """
+    Detach tensors in any nested container.
+    """
+    if hasattr(obj, "detach"):
+        return obj.detach()
+    elif isinstance(obj, dict):
+        return type(obj)({k: detach(v) for k, v in obj.items()})
+    elif isinstance(obj, (list, tuple)):
+        return type(obj)(detach(x) for x in obj)
+
+    raise TypeError("Can't perform detach over object type `%s`" % type(obj))
+
+
+def clone(obj, *args, **kwargs):
+    """
+    Clone tensors in any nested container.
+    """
+    if hasattr(obj, "clone"):
+        return obj.clone(*args, **kwargs)
+    elif isinstance(obj, dict):
+        return type(obj)({k: clone(v, *args, **kwargs) for k, v in obj.items()})
+    elif isinstance(obj, (list, tuple)):
+        return type(obj)(clone(x, *args, **kwargs) for x in obj)
+
+    raise TypeError("Can't perform clone over object type `%s`" % type(obj))
+
+
+def mean(obj, *args, **kwargs):
+    """
+    Compute mean of tensors in any nested container.
+    """
+    if hasattr(obj, "mean"):
+        return obj.mean(*args, **kwargs)
+    elif isinstance(obj, dict):
+        return type(obj)({k: mean(v, *args, **kwargs) for k, v in obj.items()})
+    elif isinstance(obj, (list, tuple)):
+        return type(obj)(mean(x, *args, **kwargs) for x in obj)
+
+    raise TypeError("Can't perform mean over object type `%s`" % type(obj))
+
+
+def cat(objs, *args, **kwargs):
+    """
+    Concatenate a list of nested containers with the same structure.
+    """
+    obj = objs[0]
+    if isinstance(obj, torch.Tensor):
+        return torch.cat(objs, *args, **kwargs)
+    elif isinstance(obj, data.PackedGraph):
+        return data.cat(objs)
+    elif isinstance(obj, dict):
+        return {k: cat([x[k] for x in objs], *args, **kwargs) for k in obj}
+    elif isinstance(obj, (list, tuple)):
+        return type(obj)(cat(xs, *args, **kwargs) for xs in zip(*objs))
+
+    raise TypeError("Can't perform concatenation over object type `%s`" % type(obj))
+
+
+def stack(objs, *args, **kwargs):
+    """
+    Stack a list of nested containers with the same structure.
+    """
+    obj = objs[0]
+    if isinstance(obj, torch.Tensor):
+        return torch.stack(objs, *args, **kwargs)
+    elif isinstance(obj, dict):
+        return {k: stack([x[k] for x in objs], *args, **kwargs) for k in obj}
+    elif isinstance(obj, (list, tuple)):
+        return type(obj)(stack(xs, *args, **kwargs) for xs in zip(*objs))
+
+    raise TypeError("Can't perform stack over object type `%s`" % type(obj))
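
cat and stack mirror torch.cat / torch.stack but walk a batch of identically structured containers, the kind of merge done over per-batch outputs at evaluation time. A sketch under the same torchdrug.utils export assumption:

import torch
from torchdrug import utils

a = {"pred": torch.zeros(2, 3), "target": torch.zeros(2)}
b = {"pred": torch.ones(4, 3), "target": torch.ones(4)}

merged = utils.cat([a, b])    # {"pred": (6, 3), "target": (6,)}, concatenated along dim 0
stacked = utils.stack([a, a]) # {"pred": (2, 2, 3), "target": (2, 2)}, new leading dim
assert merged["pred"].shape == (6, 3)
assert stacked["target"].shape == (2, 2)
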
+
+
+def sparse_coo_tensor(indices, values, size):
+    """
+    Construct a sparse COO tensor without index check. Much faster than `torch.sparse_coo_tensor`_.
+
+    .. _torch.sparse_coo_tensor:
+        https://pytorch.org/docs/stable/generated/torch.sparse_coo_tensor.html
+
+    Parameters:
+        indices (Tensor): 2D indices of shape (2, n)
+        values (Tensor): values of shape (n,)
+        size (list): size of the tensor
+    """
+    return torch_ext.sparse_coo_tensor_unsafe(indices, values, size)
+
+
+path = os.path.join(os.path.dirname(__file__), "extension")
+
+torch_ext = load_extension("torch_ext", [os.path.join(path, "torch_ext.cpp")])
\ No newline at end of file
diff --git a/dist/torchdrug-0.2.1-py3.9.egg b/dist/torchdrug-0.2.1-py3.9.egg
new file mode 100644
index 00000000..351ae751
Binary files /dev/null and b/dist/torchdrug-0.2.1-py3.9.egg differ
diff --git a/torchdrug/core/engine.py b/torchdrug/core/engine.py
index 1bc83e86..e2063eed 100644
--- a/torchdrug/core/engine.py
+++ b/torchdrug/core/engine.py
@@ -238,8 +238,9 @@ def load(self, checkpoint, load_optimizer=True, strict=True):
         logger.warning("Load checkpoint from %s" % checkpoint)
         checkpoint = os.path.expanduser(checkpoint)
         state = torch.load(checkpoint, map_location=self.device)
-
-        self.model.load_state_dict(state["model"], strict=strict)
+        state["model"].pop("graph", None)  # drop dataset-specific graph buffer (Issue #89)
+        state["model"].pop("fact_graph", None)  # drop dataset-specific fact graph buffer (Issue #89)
+        self.model.load_state_dict(state["model"], strict=False)
         if load_optimizer:
             self.optimizer.load_state_dict(state["optimizer"])
diff --git a/torchdrug/layers/functional/extension/rspmm.h b/torchdrug/layers/functional/extension/rspmm.h
index c09216a5..8025f4d7 100644
--- a/torchdrug/layers/functional/extension/rspmm.h
+++ b/torchdrug/layers/functional/extension/rspmm.h
@@ -3,7 +3,7 @@
 #include
 #include

-#include
+#include

 namespace at {
diff --git a/torchdrug/layers/functional/extension/spmm.h b/torchdrug/layers/functional/extension/spmm.h
index 94004d67..905e7836 100644
--- a/torchdrug/layers/functional/extension/spmm.h
+++ b/torchdrug/layers/functional/extension/spmm.h
@@ -3,7 +3,7 @@
 #include
 #include

-#include
+#include

 #include "rspmm.h"
diff --git a/torchdrug/tasks/reasoning.py b/torchdrug/tasks/reasoning.py
index 1052a973..802c44c0 100644
--- a/torchdrug/tasks/reasoning.py
+++ b/torchdrug/tasks/reasoning.py
@@ -250,4 +250,4 @@ def _strict_negative(self, pos_h_index, pos_t_index, pos_r_index):

         neg_index = torch.cat([neg_t_index, neg_h_index])

-        return neg_index
\ No newline at end of file
+        return neg_index
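
The engine.py hunk above makes Engine.load drop dataset-bound graph buffers before loading (Issue #89); popping with a default avoids a KeyError when a checkpoint lacks them. A standalone sketch of the same tolerant-loading pattern; the function name and buffer key list are illustrative, not torchdrug API:

import torch

def load_tolerant(model, path, buffer_keys=("graph", "fact_graph")):
    """Hypothetical helper: load a checkpoint while ignoring dataset-bound buffers."""
    state = torch.load(path, map_location="cpu")
    for key in buffer_keys:
        state["model"].pop(key, None)  # pop with a default so absent keys are not an error
    # strict=False tolerates the entries removed above
    return model.load_state_dict(state["model"], strict=False)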