support distributed GatherV2 operator

pull/221/head
c00425699 5 years ago
parent d949c17a7e
commit c8cdb6b331

@ -623,5 +623,34 @@ double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo>& inp
Shape input0_slice_shape = input0.slice_shape();
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * DROPOUT_COST_RATE;
}
// return the per device communication cost in the forward phase.
double GatherV2Cost::GetForwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
const int32_t&) const {
// GatherV2Cost does not need communication in the forward phase
return 0.0;
}
// return the per device communication cost in the backward phase.
double GatherV2Cost::GetBackwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
const int32_t&) const {
// GatherV2Cost does not need communication in the backward phase
return 0.0;
}
double GatherV2Cost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
// In forward phase, the computation cost = slice(A) + slice(B)
Shape input0_slice_shape = inputs[0].slice_shape();
Shape input1_slice_shape = inputs[1].slice_shape();
double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
return result;
}
double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
const int32_t&) const {
return 0.0;
}
} // namespace parallel
} // namespace mindspore

@ -81,6 +81,8 @@ class OperatorCost {
std::vector<size_t> outputs_type_lengths_;
};
using OperatorCostPtr = std::shared_ptr<OperatorCost>;
class MatMulCost : public OperatorCost {
public:
MatMulCost() = default;
@ -525,6 +527,31 @@ class DropOutCost : public OperatorCost {
};
using DropOutCostPtr = std::shared_ptr<DropOutCost>;
class GatherV2Cost : public OperatorCost {
public:
GatherV2Cost() = default;
~GatherV2Cost() override = default;
double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t& stage_id) const override {
return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
}
double GetForwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t& stage_id) const override;
double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t& stage_id) const override;
double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t& stage_id) const override {
return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
}
double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t& stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t&) const override;
};
using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;
} // namespace parallel
} // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_

@ -228,26 +228,6 @@ void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
}
}
void GatherV2Info::ReComputeBatchSplitFlagList() {
MS_ASSERT(inputs_shape_.size() == 2);
MS_ASSERT(input_value_.size() == 3);
MS_ASSERT(input_value_[0] == nullptr);
// the second input is the index tensor
MS_ASSERT(input_value_[1] != nullptr);
// the third input is the axis
MS_ASSERT(input_value_[2] != nullptr);
int axis = GetValue<int>(input_value_[2]);
MS_ASSERT(axis < inputs_shape_[0].size() && axis >= 0 - inputs_shape_[0].size());
if (axis < 0) {
axis += SizeToInt(inputs_shape_[0].size());
}
split_flag_list_[0] = true;
// if gather axis is 0, the index's strategy is equal to device number
if (axis == 0) {
split_flag_list_[1] = true;
}
}
Status BatchParallelInfo::InferAsLossDivisor() {
as_loss_divisor_ = 1;
return SUCCESS;

@ -62,15 +62,6 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
void ReComputeBatchSplitFlagList() override;
};
class GatherV2Info : public BatchParallelInfo {
public:
GatherV2Info(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: BatchParallelInfo(name, inputs_shape, outputs_shape, attrs) {}
~GatherV2Info() override = default;
void ReComputeBatchSplitFlagList() override;
};
} // namespace parallel
} // namespace mindspore

File diff suppressed because it is too large Load Diff

