11 #ifndef CUBBYFLOW_SVD_IMPL_HPP 12 #define CUBBYFLOW_SVD_IMPL_HPP 23 return static_cast<double>(b) >= 0.0 ? std::fabs(a) : -std::fabs(a);
37 result = at * std::sqrt(1 + ct * ct);
42 result = bt * std::sqrt(1 + ct * ct);
56 const int m =
static_cast<int>(a.
GetRows());
57 const int n =
static_cast<int>(a.
GetCols());
59 int i, j = 0, jj = 0, k = 0, l = 0, nm = 0;
60 T c = 0, f = 0, h = 0, s = 0, x = 0, y = 0, z = 0;
61 T anorm = 0, g = 0, scale = 0;
65 throw std::invalid_argument{
66 "Number of rows of input matrix must greater than or equal to " 78 for (i = 0; i < n; i++)
87 for (k = i; k < m; k++)
89 scale += std::fabs(u(k, i));
92 if (std::fabs(static_cast<double>(scale)) >=
93 std::numeric_limits<double>::epsilon())
95 for (k = i; k < m; k++)
98 s += u(k, i) * u(k, i);
108 for (j = l; j < n; j++)
112 for (k = i; k < m; k++)
114 s += u(k, i) * u(k, j);
119 for (k = i; k < m; k++)
121 u(k, j) += f * u(k, i);
126 for (k = i; k < m; k++)
138 if (i < m && i != n - 1)
140 for (k = l; k < n; k++)
142 scale += std::fabs(u(i, k));
145 if (std::fabs(static_cast<double>(scale)) >=
146 std::numeric_limits<double>::epsilon())
148 for (k = l; k < n; k++)
151 s += u(i, k) * u(i, k);
159 for (k = l; k < n; k++)
161 rv1[k] =
static_cast<T
>(u(i, k)) / h;
166 for (j = l; j < m; j++)
170 for (k = l; k < n; k++)
172 s += u(j, k) * u(i, k);
175 for (k = l; k < n; k++)
177 u(j, k) += s * rv1[k];
182 for (k = l; k < n; k++)
189 anorm = std::max(anorm,
190 (std::fabs(static_cast<T>(w[i])) + std::fabs(rv1[i])));
194 for (i = n - 1; i >= 0; i--)
198 if (std::fabs(static_cast<double>(g)) >=
199 std::numeric_limits<double>::epsilon())
201 for (j = l; j < n; j++)
203 v(j, i) = ((u(i, j) / u(i, l)) / g);
207 for (j = l; j < n; j++)
211 for (k = l; k < n; k++)
213 s += u(i, k) * v(k, j);
216 for (k = l; k < n; k++)
218 v(k, j) += s * v(k, i);
223 for (j = l; j < n; j++)
225 v(i, j) = v(j, i) = 0;
235 for (i = n - 1; i >= 0; i--)
242 for (j = l; j < n; j++)
248 if (std::fabs(static_cast<double>(g)) >=
249 std::numeric_limits<double>::epsilon())
255 for (j = l; j < n; j++)
259 for (k = l; k < m; k++)
261 s += u(k, i) * u(k, j);
264 f = (s / u(i, i)) * g;
266 for (k = i; k < m; k++)
268 u(k, j) += f * u(k, i);
273 for (j = i; j < m; j++)
275 u(j, i) = u(j, i) * g;
280 for (j = i; j < m; j++)
290 for (k = n - 1; k >= 0; k--)
293 for (
int its = 0; its < 30; its++)
298 for (l = k; l >= 0; l--)
303 if (std::fabs(static_cast<double>(rv1[l])) <=
304 std::numeric_limits<double>::epsilon())
310 if (std::fabs(static_cast<double>(w[nm])) <=
311 std::numeric_limits<double>::epsilon())
322 for (i = l; i <= k; i++)
326 if (std::fabs(static_cast<double>(f)) <=
327 std::numeric_limits<double>::epsilon())
331 w[i] =
static_cast<T
>(h);
336 for (j = 0; j < m; j++)
340 u(j, nm) = y * c + z * s;
341 u(j, i) = z * c - y * s;
357 for (j = 0; j < n; j++)
368 throw std::logic_error{
"No convergence after 30 iterations" };
377 f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2 * h * y);
379 f = ((x - z) * (x + z) +
386 for (j = l; j <= nm; j++)
402 for (jj = 0; jj < n; jj++)
406 v(jj, j) = x * c + z * s;
407 v(jj, i) = z * c - x * s;
413 if (std::fabs(static_cast<double>(z)) >=
414 std::numeric_limits<double>::epsilon())
421 f = (c * g) + (s * y);
422 x = (c * y) - (s * g);
424 for (jj = 0; jj < m; jj++)
428 u(jj, j) = y * c + z * s;
429 u(jj, i) = z * c - y * s;
440 template <
typename T,
size_t M,
size_t N>
444 const int m =
static_cast<int>(M);
445 const int n =
static_cast<int>(N);
447 int i, its, j = 0, jj = 0, k = 0, l = 0, nm = 0;
448 T c = 0, f = 0, h = 0, s = 0, x = 0, y = 0, z = 0;
449 T anorm = 0, g = 0, scale = 0;
451 static_assert(m >= n,
452 "Number of rows of input matrix must greater than or equal " 462 for (i = 0; i < n; i++)
471 for (k = i; k < m; k++)
473 scale += std::fabs(u(k, i));
478 for (k = i; k < m; k++)
481 s += u(k, i) * u(k, i);
491 for (j = l; j < n; j++)
495 for (k = i; k < m; k++)
497 s += u(k, i) * u(k, j);
502 for (k = i; k < m; k++)
504 u(k, j) += f * u(k, i);
509 for (k = i; k < m; k++)
521 if (i < m && i != n - 1)
523 for (k = l; k < n; k++)
525 scale += std::fabs(u(i, k));
530 for (k = l; k < n; k++)
533 s += u(i, k) * u(i, k);
541 for (k = l; k < n; k++)
543 rv1[k] =
static_cast<T
>(u(i, k)) / h;
548 for (j = l; j < m; j++)
552 for (k = l; k < n; k++)
554 s += u(j, k) * u(i, k);
557 for (k = l; k < n; k++)
559 u(j, k) += s * rv1[k];
564 for (k = l; k < n; k++)
570 anorm = std::max(anorm,
571 (std::fabs(static_cast<T>(w[i])) + std::fabs(rv1[i])));
575 for (i = n - 1; i >= 0; i--)
581 for (j = l; j < n; j++)
583 v(j, i) = ((u(i, j) / u(i, l)) / g);
587 for (j = l; j < n; j++)
591 for (k = l; k < n; k++)
593 s += u(i, k) * v(k, j);
596 for (k = l; k < n; k++)
598 v(k, j) += s * v(k, i);
603 for (j = l; j < n; j++)
605 v(i, j) = v(j, i) = 0;
615 for (i = n - 1; i >= 0; i--)
622 for (j = l; j < n; j++)
634 for (j = l; j < n; j++)
638 for (k = l; k < m; k++)
640 s += u(k, i) * u(k, j);
643 f = (s / u(i, i)) * g;
645 for (k = i; k < m; k++)
647 u(k, j) += f * u(k, i);
652 for (j = i; j < m; j++)
654 u(j, i) = u(j, i) * g;
659 for (j = i; j < m; j++)
669 for (k = n - 1; k >= 0; k--)
672 for (its = 0; its < 30; its++)
677 for (l = k; l >= 0; l--)
682 if (std::fabs(rv1[l]) + anorm == anorm)
688 if (std::fabs(static_cast<T>(w[nm])) + anorm == anorm)
699 for (i = l; i <= k; i++)
703 if (std::fabs(f) + anorm != anorm)
707 w[i] =
static_cast<T
>(h);
712 for (j = 0; j < m; j++)
716 u(j, nm) = y * c + z * s;
717 u(j, i) = z * c - y * s;
733 for (j = 0; j < n; j++)
744 throw std::logic_error{
"No convergence after 30 iterations" };
753 f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2 * h * y);
755 f = ((x - z) * (x + z) +
762 for (j = l; j <= nm; j++)
778 for (jj = 0; jj < n; jj++)
782 v(jj, j) = x * c + z * s;
783 v(jj, i) = z * c - x * s;
796 f = (c * g) + (s * y);
797 x = (c * y) - (s * g);
799 for (jj = 0; jj < m; jj++)
803 u(jj, j) = y * c + z * s;
804 u(jj, i) = z * c - y * s;
T Sign(T a, T b)
Definition: SVD-Impl.hpp:21
size_t GetCols() const
Definition: Matrix-Impl.hpp:1069
void Resize(size_t rows, size_t cols, ConstReference val=ValueType{})
Definition: Matrix-Impl.hpp:1035
Definition: Matrix.hpp:27
Definition: pybind11Utils.hpp:20
Definition: Matrix.hpp:531
Definition: Matrix.hpp:611
T Pythag(T a, T b)
Definition: SVD-Impl.hpp:27
void Resize(size_t rows, ConstReference val=ValueType{})
Definition: Matrix-Impl.hpp:1245
void SVD(const MatrixMxN< T > &a, MatrixMxN< T > &u, VectorN< T > &w, MatrixMxN< T > &v)
Singular value decomposition (SVD).
Definition: SVD-Impl.hpp:54
size_t GetRows() const
Definition: Matrix-Impl.hpp:1063