|
|
|
@ -129,8 +129,8 @@ class OperatorBase {
|
|
|
|
|
virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
|
|
|
|
|
|
|
|
|
|
void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }
|
|
|
|
|
virtual void RunInferShape(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const {}
|
|
|
|
|
virtual void RuntimeInferShape(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const {}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::string type_;
|
|
|
|
@ -351,8 +351,8 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RunInferShape(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const override;
|
|
|
|
|
void RuntimeInferShape(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const override;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
|
|
|
|
|