|
|
|
@ -14,21 +14,23 @@
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include "mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.h"
|
|
|
|
|
#include "utils/log_adapter.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
|
void CaffeConvolutionParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) {
|
|
|
|
|
if (attr == nullptr || attr->group == 1) {
|
|
|
|
|
return;
|
|
|
|
|
STATUS CaffeConvolutionParser::ParseGroupConvolution(schema::CNodeT *op,
|
|
|
|
|
schema::Conv2DT *attr) {
|
|
|
|
|
if (attr->group == 1) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>();
|
|
|
|
|
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam
|
|
|
|
|
= std::make_unique<schema::DepthwiseConv2DT>();
|
|
|
|
|
if (depthwiseConv2DParam == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new DepthwiseConv2DT failed";
|
|
|
|
|
return;
|
|
|
|
|
MS_LOG(ERROR) << "new op failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
depthwiseConv2DParam->format = attr->format;
|
|
|
|
|
depthwiseConv2DParam->channelIn = attr->channelIn;
|
|
|
|
|
depthwiseConv2DParam->channelMultiplier = attr->channelOut / attr->channelIn;
|
|
|
|
@ -48,19 +50,30 @@ void CaffeConvolutionParser::ParseGroupConvolution(schema::CNodeT *op, schema::C
|
|
|
|
|
delete attr;
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
|
|
|
|
op->primitive->value.value = depthwiseConv2DParam.release();
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight,
|
|
|
|
|
schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) {
|
|
|
|
|
op->name = proto.name();
|
|
|
|
|
std::unique_ptr<schema::Conv2DT> attr(new (std::nothrow) schema::Conv2DT());
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new Conv2DT failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto,
|
|
|
|
|
const caffe::LayerParameter &weight,
|
|
|
|
|
schema::CNodeT *op,
|
|
|
|
|
std::vector<schema::TensorT *> *weightVec) {
|
|
|
|
|
MS_LOG(DEBUG) << "parse CaffeConvolutionParser";
|
|
|
|
|
|
|
|
|
|
if (op == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "op is null";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
op->primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
if (op->primitive == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "op->primitive is null";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<schema::Conv2DT> attr(new (std::nothrow) schema::Conv2DT());
|
|
|
|
|
|
|
|
|
|
attr->format = schema::Format_NCHW;
|
|
|
|
|
const caffe::ConvolutionParameter convParam = proto.convolution_param();
|
|
|
|
|
|
|
|
|
|
const caffe::ConvolutionParameter convParam = proto.convolution_param();
|
|
|
|
|
CaffeConvBaseParser convParser;
|
|
|
|
|
// parse pad
|
|
|
|
|
std::vector<int64_t> pad(4, 0);
|
|
|
|
@ -119,14 +132,21 @@ STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const c
|
|
|
|
|
attr->channelIn = weightBlob.channels() * attr->group;
|
|
|
|
|
}
|
|
|
|
|
attr->padMode = schema::PadMode_CAFFE;
|
|
|
|
|
op->primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
|
|
|
|
|
op->name = proto.name();
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
|
|
|
|
op->primitive->value.value = attr.get();
|
|
|
|
|
|
|
|
|
|
ParseGroupConvolution(op, attr.release());
|
|
|
|
|
status = ParseGroupConvolution(op, attr.release());
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Parse group convolution failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
status = convParser.ParseWeight(weight, weightVec);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "ParseWeight for " << proto.name().c_str() << " failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return status;
|
|
|
|
|