diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc index 99c4cf0da6..61bae1aba4 100644 --- a/paddle/fluid/framework/ngraph_operator.cc +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -279,7 +279,7 @@ std::shared_ptr NgraphOperator::backend_ = ngraph::runtime::Backend::create("CPU"); void NgraphOperator::GetNgInputShape(std::shared_ptr op) { - op->RunInferShape(scope_, place_); + op->RuntimeInferShape(scope_, place_); for (auto& var_name_item : op->Inputs()) { for (auto& var_name : var_name_item.second) { auto* var = scope_.FindVar(var_name); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index a816aa94c0..f3d225df69 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -695,8 +695,8 @@ static void CheckTensorNANOrInf(const std::string& name, "Tensor %s contains NAN", name); } -void OperatorWithKernel::RunInferShape(const Scope& scope, - const platform::Place& place) const { +void OperatorWithKernel::RuntimeInferShape(const Scope& scope, + const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); this->InferShape(&infer_shape_ctx); } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index fcf889f3db..efc9a1b6f5 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -129,8 +129,8 @@ class OperatorBase { virtual std::vector OutputVars(bool has_intermediate) const; void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; } - virtual void RunInferShape(const Scope& scope, - const platform::Place& place) const {} + virtual void RuntimeInferShape(const Scope& scope, + const platform::Place& place) const {} protected: std::string type_; @@ -351,8 +351,8 @@ class OperatorWithKernel : public OperatorBase { OpInfoMap::Instance().Get(Type()).infer_shape_(ctx); } - void RunInferShape(const Scope& scope, - const platform::Place& place) const override; + void RuntimeInferShape(const Scope& scope, + const platform::Place& place) const override; protected: virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;