|
|
|
@ -73,6 +73,37 @@ Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator);
|
|
|
|
|
|
|
|
|
|
int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
|
|
|
|
MS_ASSERT(2 * (inputs_.size() - 1) == outputs_.size());
|
|
|
|
|
for (size_t i = 0; i < outputs_.size() / 2; i++) {
|
|
|
|
|
auto *input = inputs_[i + 1];
|
|
|
|
|
auto *output_true = outputs_[i];
|
|
|
|
|
auto *output_false = outputs_[i + outputs_.size() / 2];
|
|
|
|
|
if (input == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "input tensor is nullptr";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (output_true == nullptr || output_false == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "output tensor is nullptr";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
output_true->set_data_type(input->data_type());
|
|
|
|
|
output_false->set_data_type(input->data_type());
|
|
|
|
|
output_true->set_format(input->format());
|
|
|
|
|
output_false->set_format(input->format());
|
|
|
|
|
auto data_type = input->data_type();
|
|
|
|
|
if (data_type != kObjectTypeTensorType) {
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
auto input_tensorlist = reinterpret_cast<TensorList *>(input);
|
|
|
|
|
auto output_true_tensorlist = reinterpret_cast<TensorList *>(output_true);
|
|
|
|
|
auto output_false_tensorlist = reinterpret_cast<TensorList *>(output_false);
|
|
|
|
|
output_true_tensorlist->set_element_shape(input_tensorlist->element_shape());
|
|
|
|
|
output_false_tensorlist->set_element_shape(input_tensorlist->element_shape());
|
|
|
|
|
output_true_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
|
|
|
|
|
output_false_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
|
|
|
|
|
output_true_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
|
|
|
|
|
output_false_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!infer_flag()) {
|
|
|
|
|
return RET_INFER_INVALID;
|
|
|
|
|
}
|
|
|
|
@ -88,12 +119,8 @@ int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
|
|
|
|
|
MS_LOG(ERROR) << "output tensor is nullptr";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
output_true->set_data_type(input->data_type());
|
|
|
|
|
output_false->set_data_type(input->data_type());
|
|
|
|
|
output_true->set_shape(input->shape());
|
|
|
|
|
output_false->set_shape(input->shape());
|
|
|
|
|
output_true->set_format(input->format());
|
|
|
|
|
output_false->set_format(input->format());
|
|
|
|
|
auto data_type = input->data_type();
|
|
|
|
|
if (data_type != kObjectTypeTensorType) {
|
|
|
|
|
continue;
|
|
|
|
|