|
|
|
@ -12,7 +12,8 @@ OperatorBase* GradOpCreator::Create() {
|
|
|
|
|
|
|
|
|
|
OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
|
|
|
|
|
const VarIndexMap& var_map,
|
|
|
|
|
const vector<int>& format, InOutType type) {
|
|
|
|
|
const std::vector<int>& format,
|
|
|
|
|
InOutType type) {
|
|
|
|
|
int idx = var_map.at(var.name());
|
|
|
|
|
int begin_idx = format.empty() ? idx : format.at(idx);
|
|
|
|
|
int end_idx = format.empty() ? idx + 1 : format.at(idx + 1);
|
|
|
|
@ -23,11 +24,11 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
|
|
|
|
|
void GradOpCreator::BuildOpInOutArgList() {
|
|
|
|
|
const OpProto& op_proto = OpRegistry::protos().at(op_->type);
|
|
|
|
|
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op->type_));
|
|
|
|
|
const vector<int>& in_format =
|
|
|
|
|
const std::vector<int>& in_format =
|
|
|
|
|
op_->attrs_.count("input_format")
|
|
|
|
|
? op->GetAttr<std::vector<int>>("input_format")
|
|
|
|
|
: std::vector<int>();
|
|
|
|
|
const vector<int>& out_format =
|
|
|
|
|
const std::vector<int>& out_format =
|
|
|
|
|
op_->attrs_.count("output_format")
|
|
|
|
|
? op->GetAttr<std::vector<int>>("output_format")
|
|
|
|
|
: std::vector<int>();
|
|
|
|
@ -41,10 +42,11 @@ void GradOpCreator::BuildOpInOutArgList() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GradOpCreator::PushArgIntoGradOp(const OpInOutArg* arg,
|
|
|
|
|
vector<std::string>& in_out,
|
|
|
|
|
vector<int>& format, VarIndexMap* varmap,
|
|
|
|
|
int& idx, bool is_grad) {
|
|
|
|
|
void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
|
|
|
|
|
std::vector<std::string>& in_out,
|
|
|
|
|
std::vector<int>& format,
|
|
|
|
|
VarIndexMap* varmap, int& idx,
|
|
|
|
|
bool is_grad) {
|
|
|
|
|
std::string var_name = arg->proto_name_;
|
|
|
|
|
if (is_grad) {
|
|
|
|
|
var_name += OperatorBase::GRAD_VAR_SUFFIX();
|
|
|
|
@ -70,22 +72,22 @@ void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
|
|
|
|
|
VarIndexMap* grad_varmap = new VarIndexMap();
|
|
|
|
|
int in_idx = 0;
|
|
|
|
|
int out_idx = 0;
|
|
|
|
|
vector<int> in_format({0});
|
|
|
|
|
vector<int> out_format({0});
|
|
|
|
|
std::vector<int> in_format({0});
|
|
|
|
|
std::vector<int> out_format({0});
|
|
|
|
|
for (const auto& arg : arg_list_) {
|
|
|
|
|
// op_'s inputs_ and outputs_
|
|
|
|
|
if (arg->needed_in_grad_) {
|
|
|
|
|
PushArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
|
|
|
|
|
in_idx, false);
|
|
|
|
|
AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
|
|
|
|
|
in_idx, false);
|
|
|
|
|
}
|
|
|
|
|
if (arg->type_ == IN) {
|
|
|
|
|
// gradients of op_'s inputs_
|
|
|
|
|
PushArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap,
|
|
|
|
|
out_idx, true);
|
|
|
|
|
AddArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap,
|
|
|
|
|
out_idx, true);
|
|
|
|
|
} else {
|
|
|
|
|
// gradients of op_'s outputs_
|
|
|
|
|
PushArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
|
|
|
|
|
in_idx, true);
|
|
|
|
|
AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
|
|
|
|
|
in_idx, true);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
grad_op->attrs_["input_format"] = in_format;
|
|
|
|
|