add splice kernel && opcoders

pull/13275/head
z00512249 4 years ago
parent dd90f7f055
commit 54708a8714

@ -226,6 +226,9 @@ constexpr auto kCoeff = "coeff";
constexpr auto kIsDepthWise = "is_depth_wise";
constexpr auto kZoneoutCell = "zoneout_cell";
constexpr auto kZoneoutHidden = "zoneout_hidden";
constexpr auto kSpliceContext = "context";
constexpr auto kSpliceForwardIndexes = "forward_indexes";
constexpr auto kSpliceOutputDims = "output_dim";
const std::set<TypeId> common_valid_types = {
kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt16,

@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/splice.h"
#include <vector>
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
void Splice::Init(const std::vector<int64_t> &contexts, const std::vector<int64_t> &forward_indexes,
int64_t output_dims) {
this->set_context(contexts);
this->set_forward_indexes(forward_indexes);
this->set_output_dim(output_dims);
}
void Splice::set_context(const std::vector<int64_t> &contexts) { this->AddAttr(kSpliceContext, MakeValue(contexts)); }
void Splice::set_forward_indexes(const std::vector<int64_t> &forward_indexes) {
this->AddAttr(kSpliceForwardIndexes, MakeValue(forward_indexes));
}
void Splice::set_output_dim(int64_t output_dim) { this->AddAttr(kSpliceOutputDims, MakeValue(output_dim)); }
std::vector<int64_t> Splice::get_context() const {
auto value_ptr = GetAttr(kSpliceContext);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> Splice::get_forward_indexes() const {
auto value_ptr = GetAttr(kSpliceForwardIndexes);
return GetValue<std::vector<int64_t>>(value_ptr);
}
int64_t Splice::get_output_dim() const {
auto value_ptr = GetAttr(kSpliceOutputDims);
return GetValue<int64_t>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameSplice, Splice);
} // namespace ops
} // namespace mindspore

