|
|
|
@ -23,7 +23,7 @@ namespace framework {
|
|
|
|
|
// EigenDim converts paddle::platform::DDim into Eigen::DSizes.
|
|
|
|
|
template <int D>
|
|
|
|
|
struct EigenDim {
|
|
|
|
|
typedef Eigen::DSizes<Eigen::DenseIndex, D> Type;
|
|
|
|
|
using Type = Eigen::DSizes<Eigen::DenseIndex, D>;
|
|
|
|
|
|
|
|
|
|
static Type From(const DDim& dims) {
|
|
|
|
|
PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
|
|
|
|
@ -69,12 +69,23 @@ struct EigenVector {
|
|
|
|
|
using ConstType =
|
|
|
|
|
Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>,
|
|
|
|
|
Eigen::Aligned>;
|
|
|
|
|
|
|
|
|
|
// From is to transfer a one dimension Tensor into a one dimension EigenVector
|
|
|
|
|
static Type From(Tensor& tensor) { return EigenTensor<T, 1>::From(tensor); }
|
|
|
|
|
|
|
|
|
|
// Flatten is to reshape a Tensor into a one dimension EigenVector
|
|
|
|
|
static Type Flatten(Tensor& tensor) {
|
|
|
|
|
return EigenTensor<T, 1>::From(
|
|
|
|
|
tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static ConstType From(const Tensor& tensor) {
|
|
|
|
|
return EigenTensor<T, 1>::From(tensor);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static ConstType Flatten(const Tensor& tensor) {
|
|
|
|
|
return EigenTensor<T, 1>::From(
|
|
|
|
|
tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix.
|
|
|
|
|