Named to RuntimeInferShape

test=develop
revert-14666-feature/estiminate_flops
baojun-nervana 6 years ago
parent 24e70920db
commit e6bd53be60

@ -279,7 +279,7 @@ std::shared_ptr<ngraph::runtime::Backend> NgraphOperator::backend_ =
ngraph::runtime::Backend::create("CPU"); ngraph::runtime::Backend::create("CPU");
void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) { void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
op->RunInferShape(scope_, place_); op->RuntimeInferShape(scope_, place_);
for (auto& var_name_item : op->Inputs()) { for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
auto* var = scope_.FindVar(var_name); auto* var = scope_.FindVar(var_name);

@ -695,8 +695,8 @@ static void CheckTensorNANOrInf(const std::string& name,
"Tensor %s contains NAN", name); "Tensor %s contains NAN", name);
} }
void OperatorWithKernel::RunInferShape(const Scope& scope, void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
} }

@ -129,8 +129,8 @@ class OperatorBase {
virtual std::vector<std::string> OutputVars(bool has_intermediate) const; virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; } void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }
virtual void RunInferShape(const Scope& scope, virtual void RuntimeInferShape(const Scope& scope,
const platform::Place& place) const {} const platform::Place& place) const {}
protected: protected:
std::string type_; std::string type_;
@ -351,8 +351,8 @@ class OperatorWithKernel : public OperatorBase {
OpInfoMap::Instance().Get(Type()).infer_shape_(ctx); OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
} }
void RunInferShape(const Scope& scope, void RuntimeInferShape(const Scope& scope,
const platform::Place& place) const override; const platform::Place& place) const override;
protected: protected:
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;

Loading…
Cancel
Save