1
2
3
4
5
6 """
7 KD tree data structure for searching N-dimensional vectors.
8
9 The KD tree data structure can be used for all kinds of searches that
10 involve N-dimensional vectors, e.g. neighbor searches (find all points
11 within a radius of a given point) or finding all point pairs in a set
12 that are within a certain radius of each other. See "Computational Geometry:
13 Algorithms and Applications" (Mark de Berg, Marc van Kreveld, Mark Overmars,
14 Otfried Schwarzkopf). Author: Thomas Hamelryck.
15 """
16
17 from numpy import sum, sqrt, array
18 from numpy.random import random
19
20 from Bio.KDTree import _CKDTree
21
22
def _dist(p, q):
    """Return the Euclidean distance between points p and q.

    p and q are array-like coordinate vectors of equal length; the
    computation uses the numpy sum/sqrt imported by this module.
    """
    offset = p - q
    squared = offset * offset
    return sqrt(sum(squared))
26
27
def _neighbor_test(nr_points, dim, bucket_size, radius):
    """Test all fixed radius neighbor search.

    Test all fixed radius neighbor search using the KD tree C module:
    the number of point pairs reported by neighbor_search() is compared
    with the number reported by neighbor_simple_search(), and the outcome
    is printed.

    o nr_points - number of points used in test
    o dim - dimension of coords
    o bucket_size - nr of points per tree node
    o radius - radius of search (typically 0.05 or so)
    """
    kdt = _CKDTree.KDTree(dim, bucket_size)
    coords = random((nr_points, dim))
    kdt.set_data(coords)

    # Check the search result itself for None before iterating it: a list
    # built from the result can never be None, so testing that list would
    # be dead code and a None result would crash the iteration.
    # NOTE(review): whether neighbor_search can return None is not visible
    # here; the guard is defensive. len(list(...)) works for any iterable.
    neighbors = kdt.neighbor_search(radius)
    l1 = 0 if neighbors is None else len(list(neighbors))

    neighbors = kdt.neighbor_simple_search(radius)
    l2 = 0 if neighbors is None else len(list(neighbors))

    if l1 == l2:
        print("Passed.")
    else:
        print("Not passed: %i != %i." % (l1, l2))
60
61
def _test(nr_points, dim, bucket_size, radius):
    """Test neighbor search.

    Use the KD tree C module to find all random points within ``radius``
    of the first point, count them by brute force as well, and print
    whether the two counts agree.

    o nr_points - number of points used in test
    o dim - dimension of coords
    o bucket_size - nr of points per tree node
    o radius - radius of search (typically 0.05 or so)
    """
    tree = _CKDTree.KDTree(dim, bucket_size)
    points = random((nr_points, dim))
    origin = points[0]
    tree.set_data(points)

    # Count reported by the tree; a None result means no hits.
    tree.search_center_radius(origin, radius)
    hits = tree.get_indices()
    tree_count = len(hits) if hits is not None else 0

    # Brute-force reference count over every point.
    brute_count = 0
    for point in points:
        if _dist(point, origin) <= radius:
            brute_count += 1

    if tree_count != brute_count:
        print("Not passed: %i != %i." % (tree_count, brute_count))
    else:
        print("Passed.")
