|
|
|
@ -33,13 +33,13 @@ template <typename T>
|
|
|
|
|
struct EigenDeviceConverter;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct EigenDeviceConverter<CPUPlace> {
|
|
|
|
|
struct EigenDeviceConverter<platform::CPUPlace> {
|
|
|
|
|
using EigenDeviceType = Eigen::DefaultDevice;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
|
template <>
|
|
|
|
|
struct EigenDeviceConverter<GPUPlace> {
|
|
|
|
|
struct EigenDeviceConverter<platform::GPUPlace> {
|
|
|
|
|
using EigenDeviceType = Eigen::GpuDevice;
|
|
|
|
|
};
|
|
|
|
|
#endif
|
|
|
|
@ -87,39 +87,38 @@ class OperatorBase {
|
|
|
|
|
AttributeMap attrs_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* KernelContext is the only parameter of Kernel Run function.
|
|
|
|
|
* Run will get input/output variables, state such as momentum and
|
|
|
|
|
* device resource such as CUDA stream, cublas handle, etc. from
|
|
|
|
|
* KernelContext. User should construct it before run the Operator.
|
|
|
|
|
*/
|
|
|
|
|
class KernelContext {
|
|
|
|
|
class OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
|
|
|
|
|
const platform::DeviceContext& device_context)
|
|
|
|
|
: op_(*op), scope_(scope), device_context_(device_context) {}
|
|
|
|
|
|
|
|
|
|
const Variable* Input(int index) const {
|
|
|
|
|
return scope_->GetVariable(op_.inputs_[index]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Variable* Output(int index) const {
|
|
|
|
|
return scope_->GetVariable(op_.outputs_[index]);
|
|
|
|
|
}
|
|
|
|
|
/**
|
|
|
|
|
* KernelContext is the only parameter of Kernel Run function.
|
|
|
|
|
* Run will get input/output variables, state such as momentum and
|
|
|
|
|
* device resource such as CUDA stream, cublas handle, etc. from
|
|
|
|
|
* KernelContext. User should construct it before run the Operator.
|
|
|
|
|
*/
|
|
|
|
|
class KernelContext {
|
|
|
|
|
public:
|
|
|
|
|
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
|
|
|
|
|
const platform::DeviceContext& device_context)
|
|
|
|
|
: op_(*op), scope_(scope), device_context_(device_context) {}
|
|
|
|
|
|
|
|
|
|
const Variable* Input(int index) const {
|
|
|
|
|
return scope_->GetVariable(op_.inputs_[index]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::DeviceContext& device_context() const { return device_context_; }
|
|
|
|
|
Variable* Output(int index) const {
|
|
|
|
|
return scope_->GetVariable(op_.outputs_[index]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename PlaceType, typename DeviceType = EigenDeviceConverter<
|
|
|
|
|
PlaceType>::EigenDeviceType>
|
|
|
|
|
DeviceType* get_eigen_device();
|
|
|
|
|
template <typename PlaceType,
|
|
|
|
|
typename DeviceType =
|
|
|
|
|
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
|
|
|
|
|
DeviceType* get_eigen_device() const;
|
|
|
|
|
|
|
|
|
|
const OperatorBase& op_;
|
|
|
|
|
const std::shared_ptr<Scope>& scope_;
|
|
|
|
|
const platform::DeviceContext& device_context_;
|
|
|
|
|
};
|
|
|
|
|
const OperatorBase& op_;
|
|
|
|
|
const std::shared_ptr<Scope>& scope_;
|
|
|
|
|
const platform::DeviceContext& device_context_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
virtual void Compute(const KernelContext& context) const = 0;
|
|
|
|
|
|
|
|
|
|
virtual ~OpKernel() {}
|
|
|
|
|