|
|
@ -287,7 +287,6 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
|
|
|
|
|
|
|
|
|
|
|
|
auto p_names = ctx->Inputs(kParameters);
|
|
|
|
auto p_names = ctx->Inputs(kParameters);
|
|
|
|
auto pg_names = ctx->Outputs(kParamGrads);
|
|
|
|
auto pg_names = ctx->Outputs(kParamGrads);
|
|
|
|
auto dims = ctx->GetInputsDim(kParameters);
|
|
|
|
|
|
|
|
auto var_types = ctx->GetInputsVarType(kParameters);
|
|
|
|
auto var_types = ctx->GetInputsVarType(kParameters);
|
|
|
|
std::vector<std::string> names_to_set;
|
|
|
|
std::vector<std::string> names_to_set;
|
|
|
|
std::vector<framework::DDim> dims_to_set;
|
|
|
|
std::vector<framework::DDim> dims_to_set;
|
|
|
@ -295,13 +294,14 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
|
|
|
|
if (pg_names[i] == framework::kEmptyVarName) {
|
|
|
|
if (pg_names[i] == framework::kEmptyVarName) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
auto dims = ctx->GetInputsElementDim(kParameters, i);
|
|
|
|
if (var_types[i] == framework::VarDesc::LOD_TENSOR) {
|
|
|
|
if (var_types[i] == framework::VarDesc::LOD_TENSOR) {
|
|
|
|
names_to_set.push_back(pg_names[i]);
|
|
|
|
names_to_set.push_back(pg_names[i]);
|
|
|
|
dims_to_set.push_back(dims[i]);
|
|
|
|
dims_to_set.push_back(dims);
|
|
|
|
} else if (var_types[i] == framework::VarDesc::LOD_TENSOR_ARRAY) {
|
|
|
|
} else if (var_types[i] == framework::VarDesc::LOD_TENSOR_ARRAY) {
|
|
|
|
// not sure how to set the dim of LOD_TENSOR_ARRAY
|
|
|
|
// not sure how to set the dim of LOD_TENSOR_ARRAY
|
|
|
|
names_to_set.push_back(pg_names[i]);
|
|
|
|
names_to_set.push_back(pg_names[i]);
|
|
|
|
dims_to_set.push_back(dims[i]);
|
|
|
|
dims_to_set.push_back(dims);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ctx->SetDims(names_to_set, dims_to_set);
|
|
|
|
ctx->SetDims(names_to_set, dims_to_set);
|
|
|
|