"fix tensor shared_ptr"

gangliao-patch-1
dzhwinter 8 years ago
parent a46f3fcefc
commit df5bc78702

@ -5,7 +5,6 @@
#include <string.h> #include <string.h>
#include <memory> #include <memory>
#include "paddle/math/MemoryHandle.h"
#include "paddle/utils/Common.h" #include "paddle/utils/Common.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
@ -15,17 +14,16 @@ namespace optimizer {
template <class T> template <class T>
class TensorT { class TensorT {
public: public:
TensorT(size_t size) TensorT(size_t size) : height_(1), width_(size) {
: TensorT(std::make_shared<CpuMemoryHandle>(size * sizeof(float)), size) { data_ptr_ = std::shared_ptr<T>(new T[size], std::default_delete<T[]>());
data_ = data_ptr_.get();
} }
TensorT(CpuMemHandlePtr handle, size_t size)
: height_(1),
width_(size),
data_(reinterpret_cast<T*>(handle->getBuf())) {}
TensorT(T* data, size_t size) : height_(1), width_(size), data_(data) {} TensorT(T* data, size_t size)
: height_(1), width_(size), data_ptr_(nullptr), data_(data) {}
TensorT(T* data, size_t h, size_t w) : height_(h), width_(w), data_(data) {} TensorT(T* data, size_t h, size_t w)
: height_(h), width_(w), data_ptr_(nullptr), data_(data) {}
virtual ~TensorT() {} virtual ~TensorT() {}
@ -45,6 +43,7 @@ public:
protected: protected:
size_t height_; size_t height_;
size_t width_; size_t width_;
std::shared_ptr<T> data_ptr_;
T* data_; T* data_;
}; };

Loading…
Cancel
Save