add parallel ops

pull/8230/head
yangzhenzhang 4 years ago
parent d79bcc923e
commit 0a79ab82ae

@@ -133,6 +133,8 @@ REGISTER(LogicalAndInfo);
REGISTER(LogicalOrInfo);
REGISTER(EluInfo);
REGISTER(ReLUInfo);
REGISTER(RepeatElementsInfo);
REGISTER(TensorDotInfo);
REGISTER(ReLU6Info);
REGISTER(ReLUV2Info);
REGISTER(SoftplusInfo);

@@ -307,27 +307,18 @@ Status ActivationBase::InferTensorMap() {
 }
 Status ActivationBase::InferTensorInfo() {
-  // infer tensor shape
-  Shape input_shape = inputs_shape_.at(0);
-  // infer slice shape
-  Shapes inputs_slice_shape, outputs_slice_shape;
-  Strategys inputs_strategy = strategy_->GetInputDim();
-  Strategys outputs_strategy = {inputs_strategy.at(0)};
-  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
-    return FAILED;
-  }
-  Shape input_slice_shape = inputs_slice_shape.at(0);
-  TensorLayout input_tensor_layout;
-  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
+  TensorLayout input_tensor_layout, output_tensor_layout;
+  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) ||
+      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
     MS_LOG(ERROR) << name_ << ": init tensor layout failed";
     return FAILED;
   }
-  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
+  TensorInfo input_tensor_info(input_tensor_layout);
+  TensorInfo output_tensor_info(output_tensor_layout);
   inputs_tensor_info_.push_back(input_tensor_info);
-  outputs_tensor_info_.push_back(input_tensor_info);  // the same as input
+  outputs_tensor_info_.push_back(output_tensor_info);
   return SUCCESS;
 }
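Context for the hunk above: after this change, TensorInfo is built from the TensorLayout alone, because the layout (device matrix, tensor map, full shape) already determines each slice shape, which the old code computed by hand via InferSliceShape. A minimal sketch of that arithmetic with a hypothetical helper (not the project's API), assuming MindSpore's convention that a tensor-map entry indexes the device matrix from its rightmost dimension:

# Hypothetical illustration only; names are not from the MindSpore codebase.
def slice_shape(full_shape, tensor_map, dev_matrix):
    # tensor_map[i] = k shards tensor dimension i across the device-matrix
    # axis k counted from the right; -1 leaves the dimension unsplit.
    shards = [1 if m == -1 else dev_matrix[len(dev_matrix) - 1 - m] for m in tensor_map]
    return [dim // s for dim, s in zip(full_shape, shards)]

print(slice_shape([128, 64, 32], [2, 1, 0], [2, 2, 4]))  # -> [64, 32, 8]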

@@ -146,6 +146,14 @@ class ReLUInfo : public ActivationOther {
  ~ReLUInfo() override = default;
};

class RepeatElementsInfo : public ActivationOther {
 public:
  RepeatElementsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                     const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
  ~RepeatElementsInfo() override = default;
};

class ReLU6Info : public ActivationOther {
 public:
  ReLU6Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,

@@ -42,6 +42,7 @@
#include "frontend/parallel/ops_info/strided_slice_info.h"
#include "frontend/parallel/ops_info/concat_info.h"
#include "frontend/parallel/ops_info/split_info.h"
#include "frontend/parallel/ops_info/tensordot_info.h"
#include "frontend/parallel/ops_info/pack_info.h"
#include "frontend/parallel/ops_info/broadcast_to_info.h"
#include "frontend/parallel/ops_info/unique_info.h"

@@ -103,6 +103,7 @@ constexpr char STRIDES[] = "strides";
constexpr char GROUP[] = "group";
constexpr char FUSION[] = "fusion";
constexpr char AXIS[] = "axis";
constexpr char AXES[] = "axes";
constexpr char OUTPUT_NUM[] = "output_num";
constexpr char SPLIT_COUNT[] = "split_count";
constexpr char SPLIT_DIM[] = "split_dim";
@@ -190,6 +191,8 @@ constexpr char VIRTUAL_DATA_SET[] = "_VirtualDataset";
constexpr char VIRTUAL_DATA_SET_INFO[] = "VirtualDatasetInfo";
constexpr char SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SparseSoftmaxCrossEntropyWithLogits";
constexpr char RELU[] = "ReLU";
constexpr char REPEAT_ELEMENTS[] = "RepeatElements";
constexpr char TENSOR_DOT[] = "TensorDot";
constexpr char ONEHOT[] = "OneHot";
constexpr char DROPOUT_DO_MASK[] = "DropoutDoMask";
constexpr char DROPOUT_GEN_MASK[] = "DropoutGenMask";

@@ -0,0 +1,76 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TENSORDOT_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TENSORDOT_INFO_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "utils/ms_utils.h"
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"

namespace mindspore {
namespace parallel {
enum AxesType {
  INT_TYPE = 0,
  TUPLE_TYPE,
  TUPLE_TUPLE_TYPE,
};

class TensorDotInfo : public OperatorInfo {
 public:
  TensorDotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
  ~TensorDotInfo() override = default;

  Status Init(const StrategyPtr &strategy) override;
  Status InitForCostModel(const StrategyPtr &strategy) override;
  Status GenerateStrategies(int64_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
  Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size,
                         size_t input1_shape_size, StrategyPtr *sp);

 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferMirrorOps() override;
  Status InferForwardCommunication() override;
  Status InferTensorInfo() override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status GetAttrs() override;
  std::shared_ptr<Strategys> GenerateBatchStrategies() override;
  void InferTensorMapAxesInt(const TensorMap &tensor_map_index);
  void InferTensorMapAxesTuple(size_t size, const TensorMap &input_a_tensor_map, const TensorMap &tensor_map_index);
  void ShowAxes();

  Shape origin_dev_matrix_shape_;
  AxesType axes_type_ = INT_TYPE;
  int32_t axes_int_ = 1;
  std::vector<int32_t> axes_tuple_;
  std::vector<std::vector<int32_t>> axes_tuple_tuple_;
};
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TENSORDOT_INFO_H_
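The AxesType enum above mirrors the three forms the Python primitive accepts for axes. A hedged usage sketch, assuming TensorDot is exposed under mindspore.ops.operations and takes axes at construction (as the __init__ change further down suggests); output shapes follow the usual tensordot contraction rules:

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P

a = Tensor(np.ones([2, 3, 4]), ms.float32)
b = Tensor(np.ones([4, 3, 5]), ms.float32)

# INT_TYPE: contract the last dim of a with the first dim of b.
print(P.TensorDot(axes=1)(a, b).shape)                 # (2, 3, 3, 5)

# TUPLE_TYPE: contract a's dim 2 with b's dim 0.
print(P.TensorDot(axes=(2, 0))(a, b).shape)            # (2, 3, 3, 5)

# TUPLE_TUPLE_TYPE: contract a's dims (2, 1) with b's dims (0, 1).
print(P.TensorDot(axes=((2, 1), (0, 1)))(a, b).shape)  # (2, 5)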

@@ -276,6 +276,7 @@ std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
 }
 bool IsElementWiseOperator(const std::string &op_name) {
+  // clang-format off
   static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH,
                                                        SOFTMAX, LOG_SOFTMAX, RELU,
                                                        SQRT, CAST, POW,
@@ -294,7 +295,9 @@ bool IsElementWiseOperator(const std::string &op_name) {
                                                        DIVNONAN, LOGICALAND, ELU,
                                                        LOGICALOR, RELU6, SOFTPLUS,
                                                        SOFTSIGN, LESS, LESSEQUAL,
-                                                       BESSELI1E, GREATEREQUAL, APPROXIMATEEQUAL};
+                                                       BESSELI1E, GREATEREQUAL, APPROXIMATEEQUAL,
+                                                       REPEAT_ELEMENTS};
+  // clang-format on
   auto iter = elementwise_op.find(op_name);
   return (iter != elementwise_op.end());
 }
@@ -313,7 +316,7 @@ bool IsSplittableOperator(const std::string &op_name) {
     EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
     BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
     SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
-    UNSORTED_SEGMENT_MIN};
+    UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT};
   // clang-format on
   auto iter = splittable_op.find(op_name);

@@ -3174,6 +3174,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
(void)gettimeofday(&end_time, nullptr);
uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
MS_LOG(INFO) << "Now leaving step parallel, used time: " << time << " us";
return changes;
}

@@ -814,6 +814,7 @@ class TensorDot(PrimitiveWithInfer):
raise ValueError("Axes have to be the same size/length")
if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])):
raise ValueError("Axes cannot have duplicating values")
self.add_prim_attr("axes", self.axes)
def int_to_tuple_conv(self):
"""

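The added add_prim_attr sits next to the axes validation shown above, so constructing the primitive with duplicated axis indices fails fast. A small sketch of the failure mode (same exposure assumption as the TensorDot example above):

from mindspore.ops import operations as P

try:
    P.TensorDot(axes=((0, 0), (1, 2)))  # index 0 repeated in the first tuple
except ValueError as err:
    print(err)  # -> Axes cannot have duplicating values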
@ -0,0 +1,86 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
class Net(Cell):
    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.repeat = P.RepeatElements(rep=2, axis=1).shard(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        out = self.mul(x, self.mul_weight)
        out = self.repeat(out)
        return out


_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)


def compile_net(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_repeat_elements_data_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1), (16, 1, 1))
    strategy2 = ((16, 1, 1),)
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)


def test_repeat_elements_model_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 1, 16), (1, 1, 16))
    strategy2 = ((1, 1, 16),)
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)


def test_repeat_elements_hybrid_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((2, 2, 4), (2, 2, 4))
    strategy2 = ((2, 2, 4),)
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)


def test_repeat_elements_auto_parallel():
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
    net = Net(_w1)
    compile_net(net)


def test_repeat_elements_repeat_calc():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((2, 2, 4), (2, 2, 4))
    strategy2 = ((1, 2, 2),)
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)
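For reference, the single-device semantics these shard strategies must respect: RepeatElements(rep=2, axis=1) repeats each element twice along axis 1, so a contiguous shard of that axis can be repeated locally and still agree with the global result. A quick standalone check, reusing the constructor arguments from the test network above:

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), ms.int32)
print(P.RepeatElements(rep=2, axis=1)(x))
# [[0 0 1 1 2 2]
#  [3 3 4 4 5 5]]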