|
|
|
@ -252,7 +252,7 @@ struct EigenDeviceConverter<platform::GPUPlace> {
|
|
|
|
|
class ExecutionContext : public OperatorContext {
|
|
|
|
|
public:
|
|
|
|
|
ExecutionContext(const OperatorBase* op, const Scope& scope,
|
|
|
|
|
const platform::DeviceContext& device_context)
|
|
|
|
|
const platform::DeviceContext* device_context)
|
|
|
|
|
: OperatorContext(op, scope), device_context_(device_context) {}
|
|
|
|
|
|
|
|
|
|
template <typename PlaceType,
|
|
|
|
@ -260,9 +260,9 @@ class ExecutionContext : public OperatorContext {
|
|
|
|
|
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
|
|
|
|
|
DeviceType& GetEigenDevice() const;
|
|
|
|
|
|
|
|
|
|
platform::Place GetPlace() const { return device_context_.GetPlace(); }
|
|
|
|
|
platform::Place GetPlace() const { return device_context_->GetPlace(); }
|
|
|
|
|
|
|
|
|
|
const platform::DeviceContext& device_context_;
|
|
|
|
|
const platform::DeviceContext* device_context_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpKernel {
|
|
|
|
@ -311,7 +311,7 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
void Run(const Scope& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const final {
|
|
|
|
|
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
|
|
|
|
|
opKernel->Compute(ExecutionContext(this, scope, dev_ctx));
|
|
|
|
|
opKernel->Compute(ExecutionContext(this, scope, &dev_ctx));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
|
|
|
|
|