|
|
|
@ -32,7 +32,10 @@ Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix
|
|
|
|
|
// skip redistribution for reshape operator
|
|
|
|
|
OperatorVector ConstructOperator::SkipRedisReshapeOP(Shape shape) {
|
|
|
|
|
OperatorAttrs attrs;
|
|
|
|
|
ValuePtr param_value = MakeValue(shape);
|
|
|
|
|
std::vector<int32_t> shape_int;
|
|
|
|
|
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int),
|
|
|
|
|
[](const int64_t &value) { return static_cast<int32_t>(value); });
|
|
|
|
|
ValuePtr param_value = MakeValue(shape_int);
|
|
|
|
|
Attr param = std::make_pair(SHAPE, param_value);
|
|
|
|
|
OperatorParams params = {std::make_pair(param, 2)};
|
|
|
|
|
OperatorArgs args = std::make_pair(attrs, params);
|
|
|
|
@ -52,7 +55,10 @@ Status ConstructOperator::ReshapeOP(Shape shape) {
|
|
|
|
|
return Status::INVALID_ARGUMENT;
|
|
|
|
|
}
|
|
|
|
|
OperatorAttrs attrs;
|
|
|
|
|
ValuePtr param_value = MakeValue(shape);
|
|
|
|
|
std::vector<int32_t> shape_int;
|
|
|
|
|
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int),
|
|
|
|
|
[](const int64_t &value) { return static_cast<int32_t>(value); });
|
|
|
|
|
ValuePtr param_value = MakeValue(shape_int);
|
|
|
|
|
Attr param = std::make_pair(SHAPE, param_value);
|
|
|
|
|
OperatorParams params = {std::make_pair(param, 2)};
|
|
|
|
|
OperatorArgs args = std::make_pair(attrs, params);
|
|
|
|
@ -69,12 +75,21 @@ Operator CreateStridedSliceOp(int32_t value, const Shape &begin, const Shape &en
|
|
|
|
|
Attr attr_shrink_axis_mask = std::make_pair(SHRINK_AXIS_MASK, attr_value);
|
|
|
|
|
OperatorAttrs attrs = {attr_begin_mask, attr_end_mask, attr_ellipsis_mask, attr_new_axis_mask, attr_shrink_axis_mask};
|
|
|
|
|
|
|
|
|
|
ValuePtr param_begin_value = MakeValue(begin);
|
|
|
|
|
std::vector<int32_t> begin_int;
|
|
|
|
|
(void)std::transform(begin.begin(), begin.end(), std::back_inserter(begin_int),
|
|
|
|
|
[](const int64_t &value) { return static_cast<int32_t>(value); });
|
|
|
|
|
ValuePtr param_begin_value = MakeValue(begin_int);
|
|
|
|
|
Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), 2);
|
|
|
|
|
ValuePtr param_end_value = MakeValue(end);
|
|
|
|
|
std::vector<int32_t> end_int;
|
|
|
|
|
(void)std::transform(end.begin(), end.end(), std::back_inserter(end_int),
|
|
|
|
|
[](const int64_t &value) { return static_cast<int32_t>(value); });
|
|
|
|
|
ValuePtr param_end_value = MakeValue(end_int);
|
|
|
|
|
Param param_end = std::make_pair(std::make_pair(END, param_end_value), 3);
|
|
|
|
|
|
|
|
|
|
ValuePtr param_strides_value = MakeValue(strides);
|
|
|
|
|
std::vector<int32_t> strides_int;
|
|
|
|
|
(void)std::transform(strides.begin(), strides.end(), std::back_inserter(strides_int),
|
|
|
|
|
[](const int64_t &value) { return static_cast<int32_t>(value); });
|
|
|
|
|
ValuePtr param_strides_value = MakeValue(strides_int);
|
|
|
|
|
Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), 4);
|
|
|
|
|
OperatorParams params = {param_begin, param_end, param_strides};
|
|
|
|
|
OperatorArgs op_args = std::make_pair(attrs, params);
|
|
|
|
|