11 #ifndef CUBBYFLOW_CUDA_ARRAY_BASE_IMPL_HPP 12 #define CUBBYFLOW_CUDA_ARRAY_BASE_IMPL_HPP 14 #ifdef CUBBYFLOW_USE_CUDA 18 template <
typename T,
size_t N,
typename Derived>
19 size_t CUDAArrayBase<T, N, Derived>::Index(
size_t i)
const 24 template <
typename T,
size_t N,
typename Derived>
25 template <
typename... Args>
26 size_t CUDAArrayBase<T, N, Derived>::Index(
size_t i, Args... args)
const 28 static_assert(
sizeof...(args) == N - 1,
"Invalid number of indices.");
29 return i + m_size[0] * IndexInternal(1, args...);
32 template <
typename T,
size_t N,
typename Derived>
33 template <
size_t... I>
34 size_t CUDAArrayBase<T, N, Derived>::Index(
35 const CUDAStdArray<size_t, N>& idx)
const 37 return IndexInternal(idx, std::make_index_sequence<N>{});
40 template <
typename T,
size_t N,
typename Derived>
41 T* CUDAArrayBase<T, N, Derived>::data()
46 template <
typename T,
size_t N,
typename Derived>
47 const T* CUDAArrayBase<T, N, Derived>::data()
const 52 template <
typename T,
size_t N,
typename Derived>
53 const CUDAStdArray<size_t, N>& CUDAArrayBase<T, N, Derived>::Size()
const 58 template <
typename T,
size_t N,
typename Derived>
60 std::enable_if_t<(M > 0),
size_t> CUDAArrayBase<T, N, Derived>::Width()
const 65 template <
typename T,
size_t N,
typename Derived>
67 std::enable_if_t<(M > 1),
size_t> CUDAArrayBase<T, N, Derived>::Height()
const 72 template <
typename T,
size_t N,
typename Derived>
74 std::enable_if_t<(M > 2),
size_t> CUDAArrayBase<T, N, Derived>::Depth()
const 79 template <
typename T,
size_t N,
typename Derived>
80 size_t CUDAArrayBase<T, N, Derived>::Length()
const 84 for (
size_t i = 1; i < N; ++i)
93 template <
typename T,
size_t N,
typename Derived>
94 CUBBYFLOW_CUDA_DEVICE
typename CUDAArrayBase<T, N, Derived>::Reference
95 CUDAArrayBase<T, N, Derived>::At(
size_t i)
100 template <
typename T,
size_t N,
typename Derived>
101 CUBBYFLOW_CUDA_DEVICE
typename CUDAArrayBase<T, N, Derived>::ConstReference
102 CUDAArrayBase<T, N, Derived>::At(
size_t i)
const 107 template <
typename T,
size_t N,
typename Derived>
108 template <
typename... Args>
109 CUBBYFLOW_CUDA_DEVICE
typename CUDAArrayBase<T, N, Derived>::Reference
110 CUDAArrayBase<T, N, Derived>::At(
size_t i, Args... args)
112 return At(Index(i, args...));
115 template <
typename T,
size_t N,
typename Derived>
116 template <
typename... Args>
117 CUBBYFLOW_CUDA_DEVICE
typename CUDAArrayBase<T, N, Derived>::ConstReference
118 CUDAArrayBase<T, N, Derived>::At(
size_t i, Args... args)
const 120 return At(Index(i, args...));
123 template <
typename T,
size_t N,
typename Derived>
124 CUBBYFLOW_CUDA_DEVICE
typename CUDAArrayBase<T, N, Derived>::Reference
125 CUDAArrayBase<T, N, Derived>::At(
const CUDAStdArray<size_t, N>& idx)
127 return At(Index(idx));
130 template <
typename T,
size_t N,
typename Derived>
131 CUBBYFLOW_CUDA_DEVICE
typename CUDAArrayBase<T, N, Derived>::ConstReference
132 CUDAArrayBase<T, N, Derived>::At(
const CUDAStdArray<size_t, N>& idx)
const 134 return At(Index(idx));
137 template <
typename T,
size_t N,
typename Derived>
138 CUBBYFLOW_CUDA_DEVICE
typename CUDAArrayBase<T, N, Derived>::Reference
139 CUDAArrayBase<T, N, Derived>::operator[](
size_t i)
144 template <
typename T,
size_t N,
typename Derived>
145 CUBBYFLOW_CUDA_DEVICE
typename CUDAArrayBase<T, N, Derived>::ConstReference
146 CUDAArrayBase<T, N, Derived>::operator[](
size_t i)
const 151 template <
typename T,
size_t N,
typename Derived>
152 template <
typename... Args>
153 CUBBYFLOW_CUDA_DEVICE
typename CUDAArrayBase<T, N, Derived>::Reference
154 CUDAArrayBase<T, N, Derived>::operator()(
size_t i, Args... args)
156 return At(i, args...);
159 template <
typename T,
size_t N,
typename Derived>
160 template <
typename... Args>
161 CUBBYFLOW_CUDA_DEVICE
typename CUDAArrayBase<T, N, Derived>::ConstReference
162 CUDAArrayBase<T, N, Derived>::operator()(
size_t i, Args... args)
const 164 return At(i, args...);
167 template <
typename T,
size_t N,
typename Derived>
168 CUBBYFLOW_CUDA_DEVICE
typename CUDAArrayBase<T, N, Derived>::Reference
169 CUDAArrayBase<T, N, Derived>::operator()(
const CUDAStdArray<size_t, N>& idx)
174 template <
typename T,
size_t N,
typename Derived>
175 CUBBYFLOW_CUDA_DEVICE
typename CUDAArrayBase<T, N, Derived>::ConstReference
176 CUDAArrayBase<T, N, Derived>::operator()(
177 const CUDAStdArray<size_t, N>& idx)
const 182 template <
typename T,
size_t N,
typename Derived>
183 typename CUDAArrayBase<T, N, Derived>::HostReference
184 CUDAArrayBase<T, N, Derived>::At(
size_t i)
186 return HostReference(m_ptr + i);
189 template <
typename T,
size_t N,
typename Derived>
190 T CUDAArrayBase<T, N, Derived>::At(
size_t i)
const 192 return (T)HostReference(m_ptr + i);
195 template <
typename T,
size_t N,
typename Derived>
196 template <
typename... Args>
197 typename CUDAArrayBase<T, N, Derived>::HostReference
198 CUDAArrayBase<T, N, Derived>::At(
size_t i, Args... args)
200 return At(Index(i, args...));
203 template <
typename T,
size_t N,
typename Derived>
204 template <
typename... Args>
205 T CUDAArrayBase<T, N, Derived>::At(
size_t i, Args... args)
const 207 return At(Index(i, args...));
210 template <
typename T,
size_t N,
typename Derived>
211 typename CUDAArrayBase<T, N, Derived>::HostReference
212 CUDAArrayBase<T, N, Derived>::At(
const CUDAStdArray<size_t, N>& idx)
214 return At(Index(idx));
217 template <
typename T,
size_t N,
typename Derived>
218 T CUDAArrayBase<T, N, Derived>::At(
const CUDAStdArray<size_t, N>& idx)
const 220 return At(Index(idx));
223 template <
typename T,
size_t N,
typename Derived>
224 typename CUDAArrayBase<T, N, Derived>::HostReference
225 CUDAArrayBase<T, N, Derived>::operator[](
size_t i)
230 template <
typename T,
size_t N,
typename Derived>
231 T CUDAArrayBase<T, N, Derived>::operator[](
size_t i)
const 236 template <
typename T,
size_t N,
typename Derived>
237 template <
typename... Args>
238 typename CUDAArrayBase<T, N, Derived>::HostReference
239 CUDAArrayBase<T, N, Derived>::operator()(
size_t i, Args... args)
241 return At(i, args...);
244 template <
typename T,
size_t N,
typename Derived>
245 template <
typename... Args>
246 T CUDAArrayBase<T, N, Derived>::operator()(
size_t i, Args... args)
const 248 return At(i, args...);
251 template <
typename T,
size_t N,
typename Derived>
252 typename CUDAArrayBase<T, N, Derived>::HostReference
253 CUDAArrayBase<T, N, Derived>::operator()(
const CUDAStdArray<size_t, N>& idx)
258 template <
typename T,
size_t N,
typename Derived>
259 T CUDAArrayBase<T, N, Derived>::operator()(
260 const CUDAStdArray<size_t, N>& idx)
const 266 template <
typename T,
size_t N,
typename Derived>
267 CUDAArrayBase<T, N, Derived>::CUDAArrayBase() : m_size{}
272 template <
typename T,
size_t N,
typename Derived>
273 CUDAArrayBase<T, N, Derived>::CUDAArrayBase(
const CUDAArrayBase& other)
278 template <
typename T,
size_t N,
typename Derived>
279 CUDAArrayBase<T, N, Derived>::CUDAArrayBase(CUDAArrayBase&& other) noexcept
281 *
this = std::move(other);
284 template <
typename T,
size_t N,
typename Derived>
285 CUDAArrayBase<T, N, Derived>& CUDAArrayBase<T, N, Derived>::operator=(
286 const CUDAArrayBase& other)
292 template <
typename T,
size_t N,
typename Derived>
293 CUDAArrayBase<T, N, Derived>& CUDAArrayBase<T, N, Derived>::operator=(
294 CUDAArrayBase&& other) noexcept
297 other.SetPtrAndSize(
nullptr, CUDAStdArray<size_t, N>{});
301 template <
typename T,
size_t N,
typename Derived>
302 template <
typename... Args>
303 void CUDAArrayBase<T, N, Derived>::SetPtrAndSize(
Pointer ptr,
size_t ni,
309 template <
typename T,
size_t N,
typename Derived>
310 void CUDAArrayBase<T, N, Derived>::SetPtrAndSize(
Pointer ptr,
311 CUDAStdArray<size_t, N> size)
317 template <
typename T,
size_t N,
typename Derived>
318 void CUDAArrayBase<T, N, Derived>::SwapPtrAndSize(CUDAArrayBase& other)
320 CUDASwap(
m_ptr, other.m_ptr);
321 CUDASwap(
m_size, other.m_size);
324 template <
typename T,
size_t N,
typename Derived>
325 void CUDAArrayBase<T, N, Derived>::ClearPtrAndSize()
330 template <
typename T,
size_t N,
typename Derived>
331 template <
typename... Args>
332 size_t CUDAArrayBase<T, N, Derived>::IndexInternal(
size_t d,
size_t i,
335 return i +
m_size[d] * IndexInternal(d + 1, args...);
338 template <
typename T,
size_t N,
typename Derived>
339 size_t CUDAArrayBase<T, N, Derived>::IndexInternal(
size_t,
size_t i)
const 344 template <
typename T,
size_t N,
typename Derived>
345 template <
size_t... I>
346 size_t CUDAArrayBase<T, N, Derived>::IndexInternal(
347 const CUDAStdArray<size_t, N>& idx, std::index_sequence<I...>)
const 349 return Index(idx[I]...);
Pointer m_ptr
Definition: ArrayBase.hpp:124
Vector< size_t, N > m_size
Definition: ArrayBase.hpp:125
Definition: pybind11Utils.hpp:20
size_t Index(size_t i) const
Definition: ArrayBase-Impl.hpp:17
T * Pointer
Definition: ArrayBase.hpp:26
void SetPtrAndSize(Pointer ptr, size_t ni, Args... args)
Definition: ArrayBase-Impl.hpp:250