@@ -49,10 +49,17 @@ class WhileOp : public framework::OperatorBase {
  private:
   void RunImpl(const framework::Scope &scope,
                const platform::Place &dev_place) const override {
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
+    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)),
+                            platform::errors::NotFound(
+                                "Input(Condition) of WhileOp is not found."));

     auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
-    PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
+    PADDLE_ENFORCE_EQ(
+        cond.dims(), paddle::framework::make_ddim({1}),
+        platform::errors::InvalidArgument(
+            "The shape of Input(Condition) of WhileOp must be 1. But now "
+            "the Condition's shape is %s.\n",
+            cond.dims().to_str()));

     framework::Executor executor(dev_place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
@@ -72,7 +79,9 @@ class WhileOp : public framework::OperatorBase {
       step_scopes->clear();
     }

-    PADDLE_ENFORCE_EQ(step_scopes->size(), 0, "The StepScope should be empty.");
+    PADDLE_ENFORCE_EQ(step_scopes->size(), 0,
+                      platform::errors::PreconditionNotMet(
+                          "The Output(StepScope) of WhileOp should be empty."));

     bool cond_data = GetCondData(cond);
     bool is_test = Attr<bool>("is_test");
@@ -160,8 +169,10 @@ class WhileGradOp : public framework::OperatorBase {
  private:
   void RunImpl(const framework::Scope &scope,
                const platform::Place &dev_place) const override {
-    PADDLE_ENFORCE(!Attr<bool>("is_test"),
-                   "GradOp is only callable when is_test is false");
+    PADDLE_ENFORCE_EQ(
+        Attr<bool>("is_test"), false,
+        platform::errors::InvalidArgument(
+            "WhileGradOp is only callable when is_test is false."));
     // get device context from pool
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
@@ -180,7 +191,14 @@ class WhileGradOp : public framework::OperatorBase {
     auto inside_og_names =
         Attr<std::vector<std::string>>("original_output_grad");

-    PADDLE_ENFORCE_EQ(outside_og_names.size(), inside_og_names.size());
+    PADDLE_ENFORCE_EQ(outside_og_names.size(), inside_og_names.size(),
+                      platform::errors::InvalidArgument(
+                          "The number of original output gradient names "
+                          "does not match the number of backward input "
+                          "gradient names. The number of backward input "
+                          "names is %d and the number of original output "
+                          "gradient names is %d.",
+                          outside_og_names.size(), inside_og_names.size()));

     for (auto cur_scope_iter = step_scopes->rbegin();
          cur_scope_iter != step_scopes->rend(); ++cur_scope_iter) {
@@ -222,11 +240,18 @@ class WhileGradOp : public framework::OperatorBase {
               inside_array[j].set_lod(outside_array->at(j).lod());
               inside_array[j].ShareDataWith(outside_array->at(j));
             } else {
-              PADDLE_ENFORCE_EQ(inside_array[j].numel(), 0);
+              PADDLE_ENFORCE_EQ(
+                  inside_array[j].numel(), 0,
+                  platform::errors::InvalidArgument(
+                      "The numel of %d-th element of var %s (LoDTensorArray) "
+                      "in while block must be 0, but received its numel is %d.",
+                      j, inside_og_name, inside_array[j].numel()));
             }
           }
         } else {
-          PADDLE_THROW("Currently only support LoDTensor and LoDTensorArray.");
+          PADDLE_THROW(platform::errors::Unimplemented(
+              "Currently only support LoDTensor and LoDTensorArray in "
+              "WhileGradOp."));
         }
       }
       executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false, true,
@@ -236,7 +261,13 @@ class WhileGradOp : public framework::OperatorBase {
       // and inputs.
       auto &pg_ig_names = Outputs(kXGRAD);
       auto &p_names = Inputs(kX);
-      PADDLE_ENFORCE_EQ(pg_ig_names.size(), p_names.size());
+      PADDLE_ENFORCE_EQ(pg_ig_names.size(), p_names.size(),
+                        platform::errors::PreconditionNotMet(
+                            "The number of names in Outputs(X@GRAD) does not "
+                            "match the number of names in Inputs(X). The "
+                            "number of names in Outputs(X@GRAD) is %d and "
+                            "the number of names in Inputs(X) is %d.",
+                            pg_ig_names.size(), p_names.size()));
       for (size_t param_id = 0; param_id < pg_ig_names.size(); ++param_id) {
         if (pg_ig_names[param_id] == framework::kEmptyVarName) {
           continue;  // parameter doesn't have gradient
@@ -247,7 +278,9 @@ class WhileGradOp : public framework::OperatorBase {
         // for example lookup_table_grad_op, the input(Idx) doesn't have
         // gradient.
         auto pg_ig_var = cur_scope.FindVar(inside_grad_name);
-        PADDLE_ENFORCE(pg_ig_var != nullptr);
+        PADDLE_ENFORCE_NOT_NULL(
+            pg_ig_var, platform::errors::NotFound("Variable %s is not found.",
+                                                  inside_grad_name));
         if (pg_ig_var->IsType<framework::LoDTensorArray>()) {
           auto pg_ig_lod_t_arr =
               pg_ig_var->GetMutable<framework::LoDTensorArray>();
@@ -277,13 +310,16 @@ class WhileGradOp : public framework::OperatorBase {
         // zero gradient variable in step 0
         if (cur_scope_iter == step_scopes->rbegin()) {
           auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
-          PADDLE_ENFORCE_NOT_NULL(var, "Can not find var %s", inside_grad_name);
-          PADDLE_ENFORCE(
+          PADDLE_ENFORCE_NOT_NULL(
+              var, platform::errors::NotFound("Variable %s is not found.",
+                                              inside_grad_name));
+          PADDLE_ENFORCE_EQ(
               var->IsType<framework::LoDTensorArray>() ||
                   var->IsType<LoDTensor>(),
-              "Currently the type of var only can be LoDTensorArray, "
-              "or LoDTensor, but the received var[%s] is %s.",
-              inside_grad_name, framework::ToTypeName(var->Type()));
+              true, platform::errors::InvalidArgument(
+                        "Currently the type of var only can be LoDTensorArray, "
+                        "or LoDTensor, but the received var[%s] is %s.",
+                        inside_grad_name, framework::ToTypeName(var->Type())));

           if (var->IsType<LoDTensor>()) {
             auto &inside_tensor = var->Get<framework::LoDTensor>();
@@ -422,41 +458,24 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
     ctx->HasOutputs(framework::GradVarName(kX));
     ctx->HasInputs(kOutputs);
     ctx->HasInputs(framework::GradVarName(kOutputs));

     auto pg_ig_names = ctx->Outputs(kXGRAD);
     std::vector<framework::InferShapeVarPtr> in_var_ptrs =
         ctx->GetInputVarPtrs(kX);
     std::vector<framework::InferShapeVarPtr> out_var_ptrs =
         ctx->GetOutputVarPtrs(kXGRAD);
-    PADDLE_ENFORCE(in_var_ptrs.size() == out_var_ptrs.size());
+    PADDLE_ENFORCE_EQ(in_var_ptrs.size(), out_var_ptrs.size(),
+                      platform::errors::InvalidArgument(
+                          "The size of Inputs(X) must be the same as "
+                          "the size of Outputs(X@GRAD)."));

     for (size_t i = 0; i < in_var_ptrs.size(); ++i) {
       if (pg_ig_names[i] == framework::kEmptyVarName) {
         continue;
       }
-      if (ctx->IsRuntime()) {
-        framework::Variable *in_var =
-            boost::get<framework::Variable *>(in_var_ptrs[i]);
-        framework::Variable *out_var =
-            boost::get<framework::Variable *>(out_var_ptrs[i]);
-
-        auto type = framework::ToVarType(in_var->Type());
-        if (type == framework::proto::VarType::LOD_TENSOR) {
-          out_var->GetMutable<LoDTensor>()->Resize(
-              in_var->Get<framework::LoDTensor>().dims());
-        } else if (type == framework::proto::VarType::SELECTED_ROWS) {
-          out_var->GetMutable<framework::SelectedRows>()->set_height(
-              in_var->Get<framework::SelectedRows>().GetCompleteDims()[0]);
-        } else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
-          PADDLE_THROW("WhileGradOp doesn't support type %d",
-                       static_cast<int>(type));
-        }
-      } else {
-        framework::VarDesc *in_var =
-            boost::get<framework::VarDesc *>(in_var_ptrs[i]);
-        boost::get<framework::VarDesc *>(out_var_ptrs[i])
-            ->SetShape(in_var->GetShape());
-      }
+      framework::VarDesc *in_var =
+          boost::get<framework::VarDesc *>(in_var_ptrs[i]);
+      boost::get<framework::VarDesc *>(out_var_ptrs[i])
+          ->SetShape(in_var->GetShape());
     }
   }
 };
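Aside from the shape-inference simplification in the last hunk, every change above follows one pattern: a bare PADDLE_ENFORCE*/PADDLE_THROW call, whose failure output says little beyond the stringified condition, is rewritten to carry a typed error from platform::errors:: (NotFound, InvalidArgument, PreconditionNotMet, Unimplemented) together with a printf-style message naming the offending variable or shape. As a rough, self-contained sketch of how such a macro/error pair fits together (ENFORCE_EQ and this NotFound type are illustrative stand-ins invented here, not Paddle's actual implementation):

    #include <cstdio>
    #include <stdexcept>
    #include <string>

    // Illustrative stand-in for a platform::errors::* type: the exception
    // carries a category tag plus a printf-formatted message.
    class NotFound : public std::runtime_error {
     public:
      template <typename... Args>
      explicit NotFound(const char *fmt, Args... args)
          : std::runtime_error(Format(fmt, args...)) {}

     private:
      template <typename... Args>
      static std::string Format(const char *fmt, Args... args) {
        char buf[512];
        std::snprintf(buf, sizeof(buf), fmt, args...);
        return std::string("NotFoundError: ") + buf;
      }
    };

    // Illustrative stand-in for PADDLE_ENFORCE_EQ: the error argument is
    // textually inside the throw, so no message is formatted unless the
    // check actually fails.
    #define ENFORCE_EQ(a, b, error)           \
      do {                                    \
        if (!((a) == (b))) throw (error);     \
      } while (0)

    int main() {
      try {
        // Mirrors e.g. PADDLE_ENFORCE_NOT_NULL(var, errors::NotFound(...)).
        ENFORCE_EQ(1, 2, NotFound("Variable %s is not found.", "x@GRAD"));
      } catch (const std::exception &e) {
        std::printf("%s\n", e.what());
        // Prints: NotFoundError: Variable x@GRAD is not found.
      }
      return 0;
    }

The real PADDLE_ENFORCE_* macros additionally record the failing expression, file, and line; the gain from this migration is that the thrown error has a category callers can branch on and a message naming the variable involved, rather than only the stringified condition.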