set correct place for output tensor

cblas_new
qijun 8 years ago
parent 6dc567a52e
commit 2a03e3808d

@@ -18,14 +18,14 @@ namespace paddle {
namespace framework {
template <>
-Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device<
+Eigen::DefaultDevice* OpKernel::KernelContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
}
#ifndef PADDLE_ONLY_CPU
template <>
-Eigen::GpuDevice* OpKernel::KernelContext::get_eigen_device<
+Eigen::GpuDevice* OpKernel::KernelContext::GetEigenDevice<
platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>();
}

@@ -109,7 +109,9 @@ class OpKernel {
template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
-DeviceType* get_eigen_device() const;
+DeviceType* GetEigenDevice() const;
+platform::Place GetPlace() const { return device_context_.GetPlace(); }
const OperatorBase& op_;
const ScopePtr& scope_;

@@ -27,9 +27,9 @@ public:
auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
-output->mutable_data<T>(Place());
+output->mutable_data<T>(context.GetPlace());
-output->flat<T>().device(*(context.get_eigen_device<Place>())) =
+output->flat<T>().device(*(context.GetEigenDevice<Place>())) =
input0.flat<T>() + input1.flat<T>();
}
};

Loading…
Cancel
Save