@@ -17,6 +17,7 @@
 #include <list>
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/net_op.h"
+#include "paddle/operators/recurrent_op.h"
 
 namespace paddle {
 namespace framework {
@@ -178,6 +179,22 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     return false;
   });
 
+  // Process the recurrent gradient op as a special operator.
+  if (forwardOp.Type() == "recurrent_op") {
+    // NOTE: clean up the cyclic reference somewhere (the RNN's stepnet
+    // contains itself), or this will result in an infinite loop.
+    const auto& rnnop =
+        *static_cast<const operators::RecurrentOp*>(&forwardOp);
+    auto rnn_grad_op =
+        static_cast<operators::RecurrentGradientOp*>(grad_op.get());
+    const auto& stepnet_op =
+        *static_cast<const OperatorBase*>(&rnnop.stepnet());
+    // Create the stepnet's gradient op.
+    auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id);
+    rnn_grad_op->set_stepnet(
+        std::static_pointer_cast<operators::NetOp>(grad_stepnet));
+  }
+
   if (net->ops_.empty()) {  // Currently no aux op is added to the network.
     return grad_op;
   }
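Note for reviewers: the core of the second hunk is that BackwardRecursive is re-entered on the RNN's step network, so the gradient op receives a freshly built stepnet instead of a reference back into the forward op (which is what the cycle warning in the added NOTE is about). The standalone sketch below illustrates that recursion pattern under simplified assumptions; Op, CompositeOp, and BuildBackward are hypothetical names for illustration only, not Paddle's actual classes.

// Standalone sketch (not Paddle's API): a backward builder that recurses
// into an op's sub-network, mirroring the recurrent_op special case above.
// All names here (Op, CompositeOp, BuildBackward) are hypothetical.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Op {
  explicit Op(std::string type) : type_(std::move(type)) {}
  virtual ~Op() = default;
  std::string type_;
};

// A composite op owns a sub-network of ops, the way RecurrentOp owns a
// stepnet.
struct CompositeOp : Op {
  using Op::Op;
  std::vector<std::shared_ptr<Op>> subnet;
};

// Recursively build a gradient op. For a composite op, the gradient is
// itself a composite whose subnet holds the sub-gradients in reverse
// (backprop) order. Building a *new* subnet here, rather than pointing the
// gradient op back at the forward op's subnet, is what avoids the cyclic
// reference warned about in the diff.
std::shared_ptr<Op> BuildBackward(const Op& fwd) {
  if (auto* comp = dynamic_cast<const CompositeOp*>(&fwd)) {
    auto grad = std::make_shared<CompositeOp>(comp->type_ + "_grad");
    for (auto it = comp->subnet.rbegin(); it != comp->subnet.rend(); ++it) {
      grad->subnet.push_back(BuildBackward(**it));
    }
    return grad;
  }
  return std::make_shared<Op>(fwd.type_ + "_grad");
}

int main() {
  auto rnn = std::make_shared<CompositeOp>("recurrent_op");
  rnn->subnet.push_back(std::make_shared<Op>("mul"));
  rnn->subnet.push_back(std::make_shared<Op>("add"));
  auto grad = BuildBackward(*rnn);
  auto* g = static_cast<CompositeOp*>(grad.get());
  std::cout << g->type_ << ":";
  for (auto& op : g->subnet) std::cout << " " << op->type_;
  std::cout << "\n";  // prints: recurrent_op_grad: add_grad mul_grad
}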