!7716 [lite] support complex64 and conversion for audio models and fix bug

Merge pull request !7716 from 徐安越/master
pull/7716/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 97cde36843

@ -47,6 +47,12 @@ Float::Float(const int nbits) : Number(FloatBitsToTypeId(nbits), nbits, false) {
}
}
// Constructs a concrete Complex number type with an explicit bit width.
// Only 64-bit complex (complex64) is supported; any other width raises an exception.
Complex::Complex(const int nbits) : Number(TypeId::kNumberTypeComplex64, nbits, false) {
  if (nbits != 64) {
    MS_LOG(EXCEPTION) << "Wrong number of bits.";
  }
}
const TypePtr kBool = std::make_shared<Bool>();
const TypePtr kInt8 = std::make_shared<Int>(8);
const TypePtr kInt16 = std::make_shared<Int>(16);
@ -63,4 +69,5 @@ const TypePtr kInt = std::make_shared<Int>();
const TypePtr kUInt = std::make_shared<UInt>();
const TypePtr kFloat = std::make_shared<Float>();
const TypePtr kNumber = std::make_shared<Number>();
const TypePtr kComplex64 = std::make_shared<Complex>(64);
} // namespace mindspore

@ -150,6 +150,28 @@ class Float : public Number {
}
};
// Complex
// IR type node for complex numbers. Mirrors Int/Float: a default-constructed
// instance (nbits() == 0) acts as the generic "Complex" type, while Complex(64)
// is the concrete complex64 type. Only 64 bits is accepted (enforced by the
// out-of-line constructor).
class Complex : public Number {
 public:
  // Generic complex type (nbits == 0).
  Complex() : Number(kNumberTypeComplex64, 0) {}
  // Concrete complex type; only nbits == 64 is valid.
  explicit Complex(const int nbits);
  ~Complex() override {}
  MS_DECLARE_PARENT(Complex, Number)
  TypeId generic_type_id() const override { return kNumberTypeComplex64; }
  // Deep copies preserve the generic/concrete distinction.
  TypePtr DeepCopy() const override {
    if (nbits() == 0) {
      return std::make_shared<Complex>();
    }
    return std::make_shared<Complex>(nbits());
  }
  std::string ToString() const override { return GetTypeName("Complex64"); }
  std::string ToReprString() const override { return nbits() == 0 ? "complex64_" : GetTypeName("complex64"); }
  // Dump form: "Complex64" for the generic type, "C<nbits>" (i.e. "C64") otherwise.
  std::string DumpText() const override {
    return nbits() == 0 ? std::string("Complex64") : std::string("C") + std::to_string(nbits());
  }
};
extern const TypePtr kBool;
extern const TypePtr kInt8;
extern const TypePtr kInt16;
@ -166,6 +188,7 @@ extern const TypePtr kInt;
extern const TypePtr kUInt;
extern const TypePtr kFloat;
extern const TypePtr kNumber;
extern const TypePtr kComplex64;
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_

