Temporarily cast between int64 and int32 until ME supports int64

pull/3928/head
Yi Huaijie 5 years ago
parent 518cb80133
commit 80bdcab982

@ -189,7 +189,10 @@ void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, cons
return;
}
Shape shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
std::vector<int> shape_int = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
Shape shape;
(void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape),
[](const int &value) { return static_cast<int64_t>(value); });
auto ret = param_shapes.try_emplace(param_node->name(), shape);
if (!ret.second) {
MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed";

@ -176,7 +176,7 @@ Status Softmax::GetAttrs() {
}
std::vector<ValuePtr> value_vector = value_tuple->value();
(void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_),
[](const ValuePtr &value) { return static_cast<int32_t>(GetValue<int64_t>(value)); });
[](const ValuePtr &value) { return static_cast<int32_t>(GetValue<int>(value)); });
if (axis_.empty()) {
MS_LOG(ERROR) << name_ << " : The axis tuple is empty.";
return FAILED;

@ -259,8 +259,10 @@ void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) {
if (manager == nullptr) {
MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr.";
}
ValuePtr new_shape = MakeValue(input_slice_shape);
std::vector<int32_t> input_slice_shape_int;
(void)std::transform(input_slice_shape.begin(), input_slice_shape.end(), std::back_inserter(input_slice_shape_int),
[](const int64_t &value) { return static_cast<int32_t>(value); });
ValuePtr new_shape = MakeValue(input_slice_shape_int);
AnfNodePtr val = NewValueNode(new_shape);
(void)manager->Replace(dropout_gen_mask_cnode->input(1), val);
}
@ -306,8 +308,10 @@ std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodeP
MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape);
return replace_ops;
}
ValuePtr new_shape = MakeValue(input_slice_shape);
std::vector<int32_t> input_slice_shape_int;
(void)std::transform(input_slice_shape.begin(), input_slice_shape.end(), std::back_inserter(input_slice_shape_int),
[](const int64_t &value) { return static_cast<int32_t>(value); });
ValuePtr new_shape = MakeValue(input_slice_shape_int);
Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0));
Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1));
OperatorAttrs attrs = {attr_0, attr_1};

@ -560,7 +560,8 @@ Status GatherV2PInfo::ComputeReplaceOp() {
OperatorName op_name = EMBEDDING_LOOKUP;
OperatorAttrs attrs;
Attr param_offset = std::make_pair("offset", MakeValue(bias));
int32_t bias_int = static_cast<int32_t>(bias);
Attr param_offset = std::make_pair("offset", MakeValue(bias_int));
OperatorParams params = {std::make_pair(param_offset, 3)};
OperatorArgs args = std::make_pair(attrs, params);
Operator op = std::make_pair(op_name, args);

@ -215,7 +215,15 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
out_shapes[i][0] = out_shapes[i][0] / dev_num_;
}
}
ValuePtr new_shapes = MakeValue(out_shapes);
std::vector<std::vector<int32_t>> out_shapes_int;
(void)std::transform(out_shapes.begin(), out_shapes.end(), std::back_inserter(out_shapes_int),
[](const std::vector<int64_t> &shape) {
std::vector<int32_t> shape_int;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int),
[](const int64_t &v) { return static_cast<int32_t>(v); });
return shape_int;
});
ValuePtr new_shapes = MakeValue(out_shapes_int);
Attr attr_types = std::make_pair(TYPES, attrs_[TYPES]);
Attr attr_shapes = std::make_pair(SHAPES, new_shapes);
Attr attr_num = std::make_pair(GETNEXT_NUM, attrs_[GETNEXT_NUM]);

@ -180,7 +180,10 @@ void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
ValuePtr new_multiples = MakeValue(slice_multiples_);
std::vector<int32_t> slice_multiples_int;
(void)std::transform(slice_multiples_.begin(), slice_multiples_.end(), std::back_inserter(slice_multiples_int),
[](const int64_t &value) { return static_cast<int32_t>(value); });
ValuePtr new_multiples = MakeValue(slice_multiples_int);
AnfNodePtr val = NewValueNode(new_multiples);
(void)manager->Replace(cnode->input(2), val);
}

