|
|
|
@ -80,5 +80,21 @@ struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
struct EigenScalar {
|
|
|
|
|
// Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
|
|
|
|
|
using Type = Eigen::TensorMap<
|
|
|
|
|
Eigen::TensorFixedSize<T, Eigen::Sizes<>, MajorType, IndexType>>;
|
|
|
|
|
using ConstType = Eigen::TensorMap<
|
|
|
|
|
Eigen::TensorFixedSize<const T, Eigen::Sizes<>, MajorType, IndexType>>;
|
|
|
|
|
|
|
|
|
|
static Type From(Tensor& tensor) { return Type(tensor.data<T>()); }
|
|
|
|
|
|
|
|
|
|
static ConstType From(const Tensor& tensor) {
|
|
|
|
|
return ConstType(tensor.data<T>());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|