@ -69,6 +69,8 @@ TypePtr TypeIdToType(TypeId id) {
return kFloat32;
case kNumberTypeFloat64:
return kFloat64;
case kNumberTypeComplex64:
return kComplex64;
case kNumberTypeInt8:
return kInt8;
case kNumberTypeInt16:

@ -0,0 +1,107 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/audio_spectrogram.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int AudioSpectrogram::GetWindowSize() const { return this->primitive_->value.AsAudioSpectrogram()->windowSize; }
int AudioSpectrogram::GetStride() const { return this->primitive_->value.AsAudioSpectrogram()->stride; }
bool AudioSpectrogram::GetMagSquare() const { return this->primitive_->value.AsAudioSpectrogram()->magSquare; }
#else
int AudioSpectrogram::GetWindowSize() const { return this->primitive_->value_as_AudioSpectrogram()->windowSize(); }
int AudioSpectrogram::GetStride() const { return this->primitive_->value_as_AudioSpectrogram()->stride(); }
bool AudioSpectrogram::GetMagSquare() const { return this->primitive_->value_as_AudioSpectrogram()->magSquare(); }
// Re-serializes an AudioSpectrogram primitive into a fresh flatbuffer.
// Returns RET_ERROR if the primitive does not actually hold an AudioSpectrogram.
int AudioSpectrogram::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_AudioSpectrogram();
  if (attr == nullptr) {
    // Fixed copy-paste log message: previously said "value_as_Add".
    MS_LOG(ERROR) << "value_as_AudioSpectrogram return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateAudioSpectrogram(*fbb, attr->windowSize(), attr->stride(), attr->magSquare());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AudioSpectrogram, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
PrimitiveC *AudioSpectrogramCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<AudioSpectrogram>(primitive);
}
Registry AudioSpectrogramRegistry(schema::PrimitiveType_AudioSpectrogram, AudioSpectrogramCreator);
#endif
// Returns ceil(log2(length)), or -1 when length == 0.
int AudioSpectrogram::Log2Ceil(uint32_t length) {
  if (length == 0) {
    return -1;
  }
  // Keep the untouched input: the reduction loop below destroys `length`.
  const uint32_t original = length;
  int floor = 0;
  // Binary-search style floor(log2): fold in shifts of 16, 8, 4, 2, 1.
  for (int i = 4; i >= 0; --i) {
    int shift = (1 << i);
    uint32_t tmp = length >> shift;
    if (tmp != 0) {
      length = tmp;
      floor += shift;
    }
  }
  // Bug fix: the power-of-two test must use the ORIGINAL value. The old code
  // tested the reduced `length`, which is always 1 after the loop, so the
  // function returned floor(log2(n)) for every input and GetFftLength produced
  // an FFT size smaller than the window for non-power-of-two windows.
  return original == (original & ~(original - 1)) ? floor : floor + 1;
}
// Returns the FFT size for a window: the smallest power of two >= length.
// NOTE(review): Log2Ceil(0) returns -1, which would make `1 << shift`
// undefined — callers must pass length > 0 (InferShape enforces window
// size >= 2 before calling this). TODO confirm no other call sites pass 0.
uint32_t AudioSpectrogram::GetFftLength(uint32_t length) {
  int shift = Log2Ceil(length);
  return 1 << shift;
}
// Infers the AudioSpectrogram output shape from a 2-D input.
// Output is 3-D: (input_shape[1], frame count, fft_length / 2 + 1).
// Returns RET_ERROR on a bad rank, window size < 2, or stride < 1.
int AudioSpectrogram::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  // Type and format are propagated even when shape inference is deferred.
  output->set_data_type(input->data_type());
  output->SetFormat(input->GetFormat());
  if (!GetInferFlag()) {
    return RET_OK;
  }
  auto input_shape = input->shape();
  if (input_shape.size() != 2) {
    MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions";
    return RET_ERROR;
  }
  if (GetWindowSize() < 2) {
    MS_LOG(ERROR) << "window size is too short, now is " << GetWindowSize();
    return RET_ERROR;
  }
  if (GetStride() < 1) {
    MS_LOG(ERROR) << "stride must be positive, now is " << GetStride();
    return RET_ERROR;
  }
  std::vector<int> output_shape(3);
  // assumes input layout is (samples, channels) — TODO confirm against converter
  output_shape[0] = input_shape[1];
  // output height: number of full windows that fit, stepping by the stride;
  // zero frames when the signal is shorter than one window
  int sample_sub_window = input_shape[0] - GetWindowSize();
  output_shape[1] = sample_sub_window < 0 ? 0 : 1 + sample_sub_window / GetStride();
  // compute fft length (next power of two >= window size); spectrum keeps
  // only the non-redundant half plus DC: fft_length / 2 + 1 bins
  int fft_length = GetFftLength(GetWindowSize());
  output_shape[2] = fft_length / 2 + 1;
  outputs_.front()->set_shape(output_shape);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,51 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_
#define LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
// Lite primitive wrapper for the AudioSpectrogram op: computes a spectrogram
// from a 2-D audio signal using windowSize/stride/magSquare attributes.
class AudioSpectrogram : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC);
  AudioSpectrogram() = default;
  explicit AudioSpectrogram(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Attribute setters (available only in the writeable build).
  void SetWindowSize(int window_size) { this->primitive_->value.AsAudioSpectrogram()->windowSize = window_size; }
  void SetStride(int stride) { this->primitive_->value.AsAudioSpectrogram()->stride = stride; }
  void SetMagSquare(bool mag_square) { this->primitive_->value.AsAudioSpectrogram()->magSquare = mag_square; }
#else
  AudioSpectrogram() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int GetWindowSize() const;
  int GetStride() const;
  bool GetMagSquare() const;
  // ceil(log2(length)); -1 when length == 0.
  int Log2Ceil(uint32_t length);
  // Smallest power of two >= length (the FFT size for a window).
  uint32_t GetFftLength(uint32_t length);
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_

@ -0,0 +1,54 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/fft_imag.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
// Re-serializes an FftImag primitive into a fresh flatbuffer.
int FftImag::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  // FftImag carries no attributes, but the value table must still match the
  // union tag below. Fixed: previously created schema::CreateEqual (copy-paste
  // from the Equal op).
  auto val_offset = schema::CreateFftImag(*fbb);
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FftImag, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
PrimitiveC *FftImagCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<FftImag>(primitive); }
Registry FftImagRegistry(schema::PrimitiveType_FftImag, FftImagCreator);
#endif
// Infers FftImag output: float32 tensor whose shape is the input shape with
// the trailing (real, imag) pair dimension removed.
int FftImag::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(TypeId::kNumberTypeFloat32);
  output->SetFormat(input->GetFormat());
  if (!GetInferFlag()) {
    return RET_OK;
  }
  auto input_shape = input->shape();
  // Guard added: pop_back() on an empty shape is undefined behavior.
  if (input_shape.empty()) {
    MS_LOG(ERROR) << "input shape should not be empty";
    return RET_ERROR;
  }
  input_shape.pop_back();
  outputs_.front()->set_shape(input_shape);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,43 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_