93
94
class KDTree(object):
    """
    KD tree implementation (C++, SWIG python wrapper)

    The KD tree data structure can be used for all kinds of searches that
    involve N-dimensional vectors, e.g. neighbor searches (find all points
    within a radius of a given point) or finding all point pairs in a set
    that are within a certain radius of each other.

    Reference:

    Computational Geometry: Algorithms and Applications
    Second Edition
    Mark de Berg, Marc van Kreveld, Mark Overmars, Otfried Schwarzkopf
    published by Springer-Verlag
    2nd rev. ed. 2000.
    ISBN: 3-540-65620-0

    The KD tree data structure is described in chapter 5, pg. 99.

    The following article made clear to me that the nodes should
    contain more than one point (this leads to dramatic speed
    improvements for the "all fixed radius neighbor search", see
    below):

    JL Bentley, "Kd trees for semidynamic point sets," in Sixth Annual ACM
    Symposium on Computational Geometry, vol. 91. San Francisco, 1990

    This KD implementation also performs an "all fixed radius neighbor
    search", i.e. it can find all point pairs in a set that are within a
    certain radius of each other. As far as I know the algorithm has not
    been published.
    """

    def __init__(self, dim, bucket_size=1):
        """Create a tree for points of the given dimensionality.

        o dim - dimensionality of the points
        o bucket_size - nr of points per tree node
        """
        self.dim = dim
        self.kdt = _CKDTree.KDTree(dim, bucket_size)
        # Becomes true once set_coords() has loaded a point set; searches
        # are refused until then.
        self.built = False
        # Results of the last all_search(); empty until one has been run,
        # so the all_get_* accessors never hit a missing attribute.
        self.neighbors = []

    def set_coords(self, coords):
        """Add the coordinates of the points.

        o coords - two dimensional NumPy array. E.g. if the points
          have dimensionality D and there are N points, the coords
          array should be NxD dimensional.

        Raises Exception if coords has the wrong shape or contains values
        outside the open interval (-1e6, 1e6).
        """
        # Validate the shape first so a malformed array reports the shape
        # problem rather than an unrelated range problem.
        if len(coords.shape) != 2 or coords.shape[1] != self.dim:
            raise Exception("Expected a Nx%i NumPy array" % self.dim)
        if coords.min() <= -1e6 or coords.max() >= 1e6:
            raise Exception("Points should lie between -1e6 and 1e6")
        self.kdt.set_data(coords)
        self.built = True
        # Pair results from an earlier all_search() index the old points.
        self.neighbors = []

    def search(self, center, radius):
        """Search all points within radius of center.

        o center - one dimensional NumPy array. E.g. if the points have
          dimensionality D, the center array should be D dimensional.
        o radius - float>0

        Retrieve the results with get_indices() and get_radii().
        Raises Exception if no point set was specified or center has the
        wrong shape.
        """
        if not self.built:
            raise Exception("No point set specified")
        if center.shape != (self.dim,):
            raise Exception("Expected a %i-dimensional NumPy array"
                            % self.dim)
        self.kdt.search_center_radius(center, radius)

    def get_radii(self):
        """Return radii.

        Return the distances from center found by the last search();
        an empty list if the C module reports no result (None).
        """
        a = self.kdt.get_radii()
        if a is None:
            return []
        return a

    def get_indices(self):
        """Return the list of indices.

        Return the list of indices after a neighbor search.
        The indices refer to the original coords NumPy array. The
        coordinates with these indices were within radius of center.
        An empty list is returned if the C module reports None.
        """
        a = self.kdt.get_indices()
        if a is None:
            return []
        return a

    def all_search(self, radius):
        """All fixed neighbor search.

        Search all point pairs that are within radius. Retrieve the
        results with all_get_indices() and all_get_radii().

        o radius - float (>0)

        Raises Exception if no point set was specified.
        """
        if not self.built:
            raise Exception("No point set specified")
        neighbors = self.kdt.neighbor_search(radius)
        # Normalise a None result to "no pairs", as get_radii() and
        # get_indices() do for their results.
        self.neighbors = [] if neighbors is None else neighbors

    def all_get_indices(self):
        """Return All Fixed Neighbor Search results.

        Return a Nx2 dim NumPy array containing the indices of the point
        pairs, where N is the number of neighbor pairs (N may be 0).
        For an index pair, the first index<second index.
        """
        pairs = [[neighbor.index1, neighbor.index2]
                 for neighbor in self.neighbors]
        if not pairs:
            # array([]) would have shape (0,); keep the documented Nx2 shape.
            return array(pairs, dtype=int).reshape(0, 2)
        return array(pairs)

    def all_get_radii(self):
        """Return All Fixed Neighbor Search results.

        Return a list of N distances, one for each point pair found by
        all_search(), where N is the number of neighbor pairs.
        """
        return [neighbor.radius for neighbor in self.neighbors]
221
if __name__ == "__main__":

    nr_points = 100000
    dim = 3
    bucket_size = 10
    query_radius = 10
    # Points and search centers are drawn uniformly from [0, box_size)^dim.
    box_size = 200

    coords = box_size * random((nr_points, dim))

    kdtree = KDTree(dim, bucket_size)

    # Build the tree from the point set.
    kdtree.set_coords(coords)

    # Find all point pairs within query_radius of each other; the pair
    # distances are available via kdtree.all_get_radii().
    kdtree.all_search(query_radius)

    # Nx2 array of index pairs.
    indices = kdtree.all_get_indices()

    print("Found %i point pairs within radius %f." % (len(indices), query_radius))

    # Ten neighbor searches around random centers.
    for _ in range(10):
        # Draw the center from the same box as the points; random(dim)
        # alone would always place it in the [0, 1) corner of the box.
        center = box_size * random(dim)

        # Search neighbors; distances are available via kdtree.get_radii().
        kdtree.search(center, query_radius)
        indices = kdtree.get_indices()

        # Format every coordinate so the report works for any dim.
        center_text = ", ".join("%.2f" % c for c in center)
        print("Found %i points in radius %f around center (%s)." % (len(indices), query_radius, center_text))
265