|
|
|
@ -120,10 +120,10 @@ class OperatorBase {
|
|
|
|
|
std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OperatorContext {
|
|
|
|
|
class InferShapeContext {
|
|
|
|
|
public:
|
|
|
|
|
OperatorContext(const OperatorBase* op, const Scope& scope)
|
|
|
|
|
: op_(*op), scope_(scope) {}
|
|
|
|
|
InferShapeContext(const OperatorBase& op, const Scope& scope)
|
|
|
|
|
: op_(op), scope_(scope) {}
|
|
|
|
|
|
|
|
|
|
size_t InputSize() const { return op_.inputs_.size(); }
|
|
|
|
|
|
|
|
|
@ -234,12 +234,6 @@ class OperatorContext {
|
|
|
|
|
const Scope& scope_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class InferShapeContext : public OperatorContext {
|
|
|
|
|
public:
|
|
|
|
|
InferShapeContext(const OperatorBase* op, const Scope& scope)
|
|
|
|
|
: OperatorContext(op, scope) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct EigenDeviceConverter;
|
|
|
|
|
|
|
|
|
@ -255,11 +249,11 @@ struct EigenDeviceConverter<platform::GPUPlace> {
|
|
|
|
|
};
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
class ExecutionContext : public OperatorContext {
|
|
|
|
|
class ExecutionContext : public InferShapeContext {
|
|
|
|
|
public:
|
|
|
|
|
ExecutionContext(const OperatorBase* op, const Scope& scope,
|
|
|
|
|
ExecutionContext(const OperatorBase& op, const Scope& scope,
|
|
|
|
|
const platform::DeviceContext* device_context)
|
|
|
|
|
: OperatorContext(op, scope), device_context_(device_context) {}
|
|
|
|
|
: InferShapeContext(op, scope), device_context_(device_context) {}
|
|
|
|
|
|
|
|
|
|
template <typename PlaceType,
|
|
|
|
|
typename DeviceType =
|
|
|
|
@ -311,13 +305,13 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
|
|
|
|
|
|
|
|
|
|
void InferShape(const Scope& scope) const override {
|
|
|
|
|
InferShape(InferShapeContext(this, scope));
|
|
|
|
|
InferShape(InferShapeContext(*this, scope));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Run(const Scope& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const final {
|
|
|
|
|
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
|
|
|
|
|
opKernel->Compute(ExecutionContext(this, scope, &dev_ctx));
|
|
|
|
|
opKernel->Compute(ExecutionContext(*this, scope, &dev_ctx));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
|
|
|
|
|