#define LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
// Lite primitive wrapper for the FftImag op: extracts the imaginary part of a
// packed complex tensor. No attributes.
class FftImag : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(FftImag, PrimitiveC);
  FftImag() = default;
  explicit FftImag(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
  FftImag() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_

@ -0,0 +1,54 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/fft_real.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
// Re-serializes an FftReal primitive into a fresh flatbuffer.
int FftReal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  // FftReal carries no attributes, but the value table must still match the
  // union tag below. Fixed: previously created schema::CreateEqual (copy-paste
  // from the Equal op).
  auto val_offset = schema::CreateFftReal(*fbb);
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FftReal, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
PrimitiveC *FftRealCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<FftReal>(primitive); }
Registry FftRealRegistry(schema::PrimitiveType_FftReal, FftRealCreator);
#endif
// Infers FftReal output: float32 tensor whose shape is the input shape with
// the trailing (real, imag) pair dimension removed.
int FftReal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(TypeId::kNumberTypeFloat32);
  output->SetFormat(input->GetFormat());
  if (!GetInferFlag()) {
    return RET_OK;
  }
  auto input_shape = input->shape();
  // Guard added: pop_back() on an empty shape is undefined behavior.
  if (input_shape.empty()) {
    MS_LOG(ERROR) << "input shape should not be empty";
    return RET_ERROR;
  }
  input_shape.pop_back();
  outputs_.front()->set_shape(input_shape);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,43 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_
#define LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
// Lite primitive wrapper for the FftReal op: extracts the real part of a
// packed complex tensor. No attributes.
class FftReal : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(FftReal, PrimitiveC);
  FftReal() = default;
  explicit FftReal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
  FftReal() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_

