from functools import cached_property
from typing import Optional
import torch
import torch_geometric as tg
from torch_geometric.data import Data
from torch_scatter import scatter_add, scatter_max, scatter_mean
from ..utils import QueryTorchGeometric
[docs]
class PatchData(Data):
    """
    A PyTorch Geometric Data object specialized for patch-based point cloud data.

    This class extends the base `Data` object with a patch count and a
    utility for iterating over batches of patches.
    """

    def __init__(self, x: Optional[torch.Tensor] = None, **kwargs):
        super().__init__(x=x, **kwargs)

    @property
    def num_patches(self):
        """Return the number of rows of ``centers``, or 0 if it is not stored."""
        return self.centers.shape[0] if "centers" in self else 0

    def batch_iterator(self, batch_size: int):
        """
        Create a generator to iterate over batches of patches.

        Each yielded `PatchData` holds a contiguous range of patches. Only
        tensor attributes are carried over: those whose first dimension equals
        the number of patches are sliced per patch, and those whose first
        dimension equals the number of patch entries (the length of
        ``patch_indices``) are sliced per entry. Every other attribute is
        omitted from the batch. The entry tensors are assumed to be sorted by
        ``patch_number``; ``patch_number`` itself is re-based so that it
        starts at 0 within each batch.

        Parameters
        ----------
        batch_size : int
            The maximum number of patches in each batch.

        Yields
        ------
        PatchData
            A new `PatchData` object representing a batch of patches.

        Raises
        ------
        ValueError
            If ``batch_size`` is not positive (raised on first iteration).
        """
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )
        num_patches = self.num_patches
        if num_patches == 0:
            # Nothing to iterate; previously this hit range() with a zero step.
            return
        batch_size = min(num_patches, batch_size)

        # ptr[i] is the offset of the first entry of patch i; ptr[-1] is the
        # total entry count. Plain ints avoid a tensor round-trip per batch.
        patch_lens = self.patch_lens
        ptr = torch.cat((patch_lens.new_zeros(1), patch_lens.cumsum(dim=0))).tolist()
        total_edges = self.patch_indices.shape[0]

        # Classify the attributes once instead of rebuilding the dict per batch.
        per_patch = {}
        per_entry = {}
        for key, value in self.to_dict().items():
            if not isinstance(value, torch.Tensor) or value.dim() == 0:
                continue
            if value.shape[0] == num_patches:
                per_patch[key] = value
            if value.shape[0] == total_edges:
                per_entry[key] = value

        for start_idx in range(0, num_patches, batch_size):
            end_idx = min(start_idx + batch_size, num_patches)
            batch_kwargs = {
                key: value[start_idx:end_idx] for key, value in per_patch.items()
            }
            edge_start, edge_end = ptr[start_idx], ptr[end_idx]
            # Per-entry slices are written second, so they win when a tensor
            # matches both sizes (same precedence as before).
            for key, value in per_entry.items():
                sliced_val = value[edge_start:edge_end]
                if key == "patch_number":
                    sliced_val = sliced_val - start_idx
                batch_kwargs[key] = sliced_val
            yield PatchData(**batch_kwargs)
[docs]
class PatchTensor:
    """
    Processes a point cloud into a collection of overlapping patches.

    This class handles the entire pipeline of patchifying a point cloud. It
    selects patch centers, finds neighboring points for each patch, computes
    local coordinate systems using PCA, and scales the coordinates. The final
    output is a `PatchData` object ready for use in a model.

    Patch selection (centers, patch membership and patch means) runs in the
    constructor; the PCA frames and local coordinates are computed on first
    access and then cached.

    Parameters
    ----------
    data : dict
        A dictionary containing the point cloud data, requires at least an
        'x' key with a tensor of shape (N, 3). The optional keys
        'original_x', 'orientation' and 'normals' are also read: the first
        supplies a second set of points to express in the patch frames, the
        other two (in that order of preference) orient the third PCA axis.
    k : int, optional
        Number of nearest neighbors to consider for various calculations,
        by default 30.
    mode : str, optional
        The mode for center selection ('train', 'test', or 'gmls'),
        by default "test". Any other value raises `ValueError` during
        construction.
    pca : bool, optional
        Whether to use PCA to determine local coordinate systems,
        by default True.
    scale : bool, optional
        Whether to scale the local coordinates, by default True.
    min_z_scale : float, optional
        The minimum value for z-scaling, by default 5e-3.
    basis : str, optional
        The basis to use, by default "legendre".
    basis_degree : int, optional
        The degree of the basis, by default 3.
    num_training_patches : int, optional
        Number of patches to sample in 'train' mode, by default 1024.
    device : str, optional
        The device to perform computations on, by default "cpu".
    """
