@@ -72,8 +72,12 @@ static void FindAllOpAndGradOp(const framework::ProgramDesc &program,
   OpVariantSet &ops = op_and_grad_op->first;
   OpVariantSet &grad_ops = op_and_grad_op->second;
 
-  PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(),
-                    "There are extra grad ops in the graph or program");
+  PADDLE_ENFORCE_GE(
+      ops.size(), grad_ops.size(),
+      platform::errors::InvalidArgument(
+          "There are more grad ops than forward ops in the graph or program, "
+          "the number of ops is %d and the number of grad_ops is %d.",
+          ops.size(), grad_ops.size()));
 
   for (size_t i = 1; i < program.Size(); ++i) {
     auto &block = program.Block(i);
@@ -87,8 +91,12 @@ static void FindAllOpAndGradOp(const framework::ProgramDesc &program,
     }
   }
 
-  PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(),
-                    "There are extra grad ops in the graph or program");
+  PADDLE_ENFORCE_GE(
+      ops.size(), grad_ops.size(),
+      platform::errors::InvalidArgument(
+          "There are more grad ops than forward ops in the graph or program, "
+          "the number of ops is %d and the number of grad_ops is %d.",
+          ops.size(), grad_ops.size()));
 }
 
 // Returns GradVarName of input var names
@@ -169,7 +177,11 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr(
 
   PADDLE_ENFORCE_EQ(
       fwd_input.size(), in_grads.size(),
-      "Backward input gradient number does not match forward input number.");
+      platform::errors::PreconditionNotMet(
+          "Backward input gradient number does not match forward "
+          "input number. The number of forward input number is %d and the "
+          "number of backward input gradient number is %d.",
+          fwd_input.size(), in_grads.size()));
   for (size_t i = 0; i < in_grads.size(); ++i) {
     if (in_grads[i] == framework::kEmptyVarName) {
       continue;
@@ -181,9 +193,13 @@ static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr(
   auto &fwd_param = fwd_op.Inputs().at(RecurrentBase::kParameters);
   auto &param_grads =
       bwd_op.Outputs().at(framework::GradVarName(RecurrentBase::kParameters));
-  PADDLE_ENFORCE_EQ(fwd_param.size(), param_grads.size(),
-                    "Backward parameter gradient number does not match forward "
-                    "parameter number.");
+  PADDLE_ENFORCE_EQ(
+      fwd_param.size(), param_grads.size(),
+      platform::errors::PreconditionNotMet(
+          "Backward parameter gradient number does not match "
+          "forward parameter number. The number of forward parameter number is "
+          "%d and the number of backward parameter gradient is %d.",
+          fwd_param.size(), param_grads.size()));
   for (size_t i = 0; i < fwd_param.size(); ++i) {
     if (param_grads[i] == framework::kEmptyVarName) {
       continue;
@@ -241,12 +257,16 @@ void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
     const OpVariant *matched_fwd_op = nullptr;
     for (auto &fwd_op : recurrent_ops) {
       if (IsMatchedRecurrentOpAndRecurrentGradOp(fwd_op, bwd_op)) {
-        PADDLE_ENFORCE(matched_fwd_op == nullptr,
-                       "Found multiple matched recurrent op");
+        PADDLE_ENFORCE_EQ(matched_fwd_op, nullptr,
+                          platform::errors::PreconditionNotMet(
+                              "Found multiple recurrent forward op matches "
+                              "recurrent grad op."));
         matched_fwd_op = &fwd_op;
       }
     }
-    PADDLE_ENFORCE_NOT_NULL(matched_fwd_op, "Cannot find matched forward op");
+    PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
+                            platform::errors::PreconditionNotMet(
+                                "Cannot find matched forward op."));
     SetRecurrentOpAndRecurrentGradOpSkipVarAttr(*matched_fwd_op, bwd_op);
     recurrent_ops.erase(*matched_fwd_op);
   }
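Note: every hunk above applies the same upgrade. A bare string message passed to a
PADDLE_ENFORCE_* macro is replaced with a typed payload from platform::errors::
(InvalidArgument or PreconditionNotMet) that formats the offending values into the
message. A minimal before/after sketch of that pattern, assuming the PaddlePaddle
enforce headers are available; the names lhs and rhs are illustrative only and do
not appear in this file:

    // Before: comparison with a bare string message.
    PADDLE_ENFORCE_GE(lhs.size(), rhs.size(),
                      "There are extra grad ops in the graph or program");

    // After: a typed error payload that interpolates the actual sizes
    // into the reported message.
    PADDLE_ENFORCE_GE(
        lhs.size(), rhs.size(),
        platform::errors::InvalidArgument(
            "There are more grad ops than forward ops, the lhs size is %d "
            "and the rhs size is %d.",
            lhs.size(), rhs.size()));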