refactor tensor mutable_data

cblas_new
qijun 8 years ago
parent 65dbeb6a24
commit 5017b15468

@ -14,17 +14,17 @@ limitations under the License. */
#pragma once
#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/framework/tensor.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/utils/Error.h"
namespace paddle {
namespace framework {

@ -62,21 +62,19 @@ class Tensor {
!(holder_->place() ==
place) /* some versions of boost::variant don't have operator!= */
|| holder_->size() < numel_ * sizeof(T) + offset_) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
}
#ifdef __CUDACC__
switch (place.which()) {
case 0:
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
break;
case 1:
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
break;
else if (platform::is_gpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
}
#else
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
else if (platform::is_gpu_place(place)) {
PADDLE_ENFORCE(true, "GPU not support!");
}
#endif
offset_ = 0;
}

@ -20,9 +20,9 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#endif
#include <paddle/platform/place.h>
#include <memory>
#include <unsupported/Eigen/CXX11/Tensor>
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace platform {

Loading…
Cancel
Save