def __init__(
    self,
    data: dict,
    k: int = 30,
    mode: str = "test",
    pca: bool = True,
    scale: bool = True,
    min_z_scale: float = 5e-3,
    basis: str = "legendre",
    basis_degree: int = 3,
    num_training_patches: int = 1024,
    device: str = "cpu",
):
    """
    Store the configuration and build the patches for the given point cloud.

    Tensors from `data` are moved to `device` in a private copy of the
    dictionary, so the caller's dictionary is left unchanged. Patch centers
    are then selected according to `mode`, patch membership is queried, and
    the mean point of every patch is stored as `centers`.

    Raises
    ------
    ValueError
        If `mode` is not one of the modes handled by `get_centers`.
    """
    # Shallow copy: do not overwrite entries of the caller's dictionary.
    self.data = {
        key: val.to(device) if isinstance(val, torch.Tensor) else val
        for key, val in data.items()
    }
    self.x = self.data["x"].squeeze().to(device)
    # Falls back to the (already squeezed) primary points when absent.
    self.original_x = self.data.get("original_x", self.x).squeeze().to(device)

    self.k = k
    self.mode = mode
    self.pca = pca
    self.scale = scale
    self.min_z_scale = min_z_scale
    self.basis = basis
    self.basis_degree = basis_degree
    self.num_training_patches = num_training_patches
    self.device = device

    self.query = QueryTorchGeometric(x=self.x, device=self.device)
    self.center_indices, self.clusters, self.knn_distances = self.get_centers()
    self.patch_indices, self.patch_number, self.patch_lens = (
        self._patch_data_query()
    )
    # dim_size keeps one row per selected center even if the trailing
    # centers ended up with no points in their patch.
    self.centers = scatter_mean(
        self.x[self.patch_indices],
        index=self.patch_number,
        dim=0,
        dim_size=self.center_indices.shape[0],
    )