@ -776,7 +776,10 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
if (input_shape == nullptr) {
MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr";
}
Shape shape = input_shape->shape();
std::vector<int> shape_int = input_shape->shape();
Shape shape;
(void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape),
[](int sub_shape) { return static_cast<int64_t>(sub_shape); });
Shapes inputs_shape = {shape};
Shapes outputs_shape = {shape};
// 2) init the attr

@ -1027,7 +1027,7 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
std::vector<ValuePtr> value_vector = value_tuple->value();
(void)std::transform(
value_vector.begin(), value_vector.end(), std::back_inserter(dim), [](const ValuePtr &value) {
return GetValue<int64_t>(value);
return value->isa<Int64Imm>() ? GetValue<int64_t>(value) : static_cast<int64_t>(GetValue<int>(value));
});
strategy.push_back(dim);
} else {
@ -1077,13 +1077,19 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
for (auto &shape : tuple_shape) {
auto each_shape = dyn_cast<abstract::Shape>(shape);
MS_EXCEPTION_IF_NULL(each_shape);
Shape new_shape = each_shape->shape();
std::vector<int> shape_int = each_shape->shape();
Shape new_shape;
(void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(new_shape),
[](const int &value) { return static_cast<int64_t>(value); });
shapes.push_back(new_shape);
}
} else {
auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
MS_EXCEPTION_IF_NULL(shape_ptr);
Shape new_shape = shape_ptr->shape();
std::vector<int> shape_int = shape_ptr->shape();
Shape new_shape;
(void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(new_shape),
[](const int &value) { return static_cast<int64_t>(value); });
shapes.push_back(new_shape);
}
return shapes;

@ -148,7 +148,7 @@ std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> Arrangement::G
Shape expand_num_list_shape;
(void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(),
std::back_inserter(expand_num_list_shape),
[](const Arrangement &arr) { return SizeToLong(arr.GetDimSize()); });
[](const Arrangement &arr) { return SizeToInt(arr.GetDimSize()); });
Arrangement expand_num_list;
Status status = expand_num_list.Init(expand_num_list_shape);
if (status != Status::SUCCESS) {

@ -32,7 +32,10 @@ Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix
// skip redistribution for reshape operator
OperatorVector ConstructOperator::SkipRedisReshapeOP(Shape shape) {
OperatorAttrs attrs;
ValuePtr param_value = MakeValue(shape);
std::vector<int32_t> shape_int;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int),
[](const int64_t &value) { return static_cast<int32_t>(value); });
ValuePtr param_value = MakeValue(shape_int);
Attr param = std::make_pair(SHAPE, param_value);
OperatorParams params = {std::make_pair(param, 2)};
OperatorArgs args = std::make_pair(attrs, params);
@ -52,7 +55,10 @@ Status ConstructOperator::ReshapeOP(Shape shape) {
return Status::INVALID_ARGUMENT;
}
OperatorAttrs attrs;
ValuePtr param_value = MakeValue(shape);
std::vector<int32_t> shape_int;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int),
[](const int64_t &value) { return static_cast<int32_t>(value); });
ValuePtr param_value = MakeValue(shape_int);
Attr param = std::make_pair(SHAPE, param_value);
OperatorParams params = {std::make_pair(param, 2)};
OperatorArgs args = std::make_pair(attrs, params);
@ -69,12 +75,21 @@ Operator CreateStridedSliceOp(int32_t value, const Shape &begin, const Shape &en
Attr attr_shrink_axis_mask = std::make_pair(SHRINK_AXIS_MASK, attr_value);
OperatorAttrs attrs = {attr_begin_mask, attr_end_mask, attr_ellipsis_mask, attr_new_axis_mask, attr_shrink_axis_mask};
ValuePtr param_begin_value = MakeValue(begin);
std::vector<int32_t> begin_int;
(void)std::transform(begin.begin(), begin.end(), std::back_inserter(begin_int),
[](const int64_t &value) { return static_cast<int32_t>(value); });
ValuePtr param_begin_value = MakeValue(begin_int);
Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), 2);
ValuePtr param_end_value = MakeValue(end);
std::vector<int32_t> end_int;
(void)std::transform(end.begin(), end.end(), std::back_inserter(end_int),
[](const int64_t &value) { return static_cast<int32_t>(value); });
ValuePtr param_end_value = MakeValue(end_int);
Param param_end = std::make_pair(std::make_pair(END, param_end_value), 3);
ValuePtr param_strides_value = MakeValue(strides);
std::vector<int32_t> strides_int;
(void)std::transform(strides.begin(), strides.end(), std::back_inserter(strides_int),
[](const int64_t &value) { return static_cast<int32_t>(value); });
ValuePtr param_strides_value = MakeValue(strides_int);
Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), 4);
OperatorParams params = {param_begin, param_end, param_strides};
OperatorArgs op_args = std::make_pair(attrs, params);

