@@ -413,37 +413,51 @@ void OperatorWithKernel::Run(const Scope& scope,
   }
 
   if (actual_kernel_key == expected_kernel_key) {
-    kernel_iter->second->Compute(ctx);
+    PADDLE_ENFORCE_EQ(actual_kernel_key.place_, expected_kernel_key.place_,
+                      "Currently, model parallelism is only supported between "
+                      "CPU and other devices. For example, multi-GPU model "
+                      "parallelism will failed.");
   } else {
-    Scope& op_scope = scope.NewScope();
-    auto input_vars = this->InputVars();
-    for (auto var_name : input_vars) {
-      op_scope.Var(var_name);
-    }
-
-    // TODO(qijun) get appropriate DeviceContext from DeviceContext pool
-    platform::DeviceContext* trans_dev_ctx = nullptr;
-    std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
+    const DataTransformFn* trans_fun =
+        DataTransformFnMap::Instance().GetNullable(
+            std::make_pair(actual_kernel_key, expected_kernel_key));
+    if (trans_fun) {
+      auto input_vars = this->InputVars();
+      // TODO(qijun) filter the input vars that do not need to be transformed
+
+      // filter vars that has been transformed
+      std::vector<std::string> need_trans;
+      for (auto var_name : input_vars) {
+        auto var_name_trans =
+            var_name + framework::KernelTypeToString(expected_kernel_key);
+        if (!scope.FindVar(var_name_trans)) {
+          const_cast<Scope&>(scope).Var(var_name_trans);
+          need_trans.push_back(var_name);
+        }
+      }
 
-    // TODO(qijun) get appropriate DataTransformFN from global map
-    framework::DataTransformFN trans_fun = nullptr;
+      if (!need_trans.empty()) {
+        // TODO(qijun) get appropriate DeviceContext from DeviceContext pool
+        platform::DeviceContext* trans_dev_ctx = nullptr;
+        std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
 
-    // Wait for transform starting
-    dev_ctx->Wait();
+        // Wait for transform starting
+        dev_ctx->Wait();
 
-    for (auto var_name : input_vars) {
-      trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
-                op_scope.FindVar(var_name));
-    }
-    // Wait for data transform finishing
-    for (auto ctx : trans_dev_ctx_vec) {
-      ctx->Wait();
+        for (auto var_name : need_trans) {
+          (*trans_fun)(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
+                       scope.FindVar(var_name + framework::KernelTypeToString(
+                                                    expected_kernel_key)));
+        }
+        // Wait for data transform finishing
+        for (auto ctx : trans_dev_ctx_vec) {
+          ctx->Wait();
+        }
+      }
     }
-
-    // Create a new ExecutionContext
-    ExecutionContext op_ctx(*this, op_scope, *dev_ctx);
-    kernel_iter->second->Compute(op_ctx);
   }
+
+  kernel_iter->second->Compute(ctx);
 }
 
 OpKernelType OperatorWithKernel::GetActualKernelType(
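Note on the + side of this hunk: each transformed input is cached in the scope under a new variable whose name is the original name plus the stringified expected kernel type, so an input only goes through the data transform once per kernel type. Below is a minimal, standalone C++ sketch of that bookkeeping; it does not use the PaddlePaddle API, and the scope set, suffix string, and variable names are invented purely for illustration.

// Standalone sketch (not the framework API): decide which inputs still
// need a data transform by checking whether the suffixed cache variable
// already exists, mirroring the need_trans filter in the diff above.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  // Pretend scope: names of variables that already exist.
  std::unordered_set<std::string> scope = {"x", "y"};
  // Hypothetical stringified expected kernel type (place/layout/dtype).
  const std::string kernel_suffix = "_CUDAPlace0_NCHW_float32";

  std::vector<std::string> input_vars = {"x", "y"};
  std::vector<std::string> need_trans;
  for (const auto& var_name : input_vars) {
    auto var_name_trans = var_name + kernel_suffix;
    if (scope.count(var_name_trans) == 0) {
      scope.insert(var_name_trans);    // create the cache slot
      need_trans.push_back(var_name);  // and remember to fill it
    }
  }

  for (const auto& v : need_trans) {
    std::cout << "transform " << v << " -> " << v + kernel_suffix << "\n";
  }
  return 0;
}

Because the check is by name, a second Run with the same expected kernel type finds the cached copies and need_trans stays empty, so no redundant transform is launched.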