from packaging import version
import warnings
from typing import Tuple
import numpy as np
from numpy import ndarray
from numba import njit, prange
import scipy as sp
# scikit-learn is an optional backend: record whether its KDTree could be
# imported instead of failing at module import time.
try:
    from sklearn.neighbors import KDTree
except Exception:
    __has_sklearn__ = False
else:
    __has_sklearn__ = True

# Version string of the installed SciPy, captured once at import time.
__scipy_version__ = sp.__version__

# Value handed to the `cache=` option of the numba-compiled helpers below.
__cache = True

__all__ = ["k_nearest_neighbours", "knn_to_lines"]
[docs]
def k_nearest_neighbours(
X: ndarray,
Y: ndarray = None,
*,
backend: str = "scipy",
k: int = 1,
workers: int = -1,
tree_kwargs: dict = None,
query_kwargs: dict = None,
leaf_size: int = 30,
return_distance: bool = False,
max_distance: float = None,
) -> Tuple[ndarray, ndarray] | ndarray:
"""
Returns the k nearest neighbours (KNN) of a KDTree for a pointcloud using `scipy`
or `sklearn`. The function acts as a uniform interface for similar functionality
of `scipy` and `sklearn`. The most important parameters are highlighted, for the
complete list of arguments, see the corresponding docs:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.KDTree.html#scipy.spatial.KDTree
https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html
To learn more about nearest neighbour searches in general:
https://scikit-learn.org/stable/modules/neighbors.html
Parameters
----------
X: numpy.ndarray
An array of points to build the tree.
Y: numpy.ndarray, Optional
An array of sampling points to query the tree. If None it is the
same as the points used to build the tree. Default is None.
k: int or Sequence[int], Optional
Either the number of nearest neighbors to return,
or a list of the k-th nearest neighbors to return, starting from 1.
leaf_size: positive int, Optional
The number of points at which the algorithm switches over to brute-force.
Default is 10.
workers: int, Optional
Only if backend is 'scipy'.
Number of workers to use for parallel processing. If -1 is given all
CPU threads are used. Default: -1.
New in 'scipy' version 1.6.0.
max_distance: float, Optional
Return only neighbors within this distance. It can be a single value, or
an array of values of shape matching the input, while a None value
translates to an infinite upper bound.
Default is None.
tree_kwargs: dict, Optional
Extra keyword arguments passed to the KDTree creator of the selected
backend. Default is None.
Returns
-------
d: float or array of floats
The distances to the nearest neighbors. Only returned if
`return_distance==True`.
i: integer or array of integers
The index of each neighbor.
Raises
------
ImportError
In the abscence of a usable backend.
Examples
--------
>>> from sigmaepsilon.mesh.grid import grid
>>> from sigmaepsilon.mesh import KNN
>>> size = 80, 60, 20
>>> shape = 10, 8, 4
>>> X, _ = grid(size=size, shape=shape, eshape='H8')
>>> i = KNN(X, X, k=3, max_distance=10.0)
"""
tree_kwargs = {} if tree_kwargs is None else tree_kwargs
query_kwargs = {} if query_kwargs is None else query_kwargs
if backend == "scipy":
from scipy.spatial import KDTree
tree = KDTree(X, leafsize=leaf_size, **tree_kwargs)
max_distance = np.inf if max_distance is None else max_distance
query_kwargs["distance_upper_bound"] = max_distance
if version.parse(__scipy_version__) < version.parse("1.6.0"):
warnings.warn(
"Multithreaded execution of a KNN search is "
+ "running on a single thread in scipy<1.6.0. Install a newer"
+ "version or use `backend=sklearn` if scikit is installed."
)
d, i = tree.query(Y, k=k, **query_kwargs)
else:
d, i = tree.query(Y, k=k, workers=workers)
elif backend == "sklearn":
if not __has_sklearn__:
raise ImportError("'sklearn' must be installed for this!")
tree = KDTree(X, leaf_size=leaf_size, **tree_kwargs)
if max_distance is None:
d, i = tree.query(Y, k=k, **query_kwargs)
else:
r = max_distance
d, i = tree.query_radius(Y, r, k=k, **query_kwargs)
else:
raise ImportError("Either `sklearn` or `scipy` must be present for this!")
return (d, i) if return_distance else i
@njit(nogil=True, parallel=True, cache=__cache)
def knn_to_lines(inds: ndarray):
    """
    Turns a 2d array of neighbour indices into an array of index pairs.

    Parameters
    ----------
    inds: numpy.ndarray
        A 2d integer array. Row `i` holds the indices associated with
        item `i`.

    Returns
    -------
    numpy.ndarray
        An array of shape (inds.shape[0], inds.shape[1], 2) with the dtype
        of `inds`, where the entry at `[i, j]` is the pair `(i, inds[i, j])`.
    """
    num_points, num_neighbours = inds.shape
    lines = np.zeros((num_points, num_neighbours, 2), dtype=inds.dtype)
    for source in prange(num_points):
        for slot in prange(num_neighbours):
            # first entry: the row itself, second entry: its neighbour
            lines[source, slot, 0] = source
            lines[source, slot, 1] = inds[source, slot]
    return lines