|
|
|
@ -12,8 +12,9 @@
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <paddle/framework/backward.h>
|
|
|
|
|
#include <paddle/framework/net.h>
|
|
|
|
|
#include "paddle/framework/backward.h"
|
|
|
|
|
#include "paddle/framework/net.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
@ -71,6 +72,24 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
|
|
|
|
|
//! TODO(dzh)
|
|
|
|
|
} else {
|
|
|
|
|
//! TODO(fjy)
|
|
|
|
|
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
|
|
|
|
|
for (std::string& grad_input : grad_op->inputs_) {
|
|
|
|
|
if (no_grad_names.count(grad_input)) {
|
|
|
|
|
std::string prefix = grad_input.substr(
|
|
|
|
|
0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
|
|
|
|
|
grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();
|
|
|
|
|
std::vector<std::string> fill_zeros_in = {prefix};
|
|
|
|
|
std::vector<std::string> fill_zeros_out = {grad_input};
|
|
|
|
|
net.AddOp(OpRegistry::CreateOp("fill_zeros_like", fill_zeros_in,
|
|
|
|
|
fill_zeros_out, AttributeMap()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (std::string& grad_output : grad_op->output_) {
|
|
|
|
|
if (no_grad_names.count(grad_output)) {
|
|
|
|
|
grad_output = OperatorBase::EMPTY_VAR_NAME();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
net.AddOp(grad_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
net->CompleteAddOp();
|
|
|
|
|