You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
55 lines
1.3 KiB
55 lines
1.3 KiB
8 years ago
|
#pragma once
|
||
8 years ago
|
/**
|
||
|
* @brief tensor used by optimizer
|
||
|
*/
|
||
|
|
||
|
#include <string.h>
|
||
8 years ago
|
#include <memory>
|
||
8 years ago
|
#include "paddle/utils/Common.h"
|
||
|
#include "paddle/utils/Logging.h"
|
||
8 years ago
|
|
||
|
namespace paddle {
|
||
|
namespace optimizer {
|
||
|
|
||
|
/**
 * @brief Minimal 1-D / 2-D tensor (flat, contiguous element buffer).
 *
 * A TensorT is either:
 *  - owning: the size-only constructor allocates a value-initialized array
 *    held by a shared_ptr, freed when the last copy of the tensor goes away;
 *  - a non-owning view: the pointer constructors wrap a caller-supplied
 *    buffer, which the caller must keep alive while the tensor is in use.
 *
 * Copies are shallow: a copied tensor refers to the same elements.
 */
template <class T>
class TensorT {
public:
  /**
   * @brief Allocate an owning tensor of shape 1 x size.
   *
   * Every element is value-initialized (zero for arithmetic T), so the
   * tensor starts from a well-defined state rather than indeterminate memory.
   *
   * NOTE(review): not `explicit`, so a size_t converts implicitly to a
   * TensorT; kept as-is for compatibility with existing callers.
   */
  TensorT(size_t size) : height_(1), width_(size) {
    // The trailing `()` value-initializes the array; plain `new T[size]`
    // would leave arithmetic elements uninitialized.
    data_ptr_ = std::shared_ptr<T>(new T[size](), std::default_delete<T[]>());
    data_ = data_ptr_.get();
  }

  /// @brief Non-owning 1 x size view over `data` (caller keeps it alive).
  TensorT(T* data, size_t size)
      : height_(1), width_(size), data_ptr_(nullptr), data_(data) {}

  /// @brief Non-owning h x w view over `data` (caller keeps it alive).
  TensorT(T* data, size_t h, size_t w)
      : height_(h), width_(w), data_ptr_(nullptr), data_(data) {}

  virtual ~TensorT() = default;

  /// @return Pointer to the first element of the flat buffer.
  T* get_buffer() { return this->data_; }

  /**
   * @brief Flat element access, validated with CHECK.
   *
   * `idx` is unsigned, so only the upper bound needs testing. The bound is
   * size() (height * width) so every element of a 2-D view is reachable.
   */
  T& operator[](const size_t idx) {
    CHECK(idx < this->size()) << "out of index range";
    return data_[idx];
  }

  /**
   * @brief Flat element access on a const tensor.
   *
   * Constness is shallow (the tensor is a view over its buffer), so this
   * still returns a mutable reference, as the original interface did.
   */
  T& operator[](const size_t idx) const {
    CHECK(idx < this->size()) << "out of index range";
    return data_[idx];
  }

  // TODO: replace with tensorshape
  /// @return Total number of elements (height * width).
  size_t size() const { return this->width_ * this->height_; }

protected:
  size_t height_;
  size_t width_;
  // Owns the buffer for the size-only constructor; null for views.
  std::shared_ptr<T> data_ptr_;
  // Raw element pointer used for access (owned or borrowed).
  T* data_;
};
|
||
|
|
||
8 years ago
|
// TODO(zhihong): design problem of dynamic datatype, need to fix it
|
||
8 years ago
|
typedef TensorT<float> Tensor;
|
||
8 years ago
|
|
||
8 years ago
|
} // namespace optimizer
|
||
|
} // namespace paddle
|