"fix tensor shared_ptr"

gangliao-patch-1
dzhwinter 8 years ago
parent a46f3fcefc
commit df5bc78702

@ -5,7 +5,6 @@
#include <string.h>
#include <memory>
#include "paddle/math/MemoryHandle.h"
#include "paddle/utils/Common.h"
#include "paddle/utils/Logging.h"
@ -15,17 +14,16 @@ namespace optimizer {
template <class T>
class TensorT {
public:
TensorT(size_t size)
: TensorT(std::make_shared<CpuMemoryHandle>(size * sizeof(float)), size) {
TensorT(size_t size) : height_(1), width_(size) {
data_ptr_ = std::shared_ptr<T>(new T[size], std::default_delete<T[]>());
data_ = data_ptr_.get();
}
TensorT(CpuMemHandlePtr handle, size_t size)
: height_(1),
width_(size),
data_(reinterpret_cast<T*>(handle->getBuf())) {}
TensorT(T* data, size_t size) : height_(1), width_(size), data_(data) {}
TensorT(T* data, size_t size)
: height_(1), width_(size), data_ptr_(nullptr), data_(data) {}
TensorT(T* data, size_t h, size_t w) : height_(h), width_(w), data_(data) {}
TensorT(T* data, size_t h, size_t w)
: height_(h), width_(w), data_ptr_(nullptr), data_(data) {}
virtual ~TensorT() {}
@ -45,6 +43,7 @@ public:
protected:
size_t height_;
size_t width_;
std::shared_ptr<T> data_ptr_;
T* data_;
};

Loading…
Cancel
Save