Source code for gnp.estimator

from pathlib import Path
from typing import Optional

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F
import torch_geometric as tg
from tqdm import tqdm

from .config import load_config, load_model
from .dataset.patch import PatchData, PatchTensor
from .geometry.surface import Surface
from .utils import smooth_values_by_gaussian, subsample_points_by_radius

# Directory that contains this module; model weights are looked up beneath it
# (``MODULE_PATH / "model_weights" / <model_name>``).
MODULE_PATH = Path(__file__).parents[0]


class GeometryEstimator:
    """
    Class used for geometry estimation using the pre-trained PatchGNP.

    Parameters
    ----------
    pcd : torch.Tensor
        Point cloud data (N, 3).
    orientation : torch.Tensor
        Normal vectors for each point (N, 3).
    function_values : torch.Tensor, optional
        Function values defined on the point cloud, by default None.
    model_name : str
        Name of the pre-trained model to use, by default "clean_30k".
        Options: "clean_30k", "clean_50k", "noise_70k", "outlier_50k".
    batch_size : int, optional
        Batch size for processing patches, by default 8192.
    device : str, optional
        Device to run the model on, by default "cpu".
    **data_kwargs :
        Additional keyword arguments for data configuration. These
        permanently override the default data configuration of this
        estimator.
    """

    # Values accepted for ``model_name``; each names a sub-directory of
    # ``MODULE_PATH / "model_weights"`` holding ``config.yaml`` and
    # ``state_dict.pth``.
    AVAILABLE_MODELS = ("clean_30k", "clean_50k", "noise_70k", "outlier_50k")

    def __init__(
        self,
        pcd: torch.Tensor,
        orientation: torch.Tensor,
        function_values: Optional[torch.Tensor] = None,
        model_name: str = "clean_30k",
        batch_size: int = 8192,
        device: str = "cpu",
        **data_kwargs,
    ):
        assert model_name in self.AVAILABLE_MODELS, (
            f"model_name must be one of {self.AVAILABLE_MODELS}, got {model_name!r}"
        )
        self.pcd = pcd.to(device)
        self.orientation = orientation.to(device)
        self.device = device
        # Tensors handed to PatchTensor whenever patches are built.
        self.data = {"x": self.pcd, "normals": self.orientation}

        weights_dir = MODULE_PATH / "model_weights" / model_name
        self.config = load_config(weights_dir / "config.yaml")
        self.model_path = weights_dir / "state_dict.pth"
        self.model = load_model(
            config=self.config["model"], model_path=self.model_path, device=device
        )
        self.model.to(device)
        self.batch_size = batch_size

        if function_values is not None:
            self.function_values = function_values.to(device)
            self.data["function_values"] = self.function_values

        # Constructor-level overrides are persistent: they update the stored
        # data configuration used by every later call to ``patch_data``.
        for k, v in data_kwargs.items():
            self.config["data"][k] = v

    def patch_data(self, **datakwargs) -> PatchData:
        """
        Create a PatchData object from the point cloud data.

        This method configures and generates patches from the input point cloud
        data, which can then be used by the model for predictions.

        Parameters
        ----------
        **datakwargs :
            Keyword arguments that override data configuration settings for
            this call only; the stored configuration is left unchanged.

        Returns
        -------
        PatchData
            A PatchData object containing the point cloud data structured into
            patches.
        """
        # Merge into a fresh dict so per-call overrides (e.g. mode="gmls")
        # do not leak into self.config and affect later calls.
        data_config = {**self.config["data"], **datakwargs}
        return PatchTensor(
            data=self.data, device=self.device, **data_config
        ).as_patch_data()

    def surface_patch(self, patch_data: Optional[PatchData] = None) -> Surface:
        """
        Create a Surface from the input patch data using the predictions from
        the model.

        Parameters
        ----------
        patch_data : PatchData, optional
            Batch containing the input patch data. If patch_data is None
            defaults to self.patch_data()

        Returns
        -------
        Surface
            Surface object built from ``patch_data`` and the model's
            coefficients for every batch, concatenated along dim 0.
        """
        if patch_data is None:
            patch_data = self.patch_data()
        surface_coefficients = []
        # Inference only: no autograd graph is needed for the coefficients.
        with torch.no_grad():
            for patch_batch in patch_data.batch_iterator(self.batch_size):
                surface_coefficients.append(
                    self.model(patch_batch.local_coordinates, patch_batch.patch_number)
                )
        return Surface(patch_data, torch.cat(surface_coefficients, dim=0))

    def estimate_quantities(self, quantity_names: list[str]) -> dict[str, torch.Tensor]:
        """
        Estimate geometric quantities on the point cloud.

        This function returns a dictionary containing the estimated scalar
        and/or vector values.

        Parameters
        ----------
        quantity_names : list[str]
            List of quantity names to estimate. Available quantities include:
            'xyz_coordinates', 'normals', 'tangents', 'mean_curvature',
            'gaussian_curvature', 'pca_coordinates', 'normals_pca',
            'tangents_pca', 'metric', 'shape', 'weingarten', 'inverse_metric',
            'inverse_metric_derivatives', 'det_metric',
            'laplace_beltrami_from_coefficients'. Names the Surface does not
            provide are silently skipped.

        Returns
        -------
        dict[str, torch.Tensor]
            A dictionary where keys are the quantity names and values are the
            estimated tensors.
        """
        surface = self.surface_patch()
        # A single getattr with a sentinel evaluates each (possibly computed)
        # attribute once, unlike hasattr + getattr which evaluates it twice.
        missing = object()
        output = {}
        for name in quantity_names:
            value = getattr(surface, name, missing)
            if value is not missing:
                output[name] = value
        return output

    def flow_step(
        self,
        delta_t: float,
        subsample_radius: float,
        smooth_radius: float,
        smooth_x: bool,
    ) -> dict:
        """
        Perform a single step of mean curvature flow on the point cloud.

        Parameters
        ----------
        delta_t : float
            Time step for the flow.
        subsample_radius : float
            Radius used for subsampling points after the flow step.
        smooth_radius : float
            Radius used for smoothing the mean curvature before the flow step.
        smooth_x : bool
            Whether to smooth the point cloud coordinates before the flow step.
            When True, ``self.pcd`` and ``self.data["x"]`` are replaced by the
            estimated ``xyz_coordinates``.

        Returns
        -------
        dict
            A dictionary containing the updated point cloud data ('x'),
            normals, and mean curvature. The estimator's own state is not
            advanced to this data (``mean_flow`` does that).
        """
        if smooth_x:
            # Replace the raw coordinates by the coordinates estimated from the
            # fitted surface patches.
            estimate = self.estimate_quantities(["xyz_coordinates"])
            x = estimate["xyz_coordinates"]
            self.pcd = x
            self.data["x"] = x

        estimate = self.estimate_quantities(["normals", "mean_curvature"])
        x = self.pcd
        normals = estimate["normals"]
        mean_curvature = smooth_values_by_gaussian(
            x=x, values=estimate["mean_curvature"], radius=smooth_radius
        )
        # Explicit step: move each point along its normal by delta_t * H.
        new_x = x + delta_t * mean_curvature.view(-1, 1) * normals

        # Keep only the points selected by radius-based subsampling.
        keep = subsample_points_by_radius(new_x, subsample_radius)
        return {
            "x": new_x[keep].contiguous(),
            "normals": normals[keep].contiguous(),
            "mean_curvature": mean_curvature[keep].contiguous(),
        }

    def mean_flow(
        self,
        num_steps: int,
        save_data_per_step: int,
        delta_t: float,
        subsample_radius: float,
        smooth_radius: float,
        smooth_x: bool,
    ) -> list[dict]:
        """
        Perform mean curvature flow on the point cloud over multiple steps.

        This method iteratively applies the mean curvature flow step and saves
        the state of the point cloud at specified intervals.

        Parameters
        ----------
        num_steps : int
            The total number of flow steps to perform.
        save_data_per_step : int
            The interval at which to save the point cloud data. For example, a
            value of 5 means data is saved every 5 steps (starting with the
            first step, index 0).
        delta_t : float
            Time step for each flow step.
        subsample_radius : float
            Radius used for subsampling points after each flow step.
        smooth_radius : float
            Radius used for smoothing the mean curvature before each flow step.
        smooth_x : bool
            Whether to smooth the point cloud coordinates before each flow step.

        Returns
        -------
        list[dict]
            A list of dictionaries. Each dictionary contains the point cloud
            data ('x'), 'normals', and 'mean_curvature' at a saved step. If
            non-finite coordinates appear, the steps saved so far are returned
            (the estimator state has already been advanced to that step).
        """
        save_data = []
        for i in tqdm(range(num_steps)):
            new_data = self.flow_step(
                delta_t=delta_t,
                subsample_radius=subsample_radius,
                smooth_radius=smooth_radius,
                smooth_x=smooth_x,
            )
            # The next step operates on the flowed, subsampled cloud.
            self.data = new_data.copy()
            self.pcd = new_data["x"]
            self.orientation = new_data["normals"]
            if not torch.isfinite(new_data["x"]).all():
                print(f"Nan or Infinite detected in Mean Flow! Exiting early at iteration {i}")
                return save_data
            if i % save_data_per_step == 0:
                save_data.append(new_data.copy())
        return save_data

    def gmls_weights(
        self, patch_data: PatchData, mask: torch.Tensor, radius: float = 1.0, p: int = 4
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the weights for the generalized moving least squares (GMLS)
        method.

        The weight of a point at in-plane distance ``d`` (norm of its first two
        local coordinates) is ``relu(1 - d / radius) ** p``.

        Parameters
        ----------
        patch_data : PatchData
            PatchData object containing coordinate data.
        mask : torch.Tensor
            A boolean mask indicating which points in the patch data to use.
        radius : float, optional
            Radius to truncate the weight function, by default 1.0.
        p : int, optional
            The exponent for the weight function, by default 4.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            A tuple containing:
            - The weight matrices for each patch (batch_size, 1, num_points).
            - The dense mask used to create the dense batch.
        """
        uv = patch_data.local_coordinates[mask, :2]
        dists = uv.norm(dim=1)
        # Padding entries get +inf distance, so for p > 0 their weight is 0.
        dists_dense, dense_mask = tg.utils.to_dense_batch(
            x=dists, batch=patch_data.patch_number[mask], fill_value=torch.inf
        )
        weights = F.relu(1 - dists_dense / radius).pow(p)
        return weights.unsqueeze(1), dense_mask

    def laplace_beltrami_legendre_blocks(self, surface: Surface) -> torch.Tensor:
        """
        Compute the Laplace-Beltrami operator of Legendre basis functions.

        Parameters
        ----------
        surface : Surface
            Surface object containing patch information and basis functions.

        Returns
        -------
        torch.Tensor
            The Laplace-Beltrami operator applied to the Legendre basis
            functions of each patch, as provided by
            ``surface.laplace_beltrami_basis_terms``. (Used as
            (num_patches, num_basis_functions) in ``stiffness_matrix_gmls``
            -- TODO confirm against Surface.)
        """
        return surface.laplace_beltrami_basis_terms

    def legendre_blocks(self, surface: Surface, mask: torch.Tensor) -> torch.Tensor:
        """
        Evaluate Legendre basis functions for each point in the patches.

        This method evaluates the Legendre basis functions at the local `uv`
        coordinates of the points specified by the mask and returns them as a
        dense tensor, batched by patch.

        Parameters
        ----------
        surface : Surface
            Surface object containing patch data and basis functions.
        mask : torch.Tensor
            A boolean mask to select which points' local coordinates to use for
            the evaluation.

        Returns
        -------
        torch.Tensor
            A dense tensor of Legendre basis function evaluations, with shape
            (num_patches, max_points_in_patch, num_basis_functions). Padding
            rows are zero.
        """
        uv = surface.patch.local_coordinates[mask, :2]
        legendre_values = surface.basis.evaluate(uv)
        legendre_blocks, _ = tg.utils.to_dense_batch(
            x=legendre_values, batch=surface.patch.patch_number[mask], fill_value=0
        )
        return legendre_blocks

    def stiffness_matrix_gmls(
        self,
        drop_ratio: float = 0.1,
        radius: float = 1.0,
        p: int = 4,
        remove_outliers: bool = False,
        outlier_threshold: float = 0.2,
    ) -> tuple[sp.csr_matrix, torch.Tensor, torch.Tensor]:
        """
        Compute the stiffness matrix using the generalized moving least squares
        (GMLS) method.

        Parameters
        ----------
        drop_ratio: float, optional
            Ratio of points to drop. Defaults to 0.1.
        radius: float, optional
            Radius to truncate weight function. Defaults to 1..
        p: int, optional
            Degree p of the weight function. Defaults to 4.
        remove_outliers: bool, optional
            Whether to remove outliers. Defaults to False. When True the
            outliers are removed from ``self.data``, ``self.pcd`` and
            ``self.orientation`` in place.
        outlier_threshold: float, optional
            The threshold used to determine which points are labeled as
            outliers. Only used if remove_outliers is True.

        Returns
        -------
        sp.csr_matrix
            Stiffness matrix for solving the Laplace-Beltrami PDE. Its shape
            is inferred from the largest row/column index present.
        torch.Tensor
            Mask for which values to solve collocation problem on.
        torch.Tensor
            Mask for points that are removed in smoothing. If remove_outliers
            is False then the mask will be all true.
        """
        if remove_outliers:
            # Keep points whose local and PCA coordinates differ by less than
            # outlier_threshold; the rest are dropped from the estimator state.
            outputs = self.estimate_quantities(["local_coordinates", "pca_coordinates"])
            outlier_mask = (
                outputs["local_coordinates"] - outputs["pca_coordinates"]
            ).norm(dim=1) < outlier_threshold
            for k, v in self.data.items():
                self.data[k] = v[outlier_mask]
            self.pcd = self.pcd[outlier_mask]
            self.orientation = self.orientation[outlier_mask]
        else:
            outlier_mask = torch.ones(
                self.pcd.shape[0], dtype=torch.bool, device=self.pcd.device
            )

        # Points chosen by tg.nn.fps are excluded from the collocation set.
        if drop_ratio > 0.0:
            drop_inds = tg.nn.fps(self.pcd, ratio=drop_ratio).to(self.device)
        else:
            # Empty index tensor on the same device as the mask it indexes.
            drop_inds = torch.empty(0, dtype=torch.long, device=self.device)
        collocation_mask = torch.ones(
            self.pcd.shape[0], dtype=torch.bool, device=self.device
        )
        collocation_mask[drop_inds] = False

        patch_data = self.patch_data(mode="gmls")
        # Per patch-point flag: does this neighbour belong to the collocation set?
        coord_mask = collocation_mask[patch_data.patch_indices]
        surface = self.surface_patch(patch_data)

        weights, tensor_mask = self.gmls_weights(patch_data, coord_mask, radius, p)
        legendre = self.legendre_blocks(surface, coord_mask)
        lb = self.laplace_beltrami_legendre_blocks(surface)

        # Per-patch weighted least squares: solve (P^T W P) C = P^T W for C.
        legendre_t_w = legendre.permute(0, 2, 1) * weights
        ls_solutions = torch.linalg.lstsq(legendre_t_w @ legendre, legendre_t_w).solution
        stiffness_values = (-lb.unsqueeze(1) @ ls_solutions).squeeze(1)[tensor_mask]

        # Renumber the collocation points that appear in patches to
        # consecutive column indices.
        _, patch_indices_reindexed = torch.unique(
            patch_data.patch_indices[coord_mask], return_inverse=True
        )
        stiffness_indices = torch.stack(
            [
                patch_data.patch_number.flatten()[coord_mask],
                patch_indices_reindexed,
            ],
            dim=0,
        )
        stiffness = sp.coo_matrix(
            (
                stiffness_values.cpu().numpy(),
                stiffness_indices.cpu().numpy().astype(np.int32),
            )
        ).tocsr()
        return stiffness, collocation_mask, outlier_mask