iMSTK
Interactive Medical Simulation Toolkit
imstkSurfaceMeshDistanceTransform.cpp
1 /*
2 ** This file is part of the Interactive Medical Simulation Toolkit (iMSTK)
3 ** iMSTK is distributed under the Apache License, Version 2.0.
4 ** See accompanying NOTICE for details.
5 */
6 
7 #include "imstkSurfaceMeshDistanceTransform.h"
8 #include "imstkDataArray.h"
9 #include "imstkGeometryUtilities.h"
10 #include "imstkImageData.h"
11 #include "imstkLogger.h"
12 #include "imstkLooseOctree.h"
13 #include "imstkParallelUtils.h"
14 #include "imstkSurfaceMesh.h"
15 #include "imstkTimer.h"
16 #include "imstkSurfaceMeshImageMask.h"
17 #include <stack>
18 #include <vtkDistancePolyDataFilter.h>
19 #include <vtkImageData.h>
20 #include <vtkImplicitPolyDataDistance.h>
21 #include <vtkOctreePointLocator.h>
22 #include <vtkPointData.h>
23 #include <vtkPolyData.h>
24 #include <vtkSelectEnclosedPoints.h>
25 
26 namespace imstk
27 {
34 //static int isNeighborhoodEquivalent(const Vec3i& pt, const Vec3i& dim, const double val, const double* imgPtr, const int dilateSize)
35 //{
36 // const Vec3i min = (pt - Vec3i(dilateSize, dilateSize, dilateSize)).cwiseMax(Vec3i(0, 0, 0)).cwiseMin(dim - Vec3i(1, 1, 1));
37 // const Vec3i max = (pt + Vec3i(dilateSize, dilateSize, dilateSize)).cwiseMax(Vec3i(0, 0, 0)).cwiseMin(dim - Vec3i(1, 1, 1));
38 //
39 // // Take the max of the neighborhood
40 // for (int z = min[2]; z < max[2] + 1; z++)
41 // {
42 // for (int y = min[1]; y < max[1] + 1; y++)
43 // {
44 // for (int x = min[0]; x < max[0] + 1; x++)
45 // {
46 // const int index = ImageData::getScalarIndex(x, y, z, dim, 1);
47 // if (val != imgPtr[index])
48 // {
49 // return false;
50 // }
51 // }
52 // }
53 // }
54 // return 0;
55 //}
56 static bool
57 isNeighborhoodEquivalent(const Vec3i& pt, const Vec3i& dim, const float val, const float* imgPtr, const int dilateSize)
58 {
59  const Vec3i min = (pt - Vec3i(dilateSize, dilateSize, dilateSize)).cwiseMax(Vec3i(0, 0, 0)).cwiseMin(dim - Vec3i(1, 1, 1));
60  const Vec3i max = (pt + Vec3i(dilateSize, dilateSize, dilateSize)).cwiseMax(Vec3i(0, 0, 0)).cwiseMin(dim - Vec3i(1, 1, 1));
61 
62  // Take the max of the neighborhood
63  for (int z = min[2]; z < max[2] + 1; z++)
64  {
65  for (int y = min[1]; y < max[1] + 1; y++)
66  {
67  for (int x = min[0]; x < max[0] + 1; x++)
68  {
69  const size_t index = ImageData::getScalarIndex(x, y, z, dim, 1);
70  if (val != imgPtr[index])
71  {
72  return false;
73  }
74  }
75  }
76  }
77  return true;
78 }
79 
80 // Narrow band is WIP, it works but is slow
81 static void
82 computeNarrowBandedDT(std::shared_ptr<ImageData> imageData, std::shared_ptr<SurfaceMesh> surfMesh, const int dilateSize,
83  const double tolerance)
84 {
85  // Rasterize a mask from the polygon
86  std::shared_ptr<SurfaceMeshImageMask> imageMask = std::make_shared<SurfaceMeshImageMask>();
87  imageMask->setInputMesh(surfMesh);
88  imageMask->setReferenceImage(imageData);
89  imageMask->update();
90 
91  auto inputScalarsPtr = std::dynamic_pointer_cast<DataArray<float>>(imageMask->getOutputImage()->getScalars());
92  DataArray<float>& inputScalars = *inputScalarsPtr;
93  float* inputImgPtr = inputScalars.getPointer();
94  auto outputScalarsPtr = std::dynamic_pointer_cast<DataArray<double>>(imageData->getScalars());
95  DataArray<double>& outputScalars = *outputScalarsPtr;
96  double* outputImgPtr = outputScalars.getPointer();
97 
98  // Separate polygons used to avoid race conditions
99  vtkSmartPointer<vtkPolyData> inputPolyData = GeometryUtils::copyToVtkPolyData(surfMesh);
100  vtkSmartPointer<vtkImplicitPolyDataDistance> distFunc = vtkSmartPointer<vtkImplicitPolyDataDistance>::New();
101  distFunc->SetInput(inputPolyData);
102  distFunc->SetTolerance(tolerance);
103 
104  std::fill_n(outputImgPtr, outputScalars.size(), 10000.0);
105 
106  // Iterate the image testing for boundary pixels (ie any 0 adjacent to a 1)
107  const Vec3i& dim = imageData->getDimensions();
108  const Vec3d shift = imageData->getOrigin() + imageData->getSpacing() * 0.5;
109  const Vec3d& spacing = imageData->getSpacing();
110  int i = 0;
111  for (int z = 0; z < dim[2]; z++)
112  {
113  for (int y = 0; y < dim[1]; y++)
114  {
115  for (int x = 0; x < dim[0]; x++, i++)
116  {
117  const float val = inputImgPtr[i];
118  const Vec3i pt = Vec3i(x, y, z);
119 
120  // If neighborhood is homogenous then its not touching the boundary
121  if (!isNeighborhoodEquivalent(pt, dim, val, inputImgPtr, dilateSize))
122  {
123  const Vec3d pos = pt.cast<double>().cwiseProduct(spacing) + shift;
124  outputImgPtr[i] = distFunc->FunctionValue(pos.data());
125  }
126  // If value is 1 (we are inside)
127  else if (val == 1.0)
128  {
129  outputImgPtr[i] = -10000.0;
130  }
131 
132  if (i % 1000000 == 0)
133  {
134  double p = static_cast<double>(i) / (dim[0] * dim[1] * dim[2]);
135  std::cout << "Progress " << p << "\n";
136  }
137  }
138  }
139  }
140 }
141 
142 static void
143 computeFullDT(std::shared_ptr<ImageData> imageData, std::shared_ptr<SurfaceMesh> surfMesh, const double tolerance)
144 {
145  // Get the optimal number of threads
146  const int numThreads = static_cast<int>(ParallelUtils::ThreadManager::getThreadPoolSize());
147 
148  const Vec3i& dim = imageData->getDimensions();
149  const Vec3d spacing = imageData->getSpacing();
150  const Vec3d shift = imageData->getOrigin() + spacing * 0.5;
151 
152  auto scalarsPtr = std::dynamic_pointer_cast<DataArray<double>>(imageData->getScalars());
153  DataArray<double>& scalars = *scalarsPtr.get();
154 
155  // Split the work up along z
156  ParallelUtils::parallelFor(numThreads, [&](const int& i)
157  {
158  // Separate polygons used to avoid race conditions
159  vtkSmartPointer<vtkPolyData> inputPolyData = GeometryUtils::copyToVtkPolyData(surfMesh);
160  vtkSmartPointer<vtkImplicitPolyDataDistance> distFunc = vtkSmartPointer<vtkImplicitPolyDataDistance>::New();
161  distFunc->SetInput(inputPolyData);
162  distFunc->SetTolerance(tolerance);
163 
164  for (int z = i; z < dim[2]; z += numThreads)
165  {
166  int j = z * dim[0] * dim[1];
167  for (int y = 0; y < dim[1]; y++)
168  {
169  for (int x = 0; x < dim[0]; x++, j++)
170  {
171  double pos[3] = { x* spacing[0] + shift[0], y * spacing[1] + shift[1], z * spacing[2] + shift[2] };
172  scalars[j] = distFunc->FunctionValue(pos);
173  }
174  }
175  }
176  });
177 }
178 
179 SurfaceMeshDistanceTransform::SurfaceMeshDistanceTransform()
180 {
181  setNumInputPorts(1);
182  setRequiredInputType<SurfaceMesh>(0);
183 
185  setOutput(std::make_shared<ImageData>(), 0);
186 }
187 
188 void
189 SurfaceMeshDistanceTransform::setInputMesh(std::shared_ptr<SurfaceMesh> mesh)
190 {
191  setInput(mesh, 0);
192 }
193 
194 std::shared_ptr<ImageData>
195 SurfaceMeshDistanceTransform::getOutputImage()
196 {
197  return std::dynamic_pointer_cast<ImageData>(getOutput());
198 }
199 
200 void
201 SurfaceMeshDistanceTransform::setupDistFunc()
202 {
203  std::shared_ptr<SurfaceMesh> inputSurfaceMesh = std::dynamic_pointer_cast<SurfaceMesh>(getInput(0));
204  vtkSmartPointer<vtkPolyData> inputPolyData = GeometryUtils::copyToVtkPolyData(inputSurfaceMesh);
205 
206  m_distFunc = vtkSmartPointer<vtkImplicitPolyDataDistance>::New();
207  m_distFunc->SetInput(inputPolyData);
208 }
209 
210 Vec3d
212 {
213  m_distFunc->SetTolerance(m_Tolerance);
214  Vec3d closestPt = Vec3d::Zero();
215  Vec3d p = pos;
216  m_distFunc->EvaluateFunctionAndGetClosestPoint(p.data(), closestPt.data());
217  return closestPt;
218 }
219 
220 void
221 SurfaceMeshDistanceTransform::setBounds(const Vec6d& bounds)
222 {
223  m_Bounds = bounds;
224  if (m_Bounds.isZero())
225  {
226  LOG(WARNING) << "SurfaceMeshDistanceTransform Bounds are zero, the input SurfaceMesh bounds will be used instead.";
227  }
228 }
229 
230 void
231 SurfaceMeshDistanceTransform::setBounds(const Vec3d& min, const Vec3d& max)
232 {
233  m_Bounds << min.x(), max.x(), min.y(), max.y(), min.z(), max.z();
234  if (m_Bounds.isZero())
235  {
236  LOG(WARNING) << "SurfaceMeshDistanceTransform Bounds are zero, the input SurfaceMesh bounds will be used instead.";
237  }
238 }
239 
240 void
242 {
243  std::shared_ptr<SurfaceMesh> inputSurfaceMesh = std::dynamic_pointer_cast<SurfaceMesh>(getInput(0));
244  std::shared_ptr<ImageData> outputImageData = std::dynamic_pointer_cast<ImageData>(getOutput(0));
245 
246  if (m_Dimensions[0] == 0 || m_Dimensions[1] == 0 || m_Dimensions[2] == 0)
247  {
248  LOG(WARNING) << "SurfaceMeshDistanceTransform Dimensions not set";
249  return;
250  }
251 
252  Vec6d bounds = m_Bounds;
253  if (bounds.isZero())
254  {
255  Vec3d min, max;
256  inputSurfaceMesh->computeBoundingBox(min, max, 0.0);
257  bounds << min.x(), max.x(), min.y(), max.y(), min.z(), max.z();
258  LOG(WARNING) << "SurfaceMeshDistanceTransform Bounds are zero, the input SurfaceMesh bounds (" << bounds.transpose() << ") will be used.";
259  }
260 
261  const Vec3d size = Vec3d(bounds[1] - bounds[0], bounds[3] - bounds[2], bounds[5] - bounds[4]);
262  const Vec3d spacing = size.cwiseQuotient(m_Dimensions.cast<double>());
263  const Vec3d origin = Vec3d(bounds[0], bounds[2], bounds[4]);
264  outputImageData->allocate(IMSTK_DOUBLE, 1, m_Dimensions, spacing, origin);
265 
266  /* StopWatch timer;
267  timer.start();*/
268 
269  if (m_NarrowBanded)
270  {
271  computeNarrowBandedDT(outputImageData, inputSurfaceMesh, m_DilateSize,
272  m_Tolerance);
273  }
274  else
275  {
276  computeFullDT(outputImageData, inputSurfaceMesh, m_Tolerance);
277  }
278 
279  //printf("time: %f\n", timer.getTimeElapsed());
280 }
281 } // namespace imstk
size_t getScalarIndex(int x, int y, int z=0)
Returns index of data in scalar array given structured image coordinate, does no bounds checking...
vtkSmartPointer< vtkPolyData > copyToVtkPolyData(std::shared_ptr< LineMesh > imstkMesh)
Converts imstk line mesh into a vtk polydata.
std::shared_ptr< Geometry > getInput(size_t port=0) const
Returns input geometry given port, returns nullptr if it doesn't exist.
Compound Geometry.
void setNumOutputPorts(const size_t numPorts)
Get/Set the amount of output ports.
void setBounds(const Vec3d &min, const Vec3d &max)
Optionally one may specify bounds, if not specified bounds of the input SurfaceMesh is used...
std::shared_ptr< Geometry > getOutput(size_t port=0) const
Returns output geometry given port, returns nullptr if it doesn't exist.
void setInput(std::shared_ptr< Geometry > inputGeometry, size_t port=0)
Set the input at the port.
void requestUpdate() override
Users can implement this for the logic to be run.
Represents a set of triangles & vertices via an array of Vec3d double vertices & Vec3i integer indice...
void setOutput(std::shared_ptr< Geometry > inputGeometry, const size_t port=0)
Set the output at the port.
void setNumInputPorts(const size_t numPorts)
Get/Set the amount of input ports.
static size_t getThreadPoolSize()
Returns the size of the thread pool.
Class to represent 1, 2, or 3D image data (i.e. structured points)
Vec3d getNearestPoint(const Vec3d &pos)
Get the nearest point.
void setInputMesh(std::shared_ptr< SurfaceMesh > surfMesh)
Required input, port 0.