|
|
@ -39,7 +39,8 @@ using OperatorList = std::vector<OperatorC>;
|
|
|
|
class RedistributionOperatorInfer {
|
|
|
|
class RedistributionOperatorInfer {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
const int NONE = -1;
|
|
|
|
const int NONE = -1;
|
|
|
|
explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {}
|
|
|
|
explicit RedistributionOperatorInfer(bool construct_op_flag = true)
|
|
|
|
|
|
|
|
: construct_op_flag_(construct_op_flag), is_cost_model_(false) {}
|
|
|
|
Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list,
|
|
|
|
Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list,
|
|
|
|
bool is_cost_model = false);
|
|
|
|
bool is_cost_model = false);
|
|
|
|
~RedistributionOperatorInfer() = default;
|
|
|
|
~RedistributionOperatorInfer() = default;
|
|
|
|