|
|
|
@ -384,6 +384,24 @@ class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
const Scope& scope_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
const platform::DeviceContext* GetDeviceContext(
|
|
|
|
|
framework::KernelTypePair& kernel_pair) {
|
|
|
|
|
auto& actual_kernel_key = kernel_pair.first;
|
|
|
|
|
auto& expected_kernel_key = kernel_pair.second;
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
|
|
|
|
|
if (platform::is_gpu_place(actual_kernel_key.place_) &&
|
|
|
|
|
platform::is_cpu_place(expected_kernel_key.place_)) {
|
|
|
|
|
return pool.Get(actual_kernel_key.place_);
|
|
|
|
|
} else if (platform::is_cpu_place(actual_kernel_key.place_) &&
|
|
|
|
|
platform::is_gpu_place(expected_kernel_key.place_)) {
|
|
|
|
|
return pool.Get(expected_kernel_key.place_);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"Currently, model parallelism is only supported between CPU and CUDA");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OperatorWithKernel::Run(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const {
|
|
|
|
|
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
|
|
|
|
@ -418,9 +436,9 @@ void OperatorWithKernel::Run(const Scope& scope,
|
|
|
|
|
"CPU and other devices. For example, multi-GPU model "
|
|
|
|
|
"parallelism will failed.");
|
|
|
|
|
} else {
|
|
|
|
|
auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key);
|
|
|
|
|
const DataTransformFn* trans_fun =
|
|
|
|
|
DataTransformFnMap::Instance().GetNullable(
|
|
|
|
|
std::make_pair(actual_kernel_key, expected_kernel_key));
|
|
|
|
|
DataTransformFnMap::Instance().GetNullable(kernel_pair);
|
|
|
|
|
if (trans_fun) {
|
|
|
|
|
auto input_vars = this->InputVars();
|
|
|
|
|
// TODO(qijun) filter the input vars that do not need to be transformed
|
|
|
|
@ -437,22 +455,18 @@ void OperatorWithKernel::Run(const Scope& scope,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!need_trans.empty()) {
|
|
|
|
|
// TODO(qijun) get appropriate DeviceContext from DeviceContext pool
|
|
|
|
|
platform::DeviceContext* trans_dev_ctx = nullptr;
|
|
|
|
|
std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
|
|
|
|
|
auto trans_dev_ctx = GetDeviceContext(kernel_pair);
|
|
|
|
|
|
|
|
|
|
// Wait for transform starting
|
|
|
|
|
dev_ctx->Wait();
|
|
|
|
|
|
|
|
|
|
for (auto var_name : need_trans) {
|
|
|
|
|
(*trans_fun)(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
|
|
|
|
|
(*trans_fun)(trans_dev_ctx, *(scope.FindVar(var_name)),
|
|
|
|
|
scope.FindVar(var_name + framework::KernelTypeToString(
|
|
|
|
|
expected_kernel_key)));
|
|
|
|
|
}
|
|
|
|
|
// Wait for data transform finishing
|
|
|
|
|
for (auto ctx : trans_dev_ctx_vec) {
|
|
|
|
|
ctx->Wait();
|
|
|
|
|
}
|
|
|
|
|
trans_dev_ctx->Wait();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|