diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h index b9caa7be1c..e8f0360625 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h @@ -172,6 +172,8 @@ using TransposeCost = ActivationCost; using TransposeCostPtr = std::shared_ptr; using StridedSliceCost = ActivationCost; using StridedSliceCostPtr = std::shared_ptr; +using SplitCost = ActivationCost; +using SplitCostPtr = std::shared_ptr; class SoftmaxCost : public OperatorCost { public: @@ -203,8 +205,8 @@ using PackCost = TileCost; using PackCostPtr = std::shared_ptr; using ConcatCost = TileCost; using ConcatCostPtr = std::shared_ptr; -using SplitCost = TileCost; -using SplitCostPtr = std::shared_ptr; +using BroadcastToCost = SoftmaxCost; +using BroadcastToCostPtr = std::shared_ptr; class TmpIdentityCost : public OperatorCost { public: diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h index ca33986334..78dfbc2f06 100644 --- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h @@ -176,6 +176,7 @@ REGISTER(SquareInfo); REGISTER(GatherV2PInfo); REGISTER(EmbeddingLookupInfo); REGISTER(TileInfo); +REGISTER(BroadcastToInfo); REGISTER(StridedSliceInfo); REGISTER(DropoutInfo); REGISTER(PackInfo); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.cc new file mode 100644 index 0000000000..4e999d01cc --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.cc @@ -0,0 +1,265 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/broadcast_to_info.h" + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include "pipeline/jit/resource.h" +#include "frontend/parallel/graph_util/generate_graph.h" + +namespace mindspore { +namespace parallel { +Status BroadcastToInfo::GetAttrs() { + out_shape_.clear(); + auto shape_iter = attrs_.find(SHAPE); + if (shape_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(shape_iter->second); + auto var = shape_iter->second->cast(); + if (var == nullptr) { + MS_LOG(ERROR) << name_ << ": shape format is wrong! Need ValueSequeue"; + return FAILED; + } + for (auto &ele : var->value()) { + if (!ele->isa()) { + MS_LOG(ERROR) << name_ << ": The element of shape must be int"; + return FAILED; + } + out_shape_.push_back(static_cast(GetValue(ele))); + } + } else { + MS_LOG(ERROR) << name_ << ": Can not find the shape attr"; + return FAILED; + } + if (out_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": shape cannot be empty"; + return FAILED; + } + + return SUCCESS; +} + +Status BroadcastToInfo::CheckStrategy(const StrategyPtr &strategy) { + MS_EXCEPTION_IF_NULL(strategy); + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Invalid strategy"; + return FAILED; + } + + auto stra = strategy->GetInputDim().at(0); + auto in_shape = inputs_shape_.at(0); + for (size_t i = 0; i < stra.size(); ++i) { + if ((in_shape[i] == 1) && (stra[i] != 1)) { + MS_LOG(ERROR) << name_ << ": dimension with size 1 is not splitable."; + return FAILED; + } + } + return SUCCESS; +} + +Status BroadcastToInfo::InferDevMatrixShape() { + MS_EXCEPTION_IF_NULL(strategy_); + std::vector stra = strategy_->GetInputDim(); + if (stra.empty()) { + MS_LOG(ERROR) << name_ << "The strategy is empty"; + return FAILED; + } + + dev_matrix_shape_ = stra[0]; + return SUCCESS; +} + +Status BroadcastToInfo::InferTensorMap() { + TensorMap in_tensor_map; + TensorMap out_tensor_map; + + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << "The inputs shape is empty"; + return FAILED; + } + + int32_t size = SizeToInt(inputs_shape_[0].size()); + for (int i = 0; i < size; ++i) { + in_tensor_map.push_back(size - i - 1); + } + inputs_tensor_map_.push_back(in_tensor_map); + + size_t len_diff = outputs_shape_.at(0).size() - inputs_shape_.at(0).size(); + for (size_t i = 0; i < len_diff; ++i) { + out_tensor_map.push_back(MAP_NONE); + } + std::copy(in_tensor_map.begin(), in_tensor_map.end(), std::back_inserter(out_tensor_map)); + outputs_tensor_map_.push_back(out_tensor_map); + return SUCCESS; +} + +Status BroadcastToInfo::InferMirrorOps() { + mirror_ops_.clear(); + if (inputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty"; + return FAILED; + } + + Shape input_tensor_map = inputs_tensor_map_[0]; + std::vector group; + if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Create group for input failed."; + return FAILED; + } + + if (group.empty()) { + MS_LOG(INFO) << name_ << ": The mirror group is empty."; + return SUCCESS; + } + + OperatorVector input_op; + input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); + mirror_ops_.push_back(input_op); + + return SUCCESS; +} + +Status BroadcastToInfo::InferTensorInfo() { + if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": Invalid args"; + return FAILED; + } + + TensorLayout input_layout, output_layout; + // infer tensor layout + if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed."; + return FAILED; + } + TensorInfo input_tensor_info(input_layout); + inputs_tensor_info_.push_back(input_tensor_info); + + if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed."; + return FAILED; + } + TensorInfo output_tensor_info(output_layout); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status BroadcastToInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } + +Status BroadcastToInfo::GenerateStrategies(int32_t stage_id) { + if (InferAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer attrs failed"; + return FAILED; + } + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; + return FAILED; + } + Shape input_split; + for (size_t i = 0; i < inputs_shape_[0].size(); ++i) { + if (inputs_shape_[0][i] == 1) { + input_split.push_back(0); + } else { + input_split.push_back(1); + } + } + + // to generate the first input's strategy + Shapes splittable_input = {input_split}; + Shapes tmp_inputs_shape = {inputs_shape_[0]}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Generate strategies failed"; + return FAILED; + } + + // the others strategies are equal to the first input's strategy + for (auto &sp : sp_vector) { + if ((sp == nullptr) || sp->GetInputDim().empty()) { + MS_LOG(ERROR) << name_ << ": The strategy is null or empty"; + return FAILED; + } + Strategys tmp_strategy; + Dimensions first_input_strategy = sp->GetInputDim()[0]; + for (size_t i = 0; i < inputs_shape_.size(); ++i) { + tmp_strategy.push_back(first_input_strategy); + } + sp->ResetInputs(tmp_strategy); + } + + size_t success = 0; + for (auto &sp : sp_vector) { + PrintStrategy(sp); + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status BroadcastToInfo::ComputeReplaceGraph(const CNodePtr &cnode) { + GenerateGraph gen_g = GenerateGraph(); + if (gen_g.Init(cnode) != SUCCESS) { + MS_LOG(ERROR) << "GenerateGraph Init failed"; + return FAILED; + } + + Shape to_shape = outputs_tensor_info_[0].slice_shape(); + Attr attr_shape = std::make_pair(SHAPE, MakeValue(to_shape)); + OperatorAttrs attrs = {attr_shape}; + auto new_broadcast_to = gen_g.PushBack({gen_g.NewOpInst(BROADCAST_TO, attrs), gen_g.virtual_input_node()}); + std::vector> input_nodes = {std::make_pair(new_broadcast_to, 1)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, new_broadcast_to)); + + return SUCCESS; +} + +ReplaceGraphPtr BroadcastToInfo::replace_graph(const CNodePtr &cnode) { + if (ComputeReplaceGraph(cnode) != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; + } + return replace_graph_; +} + +Status BroadcastToInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status BroadcastToInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.h new file mode 100644 index 0000000000..c60b0e593d --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BROADCAST_TO_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BROADCAST_TO_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +/* + * Limitation: Dimensions with size 1 can't be splited. + */ +class BroadcastToInfo : public OperatorInfo { + public: + BroadcastToInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~BroadcastToInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; + + protected: + Status GetAttrs() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status ComputeReplaceGraph(const CNodePtr &cnode); + + private: + Shape out_shape_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_BROADCAST_TO_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h index 53e87e478a..7f0b83bc96 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h @@ -42,5 +42,6 @@ #include "frontend/parallel/ops_info/concat_info.h" #include "frontend/parallel/ops_info/split_info.h" #include "frontend/parallel/ops_info/pack_info.h" +#include "frontend/parallel/ops_info/broadcast_to_info.h" #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 53f36278a2..853b071c47 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -222,6 +222,7 @@ constexpr char GATHERV2[] = "GatherV2"; constexpr char SPARSE_GATHERV2[] = "SparseGatherV2"; constexpr char STRIDEDSLICE[] = "StridedSlice"; constexpr char BROADCAST[] = "Broadcast"; +constexpr char BROADCAST_TO[] = "BroadcastTo"; constexpr char SQRT[] = "Sqrt"; constexpr char ASSIGN[] = "Assign"; constexpr char GET_NEXT[] = "GetNext"; diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 9442cc63be..370ace402e 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -265,7 +265,7 @@ bool IsSplittableOperator(const std::string &op_name) { LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT, STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT, SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, - EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT}; + EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO}; // clang-format on auto iter = splittable_op.find(op_name); diff --git a/tests/ut/python/parallel/test_broadcast_to.py b/tests/ut/python/parallel/test_broadcast_to.py new file mode 100644 index 0000000000..4159c9710e --- /dev/null +++ b/tests/ut/python/parallel/test_broadcast_to.py @@ -0,0 +1,112 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import mindspore as ms +import mindspore.context as context +from mindspore import Tensor, Parameter +import mindspore.nn as nn +from mindspore.common.api import _executor +from mindspore.nn import TrainOneStepCell, Momentum +from mindspore.ops import operations as P + + +class Net(nn.Cell): + def __init__(self, weight1, strategy1=None, strategy2=None, is_parameter=True): + super(Net, self).__init__() + self.shape = (8, 48, 64) + self.broadcast = P.BroadcastTo(self.shape).shard(strategy1) + self.mul = P.Mul().shard(strategy2) + if is_parameter: + self.weight1 = Parameter(weight1, "w1") + else: + self.weight1 = weight1 + + def construct(self, x): + out = self.broadcast(self.weight1) + out = self.mul(x, out) + return out + + +class MatMulNet(nn.Cell): + def __init__(self, weight1, strategy1=None, strategy2=None, strategy3=None, is_parameter=True): + super(MatMulNet, self).__init__() + self.shape = (8, 64, 64) + self.broadcast = P.BroadcastTo(self.shape).shard(strategy1) + self.matmul = P.BatchMatMul().shard(strategy2) + self.mul = P.Mul().shard(strategy3) + if is_parameter: + self.weight1 = Parameter(weight1, "w1") + else: + self.weight1 = weight1 + + def construct(self, x1, x2): + out = self.broadcast(x2) + out = self.matmul(x1, out) + out = self.mul(out, self.weight1) + return out + + +_w1 = Tensor(np.ones([1, 48, 64]), dtype=ms.float32) +_x1 = Tensor(np.ones([8, 48, 64]), dtype=ms.float32) +_x2 = Tensor(np.ones([64, 64]), dtype=ms.float32) + + +def compile_net(net): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + _executor.compile(train_net, _x1) + context.reset_auto_parallel_context() + + +def compile_net2(net): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + _executor.compile(train_net, _x1, _x2) + context.reset_auto_parallel_context() + + +def test_BroadcastTo_parameter(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 4, 2),) + strategy2 = ((1, 4, 2), (1, 4, 2)) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_BroadcastTo_parameter_no_full(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 2, 2),) + strategy2 = ((1, 4, 2), (1, 4, 2)) + net = Net(_w1, strategy1, strategy2) + compile_net(net) + + +def test_BroadcastTo_auto_parallel(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net(_w1) + compile_net(net) + + +def test_BroadcastTo_matmul(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 4),) + strategy2 = ((1, 1, 2), (1, 2, 4)) + strategy3 = ((1, 2, 4), (1, 2, 4)) + net = MatMulNet(_w1, strategy1, strategy2, strategy3) + compile_net2(net)