@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SPLICE_H_
#define MINDSPORE_CORE_OPS_SPLICE_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSplice = "Splice";
class Splice : public PrimitiveC {
public:
Splice() : PrimitiveC(kNameSplice) { InitIOName({"inputs"}, {"outputs"}); }
~Splice() = default;
MS_DECLARE_PARENT(Splice, PrimitiveC);
void Init(const std::vector<int64_t> &contexts, const std::vector<int64_t> &forward_indexes, int64_t output_dims);
void set_context(const std::vector<int64_t> &contexts);
void set_forward_indexes(const std::vector<int64_t> &forward_indexes);
void set_output_dim(int64_t output_dim);
std::vector<int64_t> get_context() const;
std::vector<int64_t> get_forward_indexes() const;
int64_t get_output_dim() const;
AbstractBasePtr SpliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SPLICE_H_

@ -91,6 +91,7 @@ set(CODER_OPCODERS_SRC
${MICRO_DIR}/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc
${MICRO_DIR}/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc
${MICRO_DIR}/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc
${MICRO_DIR}/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc
#### nnacl int8 coder
${MICRO_DIR}/coder/opcoders/nnacl/int8/activation_int8_coder.cc
${MICRO_DIR}/coder/opcoders/nnacl/int8/add_int8_coder.cc
@ -161,6 +162,7 @@ set(LITE_SRC
${LITE_DIR}/src/ops/populate/bias_add_populate.cc
${LITE_DIR}/src/ops/populate/activation_populate.cc
${LITE_DIR}/src/ops/populate/softmax_populate.cc
${LITE_DIR}/src/ops/populate/splice_populate.cc
### tools
${LITE_DIR}/tools/common/flag_parser.cc
)

@ -32,14 +32,10 @@ int SpliceFP32Coder::DoCode(CoderContext *const context) {
MS_LOG(ERROR) << "SpliceFP32Coder src_shape size not equal to dst_shape";
return RET_ERROR;
}
int src_row = src_shape.at(kInputIndex);
int dst_row = dst_shape.at(kInputIndex);
int src_row = src_shape.at(kWeightIndex);
int dst_row = dst_shape.at(kWeightIndex);
int src_col = src_shape.at(kBiasIndex);
int dst_col = dst_shape.at(kBiasIndex);
if (src_row != dst_row) {
MS_LOG(ERROR) << "SpliceFP32Coder src_row not equal to dst_row";
return RET_ERROR;
}
if (src_col * splice_parameter->context_dim_ != dst_col) {
MS_LOG(ERROR) << "SpliceFP32Coder src_col not match to dst_col";
return RET_ERROR;

@ -116,6 +116,7 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const DeQuantArg &
void NNaclFp32Serializer::CodeStruct(const std::string &name, const SpliceParameter &splice_parameter) {
CodeArray("splice_context", splice_parameter.context_, splice_parameter.context_dim_, false);
CodeBaseStruct("SpliceParameter", name, splice_parameter.op_parameter_, splice_parameter.context_dim_,
splice_parameter.forward_indexes_dim_, "splice_context", nullptr, splice_parameter.output_dim_);
splice_parameter.forward_indexes_dim_, splice_parameter.src_to_dst_row_offset_, "splice_context",
nullptr, splice_parameter.output_dim_);
}
} // namespace mindspore::lite::micro::nnacl

@ -17,13 +17,14 @@
#include "nnacl/fp32/splice_fp32.h"
void SpliceFp32(const float *src_data, int src_row, int src_col, const SpliceParameter *splice_parameter,
float *dst_data, int dst_row, int dst_col) {
int row_offset = splice_parameter->src_to_dst_row_offset_;
for (int r = 0; r < dst_row; ++r) {
for (int off = 0; off < splice_parameter->context_dim_; ++off) {
int r_off = r + splice_parameter->context_[off];
int r_off = r + row_offset + splice_parameter->context_[off];
r_off = MSMAX(r_off, 0);
r_off = MSMIN(r_off, src_row - 1);
const float *tmp_src_data = src_data + r_off * src_col * sizeof(float);
float *tmp_dst_data = dst_data + r * dst_col * sizeof(float);
const float *tmp_src_data = src_data + r_off * src_col;
float *tmp_dst_data = dst_data + r * dst_col;
memcpy(tmp_dst_data + off * src_col, tmp_src_data, src_col * sizeof(float));
}
}

@ -212,8 +212,9 @@ enum PrimType {
PrimType_SqrtGrad = 185,
PrimType_LayerNormGrad = 186,
PrimType_ResizeGrad = 187,
PrimType_Splice = 188,
PrimType_MIN = PrimType_NONE,
PrimType_MAX = PrimType_ResizeGrad
PrimType_MAX = PrimType_Splice
};
void RegInfer(int prim_type, InferShape func);

@ -54,3 +54,5 @@ int SpliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
output->shape_[2] = out_dim;
return NNACL_OK;
}
REG_INFER(Splice, PrimType_Splice, SpliceInferShape)

@ -21,6 +21,7 @@ typedef struct SpliceParameter {
OpParameter op_parameter_;
int context_dim_;
int forward_indexes_dim_;
int src_to_dst_row_offset_;
int *context_;
int *forward_indexes_;
int output_dim_;

@ -205,6 +205,7 @@ union PrimitiveType {
SqrtGrad,
LayerNormGrad,
ResizeGrad,
Splice,
}
table Abs {
@ -1087,3 +1088,10 @@ table ResizeGrad {
method: ResizeMethod;
align_corners: bool;
}
table Splice {
context: [long];
forward_indexes: [long];
output_dim: long;
}

@ -204,6 +204,7 @@ OP_TYPE(RsqrtGrad)
OP_TYPE(SqrtGrad)
OP_TYPE(LayerNormGrad)
OP_TYPE(ResizeGrad)
OP_TYPE(Splice)
OP_TYPE_DEF_END(PrimitiveType)
OP_SCHEMA_DEF(Abs)
@ -1086,3 +1087,9 @@ OP_SCHEMA_DEF(ResizeGrad)
OP_ATTR_ENUM(method, ResizeMethod)
OP_ATTR(align_corners, bool)
OP_SCHEMA_DEF_END(ResizeGrad)
OP_SCHEMA_DEF(Splice)
OP_ATTR(context, [long])
OP_ATTR(forward_indexes, [long])
OP_ATTR(output_dim, long)
OP_SCHEMA_DEF_END(Splice)

@ -242,6 +242,7 @@
#include "ops/lin_space.h"
#include "ops/uniform_real.h"
#include "ops/grad/abs_grad.h"
#include "ops/splice.h"
#define FUNC_MSOP2SCHEMAOP_DECLARE(OP) \
namespace mindspore::lite::ops { \
@ -453,5 +454,6 @@ FUNC_MSOP2SCHEMAOP_DECLARE(RsqrtGrad);
FUNC_MSOP2SCHEMAOP_DECLARE(SqrtGrad);
FUNC_MSOP2SCHEMAOP_DECLARE(LayerNormGrad);
FUNC_MSOP2SCHEMAOP_DECLARE(ResizeGrad);
FUNC_MSOP2SCHEMAOP_DECLARE(Splice);
#endif
#endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_

@ -745,6 +745,11 @@ schema::PrimitiveT *ErfPrimitiveCreator(const AnfNodePtr &node) {
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *SplicePrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Splice>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator);
RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator);
RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator);
@ -954,6 +959,7 @@ RegistryMSOps g_unsqueezePrimitiveCreatorRegistry("Unsqueeze", UnsqueezePrimitiv
RegistryMSOps g_wherePrimitiveCreatorRegistry("Where", WherePrimitiveCreator);
RegistryMSOps g_zerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiveCreator);
RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator);
RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator);
} // namespace lite
} // namespace mindspore

