Change GatherV2 to Gather (r1.1 to master)

pull/11761/head
mindspore-ci-bot authored 4 years ago, committed by l00591931
parent 8a61767f32
commit 9fa0499fa0
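At the Python API level, this rename means operators are now constructed as ops.Gather() instead of ops.GatherV2(). A minimal usage sketch (assuming MindSpore 1.1 or later; the tensor values are illustrative):

import numpy as np
import mindspore.ops as ops
from mindspore import Tensor

gather = ops.Gather()  # formerly ops.GatherV2()
params = Tensor(np.arange(12).reshape(3, 4).astype(np.float32))
indices = Tensor(np.array([0, 2], np.int32))
output = gather(params, indices, 0)  # rows 0 and 2 along axis 0 -> shape (2, 4)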

File diff suppressed because one or more lines are too long

@@ -43,7 +43,7 @@ class GatherV2CPUKernel : public CPUKernel {
};
MS_REG_CPU_KERNEL(
- GatherV2,
+ Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherV2CPUKernel);
} // namespace kernel
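The registration macros key each kernel implementation by primitive name, so every MS_REG_* site must follow the rename or kernel lookup fails. A rough Python sketch of the idea (hypothetical names, not MindSpore's actual registry code):

# Hypothetical name-keyed kernel registry, mirroring what MS_REG_CPU_KERNEL does.
KERNEL_REGISTRY = {}

def register_kernel(op_name, signature, kernel_cls):
    KERNEL_REGISTRY[(op_name, signature)] = kernel_cls

class GatherCPUKernel:  # stand-in for the C++ GatherV2CPUKernel class
    pass

register_kernel("Gather", ("float32", "int32", "float32"), GatherCPUKernel)
# A graph node named "Gather" now finds its kernel; the old "GatherV2" key would miss.
assert ("Gather", ("float32", "int32", "float32")) in KERNEL_REGISTRY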

@@ -19,26 +19,26 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
- GatherV2,
+ Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherV2GpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
- GatherV2,
+ Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
GatherV2GpuFwdKernel, float, int64_t)
MS_REG_GPU_KERNEL_TWO(
- GatherV2,
+ Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherV2GpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
- GatherV2,
+ Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
GatherV2GpuFwdKernel, half, int64_t)
- MS_REG_GPU_KERNEL_TWO(GatherV2,
+ MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
@@ -46,7 +46,7 @@ MS_REG_GPU_KERNEL_TWO(GatherV2,
.AddOutputAttr(kNumberTypeFloat32),
GatherV2GpuFwdKernel, float, int)
- MS_REG_GPU_KERNEL_TWO(GatherV2,
+ MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
@@ -54,7 +54,7 @@ MS_REG_GPU_KERNEL_TWO(GatherV2,
.AddOutputAttr(kNumberTypeFloat32),
GatherV2GpuFwdKernel, float, int64_t)
- MS_REG_GPU_KERNEL_TWO(GatherV2,
+ MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
@@ -62,7 +62,7 @@ MS_REG_GPU_KERNEL_TWO(GatherV2,
.AddOutputAttr(kNumberTypeFloat16),
GatherV2GpuFwdKernel, half, int)
- MS_REG_GPU_KERNEL_TWO(GatherV2,
+ MS_REG_GPU_KERNEL_TWO(Gather,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)

@@ -85,8 +85,8 @@ CNodePtr CreateGatherV2Ds(const FuncGraphPtr &graph, const CNodePtr &origin_node
if (origin_node->size() != 4) {
MS_LOG(EXCEPTION) << "In dynamic shape scene, gatherv2 should have 3 inputs";
}
- std::vector<AnfNodePtr> gatherv2_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimGatherV2->name())),
- pad, origin_node->input(2), origin_node->input(3)};
+ std::vector<AnfNodePtr> gatherv2_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimGather->name())), pad,
+ origin_node->input(2), origin_node->input(3)};
auto gather_v2 = graph->NewCNode(gatherv2_inputs);
MS_EXCEPTION_IF_NULL(gather_v2);
gather_v2->set_scope(origin_node->scope());
@@ -146,7 +146,7 @@ bool CheckInputs(const CNodePtr &origin_node) {
const BaseRef GatherV2DsFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
- VectorRef pattern({prim::kPrimGatherV2, Xs});
+ VectorRef pattern({prim::kPrimGather, Xs});
return pattern;
}
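GatherV2DsFission's pattern must likewise match the renamed primitive, otherwise the fission pass silently stops firing. A toy sketch of the matching idea (hypothetical tuple encoding of a CNode, not the real VectorRef machinery):

# Hypothetical: a CNode is encoded as (primitive_name, *inputs).
def matches_gather(cnode):
    return isinstance(cnode, tuple) and len(cnode) > 0 and cnode[0] == "Gather"

assert matches_gather(("Gather", "param", "indices", "axis"))
assert not matches_gather(("GatherV2", "param", "indices", "axis"))  # stale name no longer matches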

@@ -38,7 +38,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(prim::kPrimReduceMin->name(), {1});
Register(prim::kPrimReduceSum->name(), {1});
Register(prim::kPrimReduceMean->name(), {1});
- Register(prim::kPrimGatherV2->name(), {2});
+ Register(prim::kPrimGather->name(), {2});
Register(prim::kPrimGatherD->name(), {1});
Register(prim::kPrimEmbeddingLookup->name(), {2, 3, 4, 5});
Register(prim::kPrimEmbeddingLookupCommGrad->name(), {1});
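Register(prim::kPrimGather->name(), {2}) says that Gather's input at index 2 (the axis) is a constant to be folded into an attribute before kernel selection. A rough Python sketch of that transformation (hypothetical helper, assuming inputs are ordered [data, indices, axis]):

# Hypothetical const-input-to-attr pass: move listed constant inputs into attrs.
def const_input_to_attr(inputs, attrs, const_indices):
    remaining = []
    for i, value in enumerate(inputs):
        if i in const_indices:
            attrs["input_%d" % i] = value  # e.g. Gather's axis becomes an attribute
        else:
            remaining.append(value)
    return remaining, attrs

inputs, attrs = const_input_to_attr(["data", "indices", 0], {}, {2})
assert inputs == ["data", "indices"] and attrs == {"input_2": 0}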

@@ -62,7 +62,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
{prim::kPrimCast, {2}},
{prim::kPrimTranspose, {2}},
{prim::kPrimOneHot, {2}},
- {prim::kPrimGatherV2, {3}},
+ {prim::kPrimGather, {3}},
{prim::kPrimReshape, {2}},
{prim::kPrimAssign, {1}},
{prim::kPrimAssignAdd, {1}},
@@ -508,7 +508,7 @@ bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const Abstrac
abstract::AbstractTuplePtr false_branch_tuple = false_branch_abs->cast<abstract::AbstractTuplePtr>();
if (true_branch_tuple->elements().size() != false_branch_tuple->elements().size()) {
MS_LOG(ERROR) << "true branch size:" << true_branch_tuple->elements().size()
- << ", not equal to false banch size:" << false_branch_tuple->elements().size() << " ";
+ << ", not equal to false branch size:" << false_branch_tuple->elements().size() << " ";
return false;
}
bool all_compatible = true;

@@ -616,7 +616,7 @@ Dimensions PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_pt
return s;
}
auto name = ops[incoming_op_index]->name().substr(0, pos);
- if (name == "GatherV2") {
+ if (name == "Gather") {
return s;
} else if (name == "GatherV2P") {
return PrepareGatherV2POutputStrategy(ops, incoming_op_index);
@@ -849,7 +849,7 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
if (ops[iter_ops]->type() == GATHERV2) {
auto pos = ops[iter_ops]->name().find("Info");
auto name = ops[iter_ops]->name().substr(0, pos);
- if (name == "GatherV2") {
+ if (name == "Gather") {
return PrepareGatherV2(ops, iter_ops, basic_stra);
} else if (name == "GatherV2P") {
return PrepareGatherV2P(ops, iter_ops, basic_stra);

@@ -426,7 +426,7 @@ AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNo
AnfNodePtrList gatherv2_nodes;
auto user_set = graph->manager()->node_users()[node];
for (auto &ele : user_set) {
- if (IsPrimitiveCNode(ele.first, prim::kPrimGatherV2)) {
+ if (IsPrimitiveCNode(ele.first, prim::kPrimGather)) {
gatherv2_nodes.emplace_back(ele.first);
}
}

@@ -140,7 +140,7 @@ REGISTER(ReLU6Info);
REGISTER(ReLUV2Info);
REGISTER(SoftplusInfo);
REGISTER(SoftsignInfo);
- REGISTER(GatherV2Info);
+ REGISTER(GatherInfo);
REGISTER(SparseGatherV2Info);
REGISTER(SqrtInfo);
REGISTER(SigmoidInfo);
@@ -180,7 +180,7 @@ REGISTER(UniformCandidateSamplerInfo);
REGISTER(UnsortedSegmentSumInfo);
REGISTER(UnsortedSegmentMinInfo);
REGISTER(UnsortedSegmentMaxInfo);
- REGISTER(GatherV2PInfo);
+ REGISTER(GatherPInfo);
REGISTER(EmbeddingLookupInfo);
REGISTER(TileInfo);
REGISTER(BroadcastToInfo);

@@ -30,7 +30,7 @@
namespace mindspore {
namespace parallel {
- Status GatherV2Info::GetAttrs() {
+ Status GatherInfo::GetAttrs() {
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size();
return FAILED;
@@ -70,7 +70,7 @@ Status GatherV2Info::GetAttrs() {
return SUCCESS;
}
- Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) {
+ Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
@@ -104,7 +104,7 @@ Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) {
return SUCCESS;
}
- Status GatherV2Info::InferDevMatrixShape() {
+ Status GatherInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim();
dev_matrix_shape_ = stra.at(0);
return SUCCESS;
@@ -114,7 +114,7 @@ Status GatherV2Info::InferDevMatrixShape() {
// If index is a n dimension tensor, output dimension is input dimension plus (n - 1).
// Tensor map dimension is equal to the corresponding input and output dimension.
// If index's dimension is more than 1, we insert -1 for the output tensor map.
- Status GatherV2Info::InferTensorMap() {
+ Status GatherInfo::InferTensorMap() {
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
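The comment above encodes the rank rule: with an n-dimensional index tensor, output rank = input rank + (n - 1). The same Gather semantics can be checked with numpy's take, which gathers along one axis:

import numpy as np

params = np.arange(20.0).reshape(5, 4)      # input rank 2
indices = np.array([[0, 2, 4], [1, 1, 3]])  # index rank n = 2
out = np.take(params, indices, axis=0)      # output rank = 2 + (2 - 1) = 3
assert out.shape == (2, 3, 4)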
@@ -158,7 +158,7 @@ Status GatherV2Info::InferTensorMap() {
return SUCCESS;
}
- Status GatherV2Info::InferTensorInfo() {
+ Status GatherInfo::InferTensorInfo() {
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
@@ -219,7 +219,7 @@ OperatorVector CreateSubOp(int64_t sub_value) {
return ops;
}
- Status GatherV2Info::InferTensorSubOps() {
+ Status GatherInfo::InferTensorSubOps() {
sub_ops_.clear();
if ((index_size_ == 0) || (axis_strategy_ == 1)) {
return SUCCESS;
@@ -252,7 +252,7 @@ Status GatherV2Info::InferTensorSubOps() {
return SUCCESS;
}
- Status GatherV2Info::Init(const StrategyPtr &strategy) {
+ Status GatherInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
@@ -266,7 +266,7 @@ Status GatherV2Info::Init(const StrategyPtr &strategy) {
return SUCCESS;
}
- Status GatherV2Info::InitForCostModel(const StrategyPtr &strategy) {
+ Status GatherInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
@@ -275,7 +275,7 @@ Status GatherV2Info::InitForCostModel(const StrategyPtr &strategy) {
return SUCCESS;
}
- Status GatherV2Info::GenerateStrategies(int64_t stage_id) {
+ Status GatherInfo::GenerateStrategies(int64_t stage_id) {
if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) {
MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
<< outputs_shape_.size() << "is wrong.";
@@ -301,9 +301,9 @@ Status GatherV2Info::GenerateStrategies(int64_t stage_id) {
return SUCCESS;
}
- Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
+ Status GatherInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
- std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
+ std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() {
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
@@ -36,15 +36,15 @@ constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3;
// If the strategy corresponding to axis is more than 1, index must be evenly distributed across the axis-dimension of
// the input.
// If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1.
- class GatherV2Info : public OperatorInfo {
+ class GatherInfo : public OperatorInfo {
public:
- GatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
- const PrimitiveAttrs &attrs)
+ GatherInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+ const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2Cost>()),
axis_(-1),
index_size_(0),
axis_strategy_(1) {}
- ~GatherV2Info() override = default;
+ ~GatherInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

@@ -32,7 +32,7 @@
namespace mindspore {
namespace parallel {
- Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() {
+ Status GatherPInfo::GetManualSplitWithoutOffsetAttr() {
auto manual_split_without_offset_iter = attrs_.find("manual_split");
if (manual_split_without_offset_iter != attrs_.end()) {
manual_split_ = true;
@@ -68,7 +68,7 @@ Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() {
return SUCCESS;
}
- Status GatherV2PInfo::GetManualSplitAttr() {
+ Status GatherPInfo::GetManualSplitAttr() {
auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
if (manual_split_with_offset_iter != attrs_.end()) {
manual_split_ = true;
@@ -118,7 +118,7 @@ Status GatherV2PInfo::GetManualSplitAttr() {
return SUCCESS;
}
- Status GatherV2PInfo::GetAttrs() {
+ Status GatherPInfo::GetAttrs() {
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
if (target_ != CPU) {
if (input_value_.at(2) == nullptr) {
@@ -172,7 +172,7 @@ Status GatherV2PInfo::GetAttrs() {
return SUCCESS;
}
- Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
+ Status GatherPInfo::CheckManualSplit(const Strategys &strategy) {
if (strategy.size() != 2) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
return FAILED;
@@ -228,7 +228,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
return SUCCESS;
}
- Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
+ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED;
}
@@ -306,7 +306,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
return SUCCESS;
}
- Status GatherV2PInfo::InferMirrorOps() {
+ Status GatherPInfo::InferMirrorOps() {
// There is no mirror operators for manual split
if (manual_split_) {
return SUCCESS;
@@ -336,7 +336,7 @@ Status GatherV2PInfo::InferMirrorOps() {
return SUCCESS;
}
- Status GatherV2PInfo::InferDevMatrixShape() {
+ Status GatherPInfo::InferDevMatrixShape() {
dev_matrix_shape_.clear();
out_dev_matrix_shape_.clear();
// infer input dev_matrix_shape
@@ -386,7 +386,7 @@ Status GatherV2PInfo::InferDevMatrixShape() {
return SUCCESS;
}
- void GatherV2PInfo::InferInputsTensorMap() {
+ void GatherPInfo::InferInputsTensorMap() {
// infer input tensor map
// param_strategy(axis) != 1
size_t param_size = inputs_shape_.at(0).size();
@@ -413,7 +413,7 @@ void GatherV2PInfo::InferInputsTensorMap() {
inputs_tensor_map_.emplace_back(std::move(tensor_map_index));
}
- void GatherV2PInfo::InferOutputsTensorMap() {
+ void GatherPInfo::InferOutputsTensorMap() {
// infer output tensor map
size_t param_size = inputs_shape_.at(0).size();
size_t index_size = inputs_shape_.at(1).size();
@@ -460,7 +460,7 @@ void GatherV2PInfo::InferOutputsTensorMap() {
outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
}
- Status GatherV2PInfo::InferTensorMap() {
+ Status GatherPInfo::InferTensorMap() {
if (manual_split_) {
inputs_tensor_map_.push_back({1, 0});
inputs_tensor_map_.push_back({-1, 1});
@@ -472,7 +472,7 @@ Status GatherV2PInfo::InferTensorMap() {
return SUCCESS;
}
- Status GatherV2PInfo::InferTensorInfo() {
+ Status GatherPInfo::InferTensorInfo() {
// infer tensor shape
Shape input_shape = inputs_shape_.at(0);
Shape input_index_shape = inputs_shape_.at(1);
@@ -505,7 +505,7 @@ Status GatherV2PInfo::InferTensorInfo() {
return SUCCESS;
}
- Status GatherV2PInfo::InferBias() {
+ Status GatherPInfo::InferBias() {
CheckGlobalDeviceManager();
int64_t rank = g_device_manager->rank_index_in_stage();
auto input_shape = inputs_shape_.at(0);
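InferBias computes, per rank, where that rank's slice of the parameter starts when the table is sharded along the gather axis; the forward pass then gathers locally and combines partial results across ranks. A simplified numpy reconstruction of that scheme (a sketch of the idea, not the actual implementation; masking plus a sum stands in for the real communication):

import numpy as np

def sharded_gather(shards, indices, slice_size):
    # Rank r owns rows [r * slice_size, (r + 1) * slice_size) of the full table.
    partial = []
    for r, shard in enumerate(shards):
        bias = r * slice_size                      # what InferBias computes
        local = indices - bias
        mask = (local >= 0) & (local < slice_size)
        gathered = shard[np.clip(local, 0, slice_size - 1)]
        gathered[~mask] = 0.0                      # rows owned by other ranks
        partial.append(gathered)
    return np.sum(partial, axis=0)                 # AllReduce-style combine

table = np.arange(12.0).reshape(6, 2)
assert np.allclose(sharded_gather([table[:3], table[3:]], np.array([0, 4, 2]), 3),
                   table[[0, 4, 2]])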
@@ -559,7 +559,7 @@ Status GatherV2PInfo::InferBias() {
return FAILED;
}
- Status GatherV2PInfo::InferOffset() {
+ Status GatherPInfo::InferOffset() {
CheckGlobalDeviceManager();
size_t rank = g_device_manager->rank_index_in_stage();
@@ -580,7 +580,7 @@ Status GatherV2PInfo::InferOffset() {
return FAILED;
}
- Status GatherV2PInfo::InferGroup() {
+ Status GatherPInfo::InferGroup() {
auto param_strategy = strategy_->GetInputDim().at(0);
size_t dim = LongToSize(axis_);
if (param_strategy.at(LongToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
@@ -610,7 +610,7 @@ Status GatherV2PInfo::InferGroup() {
return SUCCESS;
}
- Status GatherV2PInfo::InferForwardCommunication() {
+ Status GatherPInfo::InferForwardCommunication() {
if (manual_split_) {
return SUCCESS;
}
@@ -647,7 +647,7 @@ Status GatherV2PInfo::InferForwardCommunication() {
return SUCCESS;
}
- Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
+ Status GatherPInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
GenerateGraph gen_g = GenerateGraph();
if (gen_g.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << "GenerateGraph Init failed";
@@ -705,7 +705,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
return SUCCESS;
}
- ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
+ ReplaceGraphPtr GatherPInfo::replace_graph(const CNodePtr &cnode) {
if (manual_split_ && target_ != CPU) {
if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
@@ -724,7 +724,7 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
return replace_graph_;
}
- Status GatherV2PInfo::ComputeReplaceOp() {
+ Status GatherPInfo::ComputeReplaceOp() {
int64_t bias = 0;
if (manual_split_) {
if (InferOffset() != SUCCESS) {
@@ -752,7 +752,7 @@ Status GatherV2PInfo::ComputeReplaceOp() {
return SUCCESS;
}
- Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
+ Status GatherPInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
@@ -765,7 +765,7 @@ Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
return SUCCESS;
}
- Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) {
+ Status GatherPInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
@@ -783,9 +783,9 @@ Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) {
return SUCCESS;
}
- Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
+ Status GatherPInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
- Status GatherV2PInfo::GenerateStrategies(int64_t stage_id) {
+ Status GatherPInfo::GenerateStrategies(int64_t stage_id) {
if (GetAttrs() != SUCCESS) {
return FAILED;
}
@@ -814,7 +814,7 @@ Status GatherV2PInfo::GenerateStrategies(int64_t stage_id) {
return SUCCESS;
}
- std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
+ std::shared_ptr<Strategys> GatherPInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
}

@@ -29,17 +29,17 @@
namespace mindspore {
namespace parallel {
- class GatherV2PInfo : public OperatorInfo {
+ class GatherPInfo : public OperatorInfo {
public:
- GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
- const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2)
+ GatherPInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+ const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()),
axis_(0),
bias_(0),
index_offset_(0),
slice_size_(0),
replace_op_name_(replace_op_name) {}
- ~GatherV2PInfo() override = default;
+ ~GatherPInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
@@ -85,19 +85,19 @@ class GatherV2PInfo : public OperatorInfo {
std::vector<int64_t> index_offsets_;
};
- class SparseGatherV2Info : public GatherV2PInfo {
+ class SparseGatherV2Info : public GatherPInfo {
public:
SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs, const std::string &replace_op_name = SPARSE_GATHERV2)
- : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
+ : GatherPInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
~SparseGatherV2Info() override = default;
};
- class EmbeddingLookupInfo : public GatherV2PInfo {
+ class EmbeddingLookupInfo : public GatherPInfo {
public:
EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
- : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {}
+ : GatherPInfo(name, inputs_shape, outputs_shape, attrs) {}
~EmbeddingLookupInfo() override = default;
};
} // namespace parallel

@@ -249,7 +249,7 @@ constexpr char MINIMUM[] = "Minimum";
constexpr char EQUAL[] = "Equal";
constexpr char NOT_EQUAL[] = "NotEqual";
constexpr char LOGICALNOT[] = "LogicalNot";
- constexpr char GATHERV2[] = "GatherV2";
+ constexpr char GATHERV2[] = "Gather";
constexpr char SPARSE_GATHERV2[] = "SparseGatherV2";
constexpr char STRIDEDSLICE[] = "StridedSlice";
constexpr char SLICE[] = "Slice";

@@ -2699,7 +2699,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
}
if (operator_info->name().find(EMBEDDING_LOOKUP) != std::string::npos ||
operator_info->name().find(GATHERV2) != std::string::npos) {
- auto gatherv2_info = std::dynamic_pointer_cast<GatherV2PInfo>(operator_info);
+ auto gatherv2_info = std::dynamic_pointer_cast<GatherPInfo>(operator_info);
auto param_split_shapes = gatherv2_info->param_split_shapes();
auto index_offsets = gatherv2_info->index_offsets();
if (param_split_shapes.size() != index_offsets.size()) {

@@ -148,7 +148,7 @@ std::string GetRealOpType(const std::string &op_type) {
static const std::map<std::string, std::string> kOpTypeMap = {
{"SparseApplyFtrl", "SparseApplyFtrlD"},
{"SparseApplyProximalAdagrad", "SparseApplyProximalAdagradD"},
- {"SparseGatherV2", "GatherV2"},
+ {"SparseGatherV2", "Gather"},
{"Pad", "PadD"},
{"Concat", "ConcatD"},
};

@@ -247,7 +247,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(
.Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
- OPERATOR_ONNX_CONVERT_DEFINE(GatherV2, Gather, OpNameInfo())
+ OPERATOR_ONNX_CONVERT_DEFINE(Gather, Gather, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(make_tuple, SequenceConstruct, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo())
@@ -970,7 +970,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
}
// MindSpore Gather(x, indices, axis) --> ONNX Gather(x, indices)
- if (node->IsApply(prim::kPrimGatherV2)) {
+ if (node->IsApply(prim::kPrimGather)) {
return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto);
}
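On the ONNX side the mapping is direct: MindSpore's Gather corresponds to the ONNX Gather op, with the constant axis input lowered to an attribute. A minimal sketch using the onnx Python helper (hypothetical tensor names):

from onnx import helper

# Gather(x, indices, axis=0): the axis travels as an attribute in ONNX.
node = helper.make_node("Gather", inputs=["x", "indices"], outputs=["y"], axis=0)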

@@ -70,7 +70,7 @@ INPUT_MAP(GatherV2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}};
INPUT_ATTR_MAP(GatherV2D) = {{3, ATTR_DESC(axis, AnyTraits<int64_t>())}};
ATTR_MAP(GatherV2D) = EMPTY_ATTR_MAP;
OUTPUT_MAP(GatherV2D) = {{0, OUTPUT_DESC(y)}};
- REG_ADPT_DESC(GatherV2D, prim::kPrimGatherV2->name(), ADPT_DESC(GatherV2D))
+ REG_ADPT_DESC(GatherV2D, prim::kPrimGather->name(), ADPT_DESC(GatherV2D))
// ScatterNdD
INPUT_MAP(ScatterNdD) = {{1, INPUT_DESC(indices)}, {2, INPUT_DESC(x)}};

@@ -208,7 +208,7 @@ constexpr auto kPushOpName = "Push";
constexpr auto kPullOpName = "Pull";
constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup";
constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
- constexpr auto kGatherV2OpName = "GatherV2";
+ constexpr auto kGatherV2OpName = "Gather";
constexpr auto kPaddingOpName = "Padding";
constexpr auto kAvgPoolOpName = "AvgPool";
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";

@@ -64,7 +64,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimPad, {InferImplPad, true}},
{prim::kPrimUnique, {InferImplUnique, true}},
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
- {prim::kPrimGatherV2, {InferImplGatherV2, true}},
+ {prim::kPrimGather, {InferImplGatherV2, true}},
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},

@@ -25,6 +25,7 @@
namespace mindspore {
namespace prim {
+ constexpr auto kGather = "Gather";
// Here list all primitives used in backend or some special primitives used by core.
// Arithmetic
inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
@@ -86,8 +87,8 @@ inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
- inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
inline const PrimitivePtr kPrimGatherD = std::make_shared<Primitive>("GatherD");
+ inline const PrimitivePtr kPrimGather = std::make_shared<Primitive>(kGather);
inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2");
inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape");
@@ -351,7 +352,7 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
- // Other primitve not used by backend but used in core;
+ // Other primitive not used by backend but used in core;
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");

@@ -607,7 +607,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<While>(prim, inputs, quantType);
} else if (op_type == "MirrorPad") {
return NewPrimitiveC<Pad>(prim, inputs, quantType);
- } else if (op_type == "GatherV2") {
+ } else if (op_type == "Gather") {
return NewPrimitiveC<Gather>(prim, inputs, quantType);
} else if (op_type == "OnesLike") {
return NewPrimitiveC<OnesLike>(prim, inputs, quantType);

@@ -97,6 +97,7 @@ STATUS TFGatherParser::Parse(const tensorflow::NodeDef &tf_op,
status = AddOpInput(tf_op, 1, inputs);
return status;
}
+ TFNodeRegistrar g_tfGatherV2Parser("GatherV2", new TFGatherParser());
} // namespace lite
} // namespace mindspore

@@ -69,7 +69,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
{tflite::BuiltinOperator_RANGE, "Range"},
{tflite::BuiltinOperator_RANK, "Rank"},
{tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, "LocalResponseNorm"},
- {tflite::BuiltinOperator_GATHER, "GatherV2"},
+ {tflite::BuiltinOperator_GATHER, "Gather"},
{tflite::BuiltinOperator_EXP, "Exp"},
{tflite::BuiltinOperator_SPLIT_V, "SplitV"},
{tflite::BuiltinOperator_SPLIT, "Split"},

@@ -112,7 +112,7 @@ class Embedding(Cell):
self.expand = P.ExpandDims()
self.reshape_flat = P.Reshape()
self.shp_flat = (-1,)
- self.gather = P.GatherV2()
+ self.gather = P.Gather()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, self.dtype)
self.off_value = Tensor(0.0, self.dtype)
@@ -154,7 +154,7 @@ class EmbeddingLookup(Cell):
When 'target' is set to 'CPU', this module will use
P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
specified 'offset = 0' to lookup table.
- When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
+ When 'target' is set to 'DEVICE', this module will use P.Gather() which
specified 'axis = 0' to lookup table.
In field slice mode, the manual_shapes must be given. It is a tuple ,where
the element is vocab[i], vocab[i] is the row numbers for i-th part.
@@ -221,7 +221,7 @@ class EmbeddingLookup(Cell):
if sparse:
self.gatherv2 = P.SparseGatherV2()
else:
- self.gatherv2 = P.GatherV2()
+ self.gatherv2 = P.Gather()
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
enable_ps = _get_ps_context("enable_ps")
if enable_ps:
@@ -231,7 +231,7 @@ class EmbeddingLookup(Cell):
name='embedding_table')
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
- self.gather_revert = P.GatherV2()
+ self.gather_revert = P.Gather()
self.reshape_first = P.Reshape()
self.reshape = P.Reshape()
self.unique = P.Unique()
@@ -379,7 +379,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
When 'target' is set to 'CPU', this module will use
P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
specified 'offset = 0' to lookup table.
- When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
+ When 'target' is set to 'DEVICE', this module will use P.Gather() which
specified 'axis = 0' to lookup table.
The vectors with the same field_ids will be combined by the 'operator', such as 'SUM', 'MAX' and
'MEAN'. Ensure the input_values of the padded id is zero, so that they can be ignored. The final
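For the nn layers the rename is internal: Embedding and EmbeddingLookup construct P.Gather() when targeting DEVICE, but their public interfaces are unchanged. A small usage sketch (assuming MindSpore 1.1 or later):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

emb = nn.Embedding(vocab_size=8, embedding_size=4)
ids = Tensor(np.array([[1, 3], [2, 0]], np.int32))
out = emb(ids)  # gathers rows of the 8x4 table -> shape (2, 2, 4)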

Some files were not shown because too many files have changed in this diff.