parent
							
								
									12e1719f96
								
							
						
					
					
						commit
						900fbb83f9
					
				| @ -0,0 +1,86 @@ | |||||||
|  | // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||||
|  | //
 | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
|  | // you may not use this file except in compliance with the License.
 | ||||||
|  | // You may obtain a copy of the License at
 | ||||||
|  | //
 | ||||||
|  | //     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | //
 | ||||||
|  | // Unless required by applicable law or agreed to in writing, software
 | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||||
|  | // See the License for the specific language governing permissions and
 | ||||||
|  | // limitations under the License.
 | ||||||
|  | 
 | ||||||
|  | #include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h" | ||||||
|  | #include "paddle/fluid/framework/data_layout.h" | ||||||
|  | #include "paddle/fluid/framework/lod_tensor.h" | ||||||
|  | #include "paddle/fluid/framework/tensor_util.h" | ||||||
|  | #include "paddle/fluid/platform/enforce.h" | ||||||
|  | 
 | ||||||
|  | namespace paddle { | ||||||
|  | namespace { | ||||||
|  | bool IsPersistable(const framework::VarDesc *var) { | ||||||
|  |   if (var->Persistable() && | ||||||
|  |       var->GetType() != framework::proto::VarType::FEED_MINIBATCH && | ||||||
|  |       var->GetType() != framework::proto::VarType::FETCH_LIST) { | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  |   return false; | ||||||
|  | } | ||||||
|  | }  // namespace
 | ||||||
|  | namespace inference { | ||||||
|  | namespace analysis { | ||||||
|  | 
 | ||||||
|  | void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) { | ||||||
|  |   PADDLE_ENFORCE(argument->scope_valid()); | ||||||
|  |   PADDLE_ENFORCE(argument->use_gpu_valid()); | ||||||
|  | 
 | ||||||
|  |   platform::Place place; | ||||||
|  | 
 | ||||||
|  |   // The parameters are on the cpu, therefore, synchronization is not necessary.
 | ||||||
|  |   if (!argument->use_gpu()) return; | ||||||
|  | 
 | ||||||
|  |   LOG(INFO) << "Sync params from CPU to GPU"; | ||||||
|  | 
 | ||||||
|  |   PADDLE_ENFORCE(argument->gpu_device_id_valid()); | ||||||
|  |   place = platform::CUDAPlace(argument->gpu_device_id()); | ||||||
|  | 
 | ||||||
|  |   auto *scope = argument->scope_ptr(); | ||||||
|  |   // Get the program which has been processed by several passes.
 | ||||||
|  |   analysis_program_.reset( | ||||||
|  |       new framework::ProgramDesc(argument->ir_analyzed_program())); | ||||||
|  | 
 | ||||||
|  |   const auto &global_block = analysis_program_->Block(0); | ||||||
|  | 
 | ||||||
|  |   // sync the params from cpu to gpu.
 | ||||||
|  |   for (auto &var : global_block.AllVars()) { | ||||||
|  |     if (IsPersistable(var)) { | ||||||
|  |       std::string var_name = var->Name(); | ||||||
|  |       LOG(INFO) << var_name; | ||||||
|  |       auto &t = inference::analysis::GetFromScope<framework::LoDTensor>( | ||||||
|  |           *scope, var_name); | ||||||
|  | 
 | ||||||
|  |       platform::CPUPlace cpu_place; | ||||||
|  |       framework::LoDTensor temp_tensor; | ||||||
|  |       temp_tensor.Resize(t.dims()); | ||||||
|  |       temp_tensor.mutable_data<float>(cpu_place); | ||||||
|  | 
 | ||||||
|  |       // Copy the parameter data to a tmp tensor.
 | ||||||
|  |       TensorCopySync(t, cpu_place, &temp_tensor); | ||||||
|  |       // Reallocation the space on GPU
 | ||||||
|  |       t.mutable_data<float>(place); | ||||||
|  | 
 | ||||||
|  |       // Copy parameter data to newly allocated GPU space.
 | ||||||
|  |       TensorCopySync(temp_tensor, place, &t); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | std::string IrParamsSyncAmongDevicesPass::repr() const { | ||||||
|  |   return "ir-params-sync-among-devices-pass"; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace analysis
 | ||||||
|  | }  // namespace inference
 | ||||||
|  | }  // namespace paddle
 | ||||||
| @ -0,0 +1,42 @@ | |||||||
|  | // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||||
|  | //
 | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
|  | // you may not use this file except in compliance with the License.
 | ||||||
|  | // You may obtain a copy of the License at
 | ||||||
|  | //
 | ||||||
|  | //     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | //
 | ||||||
|  | // Unless required by applicable law or agreed to in writing, software
 | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||||
|  | // See the License for the specific language governing permissions and
 | ||||||
|  | // limitations under the License.
 | ||||||
|  | 
 | ||||||
|  | #pragma once | ||||||
|  | 
 | ||||||
|  | #include <string> | ||||||
|  | 
 | ||||||
|  | #include "paddle/fluid/framework/scope.h" | ||||||
|  | #include "paddle/fluid/inference/analysis/analysis_pass.h" | ||||||
|  | #include "paddle/fluid/inference/analysis/helper.h" | ||||||
|  | #include "paddle/fluid/platform/place.h" | ||||||
|  | 
 | ||||||
|  | namespace paddle { | ||||||
|  | namespace inference { | ||||||
|  | namespace analysis { | ||||||
|  | 
 | ||||||
|  | /*
 | ||||||
|  |  * Sync parameter from CPU to GPU. | ||||||
|  |  */ | ||||||
|  | class IrParamsSyncAmongDevicesPass : public AnalysisPass { | ||||||
|  |  public: | ||||||
|  |   void RunImpl(Argument *argument) override; | ||||||
|  |   std::string repr() const override; | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   std::unique_ptr<framework::ProgramDesc> analysis_program_; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace analysis
 | ||||||
|  | }  // namespace inference
 | ||||||
|  | }  // namespace paddle
 | ||||||
					Loading…
					
					
				
		Reference in new issue