|
|
@ -137,6 +137,23 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Builds a RuntimeContext by resolving every variable name listed in
// `innames` and `outnames` through `scope`.
//
// For each (slot name -> list of variable names) entry, the pointer returned
// by scope.FindVar(name) is appended, in order, to `inputs[slot]` or
// `outputs[slot]` respectively. The result of FindVar is stored as-is and is
// not checked here.
// NOTE(review): presumably FindVar yields nullptr for a missing variable, in
// which case consumers must tolerate null entries -- confirm against Scope.
//
// @param innames   slot name -> input variable names to resolve.
// @param outnames  slot name -> output variable names to resolve.
// @param scope     scope used to look the variables up.
RuntimeContext::RuntimeContext(const VariableNameMap& innames,
                               const VariableNameMap& outnames,
                               const Scope& scope) {
  for (const auto& var_name_item : innames) {
    std::vector<Variable*>& input_vars = inputs[var_name_item.first];
    // The final element count is known up front; reserve once so the
    // push_back loop below does not reallocate.
    input_vars.reserve(input_vars.size() + var_name_item.second.size());
    for (const auto& var_name : var_name_item.second) {
      input_vars.push_back(scope.FindVar(var_name));
    }
  }
  for (const auto& var_name_item : outnames) {
    std::vector<Variable*>& output_vars = outputs[var_name_item.first];
    // Same single up-front allocation as for the inputs.
    output_vars.reserve(output_vars.size() + var_name_item.second.size());
    for (const auto& var_name : var_name_item.second) {
      output_vars.push_back(scope.FindVar(var_name));
    }
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
|
|
|
|
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
|
|
|
|
VLOG(4) << place << " " << DebugStringEx(&scope);
|
|
|
|
VLOG(4) << place << " " << DebugStringEx(&scope);
|
|
|
|
if (platform::is_gpu_place(place)) {
|
|
|
|
if (platform::is_gpu_place(place)) {
|
|
|
@ -704,6 +721,7 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
|
|
|
|
|
|
|
|
|
|
|
|
void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
const platform::Place& place) const {
|
|
|
|
const platform::Place& place) const {
|
|
|
|
|
|
|
|
RuntimeContext ctx(Inputs(), Outputs(), scope);
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
|
|
@ -717,15 +735,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
|
|
|
|
|
|
|
OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
|
|
|
|
|
|
|
|
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the
|
|
|
|
auto expected_kernel_key = this->GetExpectedKernelType(
|
|
|
|
// transform functions are ready.
|
|
|
|
ExecutionContext(*this, scope, *dev_ctx, ctx));
|
|
|
|
|
|
|
|
|
|
|
|
// for (auto& candidate : kKernelPriority) {
|
|
|
|
|
|
|
|
// Do selection
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto expected_kernel_key =
|
|
|
|
|
|
|
|
this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
|
|
|
|
|
|
|
|
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
|
|
|
|
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
|
|
|
|
|
|
|
|
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
|
auto kernel_iter = kernels.find(expected_kernel_key);
|
|
|
@ -744,7 +755,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
KernelTypeToString(expected_kernel_key));
|
|
|
|
KernelTypeToString(expected_kernel_key));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
RuntimeContext ctx;
|
|
|
|
|
|
|
|
// do data transformScope &transfer_scope;
|
|
|
|
// do data transformScope &transfer_scope;
|
|
|
|
std::vector<std::string> transfered_inplace_vars;
|
|
|
|
std::vector<std::string> transfered_inplace_vars;
|
|
|
|
auto* transfer_scope =
|
|
|
|
auto* transfer_scope =
|
|
|
@ -760,7 +770,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
|
|
|
|
|
|
|
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
|
|
|
|
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
|
|
|
|
this->InferShape(&infer_shape_ctx);
|
|
|
|
this->InferShape(&infer_shape_ctx);
|
|
|
|
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
|
|
|
|
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, ctx));
|
|
|
|
|
|
|
|
|
|
|
|
if (!transfered_inplace_vars.empty()) {
|
|
|
|
if (!transfered_inplace_vars.empty()) {
|
|
|
|
// there is inplace variable has been transfered.
|
|
|
|
// there is inplace variable has been transfered.
|
|
|
@ -784,6 +794,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void OperatorWithKernel::TransferInplaceVarsBack(
|
|
|
|
void OperatorWithKernel::TransferInplaceVarsBack(
|
|
|
|
const Scope& scope, const std::vector<std::string>& inplace_vars,
|
|
|
|
const Scope& scope, const std::vector<std::string>& inplace_vars,
|
|
|
|
const Scope& transfer_scope) const {
|
|
|
|
const Scope& transfer_scope) const {
|
|
|
@ -806,7 +817,6 @@ Scope* OperatorWithKernel::PrepareData(
|
|
|
|
Scope* new_scope = nullptr;
|
|
|
|
Scope* new_scope = nullptr;
|
|
|
|
for (auto& var_name_item : Inputs()) {
|
|
|
|
for (auto& var_name_item : Inputs()) {
|
|
|
|
std::vector<Variable*>& input_vars = ctx->inputs[var_name_item.first];
|
|
|
|
std::vector<Variable*>& input_vars = ctx->inputs[var_name_item.first];
|
|
|
|
input_vars.resize(var_name_item.second.size());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
|
|
|
|
auto& var_name = var_name_item.second[i];
|
|
|
|
auto& var_name = var_name_item.second[i];
|
|
|
@ -869,8 +879,6 @@ Scope* OperatorWithKernel::PrepareData(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto& var_name_item : Outputs()) {
|
|
|
|
for (auto& var_name_item : Outputs()) {
|
|
|
|
std::vector<Variable*>& output_vars = ctx->outputs[var_name_item.first];
|
|
|
|
std::vector<Variable*>& output_vars = ctx->outputs[var_name_item.first];
|
|
|
|
output_vars.resize(var_name_item.second.size());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
|
|
|
|
auto& var_name = var_name_item.second[i];
|
|
|
|
auto& var_name = var_name_item.second[i];
|
|
|
|
output_vars[i] = scope.FindVar(var_name);
|
|
|
|
output_vars[i] = scope.FindVar(var_name);
|
|
|
|