enum ==> enum class

cblas_new
fengjiayi 8 years ago
parent 5e37872462
commit 26ab453801

@ -358,7 +358,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
3UL /* external input number */
+ 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/
//- 1UL /*ignoreGradient varable number*/
+ 2U /* internal variable number*/);
EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/
+ 2UL /* input number of rowwise_add */

@ -8,9 +8,9 @@ You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
WITHOpArgType::OUT WARRANTIES OR CONDITIONS OF ANY KOpArgType::IND, either
express or implied. See the License for the specific language governing
permissions and limitations under the License. */
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_proto.pb.h"
@ -23,10 +23,10 @@ class OpRegistry;
using VarIndexMap = std::unordered_map<std::string, int>;
enum OpArgType { IN, OUT };
enum class OpArgType { IN, OUT };
static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) {
std::string key = type == IN ? "input_format" : "output_name";
std::string key = type == OpArgType::IN ? "input_format" : "output_name";
return op->attrs_.count(key)
? &boost::get<std::vector<int>>(op->attrs_.at(key))
: nullptr;
@ -34,7 +34,7 @@ static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) {
static const std::vector<int>* GetOpFormat(const OperatorBase* op,
const OpArgType& type) {
std::string key = type == IN ? "input_format" : "output_name";
std::string key = type == OpArgType::IN ? "input_format" : "output_name";
return op->attrs_.count(key)
? &boost::get<std::vector<int>>(op->attrs_.at(key))
: nullptr;
@ -44,14 +44,15 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
const OpArgType& src_type, const OpArgType& dst_type,
int& idx, bool is_grad) {
const std::vector<std::string>& src_inout =
src_type == IN ? src_op->inputs_ : src_op->outputs_;
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
const std::vector<int>* src_format = GetOpFormat(src_op, src_type);
std::vector<std::string>& dst_inout =
dst_type == IN ? dst_op->inputs_ : dst_op->outputs_;
dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
const OpProto& proto = OpRegistry::protos().at(src_op->type_);
const auto& src_arg_list = src_type == IN ? proto.inputs() : proto.outputs();
const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
for (const auto& arg : src_arg_list) {
std::string src_name = arg.name();
@ -83,19 +84,20 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
grad_op->attrs_ = op->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
if (GetOpFormat(op, OUT) != nullptr) {
if (GetOpFormat(op, OpArgType::OUT) != nullptr) {
grad_op->attrs_["output_format"] = std::vector<int>({0});
}
if (GetOpFormat(op, IN) != nullptr || GetOpFormat(op, OUT) != nullptr) {
if (GetOpFormat(op, OpArgType::IN) != nullptr ||
GetOpFormat(op, OpArgType::OUT) != nullptr) {
grad_op->attrs_["input_format"] = std::vector<int>({0});
}
grad_op->in_out_idxs_.reset(new VarIndexMap());
int in_idx = 0;
int out_idx = 0;
TransOpArg(op, grad_op, IN, IN, in_idx, false); // I
TransOpArg(op, grad_op, OUT, IN, in_idx, false); // G
TransOpArg(op, grad_op, OUT, IN, in_idx, true); // OG
TransOpArg(op, grad_op, IN, OUT, out_idx, true); // IG
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, in_idx, false); // I
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, false); // G
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, true); // OG
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true); // IG
return grad_op;
}

Loading…
Cancel
Save