"change device context to pointer"

fixstartbug
dongzhihong 8 years ago
parent d911b1b5c4
commit b18e614163

@ -22,14 +22,14 @@ namespace framework {
template <>
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return *device_context_.get_eigen_device<Eigen::DefaultDevice>();
return *device_context_->get_eigen_device<Eigen::DefaultDevice>();
}
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return *device_context_.get_eigen_device<Eigen::GpuDevice>();
return *device_context_->get_eigen_device<Eigen::GpuDevice>();
}
#endif

@ -252,7 +252,7 @@ struct EigenDeviceConverter<platform::GPUPlace> {
class ExecutionContext : public OperatorContext {
public:
ExecutionContext(const OperatorBase* op, const Scope& scope,
const platform::DeviceContext& device_context)
const platform::DeviceContext* device_context)
: OperatorContext(op, scope), device_context_(device_context) {}
template <typename PlaceType,
@ -260,9 +260,9 @@ class ExecutionContext : public OperatorContext {
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); }
platform::Place GetPlace() const { return device_context_->GetPlace(); }
const platform::DeviceContext& device_context_;
const platform::DeviceContext* device_context_;
};
class OpKernel {
@ -311,7 +311,7 @@ class OperatorWithKernel : public OperatorBase {
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(ExecutionContext(this, scope, dev_ctx));
opKernel->Compute(ExecutionContext(this, scope, &dev_ctx));
}
static std::unordered_map<std::string /* op_type */, OpKernelMap>&

Loading…
Cancel
Save