|
|
|
@ -15,6 +15,7 @@ limitations under the License. */
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <atomic>
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/data_transform.h"
|
|
|
|
|
#include "paddle/framework/executor.h"
|
|
|
|
|
#include "paddle/framework/lod_tensor_array.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
@ -411,7 +412,38 @@ void OperatorWithKernel::Run(const Scope& scope,
|
|
|
|
|
expected_kernel_key);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel_iter->second->Compute(ctx);
|
|
|
|
|
if (actual_kernel_key == expected_kernel_key) {
|
|
|
|
|
kernel_iter->second->Compute(ctx);
|
|
|
|
|
} else {
|
|
|
|
|
Scope& op_scope = scope.NewScope();
|
|
|
|
|
auto input_vars = this->InputVars();
|
|
|
|
|
for (auto var_name : input_vars) {
|
|
|
|
|
op_scope.Var(var_name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(qijun) get appropriate DeviceContext from DeviceContext pool
|
|
|
|
|
platform::DeviceContext* trans_dev_ctx = nullptr;
|
|
|
|
|
std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
|
|
|
|
|
|
|
|
|
|
// TODO(qijun) get appropriate DataTransformFN from global map
|
|
|
|
|
framework::DataTransformFN trans_fun = nullptr;
|
|
|
|
|
|
|
|
|
|
// Wait for transform starting
|
|
|
|
|
dev_ctx->Wait();
|
|
|
|
|
|
|
|
|
|
for (auto var_name : input_vars) {
|
|
|
|
|
trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
|
|
|
|
|
op_scope.FindVar(var_name));
|
|
|
|
|
}
|
|
|
|
|
// Wait for data transform finishing
|
|
|
|
|
for (auto ctx : trans_dev_ctx_vec) {
|
|
|
|
|
ctx->Wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Create a new ExecutionContext
|
|
|
|
|
ExecutionContext op_ctx(*this, op_scope, *dev_ctx);
|
|
|
|
|
kernel_iter->second->Compute(op_ctx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpKernelType OperatorWithKernel::GetActualKernelType(
|
|
|
|
|