|
|
|
|
@ -45,12 +45,10 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
|
|
|
|
|
int& idx, bool is_grad) {
|
|
|
|
|
const std::vector<std::string>& src_inout =
|
|
|
|
|
src_type == IN ? src_op->inputs_ : src_op->outputs_;
|
|
|
|
|
const VarIndexMap& src_varmap = *src_op->in_out_idxs_;
|
|
|
|
|
const std::vector<int>* src_format = GetOpFormat(src_op, src_type);
|
|
|
|
|
|
|
|
|
|
std::vector<std::string>& dst_inout =
|
|
|
|
|
dst_type == IN ? dst_op->inputs_ : dst_op->outputs_;
|
|
|
|
|
VarIndexMap& dst_varmap = *dst_op->in_out_idxs_;
|
|
|
|
|
std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
|
|
|
|
|
const OpProto& proto = OpRegistry::protos().at(src_op->type_);
|
|
|
|
|
const auto& src_arg_list = src_type == IN ? proto.inputs() : proto.outputs();
|
|
|
|
|
@ -59,8 +57,8 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
|
|
|
|
|
std::string src_name = arg.name();
|
|
|
|
|
std::string dst_name =
|
|
|
|
|
is_grad ? src_name + OperatorBase::GRAD_VAR_SUFFIX() : src_name;
|
|
|
|
|
dst_varmap[dst_name] = idx++;
|
|
|
|
|
int src_arg_idx = src_varmap.at(src_name);
|
|
|
|
|
(*dst_op->in_out_idxs_)[dst_name] = idx++;
|
|
|
|
|
int src_arg_idx = src_op->in_out_idxs_->at(src_name);
|
|
|
|
|
int src_begin =
|
|
|
|
|
src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx);
|
|
|
|
|
int src_end = src_format == nullptr ? src_arg_idx + 1
|
|
|
|
|
|