Uniform Sampler Base Update

pull/8441/head
huangxinjing 4 years ago
parent f885f6636f
commit 2730cef047

@ -910,6 +910,21 @@ double GatherV2PCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
return result; return result;
} }
double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs,
int64_t stage_id) const {
double result = 0.0;
Shape input0_slice_shape = inputs[0].slice_shape();
if (inputs_type_lengths_.size() != inputs.size()) {
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
<< " for UniformCandidateSampler cost";
}
result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
return result;
}
double GatherV2PCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, double GatherV2PCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int64_t stage_id) const { const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
double result = 0.0; double result = 0.0;

@ -684,6 +684,38 @@ class UniqueCost : public OperatorCost {
using UniqueCostPtr = std::shared_ptr<UniqueCost>; using UniqueCostPtr = std::shared_ptr<UniqueCost>;
class UniformCandidateSamplerCost : public OperatorCost {
public:
explicit UniformCandidateSamplerCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
UniformCandidateSamplerCost() : OperatorCost(false) {}
~UniformCandidateSamplerCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
}
double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return 0;
}
double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return 0;
}
double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
}
double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t) const override {
return 0.0;
}
};
using UniformCandidateSamplerCostPtr = std::shared_ptr<UniformCandidateSamplerCost>;
class GatherV2Cost : public OperatorCost { class GatherV2Cost : public OperatorCost {
public: public:
explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}

@ -176,6 +176,7 @@ REGISTER(ExpandDimsInfo);
REGISTER(SqueezeInfo); REGISTER(SqueezeInfo);
REGISTER(SigmoidCrossEntropyWithLogitsInfo); REGISTER(SigmoidCrossEntropyWithLogitsInfo);
REGISTER(SquareInfo); REGISTER(SquareInfo);
REGISTER(UniformCandidateSamplerInfo);
REGISTER(UnsortedSegmentSumInfo); REGISTER(UnsortedSegmentSumInfo);
REGISTER(UnsortedSegmentMinInfo); REGISTER(UnsortedSegmentMinInfo);
REGISTER(GatherV2PInfo); REGISTER(GatherV2PInfo);

@ -47,6 +47,7 @@
#include "frontend/parallel/ops_info/pack_info.h" #include "frontend/parallel/ops_info/pack_info.h"
#include "frontend/parallel/ops_info/broadcast_to_info.h" #include "frontend/parallel/ops_info/broadcast_to_info.h"
#include "frontend/parallel/ops_info/unique_info.h" #include "frontend/parallel/ops_info/unique_info.h"
#include "frontend/parallel/ops_info/uniform_candidate_sampler_info.h"
#include "frontend/parallel/ops_info/reluv2_info.h" #include "frontend/parallel/ops_info/reluv2_info.h"
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_ #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

@ -102,6 +102,12 @@ constexpr char END[] = "end";
constexpr char STRIDES[] = "strides"; constexpr char STRIDES[] = "strides";
constexpr char GROUP[] = "group"; constexpr char GROUP[] = "group";
constexpr char FUSION[] = "fusion"; constexpr char FUSION[] = "fusion";
constexpr char NUM_SAMPLED[] = "num_sampled";
constexpr char NUM_TRUE[] = "num_true";
constexpr char SEED[] = "seed";
constexpr char RANGE_MAX[] = "range_max";
constexpr char REMOVE_ACCIDENTAL_HITS[] = "remove_accidental_hits";
constexpr char UNIQUE_STRING[] = "unique";
constexpr char AXIS[] = "axis"; constexpr char AXIS[] = "axis";
constexpr char AXES[] = "axes"; constexpr char AXES[] = "axes";
constexpr char START[] = "start"; constexpr char START[] = "start";
@ -191,6 +197,7 @@ constexpr char DIV[] = "Div";
constexpr char REAL_DIV[] = "RealDiv"; constexpr char REAL_DIV[] = "RealDiv";
constexpr char ASSIGN_SUB[] = "AssignSub"; constexpr char ASSIGN_SUB[] = "AssignSub";
constexpr char GREATER[] = "Greater"; constexpr char GREATER[] = "Greater";
constexpr char UNIFORM_CANDIDATE_SAMPLER[] = "UniformCandidateSampler";
constexpr char VIRTUAL_DATA_SET[] = "_VirtualDataset"; constexpr char VIRTUAL_DATA_SET[] = "_VirtualDataset";
constexpr char VIRTUAL_DATA_SET_INFO[] = "VirtualDatasetInfo"; constexpr char VIRTUAL_DATA_SET_INFO[] = "VirtualDatasetInfo";
constexpr char SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SparseSoftmaxCrossEntropyWithLogits"; constexpr char SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SparseSoftmaxCrossEntropyWithLogits";

@ -0,0 +1,76 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNFORM_CANDIDATE_SAMPLER_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNFORM_CANDIDATE_SAMPLER_INFO_H_
#include <string>
#include <memory>
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
constexpr size_t UNIFORM_CANDIDATE_SAMPLER_INPUTS_SIZE = 2;
class UniformCandidateSamplerInfo : public OperatorInfo {
public:
UniformCandidateSamplerInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs,
std::make_shared<UniformCandidateSamplerCost>()),
num_sampled_(0),
num_true_(0),
unique_(false),
range_max_(0),
seed_(0),
remove_accidental_hits_(false) {}
~UniformCandidateSamplerInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
Status GenerateStrategies(int64_t) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
Status InferAsLossDivisor() override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status ComputeReplaceGraph(const CNodePtr &cnode);
private:
Status GetUniformSamplerAttrBool(const std::string &argsy, bool *value);
Status GetUniformSamplerAttrInt64(const std::string &args, int64_t *value);
int64_t num_sampled_;
int64_t num_true_;
bool unique_;
int64_t range_max_;
int64_t seed_;
bool remove_accidental_hits_;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNFORM_CANDIDATE_SAMPLER_INFO_H_

@ -317,7 +317,7 @@ bool IsSplittableOperator(const std::string &op_name) {
EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE, EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2, BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM, SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE}; UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER};
// clang-format on // clang-format on
auto iter = splittable_op.find(op_name); auto iter = splittable_op.find(op_name);

