|
|
|
@ -90,6 +90,7 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
|
|
|
|
|
// Alias template that forwards its arguments unchanged to
// framework::EigenTensor, so code in this file can write EigenTensor<...>
// without the framework:: qualifier.
//   T          - first template argument, forwarded as-is.
//   D          - second template argument; the kernel below instantiates it
//                as EigenTensor<T, Rank>, so it is presumably the tensor
//                rank (confirm against framework::EigenTensor).
//   MajorType  - defaults to Eigen::RowMajor.
//   IndexType  - defaults to Eigen::DenseIndex.
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
|
|
|
|
|
using framework::To32BitIndex;
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
class ExpandKernel : public framework::OpKernel<T> {
|
|
|
|
@ -131,7 +132,13 @@ class ExpandKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto y = EigenTensor<T, Rank>::From(*out0);
|
|
|
|
|
auto& place =
|
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
y.device(place) = x.broadcast(bcast_dims);
|
|
|
|
|
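// Use 32-bit indexing to speed up the broadcast when the output's element
// count is below the largest int; otherwise keep the default index type.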
|
|
|
|
|
bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
|
|
|
|
|
if (use_32bit_index) {
|
|
|
|
|
To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims);
|
|
|
|
|
} else {
|
|
|
|
|
y.device(place) = x.broadcast(bcast_dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|