fix concat fp16 when tensor is int32

pull/6327/head
sunsuodong 5 years ago
parent dccd231ff0
commit 2bdd184212

@ -43,4 +43,24 @@ float16_t *MallocOutputFp16(lite::Tensor *output, const lite::Context *ctx) {
} }
return fp16_data; return fp16_data;
} }
bool IsExistFp16Tensor(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs) {
bool result = false;
for (auto &input : inputs) {
if (input->data_type() == kNumberTypeFloat16) {
result = true;
break;
}
}
if (result) {
return true;
}
for (auto &output : outputs) {
if (output->data_type() == kNumberTypeFloat16) {
result = true;
break;
}
}
return result;
}
} // namespace mindspore::kernel } // namespace mindspore::kernel

@ -16,6 +16,7 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_
#include <vector>
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
namespace mindspore::kernel { namespace mindspore::kernel {
@ -23,6 +24,7 @@ float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::Context *ctx)
float16_t *MallocOutputFp16(lite::Tensor *output, const lite::Context *ctx); float16_t *MallocOutputFp16(lite::Tensor *output, const lite::Context *ctx);
bool IsExistFp16Tensor(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs);
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_

@ -13,12 +13,11 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <vector>
#include "nnacl/fp16/concat_fp16.h"
#include "src/runtime/kernel/arm/fp16/concat_fp16.h" #include "src/runtime/kernel/arm/fp16/concat_fp16.h"
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "src/runtime/kernel/arm/fp32/concat.h"
#include "nnacl/fp16/concat_fp16.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "schema/model_generated.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "nnacl/fp16/cast_fp16.h" #include "nnacl/fp16/cast_fp16.h"
@ -142,24 +141,28 @@ int ConcatFp16CPUKernel::Run() {
} }
kernel::LiteKernel *CpuConcatFp16KernelCreator(const std::vector<lite::Tensor *> &inputs, kernel::LiteKernel *CpuConcatFp16KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
const Context *ctx, const kernel::KernelKey &desc, const Context *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) { const mindspore::lite::PrimitiveC *primitive) {
if (opParameter == nullptr) { if (parameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!"; MS_LOG(ERROR) << "Input parameter is nullptr!";
return nullptr; return nullptr;
} }
MS_ASSERT(desc.type == schema::PrimitiveType_Concat); kernel::LiteKernel *kernel = nullptr;
auto *kernel = new (std::nothrow) ConcatFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (IsExistFp16Tensor(inputs, outputs)) {
kernel = new (std::nothrow) ConcatFp16CPUKernel(parameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) ConcatCPUKernel(parameter, inputs, outputs, ctx, primitive);
}
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
return nullptr; return nullptr;
} }
auto ret = kernel->Init(); auto ret = kernel->Init();
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
delete kernel; delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr; return nullptr;
} }
return kernel; return kernel;

Loading…
Cancel
Save