From 0575fd4647bf414662d31c02371a68689273b22c Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 2 Feb 2018 17:31:37 +0800 Subject: [PATCH 1/3] simplify shape inference code --- paddle/framework/op_desc.cc | 19 ------------------- paddle/framework/operator.cc | 8 -------- paddle/framework/shape_inference.cc | 23 +++++++++++++++++++---- paddle/framework/shape_inference.h | 8 +++----- 4 files changed, 22 insertions(+), 36 deletions(-) diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index f8df2cf97a..f554c77845 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -39,10 +39,6 @@ class CompileTimeInferShapeContext : public InferShapeContext { bool HasOutputs(const std::string &name) const override; - DDim GetInputDim(const std::string &name) const override; - - void SetOutputDim(const std::string &name, const DDim &dim) override; - AttrReader Attrs() const override; const std::vector &Inputs( @@ -444,21 +440,6 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const { return true; } -DDim CompileTimeInferShapeContext::GetInputDim(const std::string &name) const { - std::vector ddims = GetInputsDim(name); - auto length = ddims.size(); - PADDLE_ENFORCE_EQ(length, 1UL, - "Input(%s) should have 1 value, " - "but it has %d now", - name, length); - return ddims[0]; -} - -void CompileTimeInferShapeContext::SetOutputDim(const std::string &name, - const DDim &dim) { - SetOutputsDim(name, {dim}); -} - AttrReader CompileTimeInferShapeContext::Attrs() const { return AttrReader(op_.GetAttrMap()); } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 4e854f54dd..81fa8cf477 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -366,14 +366,6 @@ class RuntimeInferShapeContext : public InferShapeContext { return true; } - DDim GetInputDim(const std::string& name) const override { - return GetDim(op_.Input(name)); - } - - void SetOutputDim(const std::string& name, const DDim& dim) override { - SetDim(op_.Output(name), dim); - } - AttrReader Attrs() const override { return AttrReader(op_.Attrs()); } const std::vector& Inputs( diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc index e53cc0cdab..14dba75808 100644 --- a/paddle/framework/shape_inference.cc +++ b/paddle/framework/shape_inference.cc @@ -18,10 +18,18 @@ limitations under the License. */ namespace paddle { namespace framework { +framework::DDim InferShapeContext::GetInputDim(const std::string &name) const { + const std::vector &arg_names = Inputs(name); + PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, + "Input(%s) shoudl holds one element, but now it holds %d", + name, arg_names.size()); + return this->GetDim(arg_names[0]); +} + std::vector InferShapeContext::GetInputsDim( const std::string &name) const { - const std::vector &names = Inputs(name); - return GetDims(names); + const std::vector &arg_names = Inputs(name); + return GetDims(arg_names); } DDim InferShapeContext::GetInputsElementDim(const std::string &name, @@ -30,13 +38,21 @@ DDim InferShapeContext::GetInputsElementDim(const std::string &name, return this->GetDim(names[idx]); } +void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) { + auto &arg_names = Outputs(name); + PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, + "Output(%s) shoudl holds one element, but now it holds %d", + name, arg_names.size()); + SetDim(arg_names[0], dim); +} + void InferShapeContext::SetOutputsDim( const std::string &name, const std::vector &dims) { auto &names = Outputs(name); SetDims(names, dims); } -std::vector InferShapeContext::GetDims( +std::vector InferShapeContext::GetDims( const std::vector &names) const { std::vector ret; ret.reserve(names.size()); @@ -45,7 +61,6 @@ std::vector InferShapeContext::GetDims( [this](const std::string &name) { return this->GetDim(name); }); return ret; } - void InferShapeContext::SetDims(const std::vector &names, const std::vector &dims) { size_t length = names.size(); diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index f93319d8f2..77fc9359be 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -35,12 +35,12 @@ class InferShapeContext { virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; - virtual framework::DDim GetInputDim(const std::string &name) const = 0; + framework::DDim GetInputDim(const std::string &name) const; std::vector GetInputsDim(const std::string &name) const; DDim GetInputsElementDim(const std::string &name, int idx) const; - virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0; + void SetOutputDim(const std::string &name, const DDim &dim); void SetOutputsDim(const std::string &name, const std::vector &dims); @@ -63,9 +63,7 @@ class InferShapeContext { virtual framework::DDim GetDim(const std::string &name) const = 0; virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0; - std::vector GetDims( - const std::vector &names) const; - + std::vector GetDims(const std::vector &names) const; std::vector GetVarTypes( const std::vector &names) const; From ab1341eab7c67b991499ed9cad15d8901e2bc76b Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 2 Feb 2018 18:40:52 +0800 Subject: [PATCH 2/3] fix typo --- paddle/framework/shape_inference.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc index 14dba75808..0e17219e4e 100644 --- a/paddle/framework/shape_inference.cc +++ b/paddle/framework/shape_inference.cc @@ -21,7 +21,7 @@ namespace framework { framework::DDim InferShapeContext::GetInputDim(const std::string &name) const { const std::vector &arg_names = Inputs(name); PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, - "Input(%s) shoudl holds one element, but now it holds %d", + "Input(%s) should hold one element, but now it holds %d", name, arg_names.size()); return this->GetDim(arg_names[0]); } @@ -41,7 +41,7 @@ DDim InferShapeContext::GetInputsElementDim(const std::string &name, void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) { auto &arg_names = Outputs(name); PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, - "Output(%s) shoudl holds one element, but now it holds %d", + "Output(%s) should hold one element, but now it holds %d", name, arg_names.size()); SetDim(arg_names[0], dim); } From 9a970702d3cff51af75ff513c07dbeb226920f1d Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 2 Feb 2018 18:56:16 +0800 Subject: [PATCH 3/3] remove unnecessary framework:: --- paddle/framework/shape_inference.cc | 12 ++++++------ paddle/framework/shape_inference.h | 13 ++++++------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc index 0e17219e4e..a0fa467291 100644 --- a/paddle/framework/shape_inference.cc +++ b/paddle/framework/shape_inference.cc @@ -18,7 +18,7 @@ limitations under the License. */ namespace paddle { namespace framework { -framework::DDim InferShapeContext::GetInputDim(const std::string &name) const { +DDim InferShapeContext::GetInputDim(const std::string &name) const { const std::vector &arg_names = Inputs(name); PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, "Input(%s) should hold one element, but now it holds %d", @@ -26,7 +26,7 @@ framework::DDim InferShapeContext::GetInputDim(const std::string &name) const { return this->GetDim(arg_names[0]); } -std::vector InferShapeContext::GetInputsDim( +std::vector InferShapeContext::GetInputsDim( const std::string &name) const { const std::vector &arg_names = Inputs(name); return GetDims(arg_names); @@ -46,15 +46,15 @@ void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) { SetDim(arg_names[0], dim); } -void InferShapeContext::SetOutputsDim( - const std::string &name, const std::vector &dims) { +void InferShapeContext::SetOutputsDim(const std::string &name, + const std::vector &dims) { auto &names = Outputs(name); SetDims(names, dims); } std::vector InferShapeContext::GetDims( const std::vector &names) const { - std::vector ret; + std::vector ret; ret.reserve(names.size()); std::transform( names.begin(), names.end(), std::back_inserter(ret), @@ -62,7 +62,7 @@ std::vector InferShapeContext::GetDims( return ret; } void InferShapeContext::SetDims(const std::vector &names, - const std::vector &dims) { + const std::vector &dims) { size_t length = names.size(); PADDLE_ENFORCE_EQ(length, dims.size()); for (size_t i = 0; i < length; ++i) { diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index 77fc9359be..830f199ed1 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -35,14 +35,13 @@ class InferShapeContext { virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; - framework::DDim GetInputDim(const std::string &name) const; + DDim GetInputDim(const std::string &name) const; - std::vector GetInputsDim(const std::string &name) const; + std::vector GetInputsDim(const std::string &name) const; DDim GetInputsElementDim(const std::string &name, int idx) const; void SetOutputDim(const std::string &name, const DDim &dim); - void SetOutputsDim(const std::string &name, - const std::vector &dims); + void SetOutputsDim(const std::string &name, const std::vector &dims); virtual AttrReader Attrs() const = 0; virtual const std::vector &Inputs( @@ -57,11 +56,11 @@ class InferShapeContext { // Note: In while op, we need this to be public void SetDims(const std::vector &names, - const std::vector &dims); + const std::vector &dims); protected: - virtual framework::DDim GetDim(const std::string &name) const = 0; - virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0; + virtual DDim GetDim(const std::string &name) const = 0; + virtual void SetDim(const std::string &name, const DDim &dim) = 0; std::vector GetDims(const std::vector &names) const; std::vector GetVarTypes(