From 62eb43ba98931f303127441b0f53f142b12f439f Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 18 Dec 2018 20:22:56 +0800 Subject: [PATCH 1/8] convert more test=develop --- paddle/fluid/framework/operator.cc | 35 ++++++++++++++---------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 8c83748668..5bee6b41bd 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames, const Scope& scope) { for (auto& var_name_item : innames) { std::vector& input_vars = inputs[var_name_item.first]; + input_vars.reserve(var_name_item.second.size()); for (auto& var_name : var_name_item.second) { input_vars.push_back(scope.FindVar(var_name)); } } for (auto& var_name_item : outnames) { std::vector& output_vars = outputs[var_name_item.first]; + output_vars.reserve(var_name_item.second.size()); for (auto& var_name : var_name_item.second) { output_vars.push_back(scope.FindVar(var_name)); } @@ -556,30 +558,28 @@ class RuntimeInferShapeContext : public InferShapeContext { bool HasOutput(const std::string& name) const override { // has only one output - const auto& outs = op_.Outputs(); + const auto& outs = ctx_.outputs; auto it = outs.find(name); if (it == outs.end()) { return false; } const auto& out = it->second; - if (out.size() == 0 || out[0] == kEmptyVarName) { + if (out.size() == 0) { return false; } PADDLE_ENFORCE_EQ(out.size(), 1UL, "Output %s should not have more than one outputs", name); - return scope_.FindVar(out[0]) != nullptr; + return out[0] != nullptr; } bool HasInputs(const std::string& name) const override { - if (!op_.HasInputs(name)) { - return false; - } - auto inputs = op_.Inputs(name); - if (inputs.empty()) { + const auto& ins = ctx_.inputs; + auto it = ins.find(name); + if (it == ins.end()) { return false; } - for (auto& input : inputs) { - if (scope_.FindVar(input) == nullptr) { + for (auto& input : it->second) { + if (input == nullptr) { return false; } } @@ -587,15 +587,13 @@ class RuntimeInferShapeContext : public InferShapeContext { } bool HasOutputs(const std::string& name) const override { - if (!op_.HasOutputs(name)) { - return false; - } - auto outputs = op_.Outputs(name); - if (outputs.empty()) { + const auto& outs = ctx_.outputs; + auto it = outs.find(name); + if (it == outs.end()) { return false; } - for (auto& output : outputs) { - if (scope_.FindVar(output) == nullptr) { + for (auto& output : it->second) { + if (output == nullptr) { return false; } } @@ -864,8 +862,7 @@ Scope* OperatorWithKernel::PrepareData( for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto& var_name = var_name_item.second[i]; - auto* var = scope.FindVar(var_name); - input_vars[i] = var; + auto* var = input_vars[i]; // Only tensor can be tranfer to another device. if (var == nullptr || !VarIsTensor(*var)) { From 0e0983cc1d9a607ba8a339bbbe9e495e304cd11f Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 18 Dec 2018 21:27:04 +0800 Subject: [PATCH 2/8] convert more infer shape --- paddle/fluid/framework/operator.cc | 34 ++++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 5bee6b41bd..a7bee3344d 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -614,16 +614,19 @@ class RuntimeInferShapeContext : public InferShapeContext { void ShareDim(const std::string& in, const std::string& out, size_t i = 0, size_t j = 0) override { - PADDLE_ENFORCE_LT(i, Inputs(in).size()); - PADDLE_ENFORCE_LT(j, Outputs(out).size()); - const std::string& input_n = Inputs(in)[i]; - const std::string& output_n = Outputs(out)[j]; + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i, + "Inputs %s should have %llu argument", in, i); + PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j, + "Outputs %s should have %llu argument", out, j); + + Variable* in_var = in_it->second[i]; + Variable* out_var = out_it->second[j]; - Variable* in_var = scope_.FindVar(input_n); - Variable* out_var = scope_.FindVar(output_n); PADDLE_ENFORCE(in_var->Type() == out_var->Type(), - "The type of %s and %s is not the same.", output_n, - GetDim(input_n)); + "The type of %s and %s is not the same.", in_var->Type(), + out_var->Type()); if (in_var->IsType()) { auto& in_sele_rows = in_var->Get(); @@ -644,13 +647,16 @@ class RuntimeInferShapeContext : public InferShapeContext { void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, size_t j = 0) const override { - const std::vector& inputs = Inputs(in); - const std::vector& outputs = Outputs(out); - PADDLE_ENFORCE_LT(i, inputs.size()); - PADDLE_ENFORCE_LT(j, outputs.size()); - Variable* in_var = scope_.FindVar(inputs.at(i)); + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i, + "Inputs %s should have %llu argument", in, i); + PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j, + "Outputs %s should have %llu argument", out, j); + + Variable* in_var = in_it->second.at(i); if (!in_var->IsType()) return; - Variable* out_var = scope_.FindVar(outputs.at(j)); + Variable* out_var = out_it->second.at(j); PADDLE_ENFORCE(out_var->IsType(), "The %d-th output of Output(%s) must be LoDTensor.", j, out); auto in_tensor = in_var->Get(); From 52d3903a1208747c2e3c97b90bb0f48e08f7a85b Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 19 Dec 2018 10:03:36 +0800 Subject: [PATCH 3/8] fix test=develop --- paddle/fluid/framework/operator.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index a7bee3344d..e023d165b0 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -575,7 +575,7 @@ class RuntimeInferShapeContext : public InferShapeContext { bool HasInputs(const std::string& name) const override { const auto& ins = ctx_.inputs; auto it = ins.find(name); - if (it == ins.end()) { + if (it == ins.end() || it->second.empty()) { return false; } for (auto& input : it->second) { @@ -589,7 +589,7 @@ class RuntimeInferShapeContext : public InferShapeContext { bool HasOutputs(const std::string& name) const override { const auto& outs = ctx_.outputs; auto it = outs.find(name); - if (it == outs.end()) { + if (it == outs.end() || it->second.empty()) { return false; } for (auto& output : it->second) { From 4dd61e7260314faa4b9b8f5a4c5406af013d919e Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 19 Dec 2018 11:07:16 +0800 Subject: [PATCH 4/8] convert GetInputVarPtrs and GetOutputVarPtrs test=develop --- paddle/fluid/framework/op_desc.cc | 31 ++++++++++++++----- paddle/fluid/framework/operator.cc | 36 +++++++++++++++++++++-- paddle/fluid/framework/shape_inference.cc | 22 -------------- paddle/fluid/framework/shape_inference.h | 31 ++++++++++--------- 4 files changed, 74 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index dde642764f..0a3bb586fc 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -110,6 +110,30 @@ class CompileTimeInferShapeContext : public InferShapeContext { } } + std::vector GetInputVarPtrs( + const std::string &name) override { + const std::vector arg_names = Inputs(name); + std::vector res; + res.reserve(arg_names.size()); + std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res), + [this](const std::string &name) { + return block_.FindVarRecursive(name); + }); + return res; + } + + std::vector GetOutputVarPtrs( + const std::string &name) override { + const std::vector arg_names = Outputs(name); + std::vector res; + res.reserve(arg_names.size()); + std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res), + [this](const std::string &name) { + return block_.FindVarRecursive(name); + }); + return res; + } + bool IsRuntime() const override; protected: @@ -124,8 +148,6 @@ class CompileTimeInferShapeContext : public InferShapeContext { void SetRepeatedDims(const std::string &name, const std::vector &dims) override; - InferShapeVarPtr GetVarPtr(const std::string &name) override; - const OpDesc &op_; const BlockDesc &block_; }; @@ -696,10 +718,5 @@ proto::VarType::Type CompileTimeInferShapeContext::GetVarType( return block_.FindVarRecursive(name)->GetType(); } -InferShapeVarPtr CompileTimeInferShapeContext::GetVarPtr( - const std::string &name) { - return block_.FindVarRecursive(name); -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index e023d165b0..4ccef3105c 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -691,6 +691,25 @@ class RuntimeInferShapeContext : public InferShapeContext { bool IsRuntime() const override { return true; } + // TODO(paddle-dev): Can this be template? + std::vector GetInputVarPtrs( + const std::string& name) override { + const std::vector& vars = InputVars(name); + std::vector res; + res.reserve(vars.size()); + res.insert(res.begin(), vars.begin(), vars.end()); + return res; + } + + std::vector GetOutputVarPtrs( + const std::string& name) override { + const std::vector& vars = OutputVars(name); + std::vector res; + res.reserve(vars.size()); + res.insert(res.begin(), vars.begin(), vars.end()); + return res; + } + protected: DDim GetDim(const std::string& name) const override { Variable* var = scope_.FindVar(name); @@ -733,11 +752,22 @@ class RuntimeInferShapeContext : public InferShapeContext { return ToVarType(var->Type()); } - InferShapeVarPtr GetVarPtr(const std::string& name) override { - return scope_.FindVar(name); + private: + const std::vector& InputVars(const std::string& name) const { + auto it = ctx_.inputs.find(name); + PADDLE_ENFORCE(it != ctx_.inputs.end(), + "Operator %s does not have the input %s.", op_.Type(), name); + return it->second; + } + + const std::vector& OutputVars(const std::string& name) const { + auto it = ctx_.outputs.find(name); + PADDLE_ENFORCE(it != ctx_.outputs.end(), + "Operator %s does not have the outputs %s.", op_.Type(), + name); + return it->second; } - private: const OperatorBase& op_; const Scope& scope_; const RuntimeContext& ctx_; diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc index ddff2c7c26..0a7cebcc5a 100644 --- a/paddle/fluid/framework/shape_inference.cc +++ b/paddle/fluid/framework/shape_inference.cc @@ -76,28 +76,6 @@ void InferShapeContext::SetReaderDims(const std::string &name, return this->SetRepeatedDims(arg_names[0], dims); } -std::vector InferShapeContext::GetInputVarPtrs( - const std::string &name) { - const std::vector arg_names = Inputs(name); - std::vector res; - res.reserve(arg_names.size()); - std::transform( - arg_names.begin(), arg_names.end(), std::back_inserter(res), - [this](const std::string &name) { return this->GetVarPtr(name); }); - return res; -} - -std::vector InferShapeContext::GetOutputVarPtrs( - const std::string &name) { - const std::vector arg_names = Outputs(name); - std::vector res; - res.reserve(arg_names.size()); - std::transform( - arg_names.begin(), arg_names.end(), std::back_inserter(res), - [this](const std::string &name) { return this->GetVarPtr(name); }); - return res; -} - std::vector InferShapeContext::GetDims( const std::vector &names) const { std::vector ret; diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index d73cca121e..543696d43b 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -33,22 +33,24 @@ class InferShapeContext { virtual bool HasInput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0; - std::vector GetInputsVarType( + virtual std::vector GetInputsVarType( const std::string &name) const; - std::vector GetOutputsVarType( + virtual std::vector GetOutputsVarType( const std::string &name) const; virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; - DDim GetInputDim(const std::string &name) const; - std::vector GetInputsDim(const std::string &name) const; - std::vector GetReaderDims(const std::string &name) const; - DDim GetInputsElementDim(const std::string &name, int idx) const; + virtual DDim GetInputDim(const std::string &name) const; + virtual std::vector GetInputsDim(const std::string &name) const; + virtual std::vector GetReaderDims(const std::string &name) const; + virtual DDim GetInputsElementDim(const std::string &name, int idx) const; - void SetOutputDim(const std::string &name, const DDim &dim); - void SetOutputsDim(const std::string &name, const std::vector &dims); - void SetReaderDims(const std::string &name, const std::vector &dims); + virtual void SetOutputDim(const std::string &name, const DDim &dim); + virtual void SetOutputsDim(const std::string &name, + const std::vector &dims); + virtual void SetReaderDims(const std::string &name, + const std::vector &dims); virtual AttrReader Attrs() const = 0; virtual const std::vector &Inputs( @@ -67,13 +69,14 @@ class InferShapeContext { virtual bool IsRuntime() const = 0; - std::vector GetInputVarPtrs(const std::string &name); - std::vector GetOutputVarPtrs(const std::string &name); - virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0; + virtual std::vector GetInputVarPtrs( + const std::string &name) = 0; + virtual std::vector GetOutputVarPtrs( + const std::string &name) = 0; // Note: In while op, we need this to be public - void SetDims(const std::vector &names, - const std::vector &dims); + virtual void SetDims(const std::vector &names, + const std::vector &dims); protected: virtual DDim GetDim(const std::string &name) const = 0; From 8c19f0bfe3251ee270546d991e9412ff8dd50100 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 19 Dec 2018 14:53:33 +0800 Subject: [PATCH 5/8] fix test=develop --- paddle/fluid/framework/operator.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 4ccef3105c..2f418f728f 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -625,8 +625,7 @@ class RuntimeInferShapeContext : public InferShapeContext { Variable* out_var = out_it->second[j]; PADDLE_ENFORCE(in_var->Type() == out_var->Type(), - "The type of %s and %s is not the same.", in_var->Type(), - out_var->Type()); + "The type of %s and %s is not the same.", in, out); if (in_var->IsType()) { auto& in_sele_rows = in_var->Get(); From 876993887bba14c82e1d2d8e7718f9c8df630422 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 19 Dec 2018 15:44:39 +0800 Subject: [PATCH 6/8] convert more interface to avoid scope test=develop --- paddle/fluid/framework/op_desc.cc | 50 +++++++++++++------ paddle/fluid/framework/operator.cc | 28 +++++++++-- paddle/fluid/framework/shape_inference.cc | 30 ----------- paddle/fluid/framework/shape_inference.h | 8 +-- .../fluid/operators/controlflow/while_op.cc | 2 +- 5 files changed, 62 insertions(+), 56 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 0a3bb586fc..ef98558820 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -134,12 +134,46 @@ class CompileTimeInferShapeContext : public InferShapeContext { return res; } + DDim GetInputDim(const std::string &name) const override { + const std::vector &arg_names = Inputs(name); + PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, + "Input(%s) should hold one element, but now it holds %d", + name, arg_names.size()); + return this->GetDim(arg_names[0]); + } + + std::vector GetInputsDim(const std::string &name) const override { + const std::vector &arg_names = Inputs(name); + return GetDims(arg_names); + } + bool IsRuntime() const override; protected: proto::VarType::Type GetVarType(const std::string &name) const override; - DDim GetDim(const std::string &name) const override; + DDim GetDim(const std::string &name) const { + auto var = block_.FindVarRecursive(name); + PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); + DDim res; + try { + auto shape = var->GetShape(); + res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape); + } catch (...) { + VLOG(5) << "GetDim of variable " << name << " error"; + std::rethrow_exception(std::current_exception()); + } + return res; + } + + std::vector GetDims(const std::vector &names) const { + std::vector ret; + ret.reserve(names.size()); + std::transform( + names.begin(), names.end(), std::back_inserter(ret), + [this](const std::string &name) { return this->GetDim(name); }); + return ret; + } void SetDim(const std::string &name, const DDim &dim) override; @@ -666,20 +700,6 @@ const std::vector &CompileTimeInferShapeContext::Outputs( return op_.Output(name); } -DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const { - auto var = block_.FindVarRecursive(name); - PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); - DDim res; - try { - auto shape = var->GetShape(); - res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape); - } catch (...) { - VLOG(5) << "GetDim of variable " << name << " error"; - std::rethrow_exception(std::current_exception()); - } - return res; -} - std::vector CompileTimeInferShapeContext::GetRepeatedDims( const std::string &name) const { auto var = block_.FindVarRecursive(name); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 2f418f728f..2bfe055b4c 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -709,9 +709,21 @@ class RuntimeInferShapeContext : public InferShapeContext { return res; } + DDim GetInputDim(const std::string& name) const override { + const std::vector& vars = InputVars(name); + PADDLE_ENFORCE_EQ(vars.size(), 1UL, + "Input(%s) should hold one element, but now it holds %d", + name, vars.size()); + return this->GetDim(vars[0]); + } + + std::vector GetInputsDim(const std::string& name) const override { + const std::vector& vars = InputVars(name); + return GetDims(vars); + } + protected: - DDim GetDim(const std::string& name) const override { - Variable* var = scope_.FindVar(name); + DDim GetDim(Variable* var) const { PADDLE_ENFORCE_NOT_NULL(var); if (var->IsType()) { return var->Get().dims(); @@ -719,12 +731,20 @@ class RuntimeInferShapeContext : public InferShapeContext { return var->Get().GetCompleteDims(); } else { PADDLE_THROW( - "Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's " + "Only LoDTensor/SelectedRows support 'GetDim', but Variables " "type_id is %s.", - name, var->Type().name()); + var->Type().name()); } } + std::vector GetDims(const std::vector& vars) const { + std::vector ret; + ret.reserve(vars.size()); + std::transform(vars.begin(), vars.end(), std::back_inserter(ret), + [this](Variable* var) { return this->GetDim(var); }); + return ret; + } + std::vector GetRepeatedDims(const std::string& name) const override { PADDLE_THROW("Only compile time support this method"); } diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc index 0a7cebcc5a..f274a1b73f 100644 --- a/paddle/fluid/framework/shape_inference.cc +++ b/paddle/fluid/framework/shape_inference.cc @@ -22,20 +22,6 @@ limitations under the License. */ namespace paddle { namespace framework { -DDim InferShapeContext::GetInputDim(const std::string &name) const { - const std::vector &arg_names = Inputs(name); - PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, - "Input(%s) should hold one element, but now it holds %d", - name, arg_names.size()); - return this->GetDim(arg_names[0]); -} - -std::vector InferShapeContext::GetInputsDim( - const std::string &name) const { - const std::vector &arg_names = Inputs(name); - return GetDims(arg_names); -} - std::vector InferShapeContext::GetReaderDims( const std::string &name) const { const std::vector &arg_names = Inputs(name); @@ -46,12 +32,6 @@ std::vector InferShapeContext::GetReaderDims( return this->GetRepeatedDims(arg_names[0]); } -DDim InferShapeContext::GetInputsElementDim(const std::string &name, - int idx) const { - const std::vector &names = Inputs(name); - return this->GetDim(names[idx]); -} - void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) { auto &arg_names = Outputs(name); PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, @@ -76,16 +56,6 @@ void InferShapeContext::SetReaderDims(const std::string &name, return this->SetRepeatedDims(arg_names[0], dims); } -std::vector InferShapeContext::GetDims( - const std::vector &names) const { - std::vector ret; - ret.reserve(names.size()); - std::transform( - names.begin(), names.end(), std::back_inserter(ret), - [this](const std::string &name) { return this->GetDim(name); }); - return ret; -} - void InferShapeContext::SetDims(const std::vector &names, const std::vector &dims) { size_t length = names.size(); diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 543696d43b..6cf9cf3f38 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -41,10 +41,9 @@ class InferShapeContext { virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; - virtual DDim GetInputDim(const std::string &name) const; - virtual std::vector GetInputsDim(const std::string &name) const; + virtual DDim GetInputDim(const std::string &name) const = 0; + virtual std::vector GetInputsDim(const std::string &name) const = 0; virtual std::vector GetReaderDims(const std::string &name) const; - virtual DDim GetInputsElementDim(const std::string &name, int idx) const; virtual void SetOutputDim(const std::string &name, const DDim &dim); virtual void SetOutputsDim(const std::string &name, @@ -79,14 +78,11 @@ class InferShapeContext { const std::vector &dims); protected: - virtual DDim GetDim(const std::string &name) const = 0; virtual void SetDim(const std::string &name, const DDim &dim) = 0; virtual std::vector GetRepeatedDims(const std::string &name) const = 0; virtual void SetRepeatedDims(const std::string &name, const std::vector &dims) = 0; - std::vector GetDims(const std::vector &names) const; - std::vector GetVarTypes( const std::vector &names) const; diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index e91d9ef776..3f75ee956a 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -408,7 +408,7 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { if (pg_ig_names[i] == framework::kEmptyVarName) { continue; } - auto dims = ctx->GetInputsElementDim(kX, i); + auto dims = ctx->GetInputsDim(kX)[i]; if (var_types[i] == framework::proto::VarType::LOD_TENSOR) { names_to_set.push_back(pg_ig_names[i]); dims_to_set.push_back(dims); From 9ef8a76873983c61eb91fab99f3306a5be8ef0c0 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 19 Dec 2018 16:13:31 +0800 Subject: [PATCH 7/8] convert more test=develop --- paddle/fluid/framework/op_desc.cc | 23 ++++++++++++++++++++++- paddle/fluid/framework/operator.cc | 23 +++++++++++++++++++++-- paddle/fluid/framework/shape_inference.cc | 20 -------------------- paddle/fluid/framework/shape_inference.h | 9 ++------- 4 files changed, 45 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index ef98558820..4d204aefde 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -149,8 +149,29 @@ class CompileTimeInferShapeContext : public InferShapeContext { bool IsRuntime() const override; + std::vector GetInputsVarType( + const std::string &name) const override { + return GetVarTypes(Inputs(name)); + } + + std::vector GetOutputsVarType( + const std::string &name) const override { + return GetVarTypes(Outputs(name)); + } + protected: - proto::VarType::Type GetVarType(const std::string &name) const override; + std::vector GetVarTypes( + const std::vector &names) const { + std::vector retv; + retv.resize(names.size()); + std::transform( + names.begin(), names.end(), retv.begin(), + std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType), this, + std::placeholders::_1)); + return retv; + } + + proto::VarType::Type GetVarType(const std::string &name) const; DDim GetDim(const std::string &name) const { auto var = block_.FindVarRecursive(name); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 2bfe055b4c..eb172ca88f 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -722,6 +722,16 @@ class RuntimeInferShapeContext : public InferShapeContext { return GetDims(vars); } + std::vector GetInputsVarType( + const std::string& name) const override { + return GetVarTypes(InputVars(name)); + } + + std::vector GetOutputsVarType( + const std::string& name) const override { + return GetVarTypes(OutputVars(name)); + } + protected: DDim GetDim(Variable* var) const { PADDLE_ENFORCE_NOT_NULL(var); @@ -766,8 +776,17 @@ class RuntimeInferShapeContext : public InferShapeContext { PADDLE_THROW("Only compile time support this method"); } - proto::VarType::Type GetVarType(const std::string& name) const override { - auto* var = scope_.FindVar(name); + std::vector GetVarTypes( + const std::vector& vars) const { + std::vector retv; + retv.resize(vars.size()); + std::transform(vars.begin(), vars.end(), retv.begin(), + std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType), + this, std::placeholders::_1)); + return retv; + } + + proto::VarType::Type GetVarType(Variable* var) const { return ToVarType(var->Type()); } diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc index f274a1b73f..4e67855b5c 100644 --- a/paddle/fluid/framework/shape_inference.cc +++ b/paddle/fluid/framework/shape_inference.cc @@ -68,25 +68,5 @@ void InferShapeContext::SetDims(const std::vector &names, } } -std::vector InferShapeContext::GetInputsVarType( - const std::string &name) const { - return GetVarTypes(Inputs(name)); -} - -std::vector InferShapeContext::GetOutputsVarType( - const std::string &name) const { - return GetVarTypes(Outputs(name)); -} - -std::vector InferShapeContext::GetVarTypes( - const std::vector &names) const { - std::vector retv; - retv.resize(names.size()); - std::transform(names.begin(), names.end(), retv.begin(), - std::bind(std::mem_fn(&InferShapeContext::GetVarType), this, - std::placeholders::_1)); - return retv; -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 6cf9cf3f38..415339a01d 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -34,9 +34,9 @@ class InferShapeContext { virtual bool HasOutput(const std::string &name) const = 0; virtual std::vector GetInputsVarType( - const std::string &name) const; + const std::string &name) const = 0; virtual std::vector GetOutputsVarType( - const std::string &name) const; + const std::string &name) const = 0; virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; @@ -82,11 +82,6 @@ class InferShapeContext { virtual std::vector GetRepeatedDims(const std::string &name) const = 0; virtual void SetRepeatedDims(const std::string &name, const std::vector &dims) = 0; - - std::vector GetVarTypes( - const std::vector &names) const; - - virtual proto::VarType::Type GetVarType(const std::string &name) const = 0; }; } // namespace framework From 1fe3ac352a3471e87111cd3f1021fe879fbdf6fe Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 19 Dec 2018 19:13:26 +0800 Subject: [PATCH 8/8] move more and fix while test=develop --- paddle/fluid/framework/op_desc.cc | 28 +++++++++++- paddle/fluid/framework/operator.cc | 33 ++++++++++++-- paddle/fluid/framework/shape_inference.cc | 26 ----------- paddle/fluid/framework/shape_inference.h | 9 +--- .../fluid/operators/controlflow/while_op.cc | 43 +++++++++++++------ 5 files changed, 87 insertions(+), 52 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 4d204aefde..2fe1c94ec0 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -159,6 +159,20 @@ class CompileTimeInferShapeContext : public InferShapeContext { return GetVarTypes(Outputs(name)); } + void SetOutputDim(const std::string &name, const DDim &dim) override { + auto &arg_names = Outputs(name); + PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, + "Output(%s) should hold one element, but now it holds %d", + name, arg_names.size()); + SetDim(arg_names[0], dim); + } + + void SetOutputsDim(const std::string &name, + const std::vector &dims) override { + auto &names = Outputs(name); + SetDims(names, dims); + } + protected: std::vector GetVarTypes( const std::vector &names) const { @@ -196,7 +210,19 @@ class CompileTimeInferShapeContext : public InferShapeContext { return ret; } - void SetDim(const std::string &name, const DDim &dim) override; + void SetDim(const std::string &name, const DDim &dim); + + void SetDims(const std::vector &names, + const std::vector &dims) { + size_t length = names.size(); + PADDLE_ENFORCE_EQ(length, dims.size()); + for (size_t i = 0; i < length; ++i) { + if (names[i] == framework::kEmptyVarName) { + continue; + } + SetDim(names[i], dims[i]); + } + } std::vector GetRepeatedDims(const std::string &name) const override; diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index eb172ca88f..4b520a393f 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -732,6 +732,20 @@ class RuntimeInferShapeContext : public InferShapeContext { return GetVarTypes(OutputVars(name)); } + void SetOutputDim(const std::string& name, const DDim& dim) override { + auto& vars = OutputVars(name); + PADDLE_ENFORCE_EQ(vars.size(), 1UL, + "Output(%s) should hold one element, but now it holds %d", + name, vars.size()); + SetDim(vars[0], dim); + } + + void SetOutputsDim(const std::string& name, + const std::vector& dims) override { + auto& vars = OutputVars(name); + SetDims(vars, dims); + } + protected: DDim GetDim(Variable* var) const { PADDLE_ENFORCE_NOT_NULL(var); @@ -759,15 +773,26 @@ class RuntimeInferShapeContext : public InferShapeContext { PADDLE_THROW("Only compile time support this method"); } - void SetDim(const std::string& name, const DDim& dim) override { - Variable* var = scope_.FindVar(name); + void SetDim(Variable* var, const DDim& dim) { if (var->IsType()) { var->GetMutable()->Resize(dim); } else if (var->IsType()) { var->GetMutable()->set_height(dim[0]); } else { - PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", - name, var->Type().name()); + PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.", + var->Type().name()); + } + } + + void SetDims(const std::vector& vars, + const std::vector& dims) { + size_t length = vars.size(); + PADDLE_ENFORCE_EQ(length, dims.size()); + for (size_t i = 0; i < length; ++i) { + if (vars[i] == nullptr) { + continue; + } + SetDim(vars[i], dims[i]); } } diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc index 4e67855b5c..4ac872ac3d 100644 --- a/paddle/fluid/framework/shape_inference.cc +++ b/paddle/fluid/framework/shape_inference.cc @@ -32,20 +32,6 @@ std::vector InferShapeContext::GetReaderDims( return this->GetRepeatedDims(arg_names[0]); } -void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) { - auto &arg_names = Outputs(name); - PADDLE_ENFORCE_EQ(arg_names.size(), 1UL, - "Output(%s) should hold one element, but now it holds %d", - name, arg_names.size()); - SetDim(arg_names[0], dim); -} - -void InferShapeContext::SetOutputsDim(const std::string &name, - const std::vector &dims) { - auto &names = Outputs(name); - SetDims(names, dims); -} - void InferShapeContext::SetReaderDims(const std::string &name, const std::vector &dims) { const std::vector &arg_names = Outputs(name); @@ -56,17 +42,5 @@ void InferShapeContext::SetReaderDims(const std::string &name, return this->SetRepeatedDims(arg_names[0], dims); } -void InferShapeContext::SetDims(const std::vector &names, - const std::vector &dims) { - size_t length = names.size(); - PADDLE_ENFORCE_EQ(length, dims.size()); - for (size_t i = 0; i < length; ++i) { - if (names[i] == framework::kEmptyVarName) { - continue; - } - SetDim(names[i], dims[i]); - } -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 415339a01d..824f75b3d3 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -45,9 +45,9 @@ class InferShapeContext { virtual std::vector GetInputsDim(const std::string &name) const = 0; virtual std::vector GetReaderDims(const std::string &name) const; - virtual void SetOutputDim(const std::string &name, const DDim &dim); + virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0; virtual void SetOutputsDim(const std::string &name, - const std::vector &dims); + const std::vector &dims) = 0; virtual void SetReaderDims(const std::string &name, const std::vector &dims); @@ -73,12 +73,7 @@ class InferShapeContext { virtual std::vector GetOutputVarPtrs( const std::string &name) = 0; - // Note: In while op, we need this to be public - virtual void SetDims(const std::vector &names, - const std::vector &dims); - protected: - virtual void SetDim(const std::string &name, const DDim &dim) = 0; virtual std::vector GetRepeatedDims(const std::string &name) const = 0; virtual void SetRepeatedDims(const std::string &name, const std::vector &dims) = 0; diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 3f75ee956a..48800947fd 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -399,26 +399,41 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { ctx->HasInputs(kOutputs); ctx->HasInputs(framework::GradVarName(kOutputs)); - auto p_names = ctx->Inputs(kX); auto pg_ig_names = ctx->Outputs(kXGRAD); - auto var_types = ctx->GetInputsVarType(kX); - std::vector names_to_set; - std::vector dims_to_set; - for (size_t i = 0; i < p_names.size(); ++i) { + std::vector in_var_ptrs = + ctx->GetInputVarPtrs(kX); + std::vector out_var_ptrs = + ctx->GetOutputVarPtrs(kXGRAD); + PADDLE_ENFORCE(in_var_ptrs.size() == out_var_ptrs.size()); + + for (size_t i = 0; i < in_var_ptrs.size(); ++i) { if (pg_ig_names[i] == framework::kEmptyVarName) { continue; } - auto dims = ctx->GetInputsDim(kX)[i]; - if (var_types[i] == framework::proto::VarType::LOD_TENSOR) { - names_to_set.push_back(pg_ig_names[i]); - dims_to_set.push_back(dims); - } else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) { - // not sure how to set the dim of LOD_TENSOR_ARRAY - names_to_set.push_back(pg_ig_names[i]); - dims_to_set.push_back(dims); + if (ctx->IsRuntime()) { + framework::Variable *in_var = + boost::get(in_var_ptrs[i]); + framework::Variable *out_var = + boost::get(out_var_ptrs[i]); + + auto type = framework::ToVarType(in_var->Type()); + if (type == framework::proto::VarType::LOD_TENSOR) { + out_var->GetMutable()->Resize( + in_var->Get().dims()); + } else if (type == framework::proto::VarType::SELECTED_ROWS) { + out_var->GetMutable()->set_height( + in_var->Get().GetCompleteDims()[0]); + } else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) { + PADDLE_THROW("WhileGradOp doesn't support type %d", + static_cast(type)); + } + } else { + framework::VarDesc *in_var = + boost::get(in_var_ptrs[i]); + boost::get(out_var_ptrs[i]) + ->SetShape(in_var->GetShape()); } } - ctx->SetDims(names_to_set, dims_to_set); } };