|
|
|
@ -95,12 +95,12 @@ class OperatorBase {
|
|
|
|
|
const VariableNameMap& Inputs() const { return inputs_; }
|
|
|
|
|
const VariableNameMap& Outputs() const { return outputs_; }
|
|
|
|
|
//! Get a input with argument's name described in `op_proto`
|
|
|
|
|
const std::string& Input(const std::string& name) const;
|
|
|
|
|
std::string Input(const std::string& name) const;
|
|
|
|
|
//! Get a input which has multiple variables.
|
|
|
|
|
const std::vector<std::string>& Inputs(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
//! Get a output with argument's name described in `op_proto`
|
|
|
|
|
const std::string& Output(const std::string& name) const;
|
|
|
|
|
std::string Output(const std::string& name) const;
|
|
|
|
|
//! Get an output which has multiple variables.
|
|
|
|
|
//! TODO add a vector_view to prevent memory copy.
|
|
|
|
|
const std::vector<std::string>& Outputs(const std::string& name) const;
|
|
|
|
@ -127,6 +127,10 @@ class OperatorBase {
|
|
|
|
|
// IG (Inputs Gradients)
|
|
|
|
|
VariableNameMap outputs_;
|
|
|
|
|
AttributeMap attrs_;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void GenerateTemporaryNames();
|
|
|
|
|
void CheckAllInputOutputSet() const;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Macro for define a clone method.
|
|
|
|
@ -238,11 +242,13 @@ class InferShapeContext {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const Variable* InputVar(const std::string& name) const {
|
|
|
|
|
return scope_.FindVar(op_.Input(name));
|
|
|
|
|
auto ipt = op_.Input(name);
|
|
|
|
|
return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Variable* OutputVar(const std::string& name) const {
|
|
|
|
|
return scope_.FindVar(op_.Output(name));
|
|
|
|
|
auto opt = op_.Output(name);
|
|
|
|
|
return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<const Variable*> MultiInputVar(
|
|
|
|
@ -250,9 +256,11 @@ class InferShapeContext {
|
|
|
|
|
auto names = op_.Inputs(name);
|
|
|
|
|
std::vector<const Variable*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(
|
|
|
|
|
names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[this](const std::string& name) { return scope_.FindVar(name); });
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[this](const std::string& name) {
|
|
|
|
|
return name == kEmptyVarName ? nullptr
|
|
|
|
|
: scope_.FindVar(name);
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -260,24 +268,24 @@ class InferShapeContext {
|
|
|
|
|
auto names = op_.Outputs(name);
|
|
|
|
|
std::vector<const Variable*> res;
|
|
|
|
|
res.reserve(names.size());
|
|
|
|
|
std::transform(
|
|
|
|
|
names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[this](const std::string& name) { return scope_.FindVar(name); });
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[this](const std::string& name) {
|
|
|
|
|
return name == kEmptyVarName ? nullptr
|
|
|
|
|
: scope_.FindVar(name);
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
const T* Input(const std::string& name) const {
|
|
|
|
|
auto* var = InputVar(name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name);
|
|
|
|
|
return &var->Get<T>();
|
|
|
|
|
return var == nullptr ? nullptr : &var->Get<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
T* Output(const std::string& name) const {
|
|
|
|
|
auto var = OutputVar(name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name);
|
|
|
|
|
return var->GetMutable<T>();
|
|
|
|
|
return var == nullptr ? nullptr : var->GetMutable<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -288,10 +296,7 @@ class InferShapeContext {
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[&](const std::string& sub_name) {
|
|
|
|
|
auto var = scope_.FindVar(sub_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
var, "MultiInput(%s:%s) should not be nullptr", name,
|
|
|
|
|
sub_name);
|
|
|
|
|
return &var->Get<T>();
|
|
|
|
|
return var == nullptr ? nullptr : &var->Get<T>();
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
@ -304,10 +309,7 @@ class InferShapeContext {
|
|
|
|
|
std::transform(names.begin(), names.end(), std::back_inserter(res),
|
|
|
|
|
[&](const std::string& sub_name) {
|
|
|
|
|
auto var = scope_.FindVar(sub_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
var, "MultiOutput(%s:%s) should not be nullptr.", name,
|
|
|
|
|
sub_name);
|
|
|
|
|
return var->GetMutable<T>();
|
|
|
|
|
return var == nullptr ? nullptr : var->GetMutable<T>();
|
|
|
|
|
});
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|