|
|
|
@ -14,17 +14,18 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
#include "paddle/memory/memcpy.h"
|
|
|
|
|
#include "paddle/platform/enforce.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline void Tensor::check_memory_size() const {
|
|
|
|
|
PADDLE_ENFORCE(holder_ != nullptr,
|
|
|
|
|
"Tenosr holds no memory. Call Tensor::mutable_data first.");
|
|
|
|
|
PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
|
|
|
|
|
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
|
|
|
|
|
"first to re-allocate memory.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
holder_, "Tenosr holds no memory. Call Tensor::mutable_data first.");
|
|
|
|
|
PADDLE_ENFORCE_GE(holder_->size(), product(dims_) * sizeof(T) + offset_,
|
|
|
|
|
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
|
|
|
|
|
"first to re-allocate memory.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -51,9 +52,9 @@ inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline T* Tensor::mutable_data(platform::Place place) {
|
|
|
|
|
static_assert(std::is_pod<T>::value, "T must be POD");
|
|
|
|
|
PADDLE_ENFORCE(product(dims_) > 0,
|
|
|
|
|
"Tensor's numel must be larger than zero to call "
|
|
|
|
|
"Tensor::mutable_data. Call Tensor::set_dim first.");
|
|
|
|
|
PADDLE_ENFORCE_GT(product(dims_), 0,
|
|
|
|
|
"Tensor's numel must be larger than zero to call "
|
|
|
|
|
"Tensor::mutable_data. Call Tensor::set_dim first.");
|
|
|
|
|
/* some versions of boost::variant don't have operator!= */
|
|
|
|
|
size_t size = product(dims_) * sizeof(T);
|
|
|
|
|
if (holder_ == nullptr || !(holder_->place() == place) ||
|
|
|
|
@ -120,11 +121,11 @@ inline void Tensor::CopyFrom(const Tensor& src,
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
|
|
|
|
|
check_memory_size<T>();
|
|
|
|
|
PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero.");
|
|
|
|
|
PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound.");
|
|
|
|
|
PADDLE_ENFORCE(begin_idx < end_idx,
|
|
|
|
|
"Begin index must be less than end index.");
|
|
|
|
|
PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
|
|
|
|
|
PADDLE_ENFORCE_GE(begin_idx, 0, "Slice begin index is less than zero.");
|
|
|
|
|
PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound.");
|
|
|
|
|
PADDLE_ENFORCE_LT(begin_idx, end_idx,
|
|
|
|
|
"Begin index must be less than end index.");
|
|
|
|
|
PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1.");
|
|
|
|
|
int base = product(dims_) / dims_[0];
|
|
|
|
|
Tensor dst;
|
|
|
|
|
dst.holder_ = holder_;
|
|
|
|
|