|
|
@ -17,16 +17,28 @@
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace framework {
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Holds all the transfer scope across the process.
|
|
|
|
std::unordered_map<size_t, Scope*>& global_transfer_data_cache() {
|
|
|
|
std::unordered_map<size_t, Scope*>& global_transfer_data_cache() {
|
|
|
|
thread_local auto* x = new std::unordered_map<size_t, Scope*>;
|
|
|
|
typedef std::unordered_map<size_t, Scope*> map_t;
|
|
|
|
|
|
|
|
thread_local std::unique_ptr<map_t> x(new map_t);
|
|
|
|
return *x;
|
|
|
|
return *x;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Holds all the transfer scope for this thread.
|
|
|
|
std::unordered_set<Scope*>& global_transfer_scope_cache() {
|
|
|
|
std::unordered_set<Scope*>& global_transfer_scope_cache() {
|
|
|
|
thread_local auto* x = new std::unordered_set<Scope*>;
|
|
|
|
typedef std::unordered_set<Scope*> set_t;
|
|
|
|
|
|
|
|
thread_local std::unique_ptr<set_t> x(new set_t);
|
|
|
|
return *x;
|
|
|
|
return *x;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Try to create a transfer scope. If one cached scope has match the
|
|
|
|
|
|
|
|
// requirement, just return that one.
|
|
|
|
|
|
|
|
// Inputs:
|
|
|
|
|
|
|
|
// @type0: the source kernel type.
|
|
|
|
|
|
|
|
// @type1: the target kernel type.
|
|
|
|
|
|
|
|
// @scope: the execution scope of this op.
|
|
|
|
|
|
|
|
// Returns: A scope used to hold the transfer data across the different kernel
|
|
|
|
|
|
|
|
// type.
|
|
|
|
Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1,
|
|
|
|
Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1,
|
|
|
|
const Scope* scope) {
|
|
|
|
const Scope* scope) {
|
|
|
|
Scope* new_scope{nullptr};
|
|
|
|
Scope* new_scope{nullptr};
|
|
|
@ -46,27 +58,5 @@ Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1,
|
|
|
|
return new_scope;
|
|
|
|
return new_scope;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void RemoveKidsFromTransferScopeCache(Scope* scope) {
|
|
|
|
|
|
|
|
auto it = global_transfer_scope_cache().find(scope);
|
|
|
|
|
|
|
|
if (it != global_transfer_scope_cache().end()) {
|
|
|
|
|
|
|
|
global_transfer_scope_cache().erase(it);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for (auto* s : scope->kids()) {
|
|
|
|
|
|
|
|
auto it = global_transfer_scope_cache().find(s);
|
|
|
|
|
|
|
|
if (it != global_transfer_scope_cache().end()) {
|
|
|
|
|
|
|
|
global_transfer_scope_cache().erase(it);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// remove global transfer data cache
|
|
|
|
|
|
|
|
auto& cache = global_transfer_data_cache();
|
|
|
|
|
|
|
|
for (auto it = cache.begin(); it != cache.end();) {
|
|
|
|
|
|
|
|
if (it->second == scope)
|
|
|
|
|
|
|
|
it = cache.erase(it);
|
|
|
|
|
|
|
|
else
|
|
|
|
|
|
|
|
it++;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|