10 #ifndef EIGEN_VISITOR_H
11 #define EIGEN_VISITOR_H
13 #include "./InternalHeaderCheck.h"
// Primary template of the traversal driver dispatched to by DenseBase::visit.
// UnrollCount is either a compile-time coefficient count (unrolled paths) or
// Dynamic (runtime loops). The specializations in this listing cover counts
// N, 1, 0 and Dynamic for the scalar path, and any count for the vectorized
// path. Vectorize defaults to true only when both the traversed expression
// (Derived::PacketAccess) and the visitor (functor_traits<Visitor>::PacketAccess)
// advertise packet access.
// NOTE(review): the declaration that closes this template header (original
// line 20) is not part of this listing.
19 template<
typename Visitor,
typename Derived,
int UnrollCount,
bool Vectorize=((Derived::PacketAccess!=0) && functor_traits<Visitor>::PacketAccess)>
// Compile-time unrolled scalar traversal of the first UnrollCount coefficients.
// Each instantiation handles exactly one coefficient: the one at linear
// position UnrollCount-1 in storage order.
22 template<
typename Visitor,
typename Derived,
int UnrollCount>
23 struct visitor_impl<Visitor, Derived, UnrollCount, false>
// Map linear position UnrollCount-1 to (row, col). For row-major expressions
// the column index varies fastest; otherwise the row index does.
26 col = Derived::IsRowMajor ? (UnrollCount-1) % Derived::ColsAtCompileTime
27 : (UnrollCount-1) / Derived::RowsAtCompileTime,
28 row = Derived::IsRowMajor ? (UnrollCount-1) / Derived::ColsAtCompileTime
29 : (UnrollCount-1) % Derived::RowsAtCompileTime
// Visit coefficients 0..UnrollCount-2 first. The recursion ends at the
// UnrollCount == 1 specialization, which calls visitor.init. Then visit this
// coefficient through the visitor's call operator.
33 static inline void run(
const Derived &mat, Visitor& visitor)
35 visitor_impl<Visitor, Derived, UnrollCount-1>::run(mat, visitor);
36 visitor(mat.coeff(row, col), row, col);
// Base case of the unrolled recursion. The first coefficient (0,0) is given
// to visitor.init instead of the visitor's call operator, so the visitor is
// seeded with a real coefficient before any comparison is made.
40 template<
typename Visitor,
typename Derived>
41 struct visitor_impl<Visitor, Derived, 1, false>
44 static inline void run(
const Derived &mat, Visitor& visitor)
46 return visitor.init(mat.coeff(0, 0), 0, 0);
// Zero-coefficient case. Both run() parameters are unnamed, so the body
// cannot touch the expression or the visitor.
// NOTE(review): the body (original lines 55-57) is not in this listing;
// presumably it is empty.
51 template<
typename Visitor,
typename Derived>
52 struct visitor_impl<Visitor, Derived, 0, false> {
54 static inline void run(
const Derived &, Visitor& )
// Runtime-sized scalar traversal (no unrolling, no packets). Every coefficient
// is visited exactly once, in storage order:
//   - (0,0) is handed to visitor.init;
//   - the rest of the first outer slice follows (row 0 when row-major,
//     column 0 otherwise);
//   - then every remaining slice is visited in full.
// NOTE(review): mat.coeff(0,0) is read unconditionally, so the expression must
// be non-empty. Confirm every caller guards against rows()==0 || cols()==0.
58 template<
typename Visitor,
typename Derived>
59 struct visitor_impl<Visitor, Derived,
Dynamic, false>
62 static inline void run(
const Derived& mat, Visitor& visitor)
64 visitor.init(mat.coeff(0,0), 0, 0);
65 if (Derived::IsRowMajor) {
// Row-major: finish row 0, then rows 1..rows()-1, column index innermost.
66 for(
Index i = 1; i < mat.cols(); ++i) {
67 visitor(mat.coeff(0, i), 0, i);
69 for(
Index j = 1; j < mat.rows(); ++j) {
70 for(
Index i = 0; i < mat.cols(); ++i) {
71 visitor(mat.coeff(j, i), j, i);
// Column-major: finish column 0, then columns 1..cols()-1, row index innermost.
75 for(
Index i = 1; i < mat.rows(); ++i) {
76 visitor(mat.coeff(i, 0), i, 0);
78 for(
Index j = 1; j < mat.cols(); ++j) {
79 for(
Index i = 0; i < mat.rows(); ++i) {
80 visitor(mat.coeff(i, j), i, j);
// Vectorized traversal, selected when Vectorize is true. UnrollSize is not
// referenced in the visible body, so this path always uses runtime loops.
// Traversal order:
//   - (0,0) goes to visitor.init;
//   - each outer slice is walked along its inner dimension in whole packets
//     via visitor.packet;
//   - the slice's leftover coefficients go to the visitor's call operator.
// The first slice starts at inner index 1 and advances by PacketSize, so
// mat.packet(i, j) must accept arbitrary (unaligned) starting offsets.
87 template<
typename Visitor,
typename Derived,
int UnrollSize>
88 struct visitor_impl<Visitor, Derived, UnrollSize, true>
90 typedef typename Derived::Scalar Scalar;
91 typedef typename packet_traits<Scalar>::type Packet;
94 static inline void run(
const Derived& mat, Visitor& visitor)
96 const Index PacketSize = packet_traits<Scalar>::size;
97 visitor.init(mat.coeff(0,0), 0, 0);
98 if (Derived::IsRowMajor) {
// Row-major: packets run along the columns j of row i.
99 for(
Index i = 0; i < mat.rows(); ++i) {
// Skip (0,0), which init() already consumed.
100 Index j = i == 0 ? 1 : 0;
// Whole packets while at least PacketSize coefficients remain in the row.
101 for(; j+PacketSize-1 < mat.cols(); j += PacketSize) {
102 Packet p = mat.packet(i, j);
103 visitor.packet(p, i, j);
// Scalar tail of the row.
105 for(; j < mat.cols(); ++j)
106 visitor(mat.coeff(i, j), i, j);
// Column-major: packets run along the rows i of column j.
109 for(
Index j = 0; j < mat.cols(); ++j) {
// Skip (0,0), which init() already consumed.
110 Index i = j == 0 ? 1 : 0;
// Whole packets while at least PacketSize coefficients remain in the column.
111 for(; i+PacketSize-1 < mat.rows(); i += PacketSize) {
112 Packet p = mat.packet(i, j);
113 visitor.packet(p, i, j);
// Scalar tail of the column.
115 for(; i < mat.rows(); ++i)
116 visitor(mat.coeff(i, j), i, j);
// Adaptor that visitor_impl traverses.
//   - rows()/cols()/size() forward to the wrapped expression.
//   - coeff() and packet() read through internal::evaluator<XprType>.
//   - packet() always requests an Unaligned load, presumably because the
//     traversal can start a packet at any inner offset.
// NOTE(review): the object keeps a reference to xpr (m_xpr), so it must not
// outlive the expression it was built from.
// NOTE(review): parts of the class (original lines 125-126, 128-130, 135-138)
// are not in this listing. visitor_impl's default Vectorize argument reads a
// PacketAccess member, presumably declared there; confirm.
123 template<
typename XprType>
124 class visitor_evaluator
127 typedef internal::evaluator<XprType> Evaluator;
// Compile-time properties mirrored from XprType and its evaluator.
131 IsRowMajor = XprType::IsRowMajor,
132 RowsAtCompileTime = XprType::RowsAtCompileTime,
133 ColsAtCompileTime = XprType::ColsAtCompileTime,
134 CoeffReadCost = Evaluator::CoeffReadCost
139 explicit visitor_evaluator(
const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) { }
141 typedef typename XprType::Scalar Scalar;
// const-stripped versions of XprType's coefficient and packet return types.
142 typedef std::remove_const_t<typename XprType::CoeffReturnType> CoeffReturnType;
143 typedef std::remove_const_t<typename XprType::PacketReturnType> PacketReturnType;
145 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
Index rows() const EIGEN_NOEXCEPT {
return m_xpr.rows(); }
146 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
Index cols() const EIGEN_NOEXCEPT {
return m_xpr.cols(); }
147 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
Index size() const EIGEN_NOEXCEPT {
return m_xpr.size(); }
149 EIGEN_DEVICE_FUNC CoeffReturnType coeff(
Index row,
Index col)
const
150 {
return m_evaluator.coeff(row, col); }
151 EIGEN_DEVICE_FUNC PacketReturnType packet(
Index row,
Index col)
const
152 {
return m_evaluator.template packet<Unaligned,PacketReturnType>(row, col); }
155 Evaluator m_evaluator;
156 const XprType &m_xpr;
// DenseBase::visit. Wraps the expression in internal::visitor_evaluator and
// dispatches to internal::visitor_impl.
// NOTE(review): the signature and anything else in original lines 182-187 are
// not part of this listing.
180 template<
typename Derived>
181 template<
typename Visitor>
188 typedef typename internal::visitor_evaluator<Derived> ThisEvaluator;
189 ThisEvaluator thisEval(derived());
// Unroll only when the size is a compile-time constant and the estimated cost
// stays within EIGEN_UNROLLING_LIMIT. The estimate is one CoeffReadCost per
// coefficient plus functor_traits<Visitor>::Cost for each of the
// SizeAtCompileTime-1 visitor calls.
192 unroll = SizeAtCompileTime !=
Dynamic
193 && SizeAtCompileTime * int(ThisEvaluator::CoeffReadCost) + (SizeAtCompileTime-1) *
int(internal::functor_traits<Visitor>::Cost) <= EIGEN_UNROLLING_LIMIT
// Otherwise fall back to the Dynamic (runtime loop) specialization.
195 return internal::visitor_impl<Visitor, ThisEvaluator, unroll ? int(SizeAtCompileTime) :
Dynamic>::run(thisEval, visitor);
// Common base for visitors that track one coefficient value (res) and its
// position (row, col). The constructor sets row and col to -1 and res to 0.
// NOTE(review): the member declarations and the body of init() are not in
// this listing. Presumably init() records the first visited value and its
// (i, j) position; confirm against the full header.
203 template <
typename Derived>
208 coeff_visitor() : row(-1), col(-1), res(0) {}
209 typedef typename Derived::Scalar Scalar;
213 inline void init(
const Scalar& value,
Index i,
Index j)
// Comparison policy for the min/max visitors (primary template: is_min == true).
// compare(a, b) is a strict "a < b". When called as compare(value, res), an
// equal value does not win, so the earlier coefficient is kept.
// predux() reduces a packet to its minimum using the given NaNPropagation policy.
222 template<
typename Scalar,
int NaNPropagation,
bool is_min=true>
223 struct minmax_compare {
224 typedef typename packet_traits<Scalar>::type Packet;
225 static EIGEN_DEVICE_FUNC
inline bool compare(Scalar a, Scalar b) {
return a < b; }
226 static EIGEN_DEVICE_FUNC
inline Scalar predux(
const Packet& p) {
return predux_min<NaNPropagation>(p);}
// Comparison policy for the max visitors (is_min == false).
// compare(a, b) is a strict "a > b", so equal values do not replace the
// current result. predux() reduces a packet to its maximum using the given
// NaNPropagation policy.
229 template<
typename Scalar,
int NaNPropagation>
230 struct minmax_compare<Scalar, NaNPropagation, false> {
231 typedef typename packet_traits<Scalar>::type Packet;
232 static EIGEN_DEVICE_FUNC
inline bool compare(Scalar a, Scalar b) {
return a > b; }
233 static EIGEN_DEVICE_FUNC
inline Scalar predux(
const Packet& p) {
return predux_max<NaNPropagation>(p);}
// Visitor that records the minimum (is_min == true) or maximum coefficient and
// its (row, col) position. This primary template serves NaNPropagation values
// other than PropagateNumbers and PropagateNaN, which have their own
// specializations. A coefficient replaces the current result only if
// Comparator::compare(value, res) holds, i.e. it is strictly better; ties keep
// the earlier position.
// NOTE(review): the rest of operator() (original lines 247-251) is not part of
// this listing.
236 template <
typename Derived,
bool is_min,
int NaNPropagation>
237 struct minmax_coeff_visitor : coeff_visitor<Derived>
239 using Scalar =
typename Derived::Scalar;
240 using Packet =
typename packet_traits<Scalar>::type;
241 using Comparator = minmax_compare<Scalar, NaNPropagation, is_min>;
243 EIGEN_DEVICE_FUNC
inline
244 void operator() (
const Scalar& value,
Index i,
Index j)
246 if(Comparator::compare(value, this->res)) {
// Packet path: reduce the whole packet first. Locate the winning lane only
// when the reduced value beats the current result.
253 EIGEN_DEVICE_FUNC
inline
254 void packet(
const Packet& p,
Index i,
Index j) {
255 const Index PacketSize = packet_traits<Scalar>::size;
256 Scalar value = Comparator::predux(p);
257 if (Comparator::compare(value, this->res)) {
// range holds PacketSize-k in lane k. Keeping only the lanes equal to value and
// taking the maximum gives PacketSize-k for the lowest matching lane k.
// Therefore max_idx is the offset of the first occurrence inside the packet.
258 const Packet range = preverse(plset<Packet>(Scalar(1)));
259 Packet mask = pcmp_eq(pset1<Packet>(value), p);
260 Index max_idx = PacketSize -
static_cast<Index>(predux_max(pand(range, mask)));
// Packets span the inner dimension: columns when row-major, rows otherwise.
262 this->row = Derived::IsRowMajor ? i : i + max_idx;
263 this->col = Derived::IsRowMajor ? j + max_idx : j;
// PropagateNumbers specialization: NaNs lose whenever a number is available.
//   - A non-NaN value always replaces a NaN result.
//   - Otherwise the strict Comparator::compare decides.
//   - compare() is false whenever either side is NaN, so a NaN never replaces
//     a numeric result.
// NOTE(review): the rest of operator() (original lines 281-285) is not part of
// this listing.
270 template <
typename Derived,
bool is_min>
271 struct minmax_coeff_visitor<Derived, is_min,
PropagateNumbers> : coeff_visitor<Derived>
273 typedef typename Derived::Scalar Scalar;
274 using Packet =
typename packet_traits<Scalar>::type;
275 using Comparator = minmax_compare<Scalar, PropagateNumbers, is_min>;
277 EIGEN_DEVICE_FUNC
inline
278 void operator() (
const Scalar& value,
Index i,
Index j)
280 if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
// Packet path. If the reduced value is NaN, neither clause of the condition
// can hold, so the lane search below always looks for a real number.
287 EIGEN_DEVICE_FUNC
inline
288 void packet(
const Packet& p,
Index i,
Index j) {
289 const Index PacketSize = packet_traits<Scalar>::size;
290 Scalar value = Comparator::predux(p);
291 if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
// range holds PacketSize-k in lane k. The max over lanes equal to value yields
// the lowest matching lane, i.e. the first occurrence in the packet.
292 const Packet range = preverse(plset<Packet>(Scalar(1)));
294 Packet mask = pcmp_eq(pset1<Packet>(value), p);
295 Index max_idx = PacketSize -
static_cast<Index>(predux_max(pand(range, mask)));
// Packets span the inner dimension: columns when row-major, rows otherwise.
297 this->row = Derived::IsRowMajor ? i : i + max_idx;
298 this->col = Derived::IsRowMajor ? j + max_idx : j;
// PropagateNaN specialization: a NaN value replaces any non-NaN result. Once
// the result is NaN it is never replaced, because the NaN clause requires a
// non-NaN result and compare() against NaN is always false.
// NOTE(review): the rest of operator() (original lines 318-322) is not part of
// this listing.
306 template <
typename Derived,
bool is_min>
307 struct minmax_coeff_visitor<Derived, is_min,
PropagateNaN> : coeff_visitor<Derived>
309 typedef typename Derived::Scalar Scalar;
310 using Packet =
typename packet_traits<Scalar>::type;
311 using Comparator = minmax_compare<Scalar, PropagateNaN, is_min>;
313 EIGEN_DEVICE_FUNC
inline
314 void operator() (
const Scalar& value,
Index i,
Index j)
316 const bool value_is_nan = (numext::isnan)(value);
317 if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
324 EIGEN_DEVICE_FUNC
inline
325 void packet(
const Packet& p,
Index i,
Index j) {
326 const Index PacketSize = packet_traits<Scalar>::size;
327 Scalar value = Comparator::predux(p);
328 const bool value_is_nan = (numext::isnan)(value);
329 if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
// range holds PacketSize-k in lane k. The max over selected lanes yields the
// lowest selected lane.
330 const Packet range = preverse(plset<Packet>(Scalar(1)));
// NaN compares unequal to itself, so NaN lanes are located via
// pnot(pcmp_eq(p, p)). Comparing against a NaN value would select no lane.
332 Packet mask = value_is_nan ? pnot(pcmp_eq(p, p)) : pcmp_eq(pset1<Packet>(value), p);
333 Index max_idx = PacketSize -
static_cast<Index>(predux_max(pand(range, mask)));
// Packets span the inner dimension: columns when row-major, rows otherwise.
335 this->row = Derived::IsRowMajor ? i : i + max_idx;
336 this->col = Derived::IsRowMajor ? j + max_idx : j;
// Traits for minmax_coeff_visitor. Cost feeds the unrolling heuristic in
// DenseBase::visit.
// NOTE: despite its name, the first template parameter here is the visited
// expression type (minmax_coeff_visitor<Derived, ...>), not a scalar type.
// The per-call cost must therefore come from the expression's scalar type.
// NumTraits of the expression type itself is either generic or, for
// Array<...>, scaled by the compile-time size, which skews the unroll decision.
// NOTE(review): the remaining lines of this specialization (original 343,
// 345+) are not part of this listing.
341 template<
typename Scalar,
bool is_min,
int NaNPropagation>
342 struct functor_traits<minmax_coeff_visitor<Scalar, is_min, NaNPropagation> > {
344 Cost = NumTraits<typename Scalar::Scalar>::AddCost,
// DenseBase::minCoeff(rowId, colId): returns the minimum coefficient.
//   - Its position is written to *rowId and, when colId is non-null, to *colId.
//   - NaNPropagation selects the minmax_coeff_visitor policy for NaNs.
//   - Asserts that the expression is non-empty.
// NOTE(review): rowId is dereferenced without a null check (colId is checked),
// so callers must pass a valid rowId.
362 template<
typename Derived>
363 template<
int NaNPropagation,
typename IndexType>
365 typename internal::traits<Derived>::Scalar
368 eigen_assert(this->rows()>0 && this->cols()>0 &&
"you are using an empty matrix");
370 internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
371 this->visit(minVisitor);
// Make the Index -> IndexType conversion explicit, as the single-index overload does.
372 *rowId = IndexType(minVisitor.row);
373 if (colId) *colId = IndexType(minVisitor.col);
374 return minVisitor.res;
// DenseBase::minCoeff(index), for vectors only (enforced by
// EIGEN_STATIC_ASSERT_VECTOR_ONLY): returns the minimum coefficient and stores
// its linear position in *index. For a row vector (RowsAtCompileTime == 1)
// that position is the column; otherwise it is the row. Asserts that the
// expression is non-empty.
387 template<
typename Derived>
388 template<
int NaNPropagation,
typename IndexType>
390 typename internal::traits<Derived>::Scalar
393 eigen_assert(this->rows()>0 && this->cols()>0 &&
"you are using an empty matrix");
395 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
396 internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
397 this->visit(minVisitor);
398 *index = IndexType((RowsAtCompileTime==1) ? minVisitor.col : minVisitor.row);
399 return minVisitor.res;
// DenseBase::maxCoeff(rowPtr, colPtr): returns the maximum coefficient.
//   - Its position is written to *rowPtr and, when colPtr is non-null, to *colPtr.
//   - NaNPropagation selects the minmax_coeff_visitor policy for NaNs.
//   - Asserts that the expression is non-empty.
// NOTE(review): rowPtr is dereferenced without a null check (colPtr is
// checked), so callers must pass a valid rowPtr.
413 template<
typename Derived>
414 template<
int NaNPropagation,
typename IndexType>
416 typename internal::traits<Derived>::Scalar
419 eigen_assert(this->rows()>0 && this->cols()>0 &&
"you are using an empty matrix");
421 internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
422 this->visit(maxVisitor);
// Make the Index -> IndexType conversion explicit, as minCoeff(index) does.
423 *rowPtr = IndexType(maxVisitor.row);
424 if (colPtr) *colPtr = IndexType(maxVisitor.col);
425 return maxVisitor.res;
// DenseBase::maxCoeff(index), for vectors only (enforced by
// EIGEN_STATIC_ASSERT_VECTOR_ONLY): returns the maximum coefficient and stores
// its linear position in *index. For a row vector (RowsAtCompileTime == 1)
// that position is the column; otherwise it is the row. Asserts that the
// expression is non-empty.
438 template<
typename Derived>
439 template<
int NaNPropagation,
typename IndexType>
441 typename internal::traits<Derived>::Scalar
444 eigen_assert(this->rows()>0 && this->cols()>0 &&
"you are using an empty matrix");
446 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
447 internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
448 this->visit(maxVisitor);
// Explicit Index -> IndexType conversion, consistent with minCoeff(index).
449 *index = IndexType((RowsAtCompileTime==1) ? maxVisitor.col : maxVisitor.row);
450 return maxVisitor.res;
internal::traits< Derived >::Scalar minCoeff() const
Definition: Redux.h:433
void visit(Visitor &func) const
Definition: Visitor.h:183
internal::traits< Derived >::Scalar maxCoeff() const
Definition: Redux.h:448
@ PropagateNaN
Definition: Constants.h:345
@ PropagateNumbers
Definition: Constants.h:347
const unsigned int PacketAccessBit
Definition: Constants.h:96
Namespace containing all symbols from the Eigen library.
Definition: Core:139
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:59
const int Dynamic
Definition: Constants.h:24