Source code for mooon.functional

from typing import Optional, Tuple, Union

import torch
from torch import Tensor

try:
    import torch_cluster  # noqa
    random_walk = torch.ops.torch_cluster.random_walk
except ImportError:
    random_walk = None

from torch_geometric.utils import degree, sort_edge_index, subgraph
from torch_geometric.utils.num_nodes import maybe_num_nodes

classes = __all__ = [
    "drop_edge",
    "drop_node",
    "drop_path",
    "add_random_walk_edge",
]


[docs]def drop_edge( edge_index: Tensor, edge_weight: Optional[Tensor] = None, p: float = 0.5, training: bool = True, ) -> Tuple[Tensor, Optional[Tensor]]: r"""DropEdge: Sampling edge using a uniform distribution from the `"DropEdge: Towards Deep Graph Convolutional Networks on Node Classification" <https://arxiv.org/abs/1907.10903>`_ paper (ICLR'20) Parameters ---------- edge_index : torch.Tensor the input edge index edge_weight : Optional[Tensor], optional the input edge weight, by default None p : float, optional the probability of dropping out on each edge, by default 0.5 training : bool, optional whether the model is during training, do nothing if :obj:`training=True`, by default True Returns ------- Tuple[Tensor, Optional[Tensor]] the output edge index and edge weight Raises ------ ValueError p is out of range [0,1] Example ------- .. code-block:: python from mooon import drop_edge edge_index = torch.tensor([[1, 2], [3,4]]) drop_edge(edge_index, p=0.5) See also -------- :class:`mooon.DropEdge` """ if p < 0. or p > 1.: raise ValueError(f'Dropout probability has to be between 0 and 1 ' f'(got {p}') if not training or not p: return edge_index, edge_weight edge_mask = torch.rand(edge_index.size(1), device=edge_index.device) >= p edge_index = edge_index[:, edge_mask] if edge_weight is not None: edge_weight = edge_weight[edge_mask] return edge_index, edge_weight
[docs]def drop_node( edge_index: Tensor, edge_weight: Optional[Tensor] = None, p: float = 0.5, training: bool = True, num_nodes: Optional[int] = None, ) -> Tuple[Tensor, Optional[Tensor]]: """DropNode: Sampling node using a uniform distribution from the `"Graph Contrastive Learning with Augmentations" <https://arxiv.org/abs/2010.139023>`_ paper (NeurIPS'20) Parameters ---------- edge_index : torch.Tensor the input edge index edge_weight : Optional[Tensor], optional the input edge weight, by default None p : float, optional the probability of dropping out on each node, by default 0.5 training : bool, optional whether the model is during training, do nothing if :obj:`training=True`, by default True Returns ------- Tuple[Tensor, Optional[Tensor]] the output edge index and edge weight Raises ------ ValueError p is out of range [0,1] Example ------- .. code-block:: python from mooon import drop_node edge_index = torch.tensor([[1, 2], [3,4]]) drop_node(edge_index, p=0.5) See also -------- :class:`mooon.DropNode` """ if p < 0. or p > 1.: raise ValueError(f'Dropout probability has to be between 0 and 1 ' f'(got {p}') if not training or not p: return edge_index, edge_weight num_nodes = maybe_num_nodes(edge_index, num_nodes) prob = torch.rand(num_nodes, device=edge_index.device) node_mask = prob > p return subgraph(node_mask, edge_index, edge_weight)
[docs]def drop_path( edge_index: Tensor, edge_weight: Optional[Tensor] = None, p: float = 0.5, walks_per_node: int = 1, walk_length: int = 3, num_nodes: Optional[int] = None, start: Union[str, Tensor] = 'node', is_sorted: bool = False, training: bool = True, ) -> Tuple[Tensor, Optional[Tensor]]: """DropPath: a structured form of :class:`~mooon.drop_edge` from the `"MaskGAE: Masked Graph Modeling Meets Graph Autoencoders" <https://arxiv.org/abs/2205.10053>`_ paper (arXiv'22) Parameters ---------- edge_index : torch.Tensor the input edge index edge_weight : Optional[Tensor], optional the input edge weight, by default None p : float, optional the percentage of nodes in the graph that chosen as root nodes to perform random walks. By default, :obj:`p=0.5`. walks_per_node : int, optional number of walks per node, by default 1 walk_length : int, optional number of walk length per node, by default 3 num_nodes : int, optional number of total nodes in the graph, by default None start : Union[str, Tensor], optional the type of starting node chosen from "node", "edge", or custom nodes, by default 'node' is_sorted : bool, optional whether the input :obj:`edge_index` is sorted training : bool, optional whether the model is during training, do nothing if :obj:`training=True`, by default True Returns ------- Tuple[Tensor, Optional[Tensor]] the output edge index and edge weight Raises ------ ImportError if :class:`torch_cluster` is not installed. ValueError :obj:`p` is out of scope [0,1] ValueError :obj:`p` is not integer value or a Tensor Example ------- .. code-block:: python from mooon import drop_path edge_index = torch.tensor([[1, 2], [3,4]]) drop_path(edge_index, p=0.5) # specify root nodes drop_path(edge_index, start=torch.tensor([1,2])) See also -------- :class:`mooon.DropPath` """ if torch_cluster is None: raise ImportError("`torch_cluster` is not installed.") if not training: return edge_index, edge_weight if p < 0. or p > 1.: raise ValueError(f'Sample probability has to be between 0 and 1 ' f'(got {p}') assert isinstance(start, Tensor) or start in ['node', 'edge'] num_edges = edge_index.size(1) edge_mask = edge_index.new_ones(num_edges, dtype=torch.bool) if not training or p == 0.0: return edge_index, edge_mask if random_walk is None: raise ImportError('`drop_path` requires `torch-cluster`.') num_nodes = maybe_num_nodes(edge_index, num_nodes) if not is_sorted: edge_index = sort_edge_index(edge_index, edge_weight, num_nodes=num_nodes) if edge_weight is not None: edge_index, edge_weight = edge_index row, col = edge_index if start == 'edge': sample_mask = torch.rand(row.size(0), device=edge_index.device) <= p start = row[sample_mask].repeat(walks_per_node) elif start == 'node': perm = torch.randperm(num_nodes, device=edge_index.device) start = perm[:round(num_nodes * p)].repeat(walks_per_node) elif start.dtype == torch.bool: start = start.nonzero().view(-1) deg = degree(row, num_nodes=num_nodes) rowptr = row.new_zeros(num_nodes + 1) torch.cumsum(deg, 0, out=rowptr[1:]) n_id, e_id = random_walk(rowptr, col, start, walk_length, 1.0, 1.0) e_id = e_id[e_id != -1].view(-1) # filter illegal edges edge_mask[e_id] = False if edge_weight is not None: edge_weight = edge_weight[edge_mask] return edge_index[:, edge_mask], edge_weight
[docs]def add_random_walk_edge( edge_index: Tensor, edge_weight: Optional[Tensor] = None, start: Optional[Tensor] = None, walks_per_node: int = 1, walk_length: int = 3, skip_first: bool = True, num_nodes: Optional[int] = None, is_sorted: bool = False, training: bool = True, ) -> Tuple[Tensor, Optional[Tensor]]: """Adds edges and corresponding edge weights based on random walks. Parameters ---------- edge_index : torch.Tensor the input edge index edge_weight : Optional[Tensor], optional the input edge weight, by default None start : Tensor, optional the starting node to perform random walks, if None, use all nodes in the graph as root nodes, by default None walks_per_node : int, optional number of walks per node, by default 1 walk_length : int, optional number of walk length per node, by default 3 skip_first : bool, optional whether to skip the first-hop node when adding edges between root nodes and nodes sampled from random walks, by default False num_nodes : int, optional number of total nodes in the graph, by default None is_sorted : bool, optional whether the input :obj:`edge_index` is sorted training : bool, optional whether the model is during training, do nothing if :obj:`training=True`, by default True Returns ------- Tuple[Tensor, Optional[Tensor]] the output edge index and edge weight Raises ------ ImportError if :class:`torch_cluster` is not installed. Example ------- .. code-block:: python from mooon import add_random_walk_edge edge_index = torch.tensor([[1, 2], [3,4]]) add_random_walk_edge(edge_index) # specify root nodes add_random_walk_edge(edge_index, start=torch.tensor([1,2])) See also -------- :class:`mooon.AddRandomWalkEdge` """ if random_walk is None: raise ImportError('`add_random_walk_edge` requires `torch-cluster`.') if not training: return edge_index, edge_weight num_nodes = maybe_num_nodes(edge_index, num_nodes) if not is_sorted: edge_index = sort_edge_index(edge_index, edge_weight, num_nodes=num_nodes) if edge_weight is not None: edge_index, edge_weight = edge_index row, col = edge_index device = edge_index.device if start is None: num_nodes = maybe_num_nodes(edge_index, num_nodes) start = torch.arange(num_nodes, device=device) elif start.dtype == torch.bool: start = start.nonzero().view(-1) start = start.repeat(walks_per_node) deg = degree(row, num_nodes=num_nodes) rowptr = row.new_zeros(num_nodes + 1) torch.cumsum(deg, 0, out=rowptr[1:]) p = q = 1.0 walks = random_walk(rowptr, col, start, walk_length, p, q)[0] if skip_first: assert walk_length > 1 rw_row = walks[:, [0]].repeat(1, walk_length - 1) rw_col = walks[:, 2:] else: rw_row = walks[:, [0]].repeat(1, walk_length) rw_col = walks[:, 1:] aug_edge_index = torch.stack([rw_row, rw_col]).view(2, -1).contiguous() # filter self-loops mask = aug_edge_index[0] != aug_edge_index[1] aug_edge_index = aug_edge_index[:, mask] edge_index = torch.cat([edge_index, aug_edge_index], dim=1) if edge_weight is not None: assert edge_weight.ndim == 1 aug_edge_weight = 1. / torch.arange( int(skip_first) + 1, walk_length + 1, dtype=torch.float, device=device) aug_edge_weight = aug_edge_weight.repeat(start.size(0))[mask] edge_weight = torch.cat([edge_weight, aug_edge_weight]) return edge_index, edge_weight