diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h
index 62cc4c5da3..1650ff0b21 100644
--- a/mindspore/ccsrc/parallel/dynamic_creator.h
+++ b/mindspore/ccsrc/parallel/dynamic_creator.h
@@ -125,6 +125,7 @@ REGISTER(GetNextInfo);
 REGISTER(NegInfo);
 REGISTER(BatchMatMulInfo);
 REGISTER(ExpandDimsInfo);
+REGISTER(SqueezeInfo);
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/parallel/ops_info/activation_info.cc
index 9ba3624b01..c59ca8402b 100644
--- a/mindspore/ccsrc/parallel/ops_info/activation_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/activation_info.cc
@@ -19,6 +19,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include <algorithm>
 
 #include "ir/value.h"
 #include "parallel/auto_parallel/costmodel.h"
@@ -544,5 +545,160 @@ Status ExpandDimsInfo::InferMirrorOps() {
   MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name();
   return SUCCESS;
 }
+
+Status SqueezeInfo::InferAxis(const ValueTuplePtr& value_tuple) {
+  std::vector<int32_t> axis;
+  auto axis_list = value_tuple->value();
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+  Shape input_shape = inputs_shape_.at(0);
+  size_t input_size = input_shape.size();
+  // if the axis tuple is empty, squeeze every dimension of the input whose size is 1
+  if (axis_list.empty()) {
+    for (size_t i = 0; i < input_size; ++i) {
+      if (input_shape[i] == 1) {
+        axis.push_back(SizeToInt(i));
+      }
+    }
+    axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
+    return SUCCESS;
+  }
+
+  // convert negative axis values to positive
+  for (auto& dim : axis_list) {
+    if (!dim->isa<Int32Imm>()) {
+      MS_LOG(ERROR) << name_ << ": The type of axis is not int";
+      return FAILED;
+    }
+    int32_t dim_value = GetValue<int32_t>(dim);
+    int32_t positive_value = (dim_value < 0) ? (dim_value + SizeToInt(input_size)) : dim_value;
+    axis.push_back(positive_value);
+  }
+  axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
+  return SUCCESS;
+}
+
+Status SqueezeInfo::GetAttrs() {
+  auto iter = attrs_.find(AXIS);
+  if (iter == attrs_.end()) {
+    MS_LOG(ERROR) << name_ << ": Can't find axis attribute.";
+    return FAILED;
+  }
+  MS_EXCEPTION_IF_NULL(iter->second);
+  auto value_tuple = iter->second->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(value_tuple);
+  if (InferAxis(value_tuple) != SUCCESS) {
+    return FAILED;
+  }
+  attrs_[AXIS] = axis_;
+  return SUCCESS;
+}
+
+Status SqueezeInfo::InferReplaceOps(const StrategyPtr& strategy) {
+  Attr attr = std::make_pair(AXIS, axis_);
+  OperatorAttrs attrs = {attr};
+  OperatorParams params;
+  OperatorArgs args = std::make_pair(attrs, params);
+  replace_op_ = {std::make_pair(SQUEEZE, args)};
+  return SUCCESS;
+}
+
+Status SqueezeInfo::InferTensorMap() {
+  // for example: if the shape of the input is [32, 32, 1] and the axis is (2,),
+  // then the input tensor map is [2, 1, 0] and the output tensor map is [2, 1]
+  std::vector<int32_t> input_tensor_map, output_tensor_map;
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+  size_t size = inputs_shape_[0].size();
+  std::vector<int32_t> axis = GetValue<std::vector<int32_t>>(axis_);
+  for (size_t i = 0; i < size; ++i) {
+    size_t index = size - i - 1;
+    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
+    if (iter == axis.end()) {
+      output_tensor_map.push_back(SizeToInt(index));
+    }
+    input_tensor_map.push_back(SizeToInt(index));
+  }
+  inputs_tensor_map_.push_back(input_tensor_map);
+  outputs_tensor_map_.push_back(output_tensor_map);
+  MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
+               << ", and the tensor map of output is " << ShapeToString(output_tensor_map);
+
+  return SUCCESS;
+}
+
+Status SqueezeInfo::InferTensorInfo() {
+  if (inputs_shape_.empty() || outputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
+    return FAILED;
+  }
+
+  if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
+    return FAILED;
+  }
+
+  Shape input_shape = inputs_shape_[0];
+  Shape output_shape = outputs_shape_[0];
+
+  // infer slice shape: the output strategy keeps only the dimensions that are not squeezed
+  Shapes inputs_slice_shape, outputs_slice_shape;
+  Strategys inputs_strategy = strategy_->GetInputDim();
+  Dimensions output_strategy;
+  std::vector<int32_t> axis = GetValue<std::vector<int32_t>>(axis_);
+  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
+    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
+    if (iter == axis.end()) {
+      output_strategy.push_back(inputs_strategy[0].at(i));
+    }
+  }
+  Strategys outputs_strategy = {output_strategy};
+  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
+    return FAILED;
+  }
+
+  if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
+    MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
+    return FAILED;
+  }
+
+  Shape input_slice_shape = inputs_slice_shape[0];
+  Shape output_slice_shape = outputs_slice_shape[0];
+
+  // infer tensor layout
+  TensorLayout input_tensor_layout, output_tensor_layout;
+  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
+    return FAILED;
+  }
+
+  if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
+    return FAILED;
+  }
+
+  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
+  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
+
+  inputs_tensor_info_.push_back(input_tensor_info);
+  outputs_tensor_info_.push_back(output_tensor_info);
+  return SUCCESS;
+}
+
+Status SqueezeInfo::Init(const StrategyPtr& strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Init failed.";
+    return FAILED;
+  }
+
+  if (InferReplaceOps(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Infer replace ops failed";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << " : Init success.";
+  return SUCCESS;
+}
 }  // namespace parallel
 }  // namespace mindspore
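Review note: the mapping rule added in `SqueezeInfo::InferTensorMap` can be checked in isolation. Dimension i of an N-dimensional input is mapped to device-matrix axis N - 1 - i, and squeezed dimensions are simply dropped from the output map. The sketch below is a minimal standalone reimplementation for review purposes only; `InferSqueezeTensorMaps` is a hypothetical helper, not a MindSpore API.

```cpp
// Standalone sketch of the rule used by SqueezeInfo::InferTensorMap.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper, for illustration only.
void InferSqueezeTensorMaps(size_t input_dims, const std::vector<int32_t>& axis,
                            std::vector<int32_t>* input_map, std::vector<int32_t>* output_map) {
  for (size_t i = 0; i < input_dims; ++i) {
    int32_t index = static_cast<int32_t>(input_dims - i - 1);
    if (std::find(axis.begin(), axis.end(), static_cast<int32_t>(i)) == axis.end()) {
      output_map->push_back(index);  // a kept dimension keeps its device-matrix axis
    }
    input_map->push_back(index);     // every input dimension gets a map entry
  }
}

int main() {
  // The example from the comment in InferTensorMap: input shape [32, 32, 1], axis (2,).
  std::vector<int32_t> input_map, output_map;
  InferSqueezeTensorMaps(3, {2}, &input_map, &output_map);
  for (int32_t v : input_map) std::cout << v << ' ';   // prints: 2 1 0
  std::cout << '\n';
  for (int32_t v : output_map) std::cout << v << ' ';  // prints: 2 1
  std::cout << '\n';
  return 0;
}
```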
diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.h b/mindspore/ccsrc/parallel/ops_info/activation_info.h
index 21774c43ee..b19e38b910 100644
--- a/mindspore/ccsrc/parallel/ops_info/activation_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/activation_info.h
@@ -184,6 +184,25 @@ class ExpandDimsInfo : public ActivationOther {
   Strategys inputs_strategy_;
   Strategys outputs_strategy_;
 };
+
+class SqueezeInfo : public ActivationOther {
+ public:
+  SqueezeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
+              const PrimitiveAttrs& attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~SqueezeInfo() override = default;
+
+ protected:
+  Status InferAxis(const ValueTuplePtr& value_tuple);
+  Status GetAttrs() override;
+  Status InferReplaceOps(const StrategyPtr& strategy);
+  Status InferTensorMap() override;
+  Status InferTensorInfo() override;
+  Status Init(const StrategyPtr& strategy) override;
+
+ private:
+  ValueTuplePtr axis_;
+};
 }  // namespace parallel
 }  // namespace mindspore
 
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ACTIVATION_INFO_H_
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
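Review note: `REGISTER(SqueezeInfo)` in dynamic_creator.h hooks the new class into the operator-info factory so it can be created by operator name. The macro's actual definition lives in dynamic_creator.h and is not shown in this patch; the sketch below only illustrates the general self-registration pattern such macros follow. `OpInfoFactory`, `MY_REGISTER`, and `SqueezeInfoDemo` are invented names, not MindSpore code.

```cpp
// Generic self-registration sketch: a static object's constructor runs at
// program start and records a creator function in a name-keyed factory map.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct OperatorInfoBase {
  virtual ~OperatorInfoBase() = default;
  virtual std::string Name() const = 0;
};

class OpInfoFactory {
 public:
  using Creator = std::function<std::unique_ptr<OperatorInfoBase>()>;
  static OpInfoFactory& Instance() {
    static OpInfoFactory factory;  // Meyers singleton
    return factory;
  }
  void Register(const std::string& name, Creator creator) { creators_[name] = std::move(creator); }
  std::unique_ptr<OperatorInfoBase> Create(const std::string& name) const {
    auto it = creators_.find(name);
    return (it == creators_.end()) ? nullptr : it->second();
  }

 private:
  std::map<std::string, Creator> creators_;
};

#define MY_REGISTER(CLASS)                                     \
  static const struct CLASS##Registrar {                       \
    CLASS##Registrar() {                                       \
      OpInfoFactory::Instance().Register(#CLASS, [] {          \
        return std::unique_ptr<OperatorInfoBase>(new CLASS()); \
      });                                                      \
    }                                                          \
  } g_##CLASS##_registrar;

struct SqueezeInfoDemo : OperatorInfoBase {
  std::string Name() const override { return "SqueezeInfoDemo"; }
};
MY_REGISTER(SqueezeInfoDemo)

int main() {
  auto info = OpInfoFactory::Instance().Create("SqueezeInfoDemo");
  std::cout << (info ? info->Name() : "not registered") << '\n';
  return 0;
}
```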
diff --git a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
index daa2ad595c..8010b2890a 100644
--- a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
@@ -116,4 +116,4 @@ class AssignSubInfo : public ArithmeticBase {
 }  // namespace parallel
 }  // namespace mindspore
 
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ARITHMETIC_INFO_H_
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_
diff --git a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h b/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h
index 110a9a6c38..5f51f1d0a9 100644
--- a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h
@@ -53,4 +53,4 @@ class MaximumInfo : public ArithmeticBase {
 }  // namespace parallel
 }  // namespace mindspore
 
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_COMPARISON_FUNCTION_INFO_H_
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_
diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/parallel/ops_info/onehot_info.h
index a54d8479b3..fec8d96324 100644
--- a/mindspore/ccsrc/parallel/ops_info/onehot_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/onehot_info.h
@@ -65,4 +65,4 @@ class OneHotInfo : public OperatorInfo {
 };
 }  // namespace parallel
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ONEHOT_INFO_H_
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_
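Review note: the four guard-comment renames in this patch (including the activation_info.h hunk above) align each `#endif` comment with the header's real path under mindspore/ccsrc. For a hypothetical header at mindspore/ccsrc/parallel/ops_info/foo_info.h, the convention these hunks converge on would be:

```cpp
// Hypothetical header, shown only to illustrate the guard-naming convention.
#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_FOO_INFO_H_
#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_FOO_INFO_H_

namespace mindspore {
namespace parallel {
// ... declarations ...
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_FOO_INFO_H_
```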
diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index 0a6d0b0bef..1976053eff 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -47,8 +47,8 @@ using mindspore::tensor::Tensor;
 
 namespace mindspore {
 namespace parallel {
-const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
-const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
+static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
+static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
 // g_RefMap: if input i of CNode B is a RefKey[Parameter C],
 // the map holds one entry with key C and value (B, i)
 static std::map<AnfNodePtr, std::pair<CNodePtr, int>> g_RefMap;
@@ -1832,7 +1832,6 @@ void ParallelCommunication(const FuncGraphPtr& root, const std::vector<AnfNodePtr>
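Review note: the step_parallel.cc hunk marks the two op-name tables `static`. Since a namespace-scope `const` object already has internal linkage in C++, the keyword mainly makes explicit that the tables are private to this translation unit and keeps them stylistically consistent with `g_RefMap`. A minimal standalone sketch of the lookup pattern; the string literals are illustrative placeholders, not the actual values of `ALL_REDUCE` and friends:

```cpp
// Internal-linkage lookup table: neither symbol below is visible to other
// translation units, so another .cc file may define same-named tables
// without any linker conflict.
#include <iostream>
#include <set>
#include <string>

static const std::set<std::string> COMMUNICATION_OPS = {"AllReduce", "AllGather", "AllToAll",
                                                        "ReduceScatter"};

static bool IsCommunicationOp(const std::string& op) { return COMMUNICATION_OPS.count(op) > 0; }

int main() {
  std::cout << std::boolalpha << IsCommunicationOp("AllReduce") << '\n';  // true
  std::cout << IsCommunicationOp("MatMul") << '\n';                       // false
  return 0;
}
```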