|
|
|
@ -119,16 +119,8 @@ OperatorBase::OperatorBase(const std::string& type,
|
|
|
|
|
const VariableNameMap& outputs,
|
|
|
|
|
const AttributeMap& attrs)
|
|
|
|
|
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
|
|
|
|
|
static std::atomic<size_t> gUniqId(0UL);
|
|
|
|
|
for (auto& output : outputs_) {
|
|
|
|
|
for (auto& output_name : output.second) {
|
|
|
|
|
if (output_name == kTempVarName) {
|
|
|
|
|
output_name += type_;
|
|
|
|
|
output_name += "@";
|
|
|
|
|
output_name += std::to_string(gUniqId.fetch_add(1));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
GenerateTemporaryNames();
|
|
|
|
|
CheckAllInputOutputSet();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
|
|
|
|
@ -156,6 +148,35 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
|
|
|
|
|
return ret_val;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OperatorBase::CheckAllInputOutputSet() const {
|
|
|
|
|
auto& info_map = OpInfoMap::Instance();
|
|
|
|
|
auto* op_info = info_map.GetNullable(Type());
|
|
|
|
|
if (op_info == nullptr) return;
|
|
|
|
|
|
|
|
|
|
for (auto& in : op_info->Proto().inputs()) {
|
|
|
|
|
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
|
|
|
|
|
"input %s is not set", in.name());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto& out : op_info->Proto().outputs()) {
|
|
|
|
|
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
|
|
|
|
|
"output %s is not set", out.name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OperatorBase::GenerateTemporaryNames() {
|
|
|
|
|
static std::atomic<size_t> gUniqId(0UL);
|
|
|
|
|
for (auto& output : outputs_) {
|
|
|
|
|
for (auto& output_name : output.second) {
|
|
|
|
|
if (output_name == kTempVarName) {
|
|
|
|
|
output_name += type_;
|
|
|
|
|
output_name += "@";
|
|
|
|
|
output_name += std::to_string(gUniqId.fetch_add(1));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpProtoAndCheckerMaker::Validate() {
|
|
|
|
|
validated_ = true;
|
|
|
|
|
CheckNoDuplicatedInOutAttrs();
|
|
|
|
|