from typing import Tuple
import torch
import torch_geometric as tg
from torch_scatter import scatter_add, scatter_max
class QueryTorchGeometric:
    """
    A wrapper for performing k-nearest neighbor and radius queries on a point
    cloud using PyTorch Geometric.

    Parameters
    ----------
    x : torch.Tensor
        The point cloud data to be queried, of shape (N, D).
    device : str, optional
        The device to store the point cloud on, by default "cpu".
    """

    def __init__(self, x: torch.Tensor, device="cpu"):
        self.x = x.to(device)
        self.device = torch.device(device)

    def query_knn(
        self, queries: torch.Tensor, k: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find the k-nearest neighbors for a set of query points.

        Parameters
        ----------
        queries : torch.Tensor
            The query points, of shape (M, D). They are moved to the device
            of the stored point cloud before querying.
        k : int
            The number of nearest neighbors to find. Must satisfy
            ``1 <= k <= N`` where N is the number of stored points.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple containing:
            - The distances to the k-nearest neighbors, shape (M, k).
            - The indices of the k-nearest neighbors, shape (M, k).

        Raises
        ------
        ValueError
            If ``k`` is smaller than 1 or larger than the number of stored
            points.
        """
        num_points = self.x.shape[0]
        # The (M, k) reshape below assumes every query received exactly k
        # distinct neighbours. That is impossible with fewer than k stored
        # points, and the reshape would then fail or misalign query rows.
        if not 1 <= k <= num_points:
            raise ValueError(
                f"k must be between 1 and the number of points ({num_points}), "
                f"got {k}."
            )
        # `Tensor.to` returns the tensor itself when it is already on the
        # target device, so no explicit device comparison is needed.
        queries = queries.to(self.device)
        # tg.nn.knn returns (index into `queries`, index into `self.x`).
        index_y, index_x = tg.nn.knn(self.x, queries, k=k)
        index_x = index_x.view(-1, k)
        index_y = index_y.view(-1, k)
        distances = (self.x[index_x] - queries[index_y]).norm(dim=-1)
        return distances, index_x

    @staticmethod
    def query_radius(
        x: torch.Tensor,
        y: torch.Tensor,
        radius: float,
        max_num_neighbors: int = 100,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find all points in `x` within a given radius of points in `y`.

        This method returns all pairs (i, j) such that x[i] lies within
        distance `radius` of y[j], as determined by
        ``torch_geometric.nn.radius``.

        Parameters
        ----------
        x : torch.Tensor
            The point cloud to search within (the "haystack").
        y : torch.Tensor
            The query points (the "needles").
        radius : float
            The search radius. Must be non-negative.
        max_num_neighbors : int, optional
            The maximum number of neighbors to return for each query point,
            by default 100. Neighbours beyond this limit are dropped, so the
            result may be incomplete for dense regions.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple containing:
            - `index_x`: Indices of the found points in `x`.
            - `index_y`: Indices of the corresponding query points in `y`.

        Raises
        ------
        ValueError
            If ``radius`` is negative or NaN.
        """
        # Written as `not radius >= 0` (rather than `radius < 0`) so that NaN
        # is rejected as well.
        if not radius >= 0:
            raise ValueError(f"radius must be non-negative, got {radius}.")
        # tg.nn.radius returns (index into `y`, index into `x`); swap the order
        # so the haystack index comes first for callers.
        index_y, index_x = tg.nn.radius(
            x=x, y=y, r=radius, max_num_neighbors=max_num_neighbors
        )
        return index_x, index_y
def subsample_points_by_radius(
    x: torch.Tensor,
    radius: float,
    *,
    max_num_neighbors: int = 1024,
    verbose: bool = False,
) -> torch.Tensor:
    """
    Subsample points using a parallel maximal independent set approach.

    The radius neighbourhood of every point is computed once. Each point is
    normally its own neighbour, since its distance to itself is 0. The
    function then runs rounds with fresh random ranks. In each round, every
    still-active point whose rank is the largest in its active neighbourhood
    is kept. Active points that neighbour such a local maximum are dropped.
    Rounds stop as soon as one round leaves the active set unchanged.

    Parameters
    ----------
    x : torch.Tensor
        The input points to subsample, shape (N, D).
    radius : float
        The radius for defining neighborhoods.
    max_num_neighbors : int, optional
        Maximum number of neighbours gathered per point, by default 1024.
        If a point has more neighbours than this, its neighbourhood is
        truncated. Some returned points may then lie within `radius` of
        each other.
    verbose : bool, optional
        If True, print how many points were removed. By default False.

    Returns
    -------
    torch.Tensor
        A tensor containing the indices of the subsampled points.
    """
    device = x.device
    num_points = x.shape[0]
    if num_points == 0:
        return torch.empty(0, dtype=torch.long, device=device)

    neighbor_indices, query_indices = QueryTorchGeometric.query_radius(
        x, x, radius, max_num_neighbors=max_num_neighbors
    )
    # Points that are still candidates (either kept or not yet decided).
    mask = torch.ones(num_points, dtype=torch.bool, device=device)
    num_active = num_points
    while True:
        random_rank = torch.rand(num_points, device=device)
        max_neighbor_rank, _ = scatter_max(
            random_rank[neighbor_indices], query_indices, dim=0, dim_size=num_points
        )
        # Active points whose rank is the maximum over their active neighbourhood.
        is_local_max = (random_rank == max_neighbor_rank) & mask
        # Count selected neighbours as integers, then compare, rather than
        # relying on boolean scatter-add semantics.
        has_selected_neighbor = (
            scatter_add(
                is_local_max[neighbor_indices].long(),
                query_indices,
                dim=0,
                dim_size=num_points,
            )
            > 0
        )
        # Keep local maxima and still-undecided points; drop points that
        # neighbour a selected local maximum.
        mask = (is_local_max | ~has_selected_neighbor) & mask
        # Only edges between surviving points matter in later rounds.
        edge_mask = mask[neighbor_indices] & mask[query_indices]
        neighbor_indices = neighbor_indices[edge_mask]
        query_indices = query_indices[edge_mask]
        new_num_active = int(mask.sum())
        if new_num_active == num_active:
            break
        num_active = new_num_active

    num_removed = num_points - num_active
    if verbose and num_removed > 0:
        print(f"Removed {num_removed} points.")
    return torch.where(mask)[0]
def smooth_values_by_gaussian(
    x: torch.Tensor,
    values: torch.Tensor,
    radius: float,
    *,
    max_num_neighbors: int = 1024,
) -> torch.Tensor:
    """
    Smooth values on a point cloud using a truncated Gaussian kernel.

    For each point, this function computes a weighted average of the values of
    its neighbors within a given radius. The weights are determined by a
    Gaussian function of the distance.

    Parameters
    ----------
    x : torch.Tensor
        The input data points, shape (N, D).
    values : torch.Tensor
        The values associated with each point to be smoothed, shape (N,) or
        (N, ...). Trailing dimensions are smoothed independently.
    radius : float
        The truncation radius for the Gaussian kernel. The standard deviation
        of the Gaussian is set to one-third of this radius. A radius of 0
        returns the values unchanged.
    max_num_neighbors : int, optional
        Maximum number of neighbours gathered per point, by default 1024.
        Neighbours beyond this limit are ignored.

    Returns
    -------
    torch.Tensor
        A tensor with the same shape as `values` containing the smoothed
        values.
    """
    num_points = x.size(0)
    if radius == 0:
        # A zero-width kernel is the identity. Evaluating the Gaussian would
        # divide 0 by 0 for every self-pair and yield NaN. Match the dtype
        # the weighted sum would have produced.
        return values.to(torch.promote_types(x.dtype, values.dtype), copy=True)

    neighbor_indices, query_indices = QueryTorchGeometric.query_radius(
        x, x, radius, max_num_neighbors=max_num_neighbors
    )
    sigma = radius / 3
    distances = (x[neighbor_indices] - x[query_indices]).norm(dim=1)
    weights = torch.exp(-distances.pow(2) / (2 * sigma**2))
    weight_sums = scatter_add(weights, query_indices, dim=0, dim_size=num_points)
    weights_normalized = weights / weight_sums[query_indices]
    # Shape the per-edge weights as (E, 1, ..., 1) so they broadcast over any
    # trailing value dimensions (a no-op for 1-D values).
    weights_normalized = weights_normalized.view(-1, *([1] * (values.dim() - 1)))
    return scatter_add(
        weights_normalized * values[neighbor_indices],
        query_indices,
        dim=0,
        dim_size=num_points,
    )