@ -0,0 +1,83 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/mfcc.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float Mfcc::GetFreqUpperLimit() const { return this->primitive_->value.AsMfcc()->freqUpperLimit; }
float Mfcc::GetFreqLowerLimit() const { return this->primitive_->value.AsMfcc()->freqLowerLimit; }
int Mfcc::GetFilterBankChannelNum() const { return this->primitive_->value.AsMfcc()->filterBankChannelNum; }
int Mfcc::GetDctCoeffNum() const { return this->primitive_->value.AsMfcc()->dctCoeffNum; }
#else
float Mfcc::GetFreqUpperLimit() const { return this->primitive_->value_as_Mfcc()->freqUpperLimit(); }
float Mfcc::GetFreqLowerLimit() const { return this->primitive_->value_as_Mfcc()->freqLowerLimit(); }
int Mfcc::GetFilterBankChannelNum() const { return this->primitive_->value_as_Mfcc()->filterBankChannelNum(); }
int Mfcc::GetDctCoeffNum() const { return this->primitive_->value_as_Mfcc()->dctCoeffNum(); }
// Re-serializes an Mfcc primitive into a fresh flatbuffer.
// Returns RET_ERROR if the primitive does not actually hold an Mfcc.
int Mfcc::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_Mfcc();
  if (attr == nullptr) {
    // Fixed copy-paste log message: previously said "value_as_Add".
    MS_LOG(ERROR) << "value_as_Mfcc return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateMfcc(*fbb, attr->freqUpperLimit(), attr->freqLowerLimit(),
                                       attr->filterBankChannelNum(), attr->dctCoeffNum());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Mfcc, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
PrimitiveC *MfccCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Mfcc>(primitive); }
Registry MfccRegistry(schema::PrimitiveType_Mfcc, MfccCreator);
#endif
// Infers Mfcc output shape from a 3-D spectrogram input and a scalar second
// input (presumably the sample rate — the op only checks its element count).
// Output is (input[0], input[1], dctCoeffNum).
int Mfcc::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  // Guard added: the checks below read inputs_[1] unconditionally.
  if (inputs_.size() < 2) {
    MS_LOG(ERROR) << "Mfcc needs two inputs, but input num is " << inputs_.size();
    return RET_ERROR;
  }
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  // Type and format are propagated even when shape inference is deferred.
  output->set_data_type(input->data_type());
  output->SetFormat(input->GetFormat());
  if (!GetInferFlag()) {
    return RET_OK;
  }
  auto input_shape = input->shape();
  if (input_shape.size() != 3) {
    MS_LOG(ERROR) << "first input shape is error, which need to be 3 dimensions, but the dimension is "
                  << input_shape.size();
    return RET_ERROR;
  }
  if (inputs_[1]->ElementsNum() != 1) {
    MS_LOG(ERROR) << "second input element num is error, which need only a value, but the number is "
                  << inputs_[1]->ElementsNum();
    return RET_ERROR;
  }
  std::vector<int> output_shape(3);
  output_shape[0] = input_shape[0];
  output_shape[1] = input_shape[1];
  // last dimension is the number of DCT coefficients kept per frame
  output_shape[2] = GetDctCoeffNum();
  outputs_.front()->set_shape(output_shape);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,57 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_MFCC_H_
