Merge pull request #2953 from wangkuiyi/tensor_type_to_eigen
Refactorize Tensor to Eigen convesioncblas_new
commit
d81084939b
@ -0,0 +1,84 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/framework/tensor.h"
|
||||||
|
#include "unsupported/Eigen/CXX11/Tensor"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
// EigenDim converts paddle::platform::DDim into Eigen::DSizes.
|
||||||
|
template <int D>
|
||||||
|
struct EigenDim {
|
||||||
|
using Type = Eigen::DSizes<Eigen::DenseIndex, D>;
|
||||||
|
|
||||||
|
static Type From(const DDim& dims) {
|
||||||
|
PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
|
||||||
|
Type ret;
|
||||||
|
for (int d = 0; d < arity(dims); d++) {
|
||||||
|
ret[d] = dims[d];
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Interpret paddle::platform::Tensor as EigenTensor and EigenConstTensor.
|
||||||
|
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
|
||||||
|
typename IndexType = Eigen::DenseIndex>
|
||||||
|
struct EigenTensor {
|
||||||
|
// TODO(qijun) Now, default type in unaligned, and we will make a benchmark on
|
||||||
|
// the speed of aligned and unaligned version in future.
|
||||||
|
using Type = Eigen::TensorMap<Eigen::Tensor<T, D, MajorType, IndexType>>;
|
||||||
|
|
||||||
|
using ConstType =
|
||||||
|
Eigen::TensorMap<Eigen::Tensor<const T, D, MajorType, IndexType>>;
|
||||||
|
|
||||||
|
static Type From(Tensor& tensor, DDim dims) {
|
||||||
|
return Type(tensor.data<T>(), EigenDim<D>::From(dims));
|
||||||
|
}
|
||||||
|
|
||||||
|
static Type From(Tensor& tensor) { return From(tensor, tensor.dims_); }
|
||||||
|
|
||||||
|
static ConstType From(const Tensor& tensor, DDim dims) {
|
||||||
|
return ConstType(tensor.data<T>(), EigenDim<D>::From(dims));
|
||||||
|
}
|
||||||
|
|
||||||
|
static ConstType From(const Tensor& tensor) {
|
||||||
|
return From(tensor, tensor.dims_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int MajorType = Eigen::RowMajor,
|
||||||
|
typename IndexType = Eigen::DenseIndex>
|
||||||
|
struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
|
||||||
|
// Flatten is to reshape a Tensor into a one dimension EigenVector
|
||||||
|
static typename EigenTensor<T, 1>::Type Flatten(Tensor& tensor) {
|
||||||
|
return EigenTensor<T, 1>::From(
|
||||||
|
tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
|
||||||
|
}
|
||||||
|
|
||||||
|
static typename EigenTensor<T, 1>::ConstType Flatten(const Tensor& tensor) {
|
||||||
|
return EigenTensor<T, 1>::From(
|
||||||
|
tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int MajorType = Eigen::RowMajor,
|
||||||
|
typename IndexType = Eigen::DenseIndex>
|
||||||
|
using EigenMatrix = EigenTensor<T, 2, MajorType, IndexType>;
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,101 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "paddle/framework/eigen.h"
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
TEST(EigenDim, From) {
|
||||||
|
EigenDim<3>::Type ed = EigenDim<3>::From(make_ddim({1, 2, 3}));
|
||||||
|
ASSERT_EQ(1, ed[0]);
|
||||||
|
ASSERT_EQ(2, ed[1]);
|
||||||
|
ASSERT_EQ(3, ed[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Eigen, Tensor) {
|
||||||
|
Tensor t;
|
||||||
|
float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
|
||||||
|
for (int i = 0; i < 1 * 2 * 3; i++) {
|
||||||
|
p[i] = static_cast<float>(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);
|
||||||
|
|
||||||
|
ASSERT_EQ(1, et.dimension(0));
|
||||||
|
ASSERT_EQ(2, et.dimension(1));
|
||||||
|
ASSERT_EQ(3, et.dimension(2));
|
||||||
|
|
||||||
|
for (int i = 0; i < 1; i++) {
|
||||||
|
for (int j = 0; j < 2; j++) {
|
||||||
|
for (int k = 0; k < 3; k++) {
|
||||||
|
ASSERT_NEAR((i * 2 + j) * 3 + k, et(i, j, k), 1e-6f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Eigen, VectorFrom) {
|
||||||
|
Tensor t;
|
||||||
|
float* p = t.mutable_data<float>(make_ddim({6}), platform::CPUPlace());
|
||||||
|
for (int i = 0; i < 6; i++) {
|
||||||
|
p[i] = static_cast<float>(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
EigenVector<float>::Type ev = EigenVector<float>::From(t);
|
||||||
|
|
||||||
|
ASSERT_EQ(6, ev.dimension(0));
|
||||||
|
|
||||||
|
for (int i = 0; i < 6; i++) {
|
||||||
|
ASSERT_NEAR(i, ev(i), 1e-6f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Eigen, VectorFlatten) {
|
||||||
|
Tensor t;
|
||||||
|
float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
|
||||||
|
for (int i = 0; i < 1 * 2 * 3; i++) {
|
||||||
|
p[i] = static_cast<float>(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
EigenVector<float>::Type ev = EigenVector<float>::Flatten(t);
|
||||||
|
|
||||||
|
ASSERT_EQ(1 * 2 * 3, ev.dimension(0));
|
||||||
|
|
||||||
|
for (int i = 0; i < 1 * 2 * 3; i++) {
|
||||||
|
ASSERT_NEAR(i, ev(i), 1e-6f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Eigen, Matrix) {
|
||||||
|
Tensor t;
|
||||||
|
float* p = t.mutable_data<float>(make_ddim({2, 3}), platform::CPUPlace());
|
||||||
|
for (int i = 0; i < 2 * 3; i++) {
|
||||||
|
p[i] = static_cast<float>(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
EigenMatrix<float>::Type em = EigenMatrix<float>::From(t);
|
||||||
|
|
||||||
|
ASSERT_EQ(2, em.dimension(0));
|
||||||
|
ASSERT_EQ(3, em.dimension(1));
|
||||||
|
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
for (int j = 0; j < 3; j++) {
|
||||||
|
ASSERT_NEAR(i * 3 + j, em(i, j), 1e-6f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -1,67 +0,0 @@
|
|||||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License. */
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "unsupported/Eigen/CXX11/Tensor"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace framework {
|
|
||||||
|
|
||||||
// Helper to define Tensor types given that the scalar is of type T.
|
|
||||||
template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
|
|
||||||
struct TTypes {
|
|
||||||
// Rank-<NDIMS> tensor of scalar type T.
|
|
||||||
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
|
|
||||||
Eigen::Aligned>
|
|
||||||
Tensor;
|
|
||||||
typedef Eigen::TensorMap<
|
|
||||||
Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
|
|
||||||
ConstTensor;
|
|
||||||
|
|
||||||
// Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
|
|
||||||
typedef Eigen::TensorMap<
|
|
||||||
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
|
|
||||||
Eigen::Aligned>
|
|
||||||
Scalar;
|
|
||||||
typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
|
|
||||||
Eigen::RowMajor, IndexType>,
|
|
||||||
Eigen::Aligned>
|
|
||||||
ConstScalar;
|
|
||||||
|
|
||||||
// Rank-1 tensor (vector) of scalar type T.
|
|
||||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
|
|
||||||
Eigen::Aligned>
|
|
||||||
Flat;
|
|
||||||
typedef Eigen::TensorMap<
|
|
||||||
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
|
|
||||||
ConstFlat;
|
|
||||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
|
|
||||||
Eigen::Aligned>
|
|
||||||
Vec;
|
|
||||||
typedef Eigen::TensorMap<
|
|
||||||
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
|
|
||||||
ConstVec;
|
|
||||||
|
|
||||||
// Rank-2 tensor (matrix) of scalar type T.
|
|
||||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
|
|
||||||
Eigen::Aligned>
|
|
||||||
Matrix;
|
|
||||||
typedef Eigen::TensorMap<
|
|
||||||
Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
|
|
||||||
ConstMatrix;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace framework
|
|
||||||
} // namespace paddle
|
|
Loading…
Reference in new issue