diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
index c11fb1b1d2..9f5f2abfd2 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
@@ -48,6 +48,10 @@ void SparseApplyFtrlPSKernel::InitKernel(
   if (grad_shape[0] != indices_size_) {
     MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
   }
+  init_accum_ = AnfAlgo::GetNodeAttr<float>(cnode, "init_accum");
+  if (init_accum_ < 0) {
+    MS_LOG(EXCEPTION) << "init_accum should be a non-negative scalar";
+  }
   lr_ = AnfAlgo::GetNodeAttr<float>(cnode, "lr");
   if (lr_ <= 0) {
     MS_LOG(EXCEPTION) << "lr should be a positive scalar";
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h
index 6d37dd4495..64a3c259af 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h
@@ -28,7 +28,7 @@ using mindspore::kernel::SparseApplyFtrlCPUKernel;
 class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerKernel {
  public:
   SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
-      : PServerKernel(rank_id, pserver_num, worker_num) {}
+      : PServerKernel(rank_id, pserver_num, worker_num), init_accum_(0.1) {}
   ~SparseApplyFtrlPSKernel() override = default;

   void InitKernel(const CNodePtr &cnode,
@@ -41,9 +41,11 @@ class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerKernel {
   const std::vector<size_t> &input_sizes() const override;
   const std::vector<size_t> &output_sizes() const override;
   const std::vector<size_t> &workspace_sizes() const override;
+  const float init_accum() const { return init_accum_; }

  protected:
   void ReInit(const std::vector<AddressPtr> &) override;
+  float init_accum_;
 };
 }  // namespace ps
 }  // namespace kernel
diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
index 7f7bc6e0a0..94afc07fad 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
@@ -100,16 +100,11 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
   for (size_t i = 0; i < indices_index; i++) {
     indice_offset += lengths[i];
   }
-  float *incr_indice_data = values.data() + indice_offset;
+  int *incr_indice_data = reinterpret_cast<int *>(values.data()) + indice_offset;
   size_t incr_indice_size = lengths[indices_index];
   size_t incr_indice_data_size = incr_indice_size * sizeof(int);
-  std::vector<int> converted_indices(incr_indice_size);
-  for (size_t i = 0; i < incr_indice_size; i++) {
-    converted_indices[i] = static_cast<int>(incr_indice_data[i]);
-  }
-
-  auto ret2 = memcpy_s(accum_indices_data + indices_offset_, incr_indice_data_size, converted_indices.data(),
-                       incr_indice_data_size);
+  auto ret2 =
+    memcpy_s(accum_indices_data + indices_offset_, incr_indice_data_size, incr_indice_data, incr_indice_data_size);
   if (ret2 != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")";
   }
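
Note: the SparseOptimInfo::Accumulate() change above stops converting indices element-wise (float to int) and instead addresses the indices segment as raw int32 data. That is only sound because the worker side (worker_proxy.h, further down) stops converting in the opposite direction in the same patch. A standalone sketch of why the two representations must not be mixed (hypothetical demo, not part of the patch):

    // demo_indices_bits.cc: reading raw int32 data through a float* and
    // static_cast'ing back (the old, mixed behavior) destroys the values,
    // while copying the raw bytes preserves them. The reinterpret_cast
    // aliasing is for illustration only.
    #include <cstring>
    #include <iostream>
    #include <vector>

    int main() {
      std::vector<int> indices = {0, 5, 42};
      const float *as_float = reinterpret_cast<const float *>(indices.data());
      for (size_t i = 0; i < indices.size(); i++) {
        // The int bit patterns 5 and 42 are float denormals (~7e-45, ~5.9e-44),
        // so static_cast<int> rounds them to 0:
        std::cout << indices[i] << " -> " << static_cast<int>(as_float[i]) << "\n";
      }
      std::vector<int> copied(indices.size());
      std::memcpy(copied.data(), indices.data(), indices.size() * sizeof(int));
      std::cout << "copied: " << copied[2] << "\n";  // prints 42, as expected
      return 0;
    }
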
diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
index 6da99f1422..f0a0491968 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
+++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
@@ -18,14 +18,16 @@
 #include <vector>
 #include <memory>
 #include <functional>
+#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"

 namespace mindspore {
 namespace parallel {
 namespace ps {
+using mindspore::kernel::ps::SparseApplyFtrlPSKernel;
 OptimizerInfo *OptimizerInfoBuilder::Build(const std::shared_ptr<PServerKernel> &pserver_kernel,
                                            const WeightPtr &weight, const Keys &keys, const Values &values,
                                            const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num) {
-  OptimizerInfo *optim_info = BuildInputs(weight, keys, values, lens, inputs_shape, worker_num);
+  OptimizerInfo *optim_info = BuildInputs(weight, keys, values, lens, inputs_shape, worker_num, pserver_kernel);
   std::vector<size_t> ws_sizes = pserver_kernel->workspace_sizes();
   BuildWorkspaces(optim_info, ws_sizes, worker_num);
   BuildOutputs(optim_info, worker_num);
@@ -45,7 +47,7 @@ void OptimizerInfoBuilder::BuildWorkspaces(OptimizerInfo *info, const std::vector<size_t> &ws_sizes,

 OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
                                                      const Lengths &lens, const InputsShapePtr &inputs_shape,
-                                                     size_t worker_num) {
+                                                     size_t worker_num, const std::shared_ptr<PServerKernel> &) {
   AddressPtr weight_addr = std::make_shared<kernel::Address>();
   weight_addr->addr = weight->data();
   weight_addr->size = weight->size() * sizeof(float);
@@ -74,7 +76,7 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys,

 OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
                                                        const Lengths &lens, const InputsShapePtr &inputs_shape,
-                                                       size_t worker_num) {
+                                                       size_t worker_num, const std::shared_ptr<PServerKernel> &) {
   AddressPtr weight_addr = std::make_shared<kernel::Address>();
   weight_addr->addr = weight->data();
   weight_addr->size = weight->size() * sizeof(float);
@@ -140,13 +142,9 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys,
     std::accumulate((*indices_shape).begin(), (*indices_shape).end(), sizeof(int), std::multiplies<size_t>());
   AddressPtr indices = std::make_shared<kernel::Address>();
   indices->addr = new int[total_indice_size * worker_num];
-  std::vector<int> converted_indices(lens[7]);
   size_t indices_data_size = lens[7] * sizeof(int);
-  float *indices_data = reinterpret_cast<float *>(epsilon->addr) + lens[5] + lens[6];
-  for (int i = 0; i < lens[7]; i++) {
-    converted_indices[i] = static_cast<int>(indices_data[i]);
-  }
-  ret = memcpy_s(indices->addr, indices_data_size, converted_indices.data(), indices_data_size);
+  int *indices_data = reinterpret_cast<int *>(epsilon->addr) + lens[5] + lens[6];
+  ret = memcpy_s(indices->addr, indices_data_size, indices_data, indices_data_size);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
@@ -158,7 +156,8 @@

 OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
                                                        const Lengths &lens, const InputsShapePtr &inputs_shape,
-                                                       size_t worker_num) {
+                                                       size_t worker_num,
+                                                       const std::shared_ptr<PServerKernel> &pserver_kernel) {
   AddressPtr weight_addr = std::make_shared<kernel::Address>();
   weight_addr->addr = weight->data();
   weight_addr->size = weight->size() * sizeof(float);
@@ -167,7 +166,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys,
   accum->size = weight->size() * sizeof(float);
   for (size_t i = 0; i < weight->size(); i++) {
     float *tmp = reinterpret_cast<float *>(accum->addr);
-    tmp[i] = 1.0;
+    tmp[i] = std::dynamic_pointer_cast<SparseApplyFtrlPSKernel>(pserver_kernel)->init_accum();
   }
   AddressPtr linear = std::make_shared<kernel::Address>();
   linear->addr = new float[weight->size()];
@@ -192,13 +191,9 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys,
     std::accumulate((*indices_shape).begin(), (*indices_shape).end(), 1, std::multiplies<int>());
   AddressPtr indices = std::make_shared<kernel::Address>();
   indices->addr = new int[total_indice_size * worker_num];
-  std::vector<int> converted_indices(lens[1]);
   size_t indices_data_size = lens[1] * sizeof(int);
-  float *indices_data = reinterpret_cast<float *>(values.data()) + lens[0];
-  for (int i = 0; i < lens[1]; i++) {
-    converted_indices[i] = static_cast<int>(indices_data[i]);
-  }
-  ret = memcpy_s(indices->addr, indices_data_size, converted_indices.data(), indices_data_size);
+  int *indices_data = reinterpret_cast<int *>(values.data()) + lens[0];
+  ret = memcpy_s(indices->addr, indices_data_size, indices_data, indices_data_size);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
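
Note: std::dynamic_pointer_cast<SparseApplyFtrlPSKernel>(pserver_kernel) in the hunk above yields nullptr when the kernel passed in is not the sparse FTRL PS kernel, and the patch dereferences the result unchecked. A self-contained sketch of a defensive variant (the structs below are simplified stand-ins, not the real classes):

    // demo_checked_cast.cc
    #include <iostream>
    #include <memory>

    struct PServerKernel { virtual ~PServerKernel() = default; };
    struct SparseApplyFtrlPSKernel : PServerKernel {
      float init_accum() const { return 0.1f; }  // mirrors the new default
    };

    float ReadInitAccum(const std::shared_ptr<PServerKernel> &kernel) {
      auto ftrl = std::dynamic_pointer_cast<SparseApplyFtrlPSKernel>(kernel);
      if (ftrl == nullptr) {
        std::cerr << "expected SparseApplyFtrlPSKernel\n";
        return 0.1f;  // fall back to the documented FTRL default
      }
      return ftrl->init_accum();
    }

    int main() {
      auto kernel = std::make_shared<SparseApplyFtrlPSKernel>();
      std::cout << ReadInitAccum(kernel) << "\n";  // 0.1
      return 0;
    }
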
diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h
index a719f9dc91..21920938a2 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h
@@ -38,7 +38,8 @@ class OptimizerInfoBuilder {
                        size_t worker_num);

   virtual OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
-                                     const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num) = 0;
+                                     const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num,
+                                     const std::shared_ptr<PServerKernel> &pserver_kernel) = 0;

   virtual void BuildWorkspaces(OptimizerInfo *info, const std::vector<size_t> &ws_sizes, size_t worker_num);
   virtual void BuildOutputs(OptimizerInfo *info, size_t worker_num) {}
@@ -47,19 +48,22 @@
 class MomentumOptimInfoBuilder : public OptimizerInfoBuilder {
  public:
   OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens,
-                             const InputsShapePtr &inputs_shape, size_t worker_num) override;
+                             const InputsShapePtr &inputs_shape, size_t worker_num,
+                             const std::shared_ptr<PServerKernel> &pserver_kernel) override;
 };

 class SparseAdamOptimInfoBuilder : public OptimizerInfoBuilder {
  public:
   OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens,
-                             const InputsShapePtr &inputs_shpae, size_t worker_num) override;
+                             const InputsShapePtr &inputs_shpae, size_t worker_num,
+                             const std::shared_ptr<PServerKernel> &pserver_kernel) override;
 };

 class SparseFtrlOptimInfoBuilder : public OptimizerInfoBuilder {
  public:
   OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens,
-                             const InputsShapePtr &inputs_shpae, size_t worker_num) override;
+                             const InputsShapePtr &inputs_shpae, size_t worker_num,
+                             const std::shared_ptr<PServerKernel> &pserver_kernel) override;
 };
 }  // namespace ps
 }  // namespace parallel
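
Note: the two total_indice_size computations in optimizer_info_builder.cc above fold the indices shape with std::accumulate and std::multiplies, but with different seeds: sizeof(int) in the SparseAdam path (a byte count) and 1 in the SparseFtrl path (an element count). A runnable illustration with a hypothetical 1-D shape:

    // demo_indice_size.cc
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      std::vector<size_t> indices_shape = {16};  // hypothetical indices shape
      // Seeded with sizeof(int): the result is already a byte count (16 * 4 = 64).
      size_t bytes = std::accumulate(indices_shape.begin(), indices_shape.end(),
                                     sizeof(int), std::multiplies<size_t>());
      // Seeded with 1: the result is an element count (16).
      size_t elems = std::accumulate(indices_shape.begin(), indices_shape.end(),
                                     size_t(1), std::multiplies<size_t>());
      std::cout << bytes << " bytes, " << elems << " elements\n";
      return 0;
    }

Since both results are then multiplied by worker_num and used as the extent of a new int[...] allocation, the byte-count seed in the SparseAdam path makes that buffer sizeof(int) times larger than the element count requires; that behavior predates this patch but may be worth a follow-up.
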
diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
index 70923124c1..0c4f77f962 100644
--- a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
+++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
@@ -571,11 +571,7 @@ void WorkerProxy<T>::BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index,
   int indice_offset = grad_offset + lengths[grad_index];
   data_size = lengths[indice_index] * sizeof(T);
   T *indice_data = reduced_data->data() + indice_offset;
-  std::vector<T> convert(lengths[indice_index]);
-  for (int i = 0; i < lengths[indice_index]; i++) {
-    convert[i] = static_cast<T>(indices[i]);
-  }
-  ret = memcpy_s(indice_data, data_size, convert.data(), data_size);
+  ret = memcpy_s(indice_data, data_size, indices, data_size);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py
index 80826394fe..1ede68bb50 100644
--- a/mindspore/nn/optim/ftrl.py
+++ b/mindspore/nn/optim/ftrl.py
@@ -162,6 +162,7 @@ class FTRL(Optimizer):
         self.sparse_opt = P.FusedSparseFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
         self._ps_pull = P.Pull()
         self._ps_push = P.Push("Ftrl", [0, 1, 2])
+        self._ps_push.add_prim_attr("init_accum", initial_accum)
         self._ps_push.add_prim_attr("lr", learning_rate)
         self._ps_push.add_prim_attr("l1", l1)
         self._ps_push.add_prim_attr("l2", l2)
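
Note: every copy site touched by this patch checks the return value of memcpy_s, which here is the bundled securec implementation rather than the C11 Annex K function: it returns 0 on success and a nonzero errno-style code when, for example, the requested count exceeds the destination capacity. A runnable approximation of that contract (memcpy_s_demo is a hypothetical stand-in):

    // demo_memcpy_s_contract.cc
    #include <cstring>
    #include <iostream>

    // Hypothetical stand-in mirroring the securec contract: 0 on success,
    // nonzero (ERANGE-like) when count exceeds the destination capacity.
    static int memcpy_s_demo(void *dest, size_t dest_max, const void *src, size_t count) {
      if (dest == nullptr || src == nullptr || count > dest_max) {
        return 34;  // error code; the real securec may also zero dest on failure
      }
      std::memcpy(dest, src, count);
      return 0;
    }

    int main() {
      int indices[3] = {0, 5, 42};
      int dst[3];
      // Success path, as in BuildSparseValue after the change:
      std::cout << memcpy_s_demo(dst, sizeof(dst), indices, sizeof(indices)) << "\n";  // 0
      // Failure path: destination too small, so the caller would raise:
      std::cout << memcpy_s_demo(dst, sizeof(int), indices, sizeof(indices)) << "\n";  // 34
      return 0;
    }
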