|
|
|
@ -1,4 +1,19 @@
|
|
|
|
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/grad_op_creator.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
@ -22,15 +37,15 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GradOpCreator::BuildOpInOutArgList() {
|
|
|
|
|
const OpProto& op_proto = OpRegistry::protos().at(op_->type);
|
|
|
|
|
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op->type_));
|
|
|
|
|
const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
|
|
|
|
|
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
|
|
|
|
|
const std::vector<int>& in_format =
|
|
|
|
|
op_->attrs_.count("input_format")
|
|
|
|
|
? op->GetAttr<std::vector<int>>("input_format")
|
|
|
|
|
? op_->GetAttr<std::vector<int>>("input_format")
|
|
|
|
|
: std::vector<int>();
|
|
|
|
|
const std::vector<int>& out_format =
|
|
|
|
|
op_->attrs_.count("output_format")
|
|
|
|
|
? op->GetAttr<std::vector<int>>("output_format")
|
|
|
|
|
? op_->GetAttr<std::vector<int>>("output_format")
|
|
|
|
|
: std::vector<int>();
|
|
|
|
|
for (const auto& var : op_proto.inputs()) {
|
|
|
|
|
arg_list_.emplace_back(
|
|
|
|
@ -46,14 +61,15 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
|
|
|
|
|
std::vector<std::string>& in_out,
|
|
|
|
|
std::vector<int>& format,
|
|
|
|
|
VarIndexMap* varmap, int& idx,
|
|
|
|
|
bool is_grad) {
|
|
|
|
|
bool is_grad) const {
|
|
|
|
|
std::string var_name = arg->proto_name_;
|
|
|
|
|
if (is_grad) {
|
|
|
|
|
var_name += OperatorBase::GRAD_VAR_SUFFIX();
|
|
|
|
|
}
|
|
|
|
|
*(varmap)[var_name] = idx++;
|
|
|
|
|
(*varmap)[var_name] = idx++;
|
|
|
|
|
size_t pre_sz = in_out.size();
|
|
|
|
|
auto base_it = arg->type == IN ? op_->inputs_.begin() : op_->outputs_.begin();
|
|
|
|
|
auto base_it =
|
|
|
|
|
arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin();
|
|
|
|
|
std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
|
|
|
|
|
std::back_inserter(in_out));
|
|
|
|
|
if (is_grad) {
|
|
|
|
@ -96,4 +112,4 @@ void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
} // namespace paddle
|
|
|
|
|