make unit test of backward_test pass.

revert-3824-remove_grad_op_type
qingqing01 8 years ago
parent 8810490570
commit dfb4ea764b

@ -25,7 +25,7 @@ template <typename Map, typename T>
static void ForEachVarName(Map& names, T callback) {
for (auto& name : names) {
for (auto& n : name.second) {
if (callback(n)) break;
if (callback(n)) return;
}
}
}
@ -33,12 +33,12 @@ static void ForEachVarName(Map& names, T callback) {
static bool AllInSet(
const std::unordered_map<std::string, std::vector<std::string>>& names,
const std::string& suffix, const std::unordered_set<std::string>& set) {
bool ret_val = true;
ForEachVarName(names, [&ret_val, &set, &suffix](const std::string& n) {
ret_val = set.find(n + suffix) == set.end();
return !ret_val;
bool all_in_set = true;
ForEachVarName(names, [&all_in_set, &set, &suffix](const std::string& n) {
all_in_set = set.find(n + suffix) != set.end();
return !all_in_set;
});
return ret_val;
return all_in_set;
}
static std::shared_ptr<OperatorBase> NOP() {

File diff suppressed because it is too large Load Diff

@ -43,7 +43,7 @@ std::unordered_map<std::string, OpProto>& OpProtos() {
const std::string& OperatorBase::Input(const std::string& name) const {
auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have output %s", type_,
PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have input %s", type_,
name);
PADDLE_ENFORCE_EQ(it->second.size(), 1UL,
"Op %s input %s should contain only one variable", type_,

Loading…
Cancel
Save