@@ -399,26 +399,41 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
     ctx->HasInputs(kOutputs);
     ctx->HasInputs(framework::GradVarName(kOutputs));
 
-    auto p_names = ctx->Inputs(kX);
     auto pg_ig_names = ctx->Outputs(kXGRAD);
-    auto var_types = ctx->GetInputsVarType(kX);
-    std::vector<std::string> names_to_set;
-    std::vector<framework::DDim> dims_to_set;
-    for (size_t i = 0; i < p_names.size(); ++i) {
+    std::vector<framework::InferShapeVarPtr> in_var_ptrs =
+        ctx->GetInputVarPtrs(kX);
+    std::vector<framework::InferShapeVarPtr> out_var_ptrs =
+        ctx->GetOutputVarPtrs(kXGRAD);
+    PADDLE_ENFORCE(in_var_ptrs.size() == out_var_ptrs.size());
+
+    for (size_t i = 0; i < in_var_ptrs.size(); ++i) {
       if (pg_ig_names[i] == framework::kEmptyVarName) {
         continue;
       }
-      auto dims = ctx->GetInputsDim(kX)[i];
-      if (var_types[i] == framework::proto::VarType::LOD_TENSOR) {
-        names_to_set.push_back(pg_ig_names[i]);
-        dims_to_set.push_back(dims);
-      } else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) {
-        // not sure how to set the dim of LOD_TENSOR_ARRAY
-        names_to_set.push_back(pg_ig_names[i]);
-        dims_to_set.push_back(dims);
+      if (ctx->IsRuntime()) {
+        framework::Variable *in_var =
+            boost::get<framework::Variable *>(in_var_ptrs[i]);
+        framework::Variable *out_var =
+            boost::get<framework::Variable *>(out_var_ptrs[i]);
+
+        auto type = framework::ToVarType(in_var->Type());
+        if (type == framework::proto::VarType::LOD_TENSOR) {
+          out_var->GetMutable<LoDTensor>()->Resize(
+              in_var->Get<framework::LoDTensor>().dims());
+        } else if (type == framework::proto::VarType::SELECTED_ROWS) {
+          out_var->GetMutable<framework::SelectedRows>()->set_height(
+              in_var->Get<framework::SelectedRows>().GetCompleteDims()[0]);
+        } else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
+          PADDLE_THROW("WhileGradOp doesn't support type %d",
+                       static_cast<int>(type));
+        }
+      } else {
+        framework::VarDesc *in_var =
+            boost::get<framework::VarDesc *>(in_var_ptrs[i]);
+        boost::get<framework::VarDesc *>(out_var_ptrs[i])
+            ->SetShape(in_var->GetShape());
       }
     }
-    ctx->SetDims(names_to_set, dims_to_set);
   }
 };
 
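For context: the added branch dispatches on whether shape inference runs at runtime (each InferShapeVarPtr holds a framework::Variable*) or at compile time (it holds a framework::VarDesc*), and in either mode copies the input's shape onto the corresponding gradient output. Below is a minimal, self-contained C++17 sketch of that dispatch pattern using std::variant; RuntimeVar, CompileVarDesc, and CopyShape are hypothetical stand-ins for illustration, not Paddle's API.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <variant>
#include <vector>

// Stand-in for a runtime framework::Variable that owns concrete tensor dims.
struct RuntimeVar {
  std::vector<int64_t> dims;
};

// Stand-in for a compile-time framework::VarDesc that only records a shape.
struct CompileVarDesc {
  std::vector<int64_t> shape;
};

// Analogous to framework::InferShapeVarPtr: one pointer alternative per mode.
using InferShapeVarPtr = std::variant<RuntimeVar *, CompileVarDesc *>;

// Copy the input's shape onto the output, branching on the mode, mirroring
// the IsRuntime() / else branches added to WhileGradOpShapeInference above.
void CopyShape(const InferShapeVarPtr &in, const InferShapeVarPtr &out) {
  if (std::holds_alternative<RuntimeVar *>(in)) {
    // Runtime: both sides are concrete variables; resize the output tensor.
    std::get<RuntimeVar *>(out)->dims = std::get<RuntimeVar *>(in)->dims;
  } else {
    // Compile time: both sides are descriptors; propagate the declared shape.
    std::get<CompileVarDesc *>(out)->shape =
        std::get<CompileVarDesc *>(in)->shape;
  }
}

int main() {
  RuntimeVar x{{4, 8}};
  RuntimeVar x_grad;       // gradient starts with no dims
  CopyShape(&x, &x_grad);  // runtime path
  assert(x_grad.dims == x.dims);

  CompileVarDesc d{{16, 32}};
  CompileVarDesc d_grad;
  CopyShape(&d, &d_grad);  // compile-time path
  assert(d_grad.shape == d.shape);

  std::cout << "shapes propagated in both modes\n";
  return 0;
}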