From 0d3263bf2407c9ca5e181e71fef27234adf07b2e Mon Sep 17 00:00:00 2001
From: yefeng
Date: Tue, 15 Dec 2020 16:23:34 +0800
Subject: [PATCH] fix_controlflow_ops-3

---
 mindspore/lite/nnacl/fp32/arithmetic_fp32.c   |  8 ++++++
 mindspore/lite/nnacl/fp32/arithmetic_fp32.h   |  1 +
 mindspore/lite/src/ops/merge.cc               |  4 +++
 mindspore/lite/src/ops/switch.cc              | 16 +++++++++++-
 .../src/runtime/kernel/arm/base/switch.cc     |  2 +-
 .../arm/fp32/arithmetic_compare_fp32.cc       |  1 +
 .../kernel/arm/fp32/arithmetic_fp32.cc        |  2 ++
 .../runtime/kernel/arm/fp32/squeeze_fp32.cc   |  1 +
 mindspore/lite/src/scheduler.cc               | 25 +++++++++++++++++--
 mindspore/lite/src/scheduler.h                |  2 ++
 10 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/mindspore/lite/nnacl/fp32/arithmetic_fp32.c b/mindspore/lite/nnacl/fp32/arithmetic_fp32.c
index 23720bf1df..d888be181e 100644
--- a/mindspore/lite/nnacl/fp32/arithmetic_fp32.c
+++ b/mindspore/lite/nnacl/fp32/arithmetic_fp32.c
@@ -890,6 +890,14 @@ int ElementLogicalAnd(const float *input0, const float *input1, float *output, c
   return NNACL_OK;
 }
 
+int ElementLogicalAndInt(const int *input0, const int *input1, int *output, const int element_size) {
+  int index = 0;
+  for (; index < element_size; index++) {
+    output[index] = (int)((int)(input0[index]) & (int)(input1[index]));
+  }
+  return NNACL_OK;
+}
+
 int ElementSquaredDifference(const float *input0, const float *input1, float *output, const int element_size) {
   ElementSub(input0, input1, output, element_size);
   return ElementMul(output, output, output, element_size);
diff --git a/mindspore/lite/nnacl/fp32/arithmetic_fp32.h b/mindspore/lite/nnacl/fp32/arithmetic_fp32.h
index c30f992487..f3261f0ddb 100644
--- a/mindspore/lite/nnacl/fp32/arithmetic_fp32.h
+++ b/mindspore/lite/nnacl/fp32/arithmetic_fp32.h
@@ -92,6 +92,7 @@ int BroadcastDiv(const float *input0, const float *input1, float *tile_input0, f
                  int element_size, ArithmeticParameter *param);
 
 int ElementLogicalAnd(const float *input0, const float *input1, float *output, const int element_size);
+int ElementLogicalAndInt(const int *input0, const int *input1, int *output, const int element_size);
 
 int BroadcastLogicalAnd(const float *input0, const float *input1, float *tile_input0, float *tile_input1,
                         float *output, int element_size, ArithmeticParameter *param);
diff --git a/mindspore/lite/src/ops/merge.cc b/mindspore/lite/src/ops/merge.cc
index 1c7e7bc91d..7dd0397234 100644
--- a/mindspore/lite/src/ops/merge.cc
+++ b/mindspore/lite/src/ops/merge.cc
@@ -68,8 +68,12 @@ Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator);
 
 int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
   MS_ASSERT(inputs_.size() == 2 * outputs_.size());
+  if (!infer_flag()) {
+    return RET_INFER_INVALID;
+  }
   for (size_t i = 0; i < inputs_.size() / 2; i++) {
     outputs_[i]->set_data_type(inputs_[i]->data_type());
+    outputs_[i]->set_shape(inputs_[i]->shape());
   }
   return RET_OK;
 }
diff --git a/mindspore/lite/src/ops/switch.cc b/mindspore/lite/src/ops/switch.cc
index 70277c2d27..2135f5fbb7 100644
--- a/mindspore/lite/src/ops/switch.cc
+++ b/mindspore/lite/src/ops/switch.cc
@@ -70,6 +70,20 @@ PrimitiveC *SwitchCreator(const schema::Primitive *primitive) { return Primitive
 Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator);
 #endif
 
-int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; }
+int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
+  MS_ASSERT(2 * (inputs_.size() - 1) == outputs_.size());
+  if (!infer_flag()) {
+    return RET_INFER_INVALID;
+  }
+  for (size_t i = 0; i < outputs_.size() / 2; i++) {
+    outputs_[i]->set_data_type(inputs_[i + 1]->data_type());
+    outputs_[i + outputs_.size() / 2]->set_data_type(inputs_[i + 1]->data_type());
+    outputs_[i]->set_shape(inputs_[i + 1]->shape());
+    outputs_[i + outputs_.size() / 2]->set_shape(inputs_[i + 1]->shape());
+    outputs_[i]->set_format(inputs_[i + 1]->format());
+    outputs_[i + outputs_.size() / 2]->set_format(inputs_[i + 1]->format());
+  }
+  return RET_OK;
+}
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/switch.cc b/mindspore/lite/src/runtime/kernel/arm/base/switch.cc
index 1a16bdaaaa..eb2c7933aa 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/switch.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/switch.cc
@@ -30,7 +30,7 @@ int SwitchCPUKernel::PostProcess() {
   MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
   MS_ASSERT(bool_tensor->shape().size() == 1);
   MS_ASSERT(bool_tensor->shape().front() == 1);
-  auto *active = static_cast<bool *>(bool_tensor->data_c());
+  auto active = static_cast<bool *>(bool_tensor->data_c());
   if (active == nullptr) {
     MS_LOG(ERROR) << "data of bool tensor is nullptr";
     return lite::RET_NULL_PTR;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc
index ffa69ecc12..98a42b0be6 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc
@@ -132,6 +132,7 @@ kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector
 *kernels) {
 }
     // when merge is removed, this if is removed automatically
     if (kernel->Type() == schema::PrimitiveType_Merge) {
-      MS_ASSERT(kernel->in_kernels().size() == 2);
-      return (is_kernel_finish[kernel->in_kernels().at(0)] || is_kernel_finish[kernel->in_kernels().at(1)]);
+      return MergeOpIsReady(kernel, is_kernel_finish);
     } else {
       return std::all_of(kernel_inputs.begin(), kernel_inputs.end(),
                          [&](kernel::LiteKernel *kernel) { return is_kernel_finish[kernel]; });
@@ -370,6 +369,28 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) {
   }
   return RET_OK;
 }
+bool Scheduler::MergeOpIsReady(const kernel::LiteKernel *kernel,
+                               std::map<kernel::LiteKernel *, bool> is_kernel_finish) {
+  std::map<lite::Tensor *, bool> merge_in_tensors_map;
+  for (auto merge_in_tensor : kernel->in_tensors()) {
+    merge_in_tensors_map[merge_in_tensor] = false;
+    if (merge_in_tensor->category() == Tensor::CONST_TENSOR || merge_in_tensor->category() == Tensor::CONST_SCALAR) {
+      merge_in_tensors_map[merge_in_tensor] = true;
+    }
+    for (auto merge_in_kernel : kernel->in_kernels()) {
+      for (auto tensor : merge_in_kernel->out_tensors()) {
+        if (tensor == merge_in_tensor && is_kernel_finish[merge_in_kernel]) {
+          merge_in_tensors_map[merge_in_tensor] = true;
+        }
+      }
+    }
+  }
+  auto kernel_in_tensors_num = kernel->in_tensors().size();
+  return std::all_of(kernel->in_tensors().begin(), kernel->in_tensors().begin() + kernel_in_tensors_num / 2,
+                     [&](lite::Tensor *in_tensor) { return merge_in_tensors_map[in_tensor]; }) ||
+         std::all_of(kernel->in_tensors().begin() + kernel_in_tensors_num / 2, kernel->in_tensors().end(),
+                     [&](lite::Tensor *in_tensor) { return merge_in_tensors_map[in_tensor]; });
+}
 
 kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels,
                                                         kernel::SubGraphType type) {
diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h
index 755ce8fcb4..0ef4f87783 100644
--- a/mindspore/lite/src/scheduler.h
+++ b/mindspore/lite/src/scheduler.h
@@ -65,6 +65,8 @@ class Scheduler {
   kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels,
                                                kernel::SubGraphType type);
 
+  bool MergeOpIsReady(const kernel::LiteKernel *kernel, std::map<kernel::LiteKernel *, bool> is_kernel_finish);
+
   std::vector<kernel::LiteKernel *> FindAllSubGraphKernels(
       kernel::LiteKernel *head_kernel, std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map);
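
Note on the scheduling rule introduced above: the new Scheduler::MergeOpIsReady treats a Merge node as schedulable once every tensor in either the first half or the second half of its inputs has been produced (constant tensors count as already produced). The following is a minimal standalone C++ sketch of that predicate only; it is not MindSpore code, and the name MergeIsReady and the std::vector<bool> representation are illustrative assumptions.

#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative helper: `produced` marks which of a Merge node's input tensors
// are already available. The first half of the inputs comes from one incoming
// branch, the second half from the other; Merge fires as soon as one branch
// has delivered all of its tensors.
bool MergeIsReady(const std::vector<bool> &produced) {
  const std::size_t half = produced.size() / 2;
  auto branch_ready = [&](std::size_t begin, std::size_t end) {
    return std::all_of(produced.begin() + begin, produced.begin() + end,
                       [](bool ok) { return ok; });
  };
  return branch_ready(0, half) || branch_ready(half, produced.size());
}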