[docs]
def get_centers(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Select patch centers with the strategy matching ``self.mode``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        Whatever the mode-specific selector returns:
        - The indices of the center points.
        - The cluster assignment for each point in the cloud.
        - The k-NN distance for each center, used to determine patch radius.

    Raises
    ------
    ValueError
        If ``self.mode`` is not 'train', 'test' or 'gmls'.
    """
    mode = self.mode
    if mode == "train":
        return self.get_train_centers()
    if mode == "test":
        return self.get_test_centers()
    if mode == "gmls":
        return self.get_gmls_centers()
    raise ValueError(f"Unknown mode: {self.mode}")
[docs]
def get_train_centers(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Pick patch centers for training by uniform random sampling.

    A random permutation of all point indices is drawn and truncated to at
    most ``num_training_patches`` entries.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        Center indices, cluster assignments (a dummy all-zero vector with one
        entry per point), and the distance from each center to its k-th
        nearest neighbour. All three are moved to ``self.device``.
    """
    num_points = self.x.shape[0]
    sampled = torch.randperm(num_points)[: self.num_training_patches]
    # No clustering is performed in train mode; every point gets label 0.
    dummy_clusters = torch.zeros(num_points).long()
    knn_dist, _ = self.query.query_knn(self.x[sampled], k=self.k)
    # The last column holds the distance to the farthest of the k neighbours.
    kth_dist = knn_dist[:, -1]
    target = self.device
    return sampled.to(target), dummy_clusters.to(target), kth_dist.to(target)
[docs]
def get_test_centers(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Select patch centers using a greedy covering strategy for testing.

    Points are visited in random order. A point that is not yet covered
    becomes a center, and its nearest ``int(0.66 * k)`` neighbours (as
    returned by the k-NN query) are marked as covered. Afterwards each point
    is given a cluster label: the index of a center that has it among its
    ``k`` neighbours, with closer neighbour ranks written last so they take
    precedence. Points outside every center's neighbourhood keep label 0.

    As a side effect, ``self.max_patches`` is set to the number of centers
    plus one.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        Center indices, cluster assignments, and k-NN distances, all on
        ``self.device``.
    """
    distances, ind = self.query.query_knn(self.x, k=self.k)
    num_points = self.x.shape[0]

    # The greedy sweep is inherently sequential, so it runs on the CPU with
    # plain Python ints regardless of where the query results live.
    cover_table = ind[:, : int(0.66 * self.k)].cpu()
    uncovered = torch.ones(num_points, dtype=torch.bool)
    chosen = []
    for idx in torch.randperm(num_points).tolist():
        if uncovered[idx]:
            chosen.append(idx)
            uncovered[cover_table[idx]] = False
    center_indices = torch.tensor(chosen, dtype=torch.long)
    self.max_patches = center_indices.shape[0] + 1

    # Keep the label bookkeeping on the same device as the k-NN indices so
    # the indexed assignments below never mix devices.
    knn_ind = ind[center_indices]
    clusters = torch.zeros(num_points, dtype=torch.long, device=ind.device)
    patch_ids = torch.arange(center_indices.shape[0], device=ind.device)
    # Farthest rank first: nearer ranks overwrite and therefore win.
    for j in range(ind.shape[1] - 1, -1, -1):
        clusters[knn_ind[:, j]] = patch_ids

    return (
        center_indices.to(self.device),
        clusters.to(self.device),
        # Moved to the target device, matching get_train_centers.
        distances[center_indices, -1].to(self.device),
    )
[docs]
def get_gmls_centers(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Select every point as a patch center for GMLS.

    In GMLS mode, a patch is centered at every single point in the cloud, so
    both the center indices and the cluster labels are ``0 .. N-1``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        Center indices, cluster assignments, and the distance from each point
        to its k-th nearest neighbour, all on ``self.device``.
    """
    distances, _ = self.query.query_knn(self.x, k=self.k)
    num_points = self.x.shape[0]
    return (
        torch.arange(num_points, device=self.device),
        torch.arange(num_points, device=self.device),
        # Moved to the target device, matching get_train_centers.
        distances[:, -1].to(self.device),
    )
def _patch_data_query(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Find all points within the radius of each patch center.

    A single radius query is issued with the largest patch radius
    (1.1 times the largest k-NN distance) and at most ``5 * k`` neighbours
    per center. The candidates are then filtered so that each patch keeps
    only the points strictly closer than 1.1 times its own k-NN distance.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        A tuple containing:
        - The indices (into ``self.x``) of all points belonging to any patch.
        - The patch number (position within ``self.center_indices``) of each
          of those points.
        - The number of points in each patch, one entry per center.
    """
    index_x, index_y = self.query.query_radius(
        x=self.x,
        y=self.x[self.center_indices],
        radius=1.1 * self.knn_distances.max().item(),
        max_num_neighbors=5 * self.k,
    )
    # Shrink the shared query radius down to each patch's own radius.
    offsets = self.x[index_x] - self.x[self.center_indices[index_y]]
    mask = offsets.norm(dim=1) < (1.1 * self.knn_distances)[index_y]
    patch_indices = index_x[mask]
    patch_number = index_y[mask]
    # dim_size guarantees one count per center; without it the result stops
    # at the highest patch number that still has points.
    patch_lens = scatter_add(
        torch.ones(patch_number.shape[0], device=self.device, dtype=torch.long),
        patch_number,
        dim_size=self.center_indices.shape[0],
    )
    return (
        patch_indices.to(self.device),
        patch_number.to(self.device),
        patch_lens.to(self.device),
    )
[docs]
@cached_property
def tensor_centered(self):
    """
    Dense, zero-padded tensor of patch points relative to their patch mean.

    The ragged list of patch points is packed into a dense batch padded with
    ``inf``; after subtracting the patch means, every non-finite entry
    (including the padding) is replaced by 0.

    Returns
    -------
    torch.Tensor
        A tensor of shape (num_patches, max_points_in_patch, 3).
    """
    member_points = self.x[self.patch_indices]
    dense, _ = tg.utils.to_dense_batch(
        x=member_points, batch=self.patch_number, fill_value=torch.inf
    )
    # Broadcast each patch mean over that patch's (padded) point axis.
    dense -= self.centers.unsqueeze(1)
    return torch.nan_to_num(dense, nan=0.0, posinf=0.0, neginf=0.0)
@cached_property
def _pca_data(self):
    """
    Compute PCA vectors and z-scaling for each patch.

    For every patch the sum of outer products of the centered points is
    decomposed with an SVD; the rows of ``Vh`` are used as the local basis.
    The third basis vector is flipped where needed so that it has a
    non-negative dot product with the orientation at the patch center
    (``data['orientation']`` if present, else ``data['normals']``, else the
    center's own position), and the second vector is flipped where needed to
    keep the frame right-handed. The z-scale is the third singular value
    divided by ``sqrt(patch_len - 1)``, bounded below by ``min_z_scale``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        A tuple containing:
        - PCA vectors (num_patches, 3, 3).
        - Z-axis scaling factor (num_patches, 1).
    """
    x_centered = self.x[self.patch_indices] - self.centers[self.patch_number]
    outer_prod = x_centered.unsqueeze(2) * x_centered.unsqueeze(1)
    cov_matrices = scatter_add(outer_prod, self.patch_number, dim=0)
    # The singular values of the summed outer products are squared lengths.
    _, S_squared, Vh = torch.linalg.svd(cov_matrices)
    S = S_squared.sqrt()
    pca_vectors = Vh.clone()

    orientation_source = self.data.get("orientation", None)
    if orientation_source is None:
        orientation_source = self.data.get("normals", None)
    if orientation_source is None:
        orientation_source = self.x
    orientation = orientation_source[self.center_indices]

    flip_mask = (orientation * pca_vectors[:, 2]).sum(dim=-1) < 0.0
    pca_vectors[flip_mask, 2] *= -1
    # Restore right-handedness: (e0 x e1) must point along e2.
    cross_mask = (
        torch.linalg.cross(pca_vectors[:, 0], pca_vectors[:, 1]) * pca_vectors[:, 2]
    ).sum(dim=-1) < 0
    pca_vectors[cross_mask, 1] *= -1

    # Clamp the divisor to 1 so a patch with a single point yields 0 (then
    # raised to min_z_scale) instead of 0/0 = NaN, which a "<" test misses.
    dof = (self.patch_lens - 1).clamp(min=1)
    z_scale = (S[:, 2] / dof.sqrt()).clamp(min=self.min_z_scale)
    return pca_vectors, z_scale.view(-1, 1)
@cached_property
def _local_coordinate_data(self):
    """
    Express every patch point in its patch's scaled local frame.

    The offset of each point from its patch mean is projected onto the rows
    of that patch's PCA basis. The first two components are divided by the
    largest xy-norm found in the patch, the third by the patch's `z_scale`.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        A tuple containing:
        - Scaled local coordinates for each point in a patch.
        - The xy-scaling factor for each patch, as a column vector.
    """
    patch_of_point = self.patch_number
    offsets = self.x[self.patch_indices] - self.centers[patch_of_point]
    # Row vector times transposed basis: component j is <offset, basis row j>.
    basis_t = self.pca_vectors.permute(0, 2, 1)[patch_of_point]
    local_unscaled = (offsets.unsqueeze(1) @ basis_t).squeeze(1)

    planar_norm = local_unscaled[:, :2].norm(dim=1)
    xy_scale, _ = scatter_max(planar_norm, patch_of_point)
    xy_scale = xy_scale.view(-1, 1)

    scaling = torch.cat((xy_scale, xy_scale, self.z_scale), dim=1)
    return local_unscaled / scaling[patch_of_point], xy_scale
@cached_property
def _local_coordinate_original_data(self):
    """
    Compute local coordinates for the 'original' points.

    If the input dictionary has an ``original_x`` entry, those points are
    indexed with the same patch membership as the primary cloud and mapped
    through the same transformation (patch means, PCA vectors and scaling)
    derived from the primary point cloud.

    Returns
    -------
    torch.Tensor or None
        Scaled local coordinates, one row per patch entry, or None when no
        ``original_x`` was provided.
    """
    x = self.data.get("original_x", None)
    if x is None:
        return None
    # Drop a leading singleton dimension, if there is one.
    x = x.squeeze(0)
    offsets = x[self.patch_indices] - self.centers[self.patch_number]
    # Row vector times transposed basis: component j is <offset, basis row j>.
    basis_t = self.pca_vectors.permute(0, 2, 1)[self.patch_number]
    local_unscaled = (offsets.unsqueeze(1) @ basis_t).squeeze(1)
    scaling = torch.cat((self.xy_scale, self.xy_scale, self.z_scale), dim=1)
    return local_unscaled / scaling[self.patch_number]
@cached_property
def x_local(self):
    """
    All points transformed into the local coordinate system of their assigned patch.

    Each point is assigned to the patch given by ``self.clusters``; its
    offset from that patch's mean is projected onto the patch's PCA basis
    and divided by the patch's scaling vector.

    In 'train' mode the per-patch `local_coordinates` are returned instead
    (one row per patch entry rather than one row per point).

    Returns
    -------
    torch.Tensor
        Scaled local coordinates.
    """
    if self.mode == "train":
        return self.local_coordinates
    offsets = (self.x - self.centers[self.clusters]).unsqueeze(1)
    # Row vector times transposed basis: component j is <offset, basis row j>.
    basis_t = self.pca_vectors[self.clusters].permute(0, 2, 1)
    local_unscaled = (offsets @ basis_t).squeeze(1)
    return local_unscaled / self.scaling[self.clusters]
@cached_property
def tensor_local(self):
    """
    Dense tensor of patch points in local coordinates.

    Every row of `tensor_centered` is multiplied by its patch's PCA basis
    and divided by the patch's scaling vector.

    Returns
    -------
    torch.Tensor
        A tensor with the same shape as `tensor_centered`.
    """
    # (P, 1, 3, 3) @ (P, M, 3, 1) -> (P, M, 3, 1). Only the trailing axis is
    # squeezed so that P == 1 or M == 1 does not collapse a real dimension.
    rotated = (
        self.pca_vectors.unsqueeze(1) @ self.tensor_centered.unsqueeze(-1)
    ).squeeze(-1)
    return rotated / self.scaling.unsqueeze(1)
@property
def pca_vectors(self):
    """Local basis of each patch: the first element of the cached PCA result."""
    return self._pca_data[0]
@property
def z_scale(self):
    """Z-axis scaling factor of each patch: the second element of the cached PCA result."""
    return self._pca_data[1]
@property
def local_coordinates(self):
    """Scaled local coordinates of the points of each patch (first cached element)."""
    return self._local_coordinate_data[0]
@property
def local_coordinates_original(self):
    """Scaled local coordinates of the 'original' points, or None if none were given."""
    return self._local_coordinate_original_data
@property
def xy_scale(self):
    """xy-plane scaling factor of each patch (second cached element)."""
    return self._local_coordinate_data[1]
@property
def scaling(self):
    """Per-patch scaling vector: the xy scale twice, followed by the z scale."""
    planar = self.xy_scale
    vertical = self.z_scale
    return torch.cat((planar, planar, vertical), dim=1)
[docs]
def as_patch_data(self) -> PatchData:
    """
    Bundle everything computed so far into a `PatchData` object.

    The computed attributes (points, patch membership, local coordinates,
    PCA vectors, scales, ...) are stored under fixed names. Entries of the
    input dictionary are then added under their own keys, except where a key
    is already taken by a computed attribute.

    Returns
    -------
    PatchData
        The fully processed patch data object, moved to ``self.device``.
    """
    fields = dict(
        x=self.x,
        x_original=self.original_x,
        mode=self.mode,
        clusters=self.clusters,
        centers=self.centers,
        center_indices=self.center_indices,
        patch_indices=self.patch_indices,
        patch_number=self.patch_number,
        patch_lens=self.patch_lens,
        local_coordinates=self.local_coordinates,
        local_coordinates_original=self.local_coordinates_original,
        pca_vectors=self.pca_vectors,
        z_scale=self.z_scale,
        xy_scale=self.xy_scale,
        degree=self.basis_degree,
    )
    # Computed values take precedence over same-named input entries.
    for name, value in self.data.items():
        fields.setdefault(name, value)
    return PatchData(**fields).to(self.device)