Source code for gnp.geometry.surface

from functools import cached_property

import torch

from ..dataset.patch import PatchData
from .legendre import Legendre2D


[docs] class Surface: """ Surface for computing geometric quantities from basis coefficients used at evaluation time. Parameters ---------- patch_data : Batch Patch data in local coordinates from PatchTensor. basis_coefficients : torch.Tensor Basis coefficients representing the surface (e.g., Legendre coefficients). """ def __init__( self, patch_data: PatchData, basis_coefficients: torch.Tensor, use_original: bool = False, ): self.patch = patch_data self.device = patch_data.x.device self.x = patch_data.x_original if use_original else patch_data.x self.centers = patch_data.centers self.xy_scale = patch_data.xy_scale[self.batch] self.z_scale = patch_data.z_scale self.pca_vectors = patch_data.pca_vectors self.coefficients = basis_coefficients self.basis = Legendre2D(degree=patch_data.degree) self._compute_geometry() @property def batch(self): """ Batch identifier for which cluster which points belong to. Returns ------- torch.Tensor Batch indices. """ return self.patch.clusters
[docs] @cached_property def local_coordinates(self): """ Returns the local coordinates of the patch data. Returns ------- torch.Tensor local coordinates of the patch data. """ diff = (self.x - self.centers[self.batch]).unsqueeze(1) local_unscaled = (diff @ self.pca_vectors[self.batch].permute(0, 2, 1)).squeeze( 1 ) scaling = torch.cat( (self.xy_scale, self.xy_scale, self.z_scale[self.batch]), dim=1 ) return local_unscaled / scaling
def _compute_geometry(self): """ Compute height function and its derivatives for computation of downstream geometric quantities. """ self.derivative_scale = torch.reciprocal( torch.cat( (self.xy_scale.repeat(1, 2), self.xy_scale.pow(2).repeat(1, 3)), dim=1 ) ) scaled_coeffs = self.z_scale * self.coefficients self.h = self.basis.evaluate_from_coeffs( self.local_coordinates[:, :2], scaled_coeffs[self.batch] ) raw_derivatives = self.basis.derivatives_from_coeffs( self.local_coordinates[:, :2], scaled_coeffs[self.batch] ) h_derivatives = self.derivative_scale * raw_derivatives (self.h_u, self.h_v, self.h_uv, self.h_uu, self.h_vv) = torch.split( h_derivatives, 1, 1 ) @property def local_coordinate_basis(self) -> torch.Tensor: """ Return the local PCA basis coordinates at each point Returns ------- torch.Tensor local coordinates of the surface patches """ return self.pca_vectors[self.batch]
[docs] @cached_property def xyz_coordinates(self) -> torch.Tensor: """ Compute the xyz coordinates of the surface patches Returns ------- torch.Tensor xyz coordinates of the surface patches """ pca_x = self.local_coordinates[..., :2].unsqueeze(2) pca_basis = self.local_coordinate_basis xyz = ( self.centers[self.batch] + self.xy_scale * (pca_x * pca_basis[:, :2]).sum(dim=1) + self.h * pca_basis[:, 2] ) return xyz
[docs] @cached_property def pca_coordinates(self) -> torch.Tensor: """ Compute the PCA coordinates of the surface patches Returns ------- torch.Tensor PCA coordinates of the surface patches """ pca_coord = torch.cat( (self.local_coordinates[..., :2], self.h / self.z_scale[self.batch]), dim=1 ) return pca_coord
[docs] @cached_property def tangents_pca(self) -> torch.Tensor: """ Compute the tangents of the surface patches in local coordinates Returns ------- torch.Tensor tangents of the surface patches """ ones = torch.ones((self.h.shape[0], 1), device=self.device) zeros = torch.zeros((self.h.shape[0], 1), device=self.device) tangent_u = torch.cat((ones, zeros, self.h_u), dim=-1) tangent_v = torch.cat((zeros, ones, self.h_v), dim=-1) return torch.stack((tangent_u, tangent_v), dim=1)
[docs] @cached_property def tangents(self) -> torch.Tensor: """ Compute the tangents of the surface patches in global coordinates Returns ------- torch.Tensor tangents of the surface patches """ pca_basis = self.local_coordinate_basis return self.tangents_pca @ pca_basis
[docs] @cached_property def normals_pca(self) -> torch.Tensor: """ Compute the normals of the surface patches in local coordinates Returns ------- torch.Tensor normals of the surface patches """ ones = torch.ones((self.h.shape[0], 1), device=self.device) normals = torch.cat((-self.h_u, -self.h_v, ones), dim=-1) normals = normals / (1 + self.h_u.pow(2) + self.h_v.pow(2)).sqrt() return normals
[docs] @cached_property def normals(self) -> torch.Tensor: """ Compute the normals of the surface patches in global coordinates Returns ------- torch.Tensor normals of the surface patches """ pca_basis = self.local_coordinate_basis normals = (self.normals_pca.unsqueeze(1) @ pca_basis).squeeze(1) normals = normals / normals.norm(dim=1, keepdim=True) return normals
[docs] @cached_property def metric(self) -> torch.Tensor: """ Compute the metric tensor of the surface patches Returns ------- torch.Tensor metric tensor of the surface patches """ E = 1 + self.h_u.pow(2) F = self.h_u * self.h_v G = 1 + self.h_v.pow(2) return torch.stack((torch.cat((E, F), dim=1), torch.cat((F, G), dim=1)), dim=1)
[docs] @cached_property def shape(self) -> torch.Tensor: """ Compute the shape tensor of the surface patches Returns ------- torch.Tensor shape tensor of the surface patches """ divisor = (1 + self.h_u.pow(2) + self.h_v.pow(2)).sqrt() L = self.h_uu / divisor M = self.h_uv / divisor N = self.h_vv / divisor return torch.stack((torch.cat((L, M), dim=1), torch.cat((M, N), dim=1)), dim=1)
[docs] @cached_property def weingarten(self) -> torch.Tensor: """ Compute the Weingarten tensor of the surface patches Returns ------- torch.Tensor Weingarten tensor of the surface patches """ return torch.linalg.inv(self.metric) @ self.shape
[docs] @cached_property def gaussian_curvature(self) -> torch.Tensor: """ Compute the Gaussian curvature tensor of the surface patches Returns ------- torch.Tensor Gaussian curvature tensor of the surface patches """ return torch.linalg.det(self.weingarten)
[docs] @cached_property def mean_curvature(self) -> torch.Tensor: """ Compute the mean curvature tensor of the surface patches Returns ------- torch.Tensor mean curvature tensor of the surface patches """ return 0.5 * self.weingarten.diagonal(dim1=1, dim2=2).sum(dim=1)
[docs] @cached_property def inverse_metric(self) -> torch.Tensor: """ Compute the inverse metric tensor of the surface patches Returns ------- torch.Tensor inverse metric tensor of the surface patches """ return torch.linalg.inv(self.metric)
[docs] @cached_property def inverse_metric_derivatives(self) -> torch.Tensor: """ Compute the derivatives of the inverse metric tensor of the surface patches Returns ------- torch.Tensor derivatives of the inverse metric tensor of the surface patches """ zero = torch.zeros_like(self.h_u, device=self.device) mat1 = torch.cat( ( 1 + self.h_v.pow(2), -self.h_u * self.h_v, -self.h_u * self.h_v, 1 + self.h_u.pow(2), ), dim=1, ).view(-1, 2, 2) mat2_u = torch.cat( ( zero, -self.h_uu * self.h_v - self.h_u * self.h_uv, -self.h_uu * self.h_v - self.h_u * self.h_uv, 2 * self.h_u * self.h_uu, ), dim=1, ).view(-1, 2, 2) mat2_v = torch.cat( ( 2 * self.h_v * self.h_vv, -self.h_uv * self.h_v - self.h_u * self.h_vv, -self.h_uv * self.h_v - self.h_u * self.h_vv, zero, ), dim=1, ).view(-1, 2, 2) divisor = 1 + self.h_u.pow(2) + self.h_v.pow(2) det_u = ( -(2 * self.h_u * self.h_uu + 2 * self.h_v * self.h_uv) / divisor.pow(2) ).unsqueeze(2) det_v = ( -(2 * self.h_u * self.h_uv + 2 * self.h_v * self.h_vv) / divisor.pow(2) ).unsqueeze(2) ginv_u = det_u * mat1 + mat2_u / divisor.unsqueeze(2) ginv_v = det_v * mat1 + mat2_v / divisor.unsqueeze(2) return torch.stack((ginv_u, ginv_v), dim=-1)
[docs] @cached_property def det_metric(self) -> torch.Tensor: """ Compute the determinant of the metric tensor of the surface patches Returns ------- torch.Tensor determinant of the metric tensor of the surface patches """ return 1 + self.h_u.pow(2) + self.h_v.pow(2)
[docs] @cached_property def laplace_beltrami_first_terms(self) -> torch.Tensor: """ Compute the first terms of the Laplace-Beltrami operator. These terms are multiplied to the second derivatives of the function. Returns ------- torch.Tensor The first terms of the laplace-beltrami operator. """ return self.det_metric.sqrt() * self.inverse_metric.contiguous().view(-1, 4)
[docs] @cached_property def laplace_beltrami_second_terms(self) -> torch.Tensor: """ Compute the second terms of the Laplace-Beltrami operator. These terms are multiplied to the first derivatives of the function. Returns ------- torch.Tensor The second terms of the laplace-beltrami operator. """ det_g = self.det_metric uu = (2 * self.h_v * self.h_uv / det_g.sqrt()) - ( (self.h_u * self.h_uu + self.h_v * self.h_uv) * (1 + self.h_v.pow(2)) / det_g.pow(1.5) ) uv = -((self.h_uu * self.h_v + self.h_u * self.h_uv) / det_g.pow(0.5)) + ( (self.h_u * self.h_uu + self.h_v * self.h_uv) * (self.h_u * self.h_v) / det_g.pow(1.5) ) vu = -((self.h_uv * self.h_v + self.h_u * self.h_vv) / det_g.pow(0.5)) + ( (self.h_u * self.h_uv + self.h_v * self.h_vv) * (self.h_u * self.h_v) / det_g.pow(1.5) ) vv = (2 * self.h_u * self.h_uv / det_g.sqrt()) - ( (self.h_u * self.h_uv + self.h_v * self.h_vv) * (1 + self.h_u.pow(2)) / det_g.pow(1.5) ) return torch.cat((uu, uv, vu, vv), dim=1)
[docs] @cached_property def laplace_beltrami_basis_terms(self) -> torch.Tensor: """Compute the Laplace-Beltrami operator on each of the basis functions. Returns ------- torch.Tensor (n, 16) tensor containing the Laplace-Beltrami operator applied to the basis functions. """ derivatives = self.derivative_scale.unsqueeze(-1) * torch.stack( self.basis.evaluate_derivatives(self.local_coordinates[..., :2]), dim=1, ) fu, fv, fuv, fuu, fvv = torch.unbind(derivatives, dim=1) lb_first = self.laplace_beltrami_first_terms.unsqueeze(-1) lb_second = self.laplace_beltrami_second_terms.unsqueeze(-1) basis_lb = ( lb_first * torch.stack((fuu, fuv, fuv, fvv), dim=1) + lb_second * torch.stack((fu, fv, fu, fv), dim=1) ).sum(dim=1) return basis_lb / self.det_metric.sqrt()
[docs] def laplace_beltrami_from_coefficients( self, function_coefficients: torch.Tensor ) -> torch.Tensor: """ Compute the Laplace-Beltrami operator from the surface coefficients and the function coefficients. Parameters ---------- function_coefficients : torch.Tensor Basis coefficients representing the function. Returns ------- torch.Tensor Laplace-Beltrami operator applied to the function. """ lb_values = ( function_coefficients[self.batch] * self.laplace_beltrami_basis_terms ).sum(dim=1, keepdim=True) return lb_values