#define LITE_MINDSPORE_LITE_C_OPS_MFCC_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
// Lite primitive wrapper for the Mfcc op (mel-frequency cepstral coefficients
// computed from a spectrogram plus a scalar second input).
class Mfcc : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Mfcc, PrimitiveC);
  Mfcc() = default;
  explicit Mfcc(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Attribute setters (available only in the writeable build).
  void SetFreqUpperLimit(float freq_upper_limit) {
    this->primitive_->value.AsMfcc()->freqUpperLimit = freq_upper_limit;
  }
  void SetFreqLowerLimit(float freq_lower_limit) {
    this->primitive_->value.AsMfcc()->freqLowerLimit = freq_lower_limit;
  }
  void SetFilterBankChannelNum(int filter_bank_channel_num) {
    this->primitive_->value.AsMfcc()->filterBankChannelNum = filter_bank_channel_num;
  }
  void SetDctCoeffNum(int dct_coeff_num) { this->primitive_->value.AsMfcc()->dctCoeffNum = dct_coeff_num; }
#else
  Mfcc() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  float GetFreqUpperLimit() const;
  float GetFreqLowerLimit() const;
  int GetFilterBankChannelNum() const;
  int GetDctCoeffNum() const;
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_MFCC_H_

@ -137,6 +137,11 @@
#include "src/ops/upsample.h"
#include "src/ops/layer_norm.h"
#include "src/ops/non_max_suppression.h"
#include "src/ops/rfft.h"
#include "src/ops/fft_real.h"
#include "src/ops/fft_imag.h"
#include "src/ops/audio_spectrogram.h"
#include "src/ops/mfcc.h"
#include "src/ops/identity.h"
#ifdef SUPPORT_TRAIN
@ -775,6 +780,16 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new NonMaxSuppression(primitive);
case schema::PrimitiveType_Identity:
return new Identity(primitive);
case schema::PrimitiveType_Rfft:
return new Rfft(primitive);
case schema::PrimitiveType_FftReal:
return new FftReal(primitive);
case schema::PrimitiveType_FftImag:
return new FftImag(primitive);
case schema::PrimitiveType_AudioSpectrogram:
return new AudioSpectrogram(primitive);
case schema::PrimitiveType_Mfcc:
return new Mfcc(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:

@ -0,0 +1,66 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/rfft.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Rfft::GetFftLength() const { return this->primitive_->value.AsRfft()->fftLength; }
void Rfft::SetFftLength(int fft_length) { this->primitive_->value.AsRfft()->fftLength = fft_length; }
#else
int Rfft::GetFftLength() const { return this->primitive_->value_as_Rfft()->fftLength(); }
// Re-serializes an Rfft primitive into a fresh flatbuffer.
// Returns RET_ERROR if the primitive does not actually hold an Rfft.
int Rfft::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_Rfft();
  if (attr == nullptr) {
    // Fixed copy-paste log message: previously said "value_as_Add".
    MS_LOG(ERROR) << "value_as_Rfft return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateRfft(*fbb, attr->fftLength());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Rfft, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
PrimitiveC *RfftCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Rfft>(primitive); }
Registry RfftRegistry(schema::PrimitiveType_Rfft, RfftCreator);
#endif
// Infers Rfft output: the last input dimension becomes fftLength / 2 + 1
// (non-redundant spectrum bins) and a trailing dimension of 2 is appended —
// complex64 values are stored as (real, imag) pairs.
int Rfft::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(TypeId::kNumberTypeComplex64);
  output->SetFormat(input->GetFormat());
  if (!GetInferFlag()) {
    return RET_OK;
  }
  auto input_shape = input->shape();
  // Guard added: the original indexed input_shape[size - 1] on a possibly
  // empty shape, which is undefined behavior.
  if (input_shape.empty()) {
    MS_LOG(ERROR) << "input shape should not be empty";
    return RET_ERROR;
  }
  input_shape[input_shape.size() - 1] = GetFftLength() / 2 + 1;
  input_shape.push_back(2);
  outputs_.front()->set_shape(input_shape);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,45 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_RFFT_H_
#define LITE_MINDSPORE_LITE_C_OPS_RFFT_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
// Lite primitive wrapper for the Rfft op (real-to-complex FFT with a
// configurable fftLength attribute).
class Rfft : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Rfft, PrimitiveC);
  Rfft() = default;
  explicit Rfft(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Attribute setter (available only in the writeable build).
  void SetFftLength(int fft_length);
#else
  Rfft() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int GetFftLength() const;
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_RFFT_H_

@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/caffe/caffe_elu_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
// Parses a caffe ELU layer into a lite Elu primitive.
// @param proto     layer definition; may carry an optional elu_param().alpha()
// @param weight    unused for ELU
// @param op        output node whose primitive is filled in here
// @param weightVec unused for ELU
// @return RET_OK on success, RET_NULL_PTR on null argument or allocation failure
STATUS CaffeEluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight,
                             schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) {
  MS_LOG(DEBUG) << "parse CaffeEluParser";
  if (op == nullptr) {
    MS_LOG(ERROR) << "op is null";
    return RET_NULL_PTR;
  }
  op->primitive = std::make_unique<schema::PrimitiveT>();
  if (op->primitive == nullptr) {
    MS_LOG(ERROR) << "op->primitive is null";
    return RET_NULL_PTR;
  }
  std::unique_ptr<schema::EluT> attr = std::make_unique<schema::EluT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new op failed";
    return RET_NULL_PTR;
  }
  if (proto.has_elu_param()) {
    // Bind by const reference: protobuf message accessors return a reference;
    // the old code deep-copied the whole ELUParameter message.
    const caffe::ELUParameter &eluParameter = proto.elu_param();
    if (eluParameter.has_alpha()) {
      attr->alpha = eluParameter.alpha();
    }
  }
  op->name = proto.name();
  op->primitive->value.type = schema::PrimitiveType_Elu;
  // attr ownership transfers into the primitive union.
  op->primitive->value.value = attr.release();
  return RET_OK;
}
CaffeNodeRegistrar g_caffeEluParser("ELU", new CaffeEluParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_
#include <vector>
#include "tools/converter/parser/caffe/caffe_node_parser.h"
#include "tools/converter/parser/caffe/caffe_node_parser_registry.h"
namespace mindspore {
namespace lite {
// Node parser translating a caffe "ELU" layer into the lite Elu primitive.
class CaffeEluParser : public CaffeNodeParser {
 public:
  CaffeEluParser() : CaffeNodeParser("elu") {}
  STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op,
               std::vector<schema::TensorT *> *weightVec) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_

@ -243,7 +243,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff
auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec);
if (status_node != RET_OK) {
interrupt = true;
if (status_node == RET_NOT_SUPPORT) {
if (status_node == RET_NOT_FIND_OP) {
NoSupportOp::GetInstance()->InsertOp(layer.type());
} else {
MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!";

@ -156,8 +156,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
if (attr->group != 1) {
if (!ParseGroupDeConvolution(attr, op)) {
MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed";
return RET_ERROR;
MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed, generalized group deconv hasn't support";
return RET_NOT_SUPPORT;
}
} else {
op->primitive->value.type = schema::PrimitiveType_DeConv2D;

@ -522,6 +522,7 @@ STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodePr
dst_op->primitive->value.value = attr.release();
return RET_OK;
}
schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) {
TensorCache tensor_cache;
// dst_graph->name = onnx_graph.name(); // this is not used
@ -593,6 +594,7 @@ schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_gra
SetAllTensors(tensor_cache, dst_graph.get());
return dst_graph.release();
}
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
int status = ValidateFileStr(modelFile, ".onnx");

@ -72,11 +72,12 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
auto data_index = tflite_op->inputs[0];
const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index];
std::vector<int> params;
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW,
&params) != RET_OK) {
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR;
} else {
} else if (status == RET_OK) {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);

@ -73,11 +73,12 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
auto data_index = tflite_op->inputs[2];
const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index];
std::vector<int> params;
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW,
&params) != RET_OK) {
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR;
} else {
} else if (status == RET_OK) {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);

