Call Tensor::numel() everywhere.

enforce_failed
qingqing01 8 years ago
parent a2a69f2a54
commit ad64ca5da2

@ -162,7 +162,10 @@ class Tensor {
/*! points to dimensions of memory block. */ /*! points to dimensions of memory block. */
DDim dims_; DDim dims_;
/*! the element count of tensor. */ /**
* A cache of the number of elements in a tensor.
* Would be 0 for an uninitialized tensor.
*/
int64_t numel_; int64_t numel_;
/** /**

@ -24,7 +24,7 @@ inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tenosr holds no memory. Call Tensor::mutable_data first."); holder_, "Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
holder_->size(), numel_ * sizeof(T) + offset_, holder_->size(), numel() * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data " "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.\n" "first to re-allocate memory.\n"
"or maybe the required data-type mismatches the data already stored."); "or maybe the required data-type mismatches the data already stored.");
@ -54,11 +54,11 @@ inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
template <typename T> template <typename T>
inline T* Tensor::mutable_data(platform::Place place) { inline T* Tensor::mutable_data(platform::Place place) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
PADDLE_ENFORCE_GT(numel_, 0, PADDLE_ENFORCE_GT(numel(), 0,
"Tensor's numel must be larger than zero to call " "Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first."); "Tensor::mutable_data. Call Tensor::set_dim first.");
/* some versions of boost::variant don't have operator!= */ /* some versions of boost::variant don't have operator!= */
int64_t size = numel_ * sizeof(T); int64_t size = numel() * sizeof(T);
if (holder_ == nullptr || !(holder_->place() == place) || if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_) { holder_->size() < size + offset_) {
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
@ -131,7 +131,7 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
PADDLE_ENFORCE_LT(begin_idx, end_idx, PADDLE_ENFORCE_LT(begin_idx, end_idx,
"Begin index must be less than end index."); "Begin index must be less than end index.");
PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1."); PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1.");
size_t base = numel_ / dims_[0]; size_t base = numel() / dims_[0];
Tensor dst; Tensor dst;
dst.holder_ = holder_; dst.holder_ = holder_;
DDim dst_dims = dims_; DDim dst_dims = dims_;

Loading…
Cancel
Save