|
|
|
@ -517,6 +517,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
// do data transform
|
|
|
|
|
Scope& new_scope = scope.NewScope();
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> inplace_vars;
|
|
|
|
|
for (auto& var_name_item : this->Inputs()) {
|
|
|
|
|
for (auto& var_name : var_name_item.second) {
|
|
|
|
|
auto* var = scope.FindVar(var_name);
|
|
|
|
@ -529,10 +530,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
auto out_var_names = OutputVars(true);
|
|
|
|
|
if (std::find(out_var_names.begin(), out_var_names.end(),
|
|
|
|
|
var_name) != out_var_names.end()) {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"var %s is both input and output, "
|
|
|
|
|
"does not support transform",
|
|
|
|
|
var_name);
|
|
|
|
|
inplace_vars.push_back(var_name);
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "Transform Variable " << var_name << " from "
|
|
|
|
|
<< kernel_type_for_var << " to " << expected_kernel_key;
|
|
|
|
@ -551,6 +549,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
kernel_iter->second->Compute(
|
|
|
|
|
ExecutionContext(*this, new_scope, *new_dev_ctx));
|
|
|
|
|
|
|
|
|
|
for (auto& var_name : inplace_vars) {
|
|
|
|
|
VLOG(3) << "share inplace var " + var_name + " back to it's original scope";
|
|
|
|
|
auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name));
|
|
|
|
|
auto* transformed_tensor = GetTensorFromVar(new_scope.FindVar(var_name));
|
|
|
|
|
original_tensor->ShareDataWith(*transformed_tensor);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*For profiling/benchmark only*/
|
|
|
|
|
if (FLAGS_benchmark) {
|
|
|
|
|
new_dev_ctx->Wait();
|
|
|
|
|