@ -79,11 +79,12 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info,
// calculate pad params
std::vector<int> params;
if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW,
&params) != RET_OK) {
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get padding params failed";
return RET_ERROR;
} else {
} else if (status == RET_OK) {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);

@ -18,6 +18,7 @@
#include <utility>
#include <memory>
#include <vector>
#include <set>
#include "tools/common/graph_util.h"
#include "tools/common/storage.h"
#include "flatbuffers/flatbuffers.h"
@ -102,11 +103,6 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
for (const auto &tflite_op : tflite_subgraph->operators) {
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
auto op_type = GetMSOpType(tflite_op_type);
if (op_type == "CUSTOM") {
auto custom_type = (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code;
MS_LOG(ERROR) << "CUSTOM op is not supported, the type is " << custom_type;
return RET_ERROR;
}
auto op = std::make_unique<schema::CNodeT>();
op->name = op_type + "-" + std::to_string(idx++);
@ -122,7 +118,9 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
if (status == RET_OK) {
status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get());
if (status != RET_OK) {
if (status == RET_NOT_SUPPORT) {
if (status == RET_NOT_FIND_OP) {
op_type =
(op_type != "Custom" ? op_type : (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code);
NoSupportOp::GetInstance()->InsertOp(op_type);
} else {
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed";
@ -141,6 +139,16 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::MetaGraphT *sub_graph) {
std::set<int> output_index;
for (const auto &tflite_op : tflite_subgraph->operators) {
for (size_t j = 0; j < tflite_op->outputs.size(); ++j) {
int idx = tflite_op->outputs[j];
if (idx < 0) {
idx += tflite_subgraph->tensors.size();
}
output_index.insert(idx);
}
}
for (size_t i = 0; i < tensorsInfo.tensorsId.size(); i++) {
auto idx = tensorsInfo.tensorsId[i];
if (idx < 0) {
@ -173,11 +181,16 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT>
return status;
}
}
// set tensor attr
if (isInput || isConst) {
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
} else {
tensor->nodeType = schema::NodeType_Parameter;
if (output_index.find(idx) == output_index.end() && tflite_tensor->shape[0] == 0) {
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
} else {
tensor->nodeType = schema::NodeType_Parameter;
}
}
// quant param
@ -246,7 +259,6 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph)
if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
auto attr = op->primitive->value.AsDepthwiseConv2D();
if (attr->channelMultiplier > 1) {
std::unique_ptr<schema::Conv2DT> conv_attr = std::make_unique<schema::Conv2DT>();
// get channel attr
if (op->inputIndex.empty()) {
MS_LOG(ERROR) << "the input of DepthwiseConv2D is null";
@ -263,7 +275,11 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph)
return RET_NULL_PTR;
}
auto data_shape = data_tensor->dims;
if (data_shape.empty()) {
MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain only when running";
return RET_NO_CHANGE;
}
std::unique_ptr<schema::Conv2DT> conv_attr = std::make_unique<schema::Conv2DT>();
if (data_shape[3] == 1) {
conv_attr->channelIn = data_shape[3];
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
@ -372,7 +388,7 @@ schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file,
// update for depthwiseConv
status = ConvertGroupDepthwiseOp(meta_graph.get());
if (status != RET_OK) {
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "convert group depthwise conv failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;

@ -86,6 +86,10 @@ class TfliteNodeParser {
return RET_NULL_PTR;
}
auto data_ptr = buf_data->data.data();
if (data_ptr == nullptr) {
MS_LOG(DEBUG) << "data is not a constant";
return RET_NO_CHANGE;
}
switch (tflite_tensors[tensor_index]->type) {
case tflite::TensorType_UINT8: {
for (int i = 0; i < count; i++) {

@ -71,11 +71,11 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique
break;
default:
MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support";
return RET_INVALID_OP_ATTR;
return RET_NOT_SUPPORT;
}
} else {
MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported";
return RET_NOT_SUPPORT;
return RET_NOT_FIND_OP;
}
op->primitive->value.type = schema::PrimitiveType_Pad;

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save