@ -52,17 +52,26 @@ void Init_Device_Manager() {
}
CNodePtr Make_Node(Shape x, Shape y, Shape out, int condition = 0) {
std::vector<int32_t> x_shape;
std::vector<int32_t> y_shape;
std::vector<int32_t> out_shape;
FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
ParameterPtr param1 = func_graph->add_parameter();
ParameterPtr param2 = func_graph->add_parameter();
(void)std::transform(x.begin(), x.end(), std::back_inserter(x_shape),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(y.begin(), y.end(), std::back_inserter(y_shape),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(out.begin(), out.end(), std::back_inserter(out_shape),
[](const int64_t &value) { return static_cast<int>(value); });
param1->set_name("x");
param2->set_name("y");
BaseShapePtr shape1 = std::make_shared<abstract::Shape>(x);
BaseShapePtr shape2 = std::make_shared<abstract::Shape>(y);
BaseShapePtr shape3 = std::make_shared<abstract::Shape>(out);
std::shared_ptr<tensor::Tensor> inputs_x = std::make_shared<tensor::Tensor>(kNumberTypeInt32, x);
std::shared_ptr<tensor::Tensor> inputs_y = std::make_shared<tensor::Tensor>(kNumberTypeInt32, y);
std::shared_ptr<tensor::Tensor> inputs_out = std::make_shared<tensor::Tensor>(kNumberTypeInt32, out);
std::shared_ptr<tensor::Tensor> inputs_x = std::make_shared<tensor::Tensor>(kNumberTypeInt32, x_shape);
std::shared_ptr<tensor::Tensor> inputs_y = std::make_shared<tensor::Tensor>(kNumberTypeInt32, y_shape);
std::shared_ptr<tensor::Tensor> inputs_out = std::make_shared<tensor::Tensor>(kNumberTypeInt32, out_shape);
AbstractBasePtr abstract1 = abstract::FromValue(inputs_x, true);
AbstractBasePtr abstract2 = abstract::FromValue(inputs_y, true);
AbstractBasePtr abstract3 = abstract::FromValue(inputs_out, true);

@ -98,8 +98,14 @@ TEST_F(TestConstructOperator, TestStridedSliceOP) {
OperatorParams params = op.second.second;
ValuePtr begin_ptr = params[0].first.second;
ValuePtr end_ptr = params[1].first.second;
Shape begin = GetValue<const std::vector<int64_t>>(begin_ptr);
Shape end = GetValue<const std::vector<int64_t>>(end_ptr);
std::vector<int32_t> begin_int = GetValue<const std::vector<int32_t>>(begin_ptr);
std::vector<int32_t> end_int = GetValue<const std::vector<int32_t>>(end_ptr);
Shape begin;
Shape end;
(void)std::transform(begin_int.begin(), begin_int.end(), std::back_inserter(begin),
[](const int32_t &value) { return static_cast<int64_t>(value); });
(void)std::transform(end_int.begin(), end_int.end(), std::back_inserter(end),
[](const int32_t &value) { return static_cast<int64_t>(value); });
for (size_t i = 0; i < begin.size(); i++) {
int64_t diff = end[i] - begin[i];
int64_t num = shape[i];

Loading…
Cancel
Save