Merge pull request #14959 from panyx0718/clean2

Further op RunImpl refactor
for_weibo
Xin Pan 7 years ago committed by GitHub
commit a872eb90c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -110,22 +110,125 @@ class CompileTimeInferShapeContext : public InferShapeContext {
}
}
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string &name) override {
const std::vector<std::string> arg_names = Inputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) {
return block_.FindVarRecursive(name);
});
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string &name) override {
const std::vector<std::string> arg_names = Outputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) {
return block_.FindVarRecursive(name);
});
return res;
}
DDim GetInputDim(const std::string &name) const override {
const std::vector<std::string> &arg_names = Inputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, arg_names.size());
return this->GetDim(arg_names[0]);
}
std::vector<DDim> GetInputsDim(const std::string &name) const override {
const std::vector<std::string> &arg_names = Inputs(name);
return GetDims(arg_names);
}
bool IsRuntime() const override;
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const override {
return GetVarTypes(Inputs(name));
}
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const override {
return GetVarTypes(Outputs(name));
}
void SetOutputDim(const std::string &name, const DDim &dim) override {
auto &arg_names = Outputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Output(%s) should hold one element, but now it holds %d",
name, arg_names.size());
SetDim(arg_names[0], dim);
}
void SetOutputsDim(const std::string &name,
const std::vector<DDim> &dims) override {
auto &names = Outputs(name);
SetDims(names, dims);
}
protected:
proto::VarType::Type GetVarType(const std::string &name) const override;
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<proto::VarType::Type> retv;
retv.resize(names.size());
std::transform(
names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType), this,
std::placeholders::_1));
return retv;
}
proto::VarType::Type GetVarType(const std::string &name) const;
DDim GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
DDim res;
try {
auto shape = var->GetShape();
res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
std::rethrow_exception(std::current_exception());
}
return res;
}
DDim GetDim(const std::string &name) const override;
std::vector<DDim> GetDims(const std::vector<std::string> &names) const {
std::vector<DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void SetDim(const std::string &name, const DDim &dim);
void SetDim(const std::string &name, const DDim &dim) override;
void SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims) {
size_t length = names.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
if (names[i] == framework::kEmptyVarName) {
continue;
}
SetDim(names[i], dims[i]);
}
}
std::vector<DDim> GetRepeatedDims(const std::string &name) const override;
void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) override;
InferShapeVarPtr GetVarPtr(const std::string &name) override;
const OpDesc &op_;
const BlockDesc &block_;
};
@ -644,20 +747,6 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
return op_.Output(name);
}
DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
DDim res;
try {
auto shape = var->GetShape();
res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
std::rethrow_exception(std::current_exception());
}
return res;
}
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
const std::string &name) const {
auto var = block_.FindVarRecursive(name);
@ -696,10 +785,5 @@ proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
return block_.FindVarRecursive(name)->GetType();
}
InferShapeVarPtr CompileTimeInferShapeContext::GetVarPtr(
const std::string &name) {
return block_.FindVarRecursive(name);
}
} // namespace framework
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -22,20 +22,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
DDim InferShapeContext::GetInputDim(const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, arg_names.size());
return this->GetDim(arg_names[0]);
}
std::vector<DDim> InferShapeContext::GetInputsDim(
const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name);
return GetDims(arg_names);
}
std::vector<DDim> InferShapeContext::GetReaderDims(
const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name);
@ -46,26 +32,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
return this->GetRepeatedDims(arg_names[0]);
}
DDim InferShapeContext::GetInputsElementDim(const std::string &name,
int idx) const {
const std::vector<std::string> &names = Inputs(name);
return this->GetDim(names[idx]);
}
void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) {
auto &arg_names = Outputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Output(%s) should hold one element, but now it holds %d",
name, arg_names.size());
SetDim(arg_names[0], dim);
}
void InferShapeContext::SetOutputsDim(const std::string &name,
const std::vector<DDim> &dims) {
auto &names = Outputs(name);
SetDims(names, dims);
}
void InferShapeContext::SetReaderDims(const std::string &name,
const std::vector<DDim> &dims) {
const std::vector<std::string> &arg_names = Outputs(name);
@ -76,69 +42,5 @@ void InferShapeContext::SetReaderDims(const std::string &name,
return this->SetRepeatedDims(arg_names[0], dims);
}
std::vector<InferShapeVarPtr> InferShapeContext::GetInputVarPtrs(
const std::string &name) {
const std::vector<std::string> arg_names = Inputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(
arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) { return this->GetVarPtr(name); });
return res;
}
std::vector<InferShapeVarPtr> InferShapeContext::GetOutputVarPtrs(
const std::string &name) {
const std::vector<std::string> arg_names = Outputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(
arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) { return this->GetVarPtr(name); });
return res;
}
std::vector<DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const {
std::vector<DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void InferShapeContext::SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims) {
size_t length = names.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
if (names[i] == framework::kEmptyVarName) {
continue;
}
SetDim(names[i], dims[i]);
}
}
std::vector<proto::VarType::Type> InferShapeContext::GetInputsVarType(
const std::string &name) const {
return GetVarTypes(Inputs(name));
}
std::vector<proto::VarType::Type> InferShapeContext::GetOutputsVarType(
const std::string &name) const {
return GetVarTypes(Outputs(name));
}
std::vector<proto::VarType::Type> InferShapeContext::GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<proto::VarType::Type> retv;
retv.resize(names.size());
std::transform(names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
std::placeholders::_1));
return retv;
}
} // namespace framework
} // namespace paddle

