Source code for mooon.flag

from typing import Callable

import torch
import torch.nn as nn
from torch import Tensor

# Names exported by ``from mooon.flag import *``.
# NOTE(review): ``classes`` is bound to the very same list object as
# ``__all__``; presumably a registry of the classes defined here that is
# consumed elsewhere in the package — confirm against the callers.
classes = __all__ = ["FLAG"]


class FLAG(nn.Module):
    r"""The Free Large-scale Adversarial Augmentation (FLAG) from the
    `"Robust Optimization as Data Augmentation for Large-scale Graphs"
    <https://arxiv.org/abs/2010.09891>`_ paper.

    The module wraps a loss function and, on every call, runs
    :obj:`steps` forward passes of a user-supplied closure on
    adversarially perturbed inputs. The first :obj:`steps - 1` losses are
    back-propagated internally (so their gradients accumulate on whatever
    tensors take part in the closure), and the last loss is returned for
    the caller to back-propagate.

    Parameters
    ----------
    criterion : Callable
        The loss function to be used for training the model. It is called
        as ``criterion(out, y)``.
    steps : int, optional
        The number of steps to be taken for adversarial training,
        by default 3. Must be at least 1.
    step_size : float, optional
        The size of the perturbation to be added to the input
        at each step, by default 1e-3

    Raises
    ------
    ValueError
        If :obj:`steps` is smaller than 1.

    Example
    -------
    .. code-block:: python

        import torch
        from mooon import FLAG

        data = ...   # PyG-like data
        model = ...  # GNN model
        optimizer = torch.optim.Adam(model.parameters())
        criterion = torch.nn.CrossEntropyLoss()
        flag = FLAG(criterion)

        def forward(perturb):
            out = model(data.x + perturb, data.edge_index, data.edge_attr)
            return out[data.train_mask]

        def train():
            model.train()
            optimizer.zero_grad()
            loss = flag(forward, data.x, data.y[data.train_mask])
            loss.backward()
            optimizer.step()
            return float(loss)

        train()

    Reference:
        https://github.com/devnkong/FLAG
    """

    def __init__(
        self,
        criterion: Callable,
        steps: int = 3,
        step_size: float = 1e-3,
    ):
        super().__init__()
        # Every loss is divided by ``steps``; zero would silently yield
        # inf/nan and a negative value would flip the sign of the loss.
        if steps < 1:
            raise ValueError(
                f"'steps' must be a positive integer, got {steps}")
        self.criterion = criterion
        self.steps = steps
        self.step_size = step_size

    def forward(self, forward: Callable, x: Tensor, y: Tensor) -> Tensor:
        r"""Performs forward pass and adversarial training
        with FLAG algorithm.

        Parameters
        ----------
        forward : Callable
            The self-defined forward function of the model,
            which accepts :obj:`perturb` as input and returns the
            predictions that are passed to the criterion.
        x : Tensor
            The input node features. Only used as a template
            (via :func:`torch.empty_like`) for the perturbation tensor.
        y : Tensor
            The target node labels.

        Returns
        -------
        Tensor
            The loss of the last adversarial step, divided by
            :obj:`steps`. It has *not* been back-propagated yet.

        Raises
        ------
        RuntimeError
            If no gradient reaches :obj:`perturb`, i.e. the output of
            :obj:`forward` does not depend on its argument.
        """
        criterion = self.criterion
        step_size = self.step_size
        steps = self.steps

        # Random start inside the [-step_size, step_size] box.
        perturb = torch.empty_like(x).uniform_(-step_size, step_size)
        perturb.requires_grad_()

        loss = criterion(forward(perturb), y) / steps

        for _ in range(steps - 1):
            # Accumulates gradients of this step's (scaled) loss and
            # populates ``perturb.grad`` for the ascent step below.
            loss.backward()

            grad = perturb.grad
            if grad is None:
                raise RuntimeError(
                    "FLAG: no gradient was computed for the perturbation; "
                    "make sure the 'forward' callable uses its 'perturb' "
                    "argument.")

            # Signed-gradient ascent on the perturbation. Done outside of
            # autograd so that ``perturb`` stays a leaf tensor.
            with torch.no_grad():
                perturb.add_(step_size * grad.sign())
                grad.zero_()

            loss = criterion(forward(perturb), y) / steps

        return loss

    def __repr__(self):
        return (f"{self.__class__.__name__}(criterion={self.criterion}, "
                f"steps={self.steps}, step_size={self.step_size})")