@ -0,0 +1,73 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "parallel/auto_parallel/operator_costmodel.h"
#include "parallel/ops_info/operator_info.h"
#include "parallel/strategy.h"
namespace mindspore {
namespace parallel {
constexpr size_t GATHER_V2_INPUTS_SIZE = 2;
constexpr size_t GATHER_V2_OUTPUTS_SIZE = 1;
constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3;
// We now supported limited parallel strategies.
// If the strategy corresponding to axis is more than 1, index must be evenly distributed across the axis-dimension of
// the input.
// If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1.
class GatherV2Info : public OperatorInfo {
public:
GatherV2Info(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2Cost>()),
axis_(-1),
index_size_(0),
axis_strategy_(1) {}
~GatherV2Info() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
protected:
Status CheckStrategy(const StrategyPtr& strategy) override;
Status InferMirrorOps() override { return SUCCESS; }
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;
private:
Status InferTensorSubOps();
int32_t axis_;
size_t index_size_;
int32_t axis_strategy_;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_

@ -112,6 +112,7 @@ void OperatorInfo::ResetQueueMember() {
dev_matrix_shape_.clear();
forward_op_.clear();
mirror_ops_.clear();
sub_ops_.clear();
replace_op_.clear();
replace_op_info_.clear();
virtual_div_op_.clear();

@ -41,6 +41,7 @@ namespace mindspore {
namespace parallel {
using ForwardOp = OperatorVector;
using MirrorOps = std::vector<OperatorVector>;
using Ops = std::vector<OperatorVector>;
using VirtualDivOp = OperatorVector;
using TensorMaps = std::vector<std::vector<int32_t>>;
using TensorLayouts = std::vector<TensorLayout>;
@ -99,6 +100,7 @@ class OperatorInfo {
OutPutInfoVector replace_op_info() const { return replace_op_info_; }
virtual ReplaceGraphPtr replace_graph(const CNodePtr&) { return replace_graph_; }
MirrorOps mirror_ops() const { return mirror_ops_; }
Ops sub_ops() const { return sub_ops_; }
VirtualDivOp virtual_div_op() const { return virtual_div_op_; }
Shape dev_matrix_shape() const { return dev_matrix_shape_; }
std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
@ -190,6 +192,7 @@ class OperatorInfo {
TensorMaps inputs_tensor_map_;
TensorMaps outputs_tensor_map_;
ForwardOp forward_op_;
Ops sub_ops_;
ForwardOp replace_op_;
OutPutInfoVector replace_op_info_;
ReplaceGraphPtr replace_graph_;

@ -24,6 +24,7 @@
#include "parallel/ops_info/comparison_function_info.h"
#include "parallel/ops_info/dropout_do_mask_info.h"
#include "parallel/ops_info/elementary_function_info.h"
#include "parallel/ops_info/gather_v2_info.h"
#include "parallel/ops_info/get_next_info.h"
#include "parallel/ops_info/l2_normalize_info.h"
#include "parallel/ops_info/loss_info.h"

@ -464,6 +464,14 @@ void SplitTensor(const AnfNodePtr& node, const CNodePtr& next_node, int index) {
MS_EXCEPTION_IF_NULL(func_graph);
Operator op = CreateGetTensorSliceOp(tensor_layout);
InsertGetTensorSliceOp(op, next_node, func_graph, index, SPLIT_TENSOR);
if (!op_info->sub_ops().empty()) {
auto sub_ops = op_info->sub_ops();
for (size_t i = 0; i < sub_ops.size(); i++) {
if (!sub_ops.at(i).empty()) {
InsertGetTensorSliceOp(sub_ops.at(i).at(0), next_node, func_graph, index, SUB);
}
}
}
}
void StepSplitTensor(const AnfNodePtr& node, const FuncGraphManagerPtr& manager) {

@ -29,6 +29,8 @@ from mindspore.nn import Dense, Cell
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
device_number = 32
batch_size_per_device = 128
class Dataset():
@ -57,15 +59,22 @@ class Dataset():
class GatherV2(_Loss):
def __init__(self, batchsize):
def __init__(self, index_dim, strategy, index_size=16):
super(GatherV2, self).__init__()
self.pow = P.Pow()
emb_list = list(range(batchsize))
emb1_list = emb_list[0::2]
emb2_list = emb_list[1::2]
emb1_list = 21
emb2_list = 2
if index_dim == 1:
emb_list = list(range(index_size))
emb1_list = emb_list[0::2]
emb2_list = emb_list[1::2]
if index_dim == 2:
emb_list = np.arange(index_size*16)
emb1_list = np.reshape(emb_list[0::2], (int(index_size/2), 16))
emb2_list = np.reshape(emb_list[1::2], (int(index_size/2), 16))
self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
self.gatherv2 = P.GatherV2()
self.gatherv2 = P.GatherV2().set_strategy(strategy)
def construct(self, nembeddings):
emb1 = self.gatherv2(nembeddings, self.emb1_param, 0)
@ -73,10 +82,6 @@ class GatherV2(_Loss):
return self.pow((emb1 - emb2), 2.0)
def get_loss(batchsize):
return GatherV2(batchsize)
def fc_with_initialize(input_channels, out_channels):
return Dense(input_channels, out_channels)
@ -114,26 +119,23 @@ class TrainOneStepCell(Cell):
return F.depend(loss, self.optimizer(grads))
def test_trains():
def net_trains(gather_v2_strategy, criterion, rank):
init()
lr = 0.1
momentum = 0.9
max_epoch = 20
device_number = 32
batch_size_per_device = 128
input_channels = 256
out_channels = 512
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number)
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number,
global_rank=rank)
predict = Tensor(np.ones([batch_size_per_device, input_channels]), dtype=ms.float32)
dataset = Dataset(predict, 4)
network = fc_with_initialize(input_channels, out_channels)
network.set_train()
criterion = get_loss(batch_size_per_device * device_number)
train_network = BuildTrainNetwork(network, criterion)
train_network.set_train()
opt = Momentum(train_network.trainable_params(), lr, momentum)
@ -143,5 +145,90 @@ def test_trains():
model.train(max_epoch, dataset, dataset_sink_mode=False)
context.reset_auto_parallel_context()
if __name__ == "__main__":
test_trains()
def test_auto_batch_parallel():
gather_v2_strategy = None
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
rank = 2
net_trains(gather_v2_strategy, criterion, rank)
def test_2d_index_auto_batch_parallel():
gather_v2_strategy = None
criterion = GatherV2(2, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
rank = 2
net_trains(gather_v2_strategy, criterion, rank)
def test_batch_parallel():
gather_v2_strategy = ((device_number, 1),)
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
rank = 2
net_trains(gather_v2_strategy, criterion, rank)
def test_strategy1():
gather_v2_strategy = ((16, 2),)
rank = 2
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
net_trains(gather_v2_strategy, criterion, rank)
def test_strategy2():
gather_v2_strategy = ((1, device_number),)
rank = 2
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
net_trains(gather_v2_strategy, criterion, rank)
def test_strategy3():
gather_v2_strategy = ((8, 1),)
rank = 2
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
net_trains(gather_v2_strategy, criterion, rank)
class GatherV2Axis1(_Loss):
def __init__(self, index_dim, strategy, index_size=16):
super(GatherV2Axis1, self).__init__()
self.pow = P.Pow()
emb1_list = 21
emb2_list = 2
if index_dim == 1:
emb_list = list(range(index_size))
emb1_list = emb_list[0::2]
emb2_list = emb_list[1::2]
if index_dim == 2:
emb_list = np.arange(index_size*index_size)
emb1_list = np.reshape(emb_list[0::2], (int(index_size/2), index_size))
emb2_list = np.reshape(emb_list[1::2], (int(index_size/2), index_size))
self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
self.gatherv2 = P.GatherV2().set_strategy(strategy)
def construct(self, nembeddings):
emb1 = self.gatherv2(nembeddings, self.emb1_param, 1)
emb2 = self.gatherv2(nembeddings, self.emb2_param, 1)
return self.pow((emb1 - emb2), 2.0)
def test_axis1_auto_batch_parallel():
gather_v2_strategy = None
criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
rank = 2
net_trains(gather_v2_strategy, criterion, rank)
def test_axis1_batch_parallel():
gather_v2_strategy = ((device_number, 1),)
criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
rank = 2
net_trains(gather_v2_strategy, criterion, rank)
def test_axis1_strategy1():
gather_v2_strategy = ((16, 2),)
rank = 17
criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
net_trains(gather_v2_strategy, criterion, rank)

Loading…
Cancel
Save