@ -0,0 +1,72 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/populate/populate_register.h"
#include "nnacl/op_base.h"
#include "nnacl/splice_parameter.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateSpliceParameter(const void *prim) {
auto *splice_parameter = reinterpret_cast<SpliceParameter *>(malloc(sizeof(SpliceParameter)));
if (splice_parameter == nullptr) {
MS_LOG(ERROR) << "malloc Splice Parameter failed.";
return nullptr;
}
memset(splice_parameter, 0, sizeof(SpliceParameter));
auto primitive = static_cast<const schema::Primitive *>(prim);
auto splice_primitive = primitive->value_as_Splice();
splice_parameter->op_parameter_.type_ = primitive->value_type();
std::vector<int> primitive_context(splice_primitive->context()->begin(), splice_primitive->context()->end());
splice_parameter->context_dim_ = static_cast<int>(primitive_context.size());
// malloc && memset for context
splice_parameter->context_ = reinterpret_cast<int *>(malloc(splice_parameter->context_dim_ * sizeof(int)));
if (splice_parameter->context_ == nullptr) {
MS_LOG(ERROR) << "malloc splice_parameter context_ error";
free(splice_parameter);
return nullptr;
}
// src_to_dst_row_offset
int src_to_dst_row_offset = INT32_MIN;
memset(splice_parameter->context_, 0, splice_parameter->context_dim_ * sizeof(int));
for (int i = 0; i < splice_parameter->context_dim_; ++i) {
splice_parameter->context_[i] = primitive_context.at(i);
src_to_dst_row_offset = std::max(src_to_dst_row_offset, std::abs(primitive_context.at(i)));
}
std::vector<int> primitive_forward_indexes(splice_primitive->forward_indexes()->begin(),
splice_primitive->forward_indexes()->end());
splice_parameter->forward_indexes_dim_ = static_cast<int>(primitive_forward_indexes.size());
// malloc && memset for forward_indexes
splice_parameter->forward_indexes_ =
reinterpret_cast<int *>(malloc(splice_parameter->forward_indexes_dim_ * sizeof(int)));
if (splice_parameter->forward_indexes_ == nullptr) {
MS_LOG(ERROR) << "malloc splice_parameter forward_indexes_ error";
free(splice_parameter->context_);
free(splice_parameter);
return nullptr;
}
memset(splice_parameter->forward_indexes_, 0, splice_parameter->forward_indexes_dim_ * sizeof(int));
for (int i = 0; i < splice_parameter->context_dim_; ++i) {
splice_parameter->context_[i] = primitive_context.at(i);
}
splice_parameter->output_dim_ = splice_primitive->output_dim();
return reinterpret_cast<OpParameter *>(splice_parameter);
}
Registry g_SpliceParameterRegistry(schema::PrimitiveType_Splice, PopulateSpliceParameter, SCHEMA_CUR);
} // namespace lite
} // namespace mindspore

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_splice_parser.h"
#include <vector>
#include <string>
#include <memory>
#include "ops/splice.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxSpliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx Splice Parser";
auto primitive = std::make_unique<ops::Splice>();
std::vector<int64_t> context;
std::vector<int64_t> forward_indexes;
int64_t output_dim = 0;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const std::string attribute_name = onnx_node_attr.name();
if (attribute_name == "context") {
const int32_t size = onnx_node_attr.ints_size();
context.resize(size);
for (int32_t i = 0; i < size; i++) {
context[i] = static_cast<int>(onnx_node_attr.ints(i));
}
} else if (attribute_name == "forward_indexes") {
const int32_t size = onnx_node_attr.ints_size();
forward_indexes.resize(size);
for (int32_t i = 0; i < size; i++) {
forward_indexes[i] = static_cast<int>(onnx_node_attr.ints(i));
}
} else if (attribute_name == "output_dim") {
output_dim = static_cast<int>(onnx_node_attr.i());
} else {
MS_LOG(ERROR) << "unsupported attribute in splice " << attribute_name;
return nullptr;
}
}
primitive->Init(context, forward_indexes, output_dim);
return primitive.release();
}
OnnxNodeRegistrar g_onnxSpliceParser("Splice", new OnnxSpliceParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxSpliceParser : public OnnxNodeParser {
public:
OnnxSpliceParser() : OnnxNodeParser("Splice") {}
~OnnxSpliceParser() override = default;
ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H
Loading…
Cancel
Save