@ -0,0 +1,161 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops import operations as P
class Net(nn.Cell):
def __init__(self, embedding_weight, num_true, num_sampled, unique, range_max, seed, remove_accidential,
strategy1=None):
super(Net, self).__init__()
self.sampler = P.UniformCandidateSampler(num_true, num_sampled, unique, range_max, seed,
remove_accidential)
if strategy1:
self.sampler.shard(strategy1)
self.embedding_table = Parameter(embedding_weight, "embedding_weight")
self.gatherv2 = P.GatherV2()
self.reduce_sum = P.ReduceSum()
self.reduce_sum2 = P.ReduceSum()
self.reduce_sum3 = P.ReduceSum()
def construct(self, x):
out1, out2, out3 = self.sampler(x)
lookup = self.gatherv2(self.embedding_table, out1, 0)
loss = out1 - out3
loss = self.reduce_sum(loss, (0,))
loss2 = self.reduce_sum2(lookup, (0, 1))
loss3 = self.reduce_sum3(out2, (0, 1))
loss4 = loss + loss2 + loss3
return loss4
class Net2(nn.Cell):
def __init__(self, mul_weight, num_true, num_sampled, unique, range_max, seed, remove_accidential,
strategy1=None):
super(Net2, self).__init__()
self.sampler = P.UniformCandidateSampler(num_true, num_sampled, unique, range_max, seed,
remove_accidential)
self.cast = P.Cast()
self.weight = Parameter(mul_weight, "w1")
self.mul = P.Mul()
if strategy1:
self.sampler.shard(strategy1)
def construct(self, x):
x = self.mul(x, self.weight)
x = self.cast(x, ms.int32)
_, out2, _ = self.sampler(x)
return out2
_w = Tensor(np.ones([48, 16]), dtype=ms.float32)
_w1 = Tensor(np.ones([96, 64]), dtype=ms.float32)
_x = Tensor(np.ones([48, 16]), dtype=ms.int32)
def compile_net(net):
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_executor.compile(train_net, _x)
context.reset_auto_parallel_context()
def test_uniform_candidate_sampler_no_full_0d_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((4, 1),)
net = Net(_w1, num_true=16, num_sampled=16, unique=True, range_max=20, seed=1,
remove_accidential=False, strategy1=strategy1)
compile_net(net)
def test_uniform_candidate_sampler_no_full_1d_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4),)
net = Net(_w1, num_true=16, num_sampled=16, unique=True, range_max=20, seed=1,
remove_accidential=False, strategy1=strategy1)
compile_net(net)
def test_uniform_candidate_sampler_full_0d_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1),)
net = Net(_w1, num_true=16, num_sampled=16, unique=True, range_max=20, seed=1,
remove_accidential=False, strategy1=strategy1)
compile_net(net)
def test_uniform_candidate_sampler_full_1d_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8),)
net = Net(_w1, num_true=16, num_sampled=16, unique=True, range_max=20, seed=1,
remove_accidential=False, strategy1=strategy1)
compile_net(net)
def test_uniform_candidate_sampler_full_1d_unqiue_false():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8),)
net = Net(_w1, num_true=16, num_sampled=16, unique=False, range_max=20, seed=1,
remove_accidential=False, strategy1=strategy1)
compile_net(net)
def test_uniform_candidate_sampler_auto_parllel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(_w1, num_true=16, num_sampled=16, unique=False, range_max=20, seed=1,
remove_accidential=False, strategy1=None)
compile_net(net)
def test_uniform_candidate_sampler_auto_parllel_unqiue_true():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(_w1, num_true=16, num_sampled=16, unique=True, range_max=20, seed=1,
remove_accidential=False, strategy1=None)
compile_net(net)
def test_uniform_candidate_sampler_auto_parllel_remove_true():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(_w1, num_true=16, num_sampled=16, unique=True, range_max=20, seed=1,
remove_accidential=True, strategy1=None)
compile_net(net)
def test_uniform_candidate_sampler_full_1d_remove_true():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8),)
net = Net(_w1, num_true=16, num_sampled=16, unique=False, range_max=20, seed=1,
remove_accidential=True, strategy1=strategy1)
with pytest.raises(RuntimeError):
compile_net(net)
def test_uniform_candidate_sampler_as_final():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8),)
net = Net2(_w, num_true=16, num_sampled=16, unique=False, range_max=20, seed=1, remove_accidential=False,
strategy1=strategy1)
with pytest.raises(RuntimeError):
compile_net(net)
Loading…
Cancel
Save