#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H

#include "./InternalHeaderCheck.h"

namespace Eigen {
namespace internal {

template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef std::remove_reference_t<Nested> Nested_;
  static constexpr int NumDimensions = XprTraits::NumDimensions;
  static constexpr int Layout = XprTraits::Layout;
  typedef typename XprTraits::PointerType PointerType;
};

template<typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
{
  typedef const TensorBroadcastingOp<Broadcast, XprType> EIGEN_DEVICE_REF type;
};

template<typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
{
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};

template <typename Dims>
struct is_input_scalar {
  static const bool value = false;
};
template <>
struct is_input_scalar<Sizes<> > {
  static const bool value = true;
};
#ifndef EIGEN_EMULATE_CXX11_META_H
template <typename std::ptrdiff_t... Indices>
struct is_input_scalar<Sizes<Indices...> > {
  static const bool value = (Sizes<Indices...>::total_size == 1);
};
#endif

}  // end namespace internal

template<typename Broadcast, typename XprType>
class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
        : m_xpr(expr), m_broadcast(broadcast) {}

    EIGEN_DEVICE_FUNC
    const Broadcast& broadcast() const { return m_broadcast; }

    EIGEN_DEVICE_FUNC
    const internal::remove_all_t<typename XprType::Nested>&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const Broadcast m_broadcast;
};
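
// Usage sketch (illustrative comment, not part of the upstream header). A
// broadcast replicates a tensor along each dimension by the given factor:
//
//   Eigen::Tensor<float, 2> t(2, 3);
//   t.setRandom();
//   Eigen::array<Eigen::Index, 2> bcast{{3, 2}};
//   // The result has dimensions (2*3, 3*2) = (6, 6): the input is tiled
//   // 3 times along dimension 0 and 2 times along dimension 1.
//   Eigen::Tensor<float, 2> b = t.broadcast(bcast);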

// Eval as rvalue
template<typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
{
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  static constexpr int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;

 protected:
  // All non-static fields must share the same access control so that the
  // evaluator keeps standard layout.
  bool isCopy, nByOne, oneByN;

 public:
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    IsAligned         = TensorEvaluator<ArgType, Device>::IsAligned,
    PacketAccess      = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess       = TensorEvaluator<ArgType, Device>::BlockAccess,
    PreferBlockAccess = true,
    CoordAccess       = false,  // to be implemented
    RawAccess         = false
  };
  static constexpr int Layout = TensorEvaluator<ArgType, Device>::Layout;

  typedef std::remove_const_t<Scalar> ScalarNoConst;

  // Block based broadcasting uses a trick with doubled tensor rank and zero
  // input strides; see the block() method implementation for details.
  typedef DSizes<Index, 2 * NumDims> BroadcastDimensions;

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
  typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;

  typedef typename TensorEvaluator<const ArgType, Device>::TensorBlock
      ArgTensorBlock;

  typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumDims,
                                                     Layout, Index>
      TensorBlock;
  //===--------------------------------------------------------------------===//

  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : isCopy(false), nByOne(false), oneByN(false),
        m_device(device), m_broadcast(op.broadcast()), m_impl(op.expression(), device)
  {
    // The broadcasting op doesn't change the rank of the tensor. One can't
    // broadcast a scalar and store the result in a scalar: reshape the scalar
    // into a one-element N-D tensor (N >= 1) first and then broadcast it.
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const InputDimensions& input_dims = m_impl.dimensions();
    isCopy = true;
    for (int i = 0; i < NumDims; ++i) {
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i] * m_broadcast[i];
      if (m_broadcast[i] != 1) {
        isCopy = false;
      }
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      m_inputStrides[NumDims-1] = 1;
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }

    if (input_dims[0] == 1) {
      oneByN = true;
      for (int i = 1; i < NumDims; ++i) {
        if (m_broadcast[i] != 1) {
          oneByN = false;
          break;
        }
      }
    } else if (input_dims[NumDims-1] == 1) {
      nByOne = true;
      for (int i = 0; i < NumDims-1; ++i) {
        if (m_broadcast[i] != 1) {
          nByOne = false;
          break;
        }
      }
    }

    // Handle special formats like NCHW whose input shape is [1, N..., 1] and
    // whose broadcast shape is [N, 1..., N].
    if (!oneByN && !nByOne) {
      if (input_dims[0] == 1 && input_dims[NumDims-1] == 1 && NumDims > 2) {
        nByOne = true;
        oneByN = true;
        for (int i = 1; i < NumDims-1; ++i) {
          if (m_broadcast[i] != 1) {
            nByOne = false;
            oneByN = false;
            break;
          }
        }
      }
    }
  }
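
  // Worked example (explanatory comment added here, not upstream code): for a
  // ColMajor input of dimensions (2, 3) broadcast by {3, 2}, the constructor
  // computes
  //   m_dimensions    = (6, 6)
  //   m_inputStrides  = (1, 2)   // strides of the 2x3 input
  //   m_outputStrides = (1, 6)   // strides of the 6x6 output
  // so output coordinate (i, j) reads input coordinate (i % 2, j % 3).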

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
      EvaluatorPointerType, EvalSubExprsCallback done) {
    m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
  }
#endif  // EIGEN_USE_THREADS

  EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
  {
    if (internal::is_input_scalar<internal::remove_all_t<InputDimensions>>::value) {
      return m_impl.coeff(0);
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      if (isCopy) {
        return m_impl.coeff(index);
      } else {
        return coeffColMajor(index);
      }
    } else {
      if (isCopy) {
        return m_impl.coeff(index);
      } else {
        return coeffRowMajor(index);
      }
    }
  }

  // TODO: attempt to speed this up. The integer divisions and modulos are slow.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexColMajor(Index index) const {
    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return inputIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
  {
    return m_impl.coeff(indexColMajor(index));
  }
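
  // Worked example (explanatory comment, not upstream code): with a ColMajor
  // input of dimensions (2, 1) broadcast by {1, 3}, the output has dimensions
  // (2, 3), m_outputStrides = (1, 2) and m_inputStrides = (1, 2). For output
  // index 4, i.e. coordinate (0, 2), the loop computes idx = 4 / 2 = 2 for
  // dimension 1; since the input extent there is 1 the contribution is
  // (2 % 1) * 2 = 0, and the remaining index 0 maps to input coordinate
  // (0, 0), as expected for a copy of the single input column.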

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexRowMajor(Index index) const {
    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(NumDims - 1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims - 1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims - 1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims - 1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims - 1]);
      }
    }
    return inputIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
  {
    return m_impl.coeff(indexRowMajor(index));
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
  {
    if (internal::is_input_scalar<internal::remove_all_t<InputDimensions>>::value) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(0));
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      if (isCopy) {
        #ifdef EIGEN_GPU_COMPILE_PHASE
        // See PR 437: on NVIDIA P100 and K20m enforcing unaligned loads here
        // gave a 3-4x speed up; the reason is unclear.
        return m_impl.template packet<Unaligned>(index);
        #else
        return m_impl.template packet<LoadMode>(index);
        #endif
      } else if (oneByN && !nByOne) {
        return packetNByOne<LoadMode>(index);
      } else if (!oneByN && nByOne) {
        return packetOneByN<LoadMode>(index);
      } else if (oneByN && nByOne) {
        return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetColMajor<LoadMode>(index);
      }
    } else {
      if (isCopy) {
        #ifdef EIGEN_GPU_COMPILE_PHASE
        // See above.
        return m_impl.template packet<Unaligned>(index);
        #else
        return m_impl.template packet<LoadMode>(index);
        #endif
      } else if (oneByN && !nByOne) {
        return packetOneByN<LoadMode>(index);
      } else if (!oneByN && nByOne) {
        return packetNByOne<LoadMode>(index);
      } else if (oneByN && nByOne) {
        return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetRowMajor<LoadMode>(index);
      }
    }
  }
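
  // Note (explanatory comment, not upstream code): oneByN is set by the
  // constructor when the first input dimension is 1 and every other broadcast
  // factor equals 1; nByOne is the symmetric case for the last dimension. The
  // specialized packet paths below exploit these shapes to avoid the generic
  // per-coefficient div/mod arithmetic of packetColMajor / packetRowMajor.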

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne
  (Index index) const
  {
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
    Index startDim, endDim;
    Index inputIndex, outputOffset, batchedIndex;

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      startDim = NumDims - 1;
      endDim = 1;
    } else {
      startDim = 0;
      endDim = NumDims - 2;
    }

    batchedIndex = index % m_outputStrides[startDim];
    inputIndex   = batchedIndex / m_outputStrides[endDim];
    outputOffset = batchedIndex % m_outputStrides[endDim];

    if (outputOffset + PacketSize <= m_outputStrides[endDim]) {
      values[0] = m_impl.coeff(inputIndex);
      return internal::pload1<PacketReturnType>(values);
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
        if (outputOffset + cur < m_outputStrides[endDim]) {
          values[i] = m_impl.coeff(inputIndex);
        } else {
          ++inputIndex;
          inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex);
          values[i] = m_impl.coeff(inputIndex);
          outputOffset = 0;
          cur = 0;
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const
  {
    // Consider the flattened input [v0, ..., vN]: the output concatenates
    // m_broadcast[dim] copies, [v0, ..., vN, v0, ..., vN, ...], with
    // dim == NumDims - 1 for ColMajor and dim == 0 for RowMajor.
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    // Size of the flattened input.
    const Index M = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ?
        m_inputStrides[NumDims - 1] : m_inputStrides[0];
    Index inputIndex = index % M;
    if (inputIndex + PacketSize <= M) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < PacketSize; ++i) {
        if (inputIndex > M - 1) {
          inputIndex = 0;
        }
        values[i] = m_impl.coeff(inputIndex++);
      }
      return internal::pload<PacketReturnType>(values);
    }
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetNByOne(Index index) const
  {
    // Consider the flattened input [v0, ..., vN]: the output interleaves
    // m_broadcast[dim] copies, [v0, v0, ..., v1, v1, ..., vN, vN, ...], with
    // dim == 0 for ColMajor and dim == NumDims - 1 for RowMajor.
    eigen_assert(index + PacketSize-1 < dimensions().TotalSize());

    const Index M = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ?
        m_broadcast[0] : m_broadcast[NumDims - 1];

    Index inputIndex   = index / M;
    Index outputOffset = index % M;
    if (outputOffset + PacketSize <= M) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(inputIndex));
    } else {
      EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < PacketSize; ++i) {
        if (outputOffset < M) {
          values[i] = m_impl.coeff(inputIndex);
          ++outputOffset;
        } else {
          values[i] = m_impl.coeff(++inputIndex);
          outputOffset = 1;  // Next offset.
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
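
  // Illustration (explanatory comment, not upstream code): for a flattened
  // input [v0, v1, v2], packetOneByN reads concatenated copies
  // [v0, v1, v2, v0, v1, v2, ...] and can often forward an unaligned packet
  // load from the input, while packetNByOne reads interleaved copies
  // [v0, v0, ..., v1, v1, ...] and can often return a pset1 broadcast of a
  // single coefficient.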

  // Ignore the LoadMode and always use unaligned loads: alignment cannot be
  // guaranteed at compile time once the broadcast index mapping is applied.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
  {
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    // Todo: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize; ++i) {
        if (innermostLoc + i < m_impl.dimensions()[0]) {
          values[i] = m_impl.coeff(inputIndex+i);
        } else {
          values[i] = coeffColMajor(originalIndex+i);
        }
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
  {
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims-1];
      }
    }
    inputIndex += innermostLoc;

    // Todo: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims-1]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize; ++i) {
        if (innermostLoc + i < m_impl.dimensions()[NumDims-1]) {
          values[i] = m_impl.coeff(inputIndex+i);
        } else {
          values[i] = coeffRowMajor(originalIndex+i);
        }
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    double compute_cost = TensorOpCost::AddCost<Index>();
    if (!isCopy && NumDims > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = NumDims - 1; i > 0; --i) {
        compute_cost += TensorOpCost::DivCost<Index>();
        if (internal::index_statically_eq<Broadcast>(i, 1)) {
          compute_cost +=
              TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
        } else {
          if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
            compute_cost += TensorOpCost::MulCost<Index>() +
                            TensorOpCost::ModCost<Index>() +
                            TensorOpCost::AddCost<Index>();
          }
        }
        compute_cost +=
            TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
      }
    }
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  internal::TensorBlockResourceRequirements getResourceRequirements() const {
    // Aim for blocks that fit into the first level cache.
    const size_t target_size = m_device.firstLevelCacheSize();
    return internal::TensorBlockResourceRequirements::merge(
        m_impl.getResourceRequirements(),
        internal::TensorBlockResourceRequirements::skewed<Scalar>(target_size));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock
  block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
        bool /*root_of_expr_ast*/ = false) const {
    BlockBroadcastingParams params = blockBroadcastingParams(desc);

    if (params.inner_dim_size == 0 || params.bcast_dim_size == 0) {
      return emptyBlock();
    }

    // Prepare storage for the materialized broadcasting result.
    const typename TensorBlock::Storage block_storage =
        TensorBlock::prepareStorage(desc, scratch);
    ScalarNoConst* materialized_output = block_storage.data();

    // We may need to materialize input blocks into a temporary buffer.
    size_t materialized_input_size = 0;
    ScalarNoConst* materialized_input = NULL;

    // Iterator state for the dimensions outside the broadcast dimension,
    // always stored in inner-most to outer-most (col-major) order.
    array<BlockBroadcastingIteratorState, NumDims> it;
    int idx = 0;

    for (int i = params.inner_dim_count + 1; i < NumDims; ++i) {
      const Index dim = IsColMajor ? i : NumDims - 1 - i;
      it[idx].size = params.output_dims[dim];
      it[idx].count = 0;
      it[idx].output_stride = m_outputStrides[dim];
      it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);
      idx++;
    }

    // Write output into the beginning of `materialized_output`.
    Index output_offset = 0;

    // Fill the output block by broadcasting along the bcast dimension while
    // iterating over the outer dimensions.
    const Index output_size = NumDims == 0 ? 1 : params.output_dims.TotalSize();

    for (Index num_output_coeffs = 0; num_output_coeffs < output_size;) {
      ScalarNoConst* bcast_output = materialized_output + num_output_coeffs;
      Index bcast_offset = desc.offset() + output_offset;

      // Broadcast along the bcast dimension (requires block math).
      num_output_coeffs += BroadcastBlockAlongBcastDim(
          params, bcast_offset, scratch, bcast_output, &materialized_input,
          &materialized_input_size);

      // Switch to the next outer dimension.
      for (int j = 0; j < idx; ++j) {
        if (++it[j].count < it[j].size) {
          output_offset += it[j].output_stride;
          break;
        }
        it[j].count = 0;
        output_offset -= it[j].output_span;
      }
    }

    return block_storage.AsTensorMaterializedBlock();
  }

  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  Broadcast functor() const { return m_broadcast; }

#ifdef EIGEN_USE_SYCL
  // Binds the underlying evaluator to the SYCL command group handler.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
      cl::sycl::handler& cgh) const {
    m_impl.bind(cgh);
  }
#endif

 private:
  static constexpr bool IsColMajor =
      static_cast<int>(Layout) == static_cast<int>(ColMajor);

  struct BlockBroadcastingParams {
    Dimensions input_dims;      // input expression dimensions
    Dimensions output_dims;     // output block sizes
    Dimensions output_strides;  // output block strides
    int inner_dim_count;        // number of inner dimensions with matching size
    int bcast_dim;              // broadcasting dimension index
    Index bcast_dim_size;       // broadcasting dimension size
    Index inner_dim_size;       // total size of the inner dimensions

    // Input block sizes/strides, with non-inner dimensions collapsed to 1.
    Dimensions input_block_sizes;
    Dimensions input_block_strides;

    // Double-rank sizes and strides used for the 0-stride broadcast copy.
    BroadcastDimensions bcast_block_sizes;
    BroadcastDimensions bcast_block_strides;
    BroadcastDimensions bcast_input_strides;
  };

  struct BlockBroadcastingIteratorState {
    Index size;
    Index count;
    Index output_stride;
    Index output_span;
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlockBroadcastingParams
  blockBroadcastingParams(TensorBlockDesc& desc) const {
    BlockBroadcastingParams params;

    params.input_dims = Dimensions(m_impl.dimensions());

    // Output block sizes and strides.
    params.output_dims = desc.dimensions();
    params.output_strides = internal::strides<Layout>(params.output_dims);

    // Find the broadcasting dimension (the first dimension whose block size
    // is smaller than the full broadcast size).
    params.bcast_dim = 0;
    params.bcast_dim_size = 1;
    params.inner_dim_size = 1;

    // Count the number of inner dimensions that have the same size in the
    // block and in the broadcast expression.
    params.inner_dim_count = 0;

    for (int i = 0; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;

      if (params.output_dims[dim] == m_dimensions[dim]) {
        params.inner_dim_size *= params.output_dims[dim];
        ++params.inner_dim_count;
        continue;
      }

      // The first non-matching dimension is the broadcasting dimension.
      eigen_assert(params.output_dims[dim] < m_dimensions[dim]);
      params.bcast_dim = dim;
      params.bcast_dim_size = params.output_dims[dim];
      break;
    }

    // Block sizes and strides for reading the input: all dimensions outside
    // the inner (matching) dimensions collapse to size 1.
    for (int i = 0; i < params.inner_dim_count; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      params.input_block_sizes[dim] = params.input_dims[dim];
    }
    for (int i = params.inner_dim_count; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      params.input_block_sizes[dim] = 1;
    }
    params.input_block_strides =
        internal::strides<Layout>(params.input_block_sizes);

    // Broadcast with the 0-stride trick: create one extra dimension for each
    // broadcast factor and read the input through a zero stride along it. For
    // ColMajor the double-rank sizes interleave input extents and broadcast
    // factors as [d_0, b_0, d_1, b_1, ...], the output strides become
    // [out_0, out_0 * d_0, out_1, out_1 * d_1, ...], and the input strides
    // become [in_0, 0, in_1, 0, ...].
    for (int i = 0; i < params.inner_dim_count; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;

      const int copy_dim = IsColMajor ? 2 * i : 2 * NumDims - 2 * i - 1;
      const int broadcast_dim = IsColMajor ? copy_dim + 1 : copy_dim - 1;

      params.bcast_block_sizes[copy_dim] = params.input_dims[dim];
      params.bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
      params.bcast_block_strides[copy_dim] = params.output_strides[dim];
      params.bcast_block_strides[broadcast_dim] =
          params.output_strides[dim] * params.input_dims[dim];
      params.bcast_input_strides[copy_dim] = params.input_block_strides[dim];
      params.bcast_input_strides[broadcast_dim] = 0;
    }

    for (int i = 2 * params.inner_dim_count; i < 2 * NumDims; ++i) {
      const int dim = IsColMajor ? i : 2 * NumDims - i - 1;
      params.bcast_block_sizes[dim] = 1;
      params.bcast_block_strides[dim] = 0;
      params.bcast_input_strides[dim] = 0;
    }

    return params;
  }
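
  // Worked example (explanatory comment, assuming a ColMajor input of
  // dimensions (2, 3) broadcast by {2, 2} and a block covering the full (4, 6)
  // output): both dimensions match the output, so inner_dim_count == 2 and the
  // double-rank view becomes
  //   bcast_block_sizes   = [2, 2, 3, 2]
  //   bcast_block_strides = [1, 2, 4, 12]
  //   bcast_input_strides = [1, 0, 2, 0]
  // i.e. every broadcast copy re-reads the input through a zero stride.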

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock emptyBlock() const {
    DSizes<Index, NumDims> dimensions;
    for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
    return TensorBlock(internal::TensorBlockKind::kView, NULL, dimensions);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlockAlongBcastDim(
      BlockBroadcastingParams params, Index bcast_offset,
      TensorBlockScratch& scratch, ScalarNoConst* materialized_output,
      ScalarNoConst** materialized_input,
      size_t* materialized_input_size) const {
    if (params.bcast_dim_size == 1) {
      // A single block read with the parameters computed above is enough.
      return BroadcastBlock(
          params.input_block_sizes, params.input_block_strides,
          params.bcast_block_sizes, params.bcast_block_strides,
          params.bcast_input_strides, bcast_offset, 0, scratch,
          materialized_output, materialized_input, materialized_input_size);

    } else if (params.input_dims[params.bcast_dim] == 1) {
      // Broadcast the single input coefficient along the bcast dimension.
      const int broadcast_bcast_dim =
          IsColMajor ? 2 * params.inner_dim_count + 1
                     : 2 * NumDims - 2 * params.inner_dim_count - 2;

      params.bcast_block_sizes[broadcast_bcast_dim] = params.bcast_dim_size;
      params.bcast_input_strides[broadcast_bcast_dim] = 0;
      params.bcast_block_strides[broadcast_bcast_dim] =
          params.output_strides[params.bcast_dim];

      return BroadcastBlock(
          params.input_block_sizes, params.input_block_strides,
          params.bcast_block_sizes, params.bcast_block_strides,
          params.bcast_input_strides, bcast_offset, 0, scratch,
          materialized_output, materialized_input, materialized_input_size);
    } else {
      // General case: keep track of the number of coefficients written to the
      // output block.
      Index num_output_coeffs = 0;

      // The block projects onto the broadcast dimension as the interval
      // [a, a + bcast_dim_size), which may start and end in the middle of a
      // broadcast copy of the input: an optional "head" before the first
      // multiple of the input size, a run of full copies, and an optional
      // "tail" after the last multiple.

      // First output index along the bcast dimension covered by this block.
      const Index bcast_dim_left_index =
          bcast_offset / m_outputStrides[params.bcast_dim];

      // Input size along the broadcast dimension.
      const Index input_bcast_dim_size = params.input_dims[params.bcast_dim];

      // First multiple of input_bcast_dim_size at or after the block start.
      const Index first_multiple =
          divup<Index>(bcast_dim_left_index, input_bcast_dim_size) *
          input_bcast_dim_size;

      if (first_multiple <= bcast_dim_left_index + params.bcast_dim_size) {
        // Last multiple of input_bcast_dim_size before the block end.
        const Index last_multiple =
            (bcast_dim_left_index + params.bcast_dim_size) /
            input_bcast_dim_size * input_bcast_dim_size;
        const int copy_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count
                       : 2 * NumDims - 2 * params.inner_dim_count - 1;
        const int broadcast_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count + 1
                       : 2 * NumDims - 2 * params.inner_dim_count - 2;
        if (first_multiple > bcast_dim_left_index) {
          const Index head_size = first_multiple - bcast_dim_left_index;
          params.input_block_sizes[params.bcast_dim] = head_size;
          params.bcast_block_sizes[copy_bcast_dim] = head_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, 0, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
        if (first_multiple < last_multiple) {
          params.input_block_sizes[params.bcast_dim] = input_bcast_dim_size;
          params.bcast_block_sizes[copy_bcast_dim] = input_bcast_dim_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] =
              (last_multiple - first_multiple) / input_bcast_dim_size;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];
          const Index offset = (first_multiple - bcast_dim_left_index) *
                               m_outputStrides[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, offset, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
        if (last_multiple < bcast_dim_left_index + params.bcast_dim_size) {
          const Index tail_size =
              bcast_dim_left_index + params.bcast_dim_size - last_multiple;
          params.input_block_sizes[params.bcast_dim] = tail_size;
          params.bcast_block_sizes[copy_bcast_dim] = tail_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];
          const Index offset = (last_multiple - bcast_dim_left_index) *
                               m_outputStrides[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, offset, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
      } else {
        // The whole block lies strictly inside a single broadcast copy of the
        // input, so it is a plain copy without an extra broadcast dimension.
        const int copy_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count
                       : 2 * NumDims - 2 * params.inner_dim_count - 1;
        params.input_block_sizes[params.bcast_dim] = params.bcast_dim_size;
        params.bcast_block_sizes[copy_bcast_dim] = params.bcast_dim_size;
        params.bcast_input_strides[copy_bcast_dim] =
            params.input_block_strides[params.bcast_dim];
        params.bcast_block_strides[copy_bcast_dim] =
            params.output_strides[params.bcast_dim];

        num_output_coeffs += BroadcastBlock(
            params.input_block_sizes, params.input_block_strides,
            params.bcast_block_sizes, params.bcast_block_strides,
            params.bcast_input_strides, bcast_offset, 0, scratch,
            materialized_output, materialized_input, materialized_input_size);
      }

      return num_output_coeffs;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlock(
      const Dimensions& input_block_sizes,
      const Dimensions& input_block_strides,
      const BroadcastDimensions& bcast_block_sizes,
      const BroadcastDimensions& bcast_block_strides,
      const BroadcastDimensions& bcast_input_strides, Index bcast_offset,
      Index offset, TensorBlockScratch& scratch,
      ScalarNoConst* materialized_output, ScalarNoConst** materialized_input,
      size_t* materialized_input_size) const {
    // Request the input block corresponding to the given output offset.
    const Index input_offset = bcast_offset + offset;
    TensorBlockDesc input_desc(
        IsColMajor ? indexColMajor(input_offset) : indexRowMajor(input_offset),
        input_block_sizes);

    ArgTensorBlock input_block = m_impl.block(input_desc, scratch);

    // Materialize the input block into a temporary buffer only if it does not
    // already expose raw data.
    const ScalarNoConst* input_buffer = NULL;

    if (input_block.data() != NULL) {
      // The input block already has raw data: no need to materialize it.
      input_buffer = input_block.data();

    } else {
      // Otherwise do a block assignment into a temporary buffer, reusing a
      // previously allocated buffer when it is large enough.
      const size_t input_total_size = input_block_sizes.TotalSize();
      if (*materialized_input == NULL ||
          *materialized_input_size < input_total_size) {
        *materialized_input_size = input_total_size;
        void* mem = scratch.allocate(*materialized_input_size * sizeof(Scalar));
        *materialized_input = static_cast<ScalarNoConst*>(mem);
      }

      typedef internal::TensorBlockAssignment<
          ScalarNoConst, NumDims, typename ArgTensorBlock::XprType, Index>
          TensorBlockAssignment;

      TensorBlockAssignment::Run(
          TensorBlockAssignment::target(input_block_sizes, input_block_strides,
                                        *materialized_input),
          input_block.expr());

      input_buffer = *materialized_input;
    }

    // Copy data from the materialized input block to the materialized output,
    // using the double-rank broadcast sizes and strides.
    typedef internal::TensorBlockIO<ScalarNoConst, Index, 2 * NumDims, Layout>
        TensorBlockIO;

    typename TensorBlockIO::Src src(bcast_input_strides, input_buffer);
    typename TensorBlockIO::Dst dst(bcast_block_sizes, bcast_block_strides,
                                    materialized_output + offset);

    return TensorBlockIO::Copy(dst, src);
  }

 protected:
  const Device EIGEN_DEVICE_REF m_device;
  const std::remove_reference_t<Broadcast> m_broadcast;
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_inputStrides;
  TensorEvaluator<ArgType, Device> m_impl;
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H