!13275 add splice to master
From: @zoloft Reviewed-by: @wangchengyuan Signed-off-by: @wangchengyuanpull/13275/MERGE
commit
596df720af
@ -0,0 +1,54 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/splice.h"
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void Splice::Init(const std::vector<int64_t> &contexts, const std::vector<int64_t> &forward_indexes,
|
||||
int64_t output_dims) {
|
||||
this->set_context(contexts);
|
||||
this->set_forward_indexes(forward_indexes);
|
||||
this->set_output_dim(output_dims);
|
||||
}
|
||||
|
||||
void Splice::set_context(const std::vector<int64_t> &contexts) { this->AddAttr(kSpliceContext, MakeValue(contexts)); }
|
||||
|
||||
void Splice::set_forward_indexes(const std::vector<int64_t> &forward_indexes) {
|
||||
this->AddAttr(kSpliceForwardIndexes, MakeValue(forward_indexes));
|
||||
}
|
||||
|
||||
void Splice::set_output_dim(int64_t output_dim) { this->AddAttr(kSpliceOutputDims, MakeValue(output_dim)); }
|
||||
|
||||
std::vector<int64_t> Splice::get_context() const {
|
||||
auto value_ptr = GetAttr(kSpliceContext);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Splice::get_forward_indexes() const {
|
||||
auto value_ptr = GetAttr(kSpliceForwardIndexes);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t Splice::get_output_dim() const {
|
||||
auto value_ptr = GetAttr(kSpliceOutputDims);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_C(kNameSplice, Splice);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SPLICE_H_
|
||||
#define MINDSPORE_CORE_OPS_SPLICE_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSplice = "Splice";
|
||||
class Splice : public PrimitiveC {
|
||||
public:
|
||||
Splice() : PrimitiveC(kNameSplice) { InitIOName({"inputs"}, {"outputs"}); }
|
||||
~Splice() = default;
|
||||
MS_DECLARE_PARENT(Splice, PrimitiveC);
|
||||
void Init(const std::vector<int64_t> &contexts, const std::vector<int64_t> &forward_indexes, int64_t output_dims);
|
||||
void set_context(const std::vector<int64_t> &contexts);
|
||||
void set_forward_indexes(const std::vector<int64_t> &forward_indexes);
|
||||
void set_output_dim(int64_t output_dim);
|
||||
|
||||
std::vector<int64_t> get_context() const;
|
||||
std::vector<int64_t> get_forward_indexes() const;
|
||||
int64_t get_output_dim() const;
|
||||
AbstractBasePtr SpliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SPLICE_H_
|
@ -0,0 +1,72 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/splice_parameter.h"
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
OpParameter *PopulateSpliceParameter(const void *prim) {
|
||||
auto *splice_parameter = reinterpret_cast<SpliceParameter *>(malloc(sizeof(SpliceParameter)));
|
||||
if (splice_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc Splice Parameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(splice_parameter, 0, sizeof(SpliceParameter));
|
||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||
auto splice_primitive = primitive->value_as_Splice();
|
||||
splice_parameter->op_parameter_.type_ = primitive->value_type();
|
||||
|
||||
std::vector<int> primitive_context(splice_primitive->context()->begin(), splice_primitive->context()->end());
|
||||
splice_parameter->context_dim_ = static_cast<int>(primitive_context.size());
|
||||
|
||||
// malloc && memset for context
|
||||
splice_parameter->context_ = reinterpret_cast<int *>(malloc(splice_parameter->context_dim_ * sizeof(int)));
|
||||
if (splice_parameter->context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc splice_parameter context_ error";
|
||||
free(splice_parameter);
|
||||
return nullptr;
|
||||
}
|
||||
// src_to_dst_row_offset
|
||||
int src_to_dst_row_offset = INT32_MIN;
|
||||
memset(splice_parameter->context_, 0, splice_parameter->context_dim_ * sizeof(int));
|
||||
for (int i = 0; i < splice_parameter->context_dim_; ++i) {
|
||||
splice_parameter->context_[i] = primitive_context.at(i);
|
||||
src_to_dst_row_offset = std::max(src_to_dst_row_offset, std::abs(primitive_context.at(i)));
|
||||
}
|
||||
|
||||
std::vector<int> primitive_forward_indexes(splice_primitive->forward_indexes()->begin(),
|
||||
splice_primitive->forward_indexes()->end());
|
||||
splice_parameter->forward_indexes_dim_ = static_cast<int>(primitive_forward_indexes.size());
|
||||
|
||||
// malloc && memset for forward_indexes
|
||||
splice_parameter->forward_indexes_ =
|
||||
reinterpret_cast<int *>(malloc(splice_parameter->forward_indexes_dim_ * sizeof(int)));
|
||||
if (splice_parameter->forward_indexes_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc splice_parameter forward_indexes_ error";
|
||||
free(splice_parameter->context_);
|
||||
free(splice_parameter);
|
||||
return nullptr;
|
||||
}
|
||||
memset(splice_parameter->forward_indexes_, 0, splice_parameter->forward_indexes_dim_ * sizeof(int));
|
||||
for (int i = 0; i < splice_parameter->context_dim_; ++i) {
|
||||
splice_parameter->context_[i] = primitive_context.at(i);
|
||||
}
|
||||
splice_parameter->output_dim_ = splice_primitive->output_dim();
|
||||
return reinterpret_cast<OpParameter *>(splice_parameter);
|
||||
}
|
||||
Registry g_SpliceParameterRegistry(schema::PrimitiveType_Splice, PopulateSpliceParameter, SCHEMA_CUR);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,58 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_splice_parser.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/splice.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *OnnxSpliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx Splice Parser";
|
||||
auto primitive = std::make_unique<ops::Splice>();
|
||||
std::vector<int64_t> context;
|
||||
std::vector<int64_t> forward_indexes;
|
||||
int64_t output_dim = 0;
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const std::string attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "context") {
|
||||
const int32_t size = onnx_node_attr.ints_size();
|
||||
context.resize(size);
|
||||
for (int32_t i = 0; i < size; i++) {
|
||||
context[i] = static_cast<int>(onnx_node_attr.ints(i));
|
||||
}
|
||||
} else if (attribute_name == "forward_indexes") {
|
||||
const int32_t size = onnx_node_attr.ints_size();
|
||||
forward_indexes.resize(size);
|
||||
for (int32_t i = 0; i < size; i++) {
|
||||
forward_indexes[i] = static_cast<int>(onnx_node_attr.ints(i));
|
||||
}
|
||||
} else if (attribute_name == "output_dim") {
|
||||
output_dim = static_cast<int>(onnx_node_attr.i());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported attribute in splice " << attribute_name;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
primitive->Init(context, forward_indexes, output_dim);
|
||||
return primitive.release();
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxSpliceParser("Splice", new OnnxSpliceParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,34 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxSpliceParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxSpliceParser() : OnnxNodeParser("Splice") {}
|
||||
~OnnxSpliceParser() override = default;
|
||||
|
||||
ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_SPLICE_PARSER_H
|
Loading…
Reference in new issue