|
|
|
@ -28,13 +28,6 @@ namespace framework {
|
|
|
|
|
|
|
|
|
|
class Tensor {
|
|
|
|
|
public:
|
|
|
|
|
template <typename T>
|
|
|
|
|
const T* data() const {
|
|
|
|
|
PADDLE_ENFORCE(holder_ != nullptr,
|
|
|
|
|
"Tensor::data must be called after Tensor::mutable_data.");
|
|
|
|
|
return static_cast<const T*>(holder_->Ptr());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
T* data() const {
|
|
|
|
|
PADDLE_ENFORCE(holder_ != nullptr,
|
|
|
|
@ -60,14 +53,14 @@ class Tensor {
|
|
|
|
|
size_t NumElements() const { return product(dims_); }
|
|
|
|
|
|
|
|
|
|
template <typename T, size_t NDIMS>
|
|
|
|
|
typename TTypes<T, NDIMS>::Tensor Tensor::shaped(DDim new_dims) {
|
|
|
|
|
typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) {
|
|
|
|
|
Eigen::array<Eigen::DenseIndex, NDIMS> dims =
|
|
|
|
|
paddle::framework::ToEigenDSizes(new_dims);
|
|
|
|
|
paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
|
|
|
|
|
return typename TTypes<T, NDIMS>::Tensor(data<T>(), dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, size_t NDIMS>
|
|
|
|
|
typename TTypes<T, NDIMS>::Tensor Tensor::tensor() {
|
|
|
|
|
typename TTypes<T, NDIMS>::Tensor tensor() {
|
|
|
|
|
return typename TTypes<T, NDIMS>::Tensor(
|
|
|
|
|
data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
|
|
|
|
|
}
|
|
|
|
@ -92,7 +85,7 @@ class Tensor {
|
|
|
|
|
|
|
|
|
|
// const versions of all the methods above.
|
|
|
|
|
template <typename T, size_t NDIMS>
|
|
|
|
|
typename TTypes<T, NDIMS>::ConstantTensor Tensor::tensor() const {
|
|
|
|
|
typename TTypes<T, NDIMS>::ConstantTensor tensor() const {
|
|
|
|
|
return typename TTypes<T, NDIMS>::Tensor(
|
|
|
|
|
data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
|
|
|
|
|
}
|
|
|
|
|