Simple Implementation

revert-3824-remove_grad_op_type
Yu Yang 8 years ago
parent fd8df0806d
commit d7a1e40e10

@ -35,16 +35,10 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
switch (ins.size()) {
case 0:
return kEmptyVarName;
case 1:
return ins[0];
default:
PADDLE_THROW("Op %s input %s should contain only one variable", type_,
name);
return "";
}
PADDLE_ENFORCE_LE(ins.size(), 1UL,
"Op %s input %s should contain only one variable", type_,
name);
return ins.empty() ? kEmptyVarName : ins[0];
}
const std::vector<std::string>& OperatorBase::Inputs(
@ -57,16 +51,10 @@ const std::vector<std::string>& OperatorBase::Inputs(
std::string OperatorBase::Output(const std::string& name) const {
auto& outs = Outputs(name);
switch (outs.size()) {
case 0:
return kEmptyVarName;
case 1:
return outs[0];
default:
PADDLE_THROW("Op %s output %s should contain only one variable", type_,
name);
return "";
}
PADDLE_ENFORCE_LE(outs.size(), 1UL,
"Op %s output %s should contain only one variable", type_,
name);
return outs.empty() ? kEmptyVarName : outs[0];
}
const std::vector<std::string>& OperatorBase::Outputs(

@ -239,20 +239,12 @@ class InferShapeContext {
const Variable* InputVar(const std::string& name) const {
auto ipt = op_.Input(name);
if (ipt == kEmptyVarName) {
return nullptr;
} else {
return scope_.FindVar(ipt);
}
return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
}
Variable* OutputVar(const std::string& name) const {
auto opt = op_.Output(name);
if (opt == kEmptyVarName) {
return nullptr;
} else {
return scope_.FindVar(opt);
}
return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
}
const std::vector<const Variable*> MultiInputVar(
@ -262,8 +254,8 @@ class InferShapeContext {
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return name != kEmptyVarName ? scope_.FindVar(name)
: nullptr;
return name == kEmptyVarName ? nullptr
: scope_.FindVar(name);
});
return res;
}
@ -274,8 +266,8 @@ class InferShapeContext {
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return name != kEmptyVarName ? scope_.FindVar(name)
: nullptr;
return name == kEmptyVarName ? nullptr
: scope_.FindVar(name);
});
return res;
}

Loading…
Cancel
Save