@ -33,22 +33,23 @@ class InferShapeContext {
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const;
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const;
virtual std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const = 0;
virtual std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const = 0;
virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0;
DDim GetInputDim(const std::string &name) const;
std::vector<DDim> GetInputsDim(const std::string &name) const;
std::vector<DDim> GetReaderDims(const std::string &name) const;
DDim GetInputsElementDim(const std::string &name, int idx) const;
virtual DDim GetInputDim(const std::string &name) const = 0;
virtual std::vector<DDim> GetInputsDim(const std::string &name) const = 0;
virtual std::vector<DDim> GetReaderDims(const std::string &name) const;
void SetOutputDim(const std::string &name, const DDim &dim);
void SetOutputsDim(const std::string &name, const std::vector<DDim> &dims);
void SetReaderDims(const std::string &name, const std::vector<DDim> &dims);
virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0;
virtual void SetOutputsDim(const std::string &name,
const std::vector<DDim> &dims) = 0;
virtual void SetReaderDims(const std::string &name,
const std::vector<DDim> &dims);
virtual AttrReader Attrs() const = 0;
virtual const std::vector<std::string> &Inputs(
@ -67,27 +68,15 @@ class InferShapeContext {
virtual bool IsRuntime() const = 0;
std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name);
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
// Note: In while op, we need this to be public
void SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims);
virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string &name) = 0;
virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string &name) = 0;
protected:
virtual DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const DDim &dim) = 0;
virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
virtual void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) = 0;
std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const;
virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;
};
} // namespace framework

@ -399,26 +399,41 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
ctx->HasInputs(kOutputs);
ctx->HasInputs(framework::GradVarName(kOutputs));
auto p_names = ctx->Inputs(kX);
auto pg_ig_names = ctx->Outputs(kXGRAD);
auto var_types = ctx->GetInputsVarType(kX);
std::vector<std::string> names_to_set;
std::vector<framework::DDim> dims_to_set;
for (size_t i = 0; i < p_names.size(); ++i) {
std::vector<framework::InferShapeVarPtr> in_var_ptrs =
ctx->GetInputVarPtrs(kX);
std::vector<framework::InferShapeVarPtr> out_var_ptrs =
ctx->GetOutputVarPtrs(kXGRAD);
PADDLE_ENFORCE(in_var_ptrs.size() == out_var_ptrs.size());
for (size_t i = 0; i < in_var_ptrs.size(); ++i) {
if (pg_ig_names[i] == framework::kEmptyVarName) {
continue;
}
auto dims = ctx->GetInputsElementDim(kX, i);
if (var_types[i] == framework::proto::VarType::LOD_TENSOR) {
names_to_set.push_back(pg_ig_names[i]);
dims_to_set.push_back(dims);
} else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) {
// not sure how to set the dim of LOD_TENSOR_ARRAY
names_to_set.push_back(pg_ig_names[i]);
dims_to_set.push_back(dims);
if (ctx->IsRuntime()) {
framework::Variable *in_var =
boost::get<framework::Variable *>(in_var_ptrs[i]);
framework::Variable *out_var =
boost::get<framework::Variable *>(out_var_ptrs[i]);
auto type = framework::ToVarType(in_var->Type());
if (type == framework::proto::VarType::LOD_TENSOR) {
out_var->GetMutable<LoDTensor>()->Resize(
in_var->Get<framework::LoDTensor>().dims());
} else if (type == framework::proto::VarType::SELECTED_ROWS) {
out_var->GetMutable<framework::SelectedRows>()->set_height(
in_var->Get<framework::SelectedRows>().GetCompleteDims()[0]);
} else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
PADDLE_THROW("WhileGradOp doesn't support type %d",
static_cast<int>(type));
}
} else {
framework::VarDesc *in_var =
boost::get<framework::VarDesc *>(in_var_ptrs[i]);
boost::get<framework::VarDesc *>(out_var_ptrs[i])
->SetShape(in_var->GetShape());
}
}
ctx->SetDims(names_to_set, dims_to_set);
}
};

Loading…
Cancel
Save