Source code for gnp.models.gnp

import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import global_mean_pool

from . import layers as _layers


[docs] class GNP(nn.Module): """ General purpose model for Geometric Neural Operators (GNPs). This model consists of a lifting layer, multiple graph convolution blocks, and a final projection layer. Parameters ---------- node_dim : int Input dimension of node features. edge_dim : int Input dimension of edge features. out_dim : int Output dimension. layers : list[int] List of hidden dimensions. conv_name : str Name of the convolution block class to use. conv_args : dict Arguments for the convolution block. nonlinearity : str The name of the activation function to use. skip_connection: bool Whether to use a skip connection for each layer. device : str The device to place the model on. Defaults to "cuda". """ def __init__( self, node_dim: int, edge_dim: int, out_dim: int, layers: list[int], conv_name: str, conv_args: dict, nonlinearity: str, skip_connection: bool, device: str, ): super().__init__() self.device = device self.layers = layers self.depth = len(layers) - 1 self.edge_dim = edge_dim self.out_dim = out_dim self.nonlinearity = nonlinearity self.lift = nn.Linear(node_dim, layers[0]) self.proj = nn.Linear(layers[-1], out_dim) self.activation = _layers.get_activation(nonlinearity) self.num_parameters = sum( p.numel() for p in self.parameters() if p.requires_grad ) if not hasattr(_layers, conv_name): raise ValueError(f"Convolution '{conv_name}' not found in {__name__}") self.blocks = nn.ModuleList( [ _layers.ConvolutionBlock( in_dim=in_dim, out_dim=out_dim, edge_dim=edge_dim, nonlinearity=nonlinearity, conv_name=conv_name, conv_args=conv_args, skip=skip_connection, ) for in_dim, out_dim in zip(layers[:-1], layers[1:]) ] )
[docs] def forward(self, data: Data): """ Forward pass. Parameters ---------- data : Data PyG Data object containing x, edge_index, and edge_attr. Returns ------- torch.Tensor Output features. """ x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr x = self.lift(x) for block in self.blocks[:-1]: x = block(x=x, edge_index=edge_index, edge_attr=edge_attr) x = self.blocks[-1]( x=x, edge_index=edge_index, edge_attr=edge_attr, use_activation=False ) x = self.proj(x) return x
[docs] class PatchGNP(nn.Module): """ Geometric Neural Operator for processing point cloud patches. Parameters ---------- node_dim : int The dimensionality of the input node features (e.g., 3 for xyz). out_dim : int The dimensionality of the final output vector for each patch. layers : list[int] A list of integers defining the width of each layer in the network. The first element is the width after the initial lifting layer. num_channels : int The number of channels to use in the 'block' type convolution. neurons : int The number of neurons in the hidden layers of the MLPs within the convolutional layers. nonlinearity : str The name of the activation function to use. device : str The device to place the model on. Defaults to "cuda". """ def __init__( self, node_dim: int, out_dim: int, layers: list[int], num_channels: int, neurons: int, nonlinearity: str, device: str = "cuda", ): super().__init__() self.node_dim = node_dim self.out_dim = out_dim self.layer_widths = layers self.num_channels = num_channels self.neurons = neurons self.nonlinearity = nonlinearity self.device = device self.activation = _layers.get_activation(nonlinearity) self.lift = nn.Linear(node_dim, layers[0]) self.proj = nn.Sequential( nn.Linear(layers[-1], 2 * layers[-1]), self.activation, nn.Linear(2 * layers[-1], out_dim), ) self.convs = nn.ModuleList( [ _layers.PatchSeparableBlockFactorizedConvolutionBlock( in_dim=in_dim, out_dim=out_dim, dim_x=node_dim, num_channels=num_channels, neurons=neurons, nonlinearity=nonlinearity, ) for in_dim, out_dim in zip(layers[:-1], layers[1:]) ] )
[docs] def forward(self, x: torch.Tensor, batch: torch.Tensor): """ Perform a forward pass on a batch of patches. Parameters ---------- x : torch.Tensor Input coordinates/features (N, node_dim). batch : torch.Tensor Batch indices (N,). Returns ------- torch.Tensor A tensor of shape (num_patches, out_dim) containing the output vector for each patch. """ v = self.lift(x) for i, conv in enumerate(self.convs): v = conv(x, v, batch) if i < len(self.convs) - 1: v = self.activation(v) v = global_mean_pool(v, batch) return self.proj(v)