|
|
|
@ -20,8 +20,9 @@
|
|
|
|
|
#include "src/common/log_adapter.h"
|
|
|
|
|
#include "src/tensor.h"
|
|
|
|
|
|
|
|
|
|
#ifndef PRIMITIVE_WRITEABLE
|
|
|
|
|
#include "src/ops/ops_register.h"
|
|
|
|
|
#include "nnacl/batch_to_space.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
@ -75,43 +76,6 @@ PrimitiveC *BatchToSpaceCreator(const schema::Primitive *primitive) {
|
|
|
|
|
Registry BatchToSpaceRegistry(schema::PrimitiveType_BatchToSpace, BatchToSpaceCreator);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *primitive) {
|
|
|
|
|
BatchToSpaceParameter *batch_space_param =
|
|
|
|
|
reinterpret_cast<BatchToSpaceParameter *>(malloc(sizeof(BatchToSpaceParameter)));
|
|
|
|
|
if (batch_space_param == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "malloc BatchToSpaceParameter failed.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
memset(batch_space_param, 0, sizeof(BatchToSpaceParameter));
|
|
|
|
|
batch_space_param->op_parameter_.type_ = primitive->Type();
|
|
|
|
|
auto param = reinterpret_cast<mindspore::lite::BatchToSpace *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
|
|
|
|
auto block_shape = param->GetBlockShape();
|
|
|
|
|
if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) {
|
|
|
|
|
MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE;
|
|
|
|
|
free(batch_space_param);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto crops = param->GetCrops();
|
|
|
|
|
if (crops.size() != BATCH_TO_SPACE_CROPS_SIZE) {
|
|
|
|
|
MS_LOG(ERROR) << "batch_to_space crops size should be " << BATCH_TO_SPACE_CROPS_SIZE;
|
|
|
|
|
free(batch_space_param);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
|
|
|
|
|
batch_space_param->block_shape_[i] = block_shape[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) {
|
|
|
|
|
batch_space_param->crops_[i] = crops[i];
|
|
|
|
|
}
|
|
|
|
|
return reinterpret_cast<OpParameter *>(batch_space_param);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Registry BatchToSpaceParameterRegistry(schema::PrimitiveType_BatchToSpace, PopulateBatchToSpaceParameter);
|
|
|
|
|
Registry BatchToSpaceNDParameterRegistry(schema::PrimitiveType_BatchToSpaceND, PopulateBatchToSpaceParameter);
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr int kBatchToSpaceOutputNum = 1;
|
|
|
|
|
constexpr int kBatchToSpaceInputNum = 1;
|
|
|
|
|