|
|
|
@ -16,23 +16,22 @@
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "src/common/anf_exporter/anf_populater/anf_conv_populater.h"
|
|
|
|
|
#include "src/common/anf_importer/anf_populater/anf_conv_populater.h"
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
|
|
|
|
|
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
|
|
|
|
|
#include "ir/func_graph.h"
|
|
|
|
|
#include "ir/primitive.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore::lite {
|
|
|
|
|
int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
|
|
|
|
|
std::vector<schema::TensorT *> *outputs) {
|
|
|
|
|
auto p = GetCNodePrimitive(cnodePtr);
|
|
|
|
|
int group = GetValue<int>(p->GetAttr("group"));
|
|
|
|
|
|
|
|
|
|
int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
|
|
|
|
|
const std::vector<AnfNodePtr> &inputs) {
|
|
|
|
|
int group = GetValue<int>(prim->GetAttr("group"));
|
|
|
|
|
auto primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
if (group > 1) {
|
|
|
|
|
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
|
|
|
|
|
auto format = GetValue<std::string>(p->GetAttr("data_format"));
|
|
|
|
|
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
|
|
|
|
|
if (format == "NCHW") {
|
|
|
|
|
attr->format = schema::Format_NCHW;
|
|
|
|
|
} else if (format == "NHWC") {
|
|
|
|
@ -40,25 +39,25 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
|
|
|
|
|
} else {
|
|
|
|
|
attr->format = schema::Format_NUM_OF_FORMAT;
|
|
|
|
|
}
|
|
|
|
|
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pad_list"));
|
|
|
|
|
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
|
|
|
|
|
attr->padUp = pad_list[0];
|
|
|
|
|
attr->padDown = pad_list[1];
|
|
|
|
|
attr->padLeft = pad_list[2];
|
|
|
|
|
attr->padRight = pad_list[3];
|
|
|
|
|
|
|
|
|
|
auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation"));
|
|
|
|
|
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
|
|
|
|
|
attr->dilateH = dilation[0];
|
|
|
|
|
attr->dilateW = dilation[1];
|
|
|
|
|
|
|
|
|
|
auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size"));
|
|
|
|
|
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
|
|
|
|
|
attr->kernelH = kernel_size[0];
|
|
|
|
|
attr->kernelW = kernel_size[1];
|
|
|
|
|
|
|
|
|
|
auto stride = GetValue<std::vector<int>>(p->GetAttr("stride"));
|
|
|
|
|
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
|
|
|
|
|
attr->strideH = stride[2];
|
|
|
|
|
attr->strideW = stride[3];
|
|
|
|
|
|
|
|
|
|
auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode"));
|
|
|
|
|
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
|
|
|
|
|
if (pad_mode == "valid") {
|
|
|
|
|
attr->padMode = schema::PadMode_VALID;
|
|
|
|
|
} else if (pad_mode == "same") {
|
|
|
|
@ -67,14 +66,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
|
|
|
|
|
attr->padMode = schema::PadMode_NOTSET;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
node->nodeType = schema::NodeType_CNode;
|
|
|
|
|
node->primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
|
|
|
|
node->primitive->value.value = attr.release();
|
|
|
|
|
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
|
|
|
|
primitive->value.value = attr.release();
|
|
|
|
|
} else {
|
|
|
|
|
auto attr = std::make_unique<schema::Conv2DT>();
|
|
|
|
|
attr->group = group;
|
|
|
|
|
auto format = GetValue<std::string>(p->GetAttr("data_format"));
|
|
|
|
|
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
|
|
|
|
|
if (format == "NCHW") {
|
|
|
|
|
attr->format = schema::Format_NCHW;
|
|
|
|
|
} else if (format == "NHWC") {
|
|
|
|
@ -82,27 +79,27 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
|
|
|
|
|
} else {
|
|
|
|
|
attr->format = schema::Format_NUM_OF_FORMAT;
|
|
|
|
|
}
|
|
|
|
|
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pad_list"));
|
|
|
|
|
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
|
|
|
|
|
attr->padUp = pad_list[0];
|
|
|
|
|
attr->padDown = pad_list[1];
|
|
|
|
|
attr->padLeft = pad_list[2];
|
|
|
|
|
attr->padRight = pad_list[3];
|
|
|
|
|
|
|
|
|
|
auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation"));
|
|
|
|
|
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
|
|
|
|
|
attr->dilateH = dilation[0];
|
|
|
|
|
attr->dilateW = dilation[1];
|
|
|
|
|
|
|
|
|
|
auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size"));
|
|
|
|
|
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
|
|
|
|
|
attr->kernelH = kernel_size[0];
|
|
|
|
|
attr->kernelW = kernel_size[1];
|
|
|
|
|
|
|
|
|
|
auto stride = GetValue<std::vector<int>>(p->GetAttr("stride"));
|
|
|
|
|
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
|
|
|
|
|
attr->strideH = stride[2];
|
|
|
|
|
attr->strideW = stride[3];
|
|
|
|
|
|
|
|
|
|
attr->channelOut = GetValue<int>(p->GetAttr("out_channel"));
|
|
|
|
|
attr->channelOut = GetValue<int>(prim->GetAttr("out_channel"));
|
|
|
|
|
|
|
|
|
|
auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode"));
|
|
|
|
|
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
|
|
|
|
|
if (pad_mode == "valid") {
|
|
|
|
|
attr->padMode = schema::PadMode_VALID;
|
|
|
|
|
} else if (pad_mode == "same") {
|
|
|
|
@ -110,12 +107,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
|
|
|
|
|
} else {
|
|
|
|
|
attr->padMode = schema::PadMode_NOTSET;
|
|
|
|
|
}
|
|
|
|
|
node->primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
node->primitive->value.type = schema::PrimitiveType_Conv2D;
|
|
|
|
|
node->primitive->value.value = attr.release();
|
|
|
|
|
primitive->value.type = schema::PrimitiveType_Conv2D;
|
|
|
|
|
primitive->value.value = attr.release();
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(primitiveTValuePtr != nullptr);
|
|
|
|
|
primitiveTValuePtr->SetPrimitiveT(primitive.release());
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePopulaterRegistrar anfConvParser("Conv2D", new AnfConvPopulater());
|
|
|
|
|
AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater());
|
|
|
|
|
} // namespace mindspore::lite
|