|
|
|
@ -13,12 +13,11 @@
|
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "nnacl/fp16/concat_fp16.h"
|
|
|
|
|
#include "src/runtime/kernel/arm/fp16/concat_fp16.h"
|
|
|
|
|
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
|
|
|
|
|
#include "src/runtime/kernel/arm/fp32/concat.h"
|
|
|
|
|
#include "nnacl/fp16/concat_fp16.h"
|
|
|
|
|
#include "src/kernel_registry.h"
|
|
|
|
|
#include "schema/model_generated.h"
|
|
|
|
|
#include "include/errorcode.h"
|
|
|
|
|
#include "nnacl/fp16/cast_fp16.h"
|
|
|
|
|
|
|
|
|
@ -142,24 +141,28 @@ int ConcatFp16CPUKernel::Run() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel::LiteKernel *CpuConcatFp16KernelCreator(const std::vector<lite::Tensor *> &inputs,
|
|
|
|
|
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
|
|
|
|
|
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
|
|
|
|
|
const Context *ctx, const kernel::KernelKey &desc,
|
|
|
|
|
const mindspore::lite::PrimitiveC *primitive) {
|
|
|
|
|
if (opParameter == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Input opParameter is nullptr!";
|
|
|
|
|
if (parameter == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Input parameter is nullptr!";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
|
|
|
|
|
auto *kernel = new (std::nothrow) ConcatFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
kernel::LiteKernel *kernel = nullptr;
|
|
|
|
|
if (IsExistFp16Tensor(inputs, outputs)) {
|
|
|
|
|
kernel = new (std::nothrow) ConcatFp16CPUKernel(parameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
} else {
|
|
|
|
|
kernel = new (std::nothrow) ConcatCPUKernel(parameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
}
|
|
|
|
|
if (kernel == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto ret = kernel->Init();
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
|
|
|
|
|
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
|
|
|
|
|
delete kernel;
|
|
|
|
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
|
|
|
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
return kernel;
|
|
|
|
|