KdTree-Impl.hpp
Go to the documentation of this file.
1 // This code is based on Jet framework.
2 // Copyright (c) 2018 Doyub Kim
3 // CubbyFlow is voxel-based fluid simulation engine for computer games.
4 // Copyright (c) 2020 CubbyFlow Team
5 // Core Part: Chris Ohk, Junwoo Hwang, Jihong Sin, Seungwoo Yoo
6 // AI Part: Dongheon Cho, Minseo Kim
7 // We are making my contributions/submissions to this project solely in our
8 // personal capacity and are not conveying any rights to any intellectual
9 // property of any third parties.
10 
11 #ifndef CUBBYFLOW_KDTREE_IMPL_HPP
12 #define CUBBYFLOW_KDTREE_IMPL_HPP
13 
14 #include <numeric>
15 
16 namespace CubbyFlow
17 {
18 template <typename T, size_t K>
19 void KdTree<T, K>::Node::InitLeaf(size_t it, const Point& pt)
20 {
21  flags = K;
22  item = it;
23  child = std::numeric_limits<size_t>::max();
24  point = pt;
25 }
26 
27 template <typename T, size_t K>
28 void KdTree<T, K>::Node::InitInternal(size_t axis, size_t it, size_t c,
29  const Point& pt)
30 {
31  flags = axis;
32  item = it;
33  child = c;
34  point = pt;
35 }
36 
// IsLeaf: returns true when this node is a leaf. Leaves are tagged with
// flags == K (see InitLeaf); internal nodes store their split axis (< K) in
// flags instead.
// NOTE(review): the signature line (source line 38, `bool IsLeaf() const`
// per the tooltip further below) is missing from this listing.
37 template <typename T, size_t K>
39 {
40  return flags == K;
41 }
42 
// Build(points): copies `points` into m_points and rebuilds the node array
// by recursive median splitting, starting with node 0 at depth 0.
// NOTE(review): the signature line (source line 44,
// `void Build(const ConstArrayView1<Point>& points)` per the tooltip below)
// is missing from this listing.
43 template <typename T, size_t K>
45 {
46  m_points.resize(points.Length());
47  std::copy(points.begin(), points.end(), m_points.begin());
48
// NOTE(review): this early return happens before m_nodes.clear(), so
// rebuilding with an empty list keeps the nodes (and stored points) of the
// previous build; queries would then still report the old items. Consider
// clearing m_nodes before returning.
49  if (m_points.empty())
50  {
51  return;
52  }
53
54  m_nodes.clear();
55
// Item indices 0..n-1; Build reorders them in place while partitioning.
// NOTE(review): the literal 0 makes std::iota count with an `int`, which
// would overflow for more than INT_MAX points; `size_t{0}` avoids that.
56  std::vector<size_t> itemIndices(m_points.size());
57  std::iota(std::begin(itemIndices), std::end(itemIndices), 0);
58
// The returned tree depth is currently unused.
59  [[maybe_unused]] const size_t d =
60  Build(0, itemIndices.data(), m_points.size(), 0);
61 }
62 
// ForEachNearbyPoint: calls `callback(item, point)` for every stored point
// whose squared distance to `origin` is <= radius * radius. The tree is
// walked depth-first with a fixed-size explicit stack (`todo`). Empty
// placeholder leaves (item == max size_t) are never reported.
// NOTE(review): the first signature line (source line 64,
// `void KdTree<T, K>::ForEachNearbyPoint(` per the tooltip below) is
// missing from this listing.
63 template <typename T, size_t K>
65  const Point& origin, T radius,
66  const std::function<void(size_t, const Point&)>& callback) const
67 {
68  const T r2 = radius * radius;
69
70  // prepare to traverse the tree for sphere
71  static const int maxTreeDepth = 8 * sizeof(size_t);
72  const Node* todo[maxTreeDepth];
73  size_t todoPos = 0;
74
// NOTE(review): there is no guard for an empty node array; if m_nodes is
// empty, the dereferences below are undefined behavior. An early
// `if (m_nodes.empty()) return;` would fix this.
75  // traverse the tree nodes for sphere
76  const Node* node = m_nodes.data();
77
// The loop only exits via `break` once the stack is drained.
78  while (node != nullptr)
79  {
80  if (node->item != std::numeric_limits<size_t>::max() &&
81  (node->point - origin).LengthSquared() <= r2)
82  {
83  callback(node->item, node->point);
84  }
85
86  if (node->IsLeaf())
87  {
88  // grab next node to process from todo stack
89  if (todoPos > 0)
90  {
91  // dequeue
92  --todoPos;
93  node = todo[todoPos];
94  }
95  else
96  {
97  break;
98  }
99  }
100  else
101  {
// The first child is the node stored right after this one; the second
// child's index is kept in `child`.
// NOTE(review): the const_cast is redundant; `&m_nodes[node->child]` is
// already a `const Node*` here.
102  // get node children pointers for sphere
103  const Node* firstChild = node + 1;
104  const Node* secondChild = const_cast<Node*>(&m_nodes[node->child]);
105
106  // advance to next child node, possibly enqueue other child
107  const size_t axis = node->flags;
108  const T plane = node->point[axis];
109
// If the plane is more than `radius` above origin on this axis, only the
// first child is visited. If it is more than `radius` below, only the
// second. Otherwise the sphere straddles the plane, so the second child is
// pushed for later and the first is visited now.
110  if (plane - origin[axis] > radius)
111  {
112  node = firstChild;
113  }
114  else if (origin[axis] - plane > radius)
115  {
116  node = secondChild;
117  }
118  else
119  {
120  // enqueue secondChild in todo stack
121  todo[todoPos] = secondChild;
122  ++todoPos;
123  node = firstChild;
124  }
125  }
126  }
127 }
128 
129 template <typename T, size_t K>
130 bool KdTree<T, K>::HasNearbyPoint(const Point& origin, T radius) const
131 {
132  const T r2 = radius * radius;
133 
134  // prepare to traverse the tree for sphere
135  static const int maxTreeDepth = 8 * sizeof(size_t);
136  const Node* todo[maxTreeDepth];
137  size_t todoPos = 0;
138 
139  // traverse the tree nodes for sphere
140  const Node* node = m_nodes.data();
141 
142  while (node != nullptr)
143  {
144  if (node->item != std::numeric_limits<size_t>::max() &&
145  (node->point - origin).LengthSquared() <= r2)
146  {
147  return true;
148  }
149 
150  if (node->IsLeaf())
151  {
152  // grab next node to process from todo stack
153  if (todoPos > 0)
154  {
155  // dequeue
156  --todoPos;
157  node = todo[todoPos];
158  }
159  else
160  {
161  break;
162  }
163  }
164  else
165  {
166  // get node children pointers for sphere
167  const Node* firstChild = node + 1;
168  const Node* secondChild = const_cast<Node*>(&m_nodes[node->child]);
169 
170  // advance to next child node, possibly enqueue other child
171  const size_t axis = node->flags;
172  const T plane = node->point[axis];
173 
174  if (origin[axis] < plane && plane - origin[axis] > radius)
175  {
176  node = firstChild;
177  }
178  else if (origin[axis] > plane && origin[axis] - plane > radius)
179  {
180  node = secondChild;
181  }
182  else
183  {
184  // enqueue secondChild in todo stack
185  todo[todoPos] = secondChild;
186  ++todoPos;
187  node = firstChild;
188  }
189  }
190  }
191 
192  return false;
193 }
194 
195 template <typename T, size_t K>
196 size_t KdTree<T, K>::GetNearestPoint(const Point& origin) const
197 {
198  // prepare to traverse the tree for sphere
199  static const int maxTreeDepth = 8 * sizeof(size_t);
200  const Node* todo[maxTreeDepth];
201  size_t todoPos = 0;
202 
203  // traverse the tree nodes for sphere
204  const Node* node = m_nodes.data();
205  size_t nearest = 0;
206  T minDist2 = (node->point - origin).LengthSquared();
207 
208  while (node != nullptr)
209  {
210  const T newDist2 = (node->point - origin).LengthSquared();
211  if (newDist2 <= minDist2)
212  {
213  nearest = node->item;
214  minDist2 = newDist2;
215  }
216 
217  if (node->IsLeaf())
218  {
219  // grab next node to process from todo stack
220  if (todoPos > 0)
221  {
222  // Dequeue
223  --todoPos;
224  node = todo[todoPos];
225  }
226  else
227  {
228  break;
229  }
230  }
231  else
232  {
233  // get node children pointers for sphere
234  const Node* firstChild = node + 1;
235  const Node* secondChild = static_cast<Node*>(&m_nodes[node->child]);
236 
237  // advance to next child node, possibly enqueue other child
238  const size_t axis = node->flags;
239  const T plane = node->point[axis];
240  const T minDist = std::sqrt(minDist2);
241 
242  if (plane - origin[axis] > minDist)
243  {
244  node = firstChild;
245  }
246  else if (origin[axis] - plane > minDist)
247  {
248  node = secondChild;
249  }
250  else
251  {
252  // enqueue secondChild in todo stack
253  todo[todoPos] = secondChild;
254  ++todoPos;
255  node = firstChild;
256  }
257  }
258  }
259 
260  return nearest;
261 }
262 
263 template <typename T, size_t K>
264 void KdTree<T, K>::Reserve(size_t numPoints, size_t numNodes)
265 {
266  m_points.resize(numPoints);
267  m_nodes.resize(numNodes);
268 }
269 
// begin(): mutable iterator to the first stored point (m_points.begin()).
// NOTE(review): the signature line (source line 271, `Iterator begin()` per
// the tooltip below) is missing from this listing. The `;` after the closing
// brace is a redundant empty declaration (-Wextra-semi).
270 template <typename T, size_t K>
272 {
273  return m_points.begin();
274 };
275 
// end(): mutable past-the-end iterator of the stored points
// (m_points.end()).
// NOTE(review): the signature line (source line 277, `Iterator end()` per
// the tooltip below) is missing from this listing. The trailing `;` is a
// redundant empty declaration.
276 template <typename T, size_t K>
278 {
279  return m_points.end();
280 };
281 
// Returns m_points.begin().
// NOTE(review): the signature line (source line 283) is missing from this
// listing. It is presumably the const overload of begin() returning a
// ConstIterator; confirm against KdTree.hpp. The trailing `;` is a
// redundant empty declaration.
282 template <typename T, size_t K>
284 {
285  return m_points.begin();
286 };
287 
// Returns m_points.end().
// NOTE(review): the signature line (source line 289) is missing from this
// listing. It is presumably the const overload of end() returning a
// ConstIterator; confirm against KdTree.hpp. The trailing `;` is a
// redundant empty declaration.
288 template <typename T, size_t K>
290 {
291  return m_points.end();
292 };
293 
// BeginNode(): mutable iterator to the first node (m_nodes.begin()).
// NOTE(review): the signature line (source line 295,
// `NodeIterator BeginNode()` per the tooltip below) is missing from this
// listing. The trailing `;` is a redundant empty declaration.
294 template <typename T, size_t K>
296 {
297  return m_nodes.begin();
298 };
299 
// EndNode(): mutable past-the-end iterator of the nodes (m_nodes.end()).
// NOTE(review): the signature line (source line 301,
// `NodeIterator EndNode()` per the tooltip below) is missing from this
// listing. The trailing `;` is a redundant empty declaration.
300 template <typename T, size_t K>
302 {
303  return m_nodes.end();
304 };
305 
// Returns m_nodes.begin().
// NOTE(review): the signature line (source line 307) is missing from this
// listing. It is presumably the const overload of BeginNode() returning a
// ConstNodeIterator; confirm against KdTree.hpp. The trailing `;` is a
// redundant empty declaration.
306 template <typename T, size_t K>
308 {
309  return m_nodes.begin();
310 };
311 
// Returns m_nodes.end().
// NOTE(review): the signature line (source line 313) is missing from this
// listing. It is presumably the const overload of EndNode() returning a
// ConstNodeIterator; confirm against KdTree.hpp. The trailing `;` is a
// redundant empty declaration.
312 template <typename T, size_t K>
314 {
315  return m_nodes.end();
316 };
317 
318 template <typename T, size_t K>
319 size_t KdTree<T, K>::Build(size_t nodeIndex, size_t* itemIndices, size_t nItems,
320  size_t currentDepth)
321 {
322  // add a node
323  m_nodes.emplace_back();
324 
325  // initialize leaf node if termination criteria met
326  if (nItems == 0)
327  {
328  m_nodes[nodeIndex].InitLeaf(std::numeric_limits<size_t>::max(), {});
329  return currentDepth + 1;
330  }
331  if (nItems == 1)
332  {
333  m_nodes[nodeIndex].InitLeaf(itemIndices[0], m_points[itemIndices[0]]);
334  return currentDepth + 1;
335  }
336 
337  // choose which axis to split along
338  BBox nodeBound;
339  for (size_t i = 0; i < nItems; ++i)
340  {
341  nodeBound.Merge(m_points[itemIndices[i]]);
342  }
343  Point d = nodeBound.upperCorner - nodeBound.lowerCorner;
344  const size_t axis = static_cast<size_t>(d.DominantAxis());
345 
346  // pick mid point
347  std::nth_element(itemIndices, itemIndices + nItems / 2,
348  itemIndices + nItems, [&](size_t a, size_t b) {
349  return m_points[a][axis] < m_points[b][axis];
350  });
351  const size_t midPoint = nItems / 2;
352 
353  // recursively initialize children nodes
354  const size_t d0 =
355  Build(nodeIndex + 1, itemIndices, midPoint, currentDepth + 1);
356  m_nodes[nodeIndex].InitInternal(axis, itemIndices[midPoint], m_nodes.size(),
357  m_points[itemIndices[midPoint]]);
358  const size_t d1 =
359  Build(m_nodes[nodeIndex].child, itemIndices + midPoint + 1,
360  nItems - midPoint - 1, currentDepth + 1);
361 
362  return std::max(d0, d1);
363 }
364 } // namespace CubbyFlow
365 
366 #endif
VectorType upperCorner
Upper corner of the bounding box.
Definition: BoundingBox.hpp:148
void Reserve(size_t numPoints, size_t numNodes)
Reserves memory space for this tree.
Definition: KdTree-Impl.hpp:264
NodeIterator EndNode()
Returns the mutable end iterator of the node.
Definition: KdTree-Impl.hpp:301
typename ContainerType::const_iterator ConstIterator
Definition: KdTree.hpp:56
typename NodeContainerType::iterator NodeIterator
Definition: KdTree.hpp:59
N-D axis-aligned bounding box class.
Definition: BoundingBox.hpp:46
Iterator begin()
Returns the mutable begin iterator of the item.
Definition: KdTree-Impl.hpp:271
bool IsLeaf() const
Returns true if leaf.
Definition: KdTree-Impl.hpp:38
Iterator end()
Definition: ArrayBase-Impl.hpp:102
bool HasNearbyPoint(const Point &origin, T radius) const
Definition: KdTree-Impl.hpp:130
Point point
Point stored in the node.
Definition: KdTree.hpp:51
NodeIterator BeginNode()
Returns the mutable begin iterator of the node.
Definition: KdTree-Impl.hpp:295
Definition: Matrix.hpp:27
void InitInternal(size_t axis, size_t it, size_t c, const Point &pt)
Initializes internal node.
Definition: KdTree-Impl.hpp:28
size_t GetNearestPoint(const Point &origin) const
Returns index of the nearest point.
Definition: KdTree-Impl.hpp:196
Definition: pybind11Utils.hpp:20
size_t DominantAxis() const
Definition: MatrixExpression-Impl.hpp:206
Iterator begin()
Definition: ArrayBase-Impl.hpp:90
typename NodeContainerType::const_iterator ConstNodeIterator
Definition: KdTree.hpp:60
VectorType lowerCorner
Lower corner of the bounding box.
Definition: BoundingBox.hpp:145
size_t Length() const
Definition: ArrayBase-Impl.hpp:84
void Build(const ConstArrayView1< Point > &points)
Builds internal acceleration structure for given points list.
Definition: KdTree-Impl.hpp:44
size_t flags
Split axis if flags < K, leaf indicator if flags == K.
Definition: KdTree.hpp:41
void Merge(const VectorType &point)
Merges this and other point.
Definition: BoundingBox-Impl.hpp:219
Generic N-dimensional array class interface.
Definition: Array.hpp:32
size_t child
Right child index. Note that left child index is this node index + 1.
Definition: KdTree.hpp:45
Iterator end()
Returns the mutable end iterator of the item.
Definition: KdTree-Impl.hpp:277
typename ContainerType::iterator Iterator
Definition: KdTree.hpp:55
size_t item
Item index.
Definition: KdTree.hpp:48
void ForEachNearbyPoint(const Point &origin, T radius, const std::function< void(size_t, const Point &)> &callback) const
Definition: KdTree-Impl.hpp:64
void InitLeaf(size_t it, const Point &pt)
Initializes leaf node.
Definition: KdTree-Impl.hpp:19