|
|
|
@ -41,7 +41,9 @@ void ThreadSafeNameSet::Insert(const std::string& name) {
|
|
|
|
|
void ThreadSafeNameSet::Remove(const std::string& name) {
|
|
|
|
|
std::lock_guard<std::mutex> guard(mtx_);
|
|
|
|
|
auto iter = set_.find(name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(iter != set_.end(), true, "%s does not exist", name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
iter != set_.end(), true,
|
|
|
|
|
platform::errors::NotFound("Variable name %s does not exist", name));
|
|
|
|
|
set_.erase(iter);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -54,48 +56,6 @@ ThreadSafeNameSet VarBase::name_set_;
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> VarBase::AliveVarNames() { return name_set_.Names(); }
|
|
|
|
|
|
|
|
|
|
static framework::VariableNameMap CreateVarNameMap(
|
|
|
|
|
const framework::OpInfo& op_info, const std::string& op_type,
|
|
|
|
|
const NameVarBaseMap& varbase_map, bool is_input) {
|
|
|
|
|
if (op_info.proto_ == nullptr) {
|
|
|
|
|
framework::VariableNameMap result;
|
|
|
|
|
|
|
|
|
|
for (auto& it : varbase_map) {
|
|
|
|
|
auto& var_vector = it.second;
|
|
|
|
|
std::vector<std::string> args;
|
|
|
|
|
args.reserve(var_vector.size());
|
|
|
|
|
for (auto& var_base : var_vector) {
|
|
|
|
|
args.emplace_back(var_base->Name());
|
|
|
|
|
}
|
|
|
|
|
result[it.first] = std::move(args);
|
|
|
|
|
}
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::VariableNameMap result;
|
|
|
|
|
|
|
|
|
|
for (auto& var :
|
|
|
|
|
is_input ? op_info.Proto().inputs() : op_info.Proto().outputs()) {
|
|
|
|
|
auto it = varbase_map.find(var.name());
|
|
|
|
|
if (it == varbase_map.end()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
var.dispensable(), true,
|
|
|
|
|
"Var: %s not dispensable and there are no such var in inputs",
|
|
|
|
|
var.name());
|
|
|
|
|
result[var.name()] = {};
|
|
|
|
|
} else {
|
|
|
|
|
auto& var_vector = it->second;
|
|
|
|
|
std::vector<std::string> args;
|
|
|
|
|
args.reserve(var_vector.size());
|
|
|
|
|
for (auto& var_base : var_vector) {
|
|
|
|
|
args.emplace_back(var_base->Name());
|
|
|
|
|
}
|
|
|
|
|
result[var.name()] = std::move(args);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static framework::RuntimeContext PrepareRuntimeContext(
|
|
|
|
|
const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
|
|
|
|
|
framework::VariableValueMap inputs, outputs;
|
|
|
|
@ -323,7 +283,9 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
|
|
|
|
|
const framework::AttributeMap& attrs,
|
|
|
|
|
const platform::Place& place) {
|
|
|
|
|
auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
op_kernel, platform::errors::PermissionDenied(
|
|
|
|
|
"Only support operator with kernel in Dygraph mode."));
|
|
|
|
|
auto& info = op.Info();
|
|
|
|
|
if (info.infer_var_type_) {
|
|
|
|
|
RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(ins, outs, attrs);
|
|
|
|
|