/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <vector>
#include "dgc/dgc.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"

namespace paddle {
namespace operators {

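// Maps the current step of the DGC sparsity warm-up schedule to a sparsity
// value: `sparsity` holds one value per ramp-up period, and the index is
// cur_step scaled into [0, sparsity.size()) and clamped to the last entry.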
inline float get_period_sparcity(const std::vector<float>& sparsity,
                                 float cur_step, float rampup_steps) {
  PADDLE_ENFORCE_GE(static_cast<int>(cur_step), 0,
                    platform::errors::InvalidArgument(
                        "DGC current step=%d, but it must be >= 0. "
                        "Please submit an issue on GitHub.",
                        static_cast<int>(cur_step)));

  size_t idx = static_cast<int>(cur_step * sparsity.size() / rampup_steps);
  if (idx >= sparsity.size()) {
    idx = sparsity.size() - 1;
  }

  PADDLE_ENFORCE_LT(
      idx, sparsity.size(),
      platform::errors::OutOfRange(
          "Sparsity index out of bounds: idx=%d >= sparsity.size()=%d.", idx,
          sparsity.size()));
  return sparsity[idx];
}
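
// Illustrative example (assumed values, not from the original source): with
// the typical DGC warm-up schedule sparsity = {0.75, 0.9375, 0.984375,
// 0.996, 0.999} and rampup_steps = 100, cur_step = 50 gives
// idx = 50 * 5 / 100 = 2, so 98.4375% of gradient elements are dropped
// during that period.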

template <typename DeviceContext, typename T>
class DGCOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto u = ctx.Input<framework::Tensor>("U");
    auto v = ctx.Input<framework::Tensor>("V");
    auto g = ctx.Input<framework::Tensor>("Grad");

    auto grad_out = ctx.Output<framework::Tensor>("Grad_out");

    // attrs
    float m = ctx.Attr<float>("m");
    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
    auto sparsity = ctx.Attr<std::vector<float>>("sparsity");
    auto rampup_begin_step = ctx.Attr<float>("rampup_begin_step");
    auto rampup_step = ctx.Attr<float>("rampup_step");

    // nranks
    auto nranks_tensor = ctx.Input<framework::Tensor>("nranks");
    const int nranks = static_cast<const int>(*nranks_tensor->data<float>());
    PADDLE_ENFORCE_GT(nranks, 1,
                      platform::errors::PreconditionNotMet(
                          "DGC is not useful when num_trainers <= 1. Please "
                          "use multiple GPU cards or multiple machines."));

    // regularization
    auto p = ctx.Input<framework::Tensor>("Param");
    float regular_coeff = ctx.Attr<float>("regular_coeff");
    int regular_type = ctx.Attr<int>("regular_type");

    auto p_e = framework::EigenVector<T>::Flatten(*p);
    auto g_e = framework::EigenVector<T>::Flatten(*g);
    auto grad_out_e = framework::EigenVector<T>::Flatten(*grad_out);

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto& eigen_ctx = *dev_ctx.eigen_device();

    // NOTE: In Paddle, the loss has already been divided by nranks. Since
    // dgc_op runs before allreduce, the local regular_coeff would also have
    // to be divided by nranks. Instead, we multiply the gradient by nranks
    // inside dgc_op, so regular_coeff does not need the /nranks and precision
    // loss is avoided: coeff is often about 1e-4, and with nranks=32,
    // coeff/nranks would be 3.125e-6, whose numerical accuracy is too low.
    PADDLE_ENFORCE_EQ(regular_type >= 0 && regular_type <= 2, true,
                      platform::errors::InvalidArgument(
                          "DGC only supports None|L1Decay|L2Decay "
                          "regularization for now."));
    if (regular_type == 0) {
      grad_out_e.device(eigen_ctx) = (1.0 * nranks) * g_e;
    } else if (regular_type == 1) {
      // L1Decay: grad = grad + coeff * sign(param)
      grad_out_e.device(eigen_ctx) =
          (1.0 * nranks) * g_e + regular_coeff * p_e.sign();
    } else if (regular_type == 2) {
      // L2Decay: grad = grad + coeff * param
      grad_out_e.device(eigen_ctx) = (1.0 * nranks) * g_e + regular_coeff * p_e;
    }

    // current step
    auto current_step_tensor = ctx.Input<framework::Tensor>("current_step");
    const float* current_step = current_step_tensor->data<float>();

    if (static_cast<int>(*current_step) < static_cast<int>(rampup_begin_step)) {
      VLOG(10) << "current_step:" << *current_step
               << " < rampup_begin_step:" << rampup_begin_step
               << " so DGC is not used";
      return;
    }
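
    // Before rampup_begin_step the kernel returns early: Grad_out already
    // holds nranks * grad (plus any decay term), and no top-k compression
    // is performed.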

    float ratio =
        1 - get_period_sparcity(
                sparsity, static_cast<float>(*current_step - rampup_begin_step),
                rampup_step);
    PADDLE_ENFORCE_GE(ratio, 0.0, platform::errors::InvalidArgument(
                                      "DGC sparsity ratio must be >= 0"));
    PADDLE_ENFORCE_LT(ratio, 1.0, platform::errors::InvalidArgument(
                                      "DGC sparsity ratio must be < 1"));
    int k = static_cast<int>(g->numel() * ratio);
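
    // Example (assumed numbers): with g->numel() = 1,000,000 and sparsity
    // 0.999, ratio = 0.001 and k = 1000, i.e. only the top 1000 elements
    // of the accumulated gradient are exchanged.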

    VLOG(10) << "m:" << m << ", use_nesterov:" << use_nesterov
             << ", rampup_begin_step:" << rampup_begin_step
             << ", rampup_step:" << rampup_step
             << ", current_step:" << *current_step << ", ratio:" << ratio
             << ", k:" << k << ", nranks:" << nranks;

    auto k_out = ctx.Output<framework::Tensor>("k");
    T* k_out_data = k_out->data<T>();
    *k_out_data = k;

    auto u_out = ctx.Output<framework::Tensor>("U_out");
    auto v_out = ctx.Output<framework::Tensor>("V_out");
    auto encode_grad_out = ctx.Output<framework::Tensor>("EncodeGrad");
    auto gather_buff = ctx.Output<framework::Tensor>("GatherBuff");

    // FIXME(gongwb): use cublas.
    auto u_out_e = framework::EigenVector<T>::Flatten(*u_out);
    auto u_e = framework::EigenVector<T>::Flatten(*u);

    // calc local momentum from global momentum
    // NOTE: if the gradient were not multiplied by nranks, the code below
    // would be needed.
    // if (static_cast<int>(*current_step) ==
    //     static_cast<int>(rampup_begin_step)) {
    //   u_out_e.device(eigen_ctx) = (1.0 / nranks) * u_e;
    // }

    if (use_nesterov) {
      // u = m * (u + g)
      u_out_e.device(eigen_ctx) = m * (u_e + grad_out_e);

      // v = u + v + g
      ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
          ctx, u, v, 0, AddFunctor<T>(), v_out);

      ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
          ctx, g, v, 0, AddFunctor<T>(), v_out);
    } else {
      // u = m * u + g
      u_out_e.device(eigen_ctx) = m * u_e + grad_out_e;

      // v = u + v
      ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
          ctx, u, v, 0, AddFunctor<T>(), v_out);
    }
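
    // The branches above implement DGC's momentum correction (presumably
    // following Lin et al., "Deep Gradient Compression", 2017): momentum u
    // and a local accumulator v are kept per trainer, and only the top-k
    // entries of v selected below are communicated.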

    T* v_out_data = v_out->mutable_data<T>(ctx.GetPlace());
    T* u_out_data = u_out->mutable_data<T>(ctx.GetPlace());
    T* encode_grad_out_data = encode_grad_out->mutable_data<T>(
        framework::DDim{2 * k}, ctx.GetPlace());
    gather_buff->mutable_data<T>(framework::DDim{2 * k * nranks},
                                 ctx.GetPlace());
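
    // EncodeGrad has 2 * k elements, presumably the k selected (index, value)
    // pairs; GatherBuff stages one such 2 * k chunk per trainer for the
    // subsequent allgather.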

    int buf_size = paddle::communication::dgc::get_buffer_size(k);
    auto tmp_ious_data = memory::Alloc(dev_ctx, buf_size);
    void* buf = reinterpret_cast<void*>(tmp_ious_data->ptr());

    if (!paddle::communication::dgc::k_select(
            static_cast<void*>(encode_grad_out_data), k, v_out_data,
            static_cast<int>(v_out->numel()), buf, dev_ctx.stream(),
            u_out_data)) {
      // TODO(weihang): owner should polish this error message
      PADDLE_THROW(platform::errors::InvalidArgument(
          "V_out numel error, V_out numel is %d.", v_out->numel()));
    }
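
    // Grad_out is zeroed after encoding: the dense gradient is replaced by
    // the sparse EncodeGrad representation for communication.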
    math::SetConstant<DeviceContext, T> tset;
    tset(dev_ctx, grad_out, static_cast<T>(0));
  }
};
}  // namespace operators
}  // namespace paddle