!9481 support controlflow in mindspore lite runtime
From: @hangangqiang Reviewed-by: Signed-off-by:pull/9481/MERGE
commit
825a9a3b74
@ -0,0 +1,78 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/merge.h"
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
||||
int Merge::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_Merge;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_Merge) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
this->primitive_->value.value = new (std::nothrow) schema::MergeT();
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
PopulaterQuantParam(prim, inputs);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
int Merge::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Merge();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Merge return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateMerge(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Merge, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *MergeCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Merge>(primitive); }
|
||||
Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator);
|
||||
#endif
|
||||
|
||||
int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(outputs_.size() == 1);
|
||||
MS_ASSERT(inputs_.size() == 2);
|
||||
outputs_[0]->set_data_type(inputs_[0]->data_type());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,44 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_MERGE_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_MERGE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
class Merge : public PrimitiveC {
|
||||
public:
|
||||
Merge() = default;
|
||||
~Merge() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Merge, PrimitiveC);
|
||||
explicit Merge(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_MERGE_H_
|
@ -0,0 +1,83 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/partial.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
||||
int Partial::GetSubGraphIndex() const { return this->primitive_->value.AsPartial()->subGraphIndex; }
|
||||
|
||||
int Partial::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_Partial;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_Partial) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::PartialT();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive value is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
int Partial::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Partial();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Partial return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreatePartial(*fbb, attr->subGraphIndex());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Partial, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Partial::GetSubGraphIndex() const { return this->primitive_->value_as_Partial()->subGraphIndex(); }
|
||||
|
||||
PrimitiveC *PartialCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Partial>(primitive); }
|
||||
Registry PartialRegistry(schema::PrimitiveType_Partial, PartialCreator);
|
||||
|
||||
#endif
|
||||
|
||||
int Partial::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; }
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,48 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Partial : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Partial, PrimitiveC);
|
||||
Partial() = default;
|
||||
explicit Partial(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
|
||||
#else
|
||||
Partial() = default;
|
||||
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
int GetSubGraphIndex() const;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_
|
@ -0,0 +1,35 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
OpParameter *PopulateMergeParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
OpParameter *merge_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
|
||||
if (merge_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc Merge parameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(merge_parameter, 0, sizeof(OpParameter));
|
||||
merge_parameter->type_ = primitive->Type();
|
||||
return reinterpret_cast<OpParameter *>(merge_parameter);
|
||||
}
|
||||
Registry MergeParameterRegistry(schema::PrimitiveType_Merge, PopulateMergeParameter);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,44 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/partial.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
typedef struct PartialParameter {
|
||||
OpParameter op_parameter_;
|
||||
int sub_graph_index_;
|
||||
} PartialParameter;
|
||||
|
||||
OpParameter *PopulatePartialParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
PartialParameter *partial_parameter = reinterpret_cast<PartialParameter *>(malloc(sizeof(PartialParameter)));
|
||||
if (partial_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc partial parameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(partial_parameter, 0, sizeof(PartialParameter));
|
||||
partial_parameter->op_parameter_.type_ = primitive->Type();
|
||||
|
||||
auto param = reinterpret_cast<mindspore::lite::Partial *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
partial_parameter->sub_graph_index_ = param->GetSubGraphIndex();
|
||||
|
||||
return reinterpret_cast<OpParameter *>(partial_parameter);
|
||||
}
|
||||
Registry PartialParameterRegistry(schema::PrimitiveType_Partial, PopulatePartialParameter);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,36 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/switch.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
OpParameter *PopulateSwitchParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
OpParameter *switch_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
|
||||
if (switch_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc SwitchParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(switch_parameter, 0, sizeof(OpParameter));
|
||||
switch_parameter->type_ = primitive->Type();
|
||||
|
||||
return reinterpret_cast<OpParameter *>(switch_parameter);
|
||||
}
|
||||
Registry SwitchParameterRegistry(schema::PrimitiveType_Switch, PopulateSwitchParameter);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,75 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/switch.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Switch::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_Switch;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_Switch) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::SwitchT();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive value is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
#else
|
||||
int Switch::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_Switch();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_Switch return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateSwitch(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Switch, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *SwitchCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Switch>(primitive); }
|
||||
Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator);
|
||||
#endif
|
||||
|
||||
int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; }
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Switch : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(Switch, PrimitiveC);
|
||||
Switch() = default;
|
||||
explicit Switch(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
|
||||
#else
|
||||
Switch() = default;
|
||||
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_
|
@ -0,0 +1,84 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/base/merge.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Merge;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// if one of input of merge is const-tensor, merge is always ready, this will cause error.
|
||||
bool MergeCPUKernel::IsReady() {
|
||||
MS_ASSERT(in_tensors().size() == 2);
|
||||
return std::any_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) {
|
||||
return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1;
|
||||
});
|
||||
}
|
||||
|
||||
int MergeCPUKernel::Init() { return RET_OK; }
|
||||
|
||||
int MergeCPUKernel::ReSize() { return RET_ERROR; }
|
||||
|
||||
int MergeCPUKernel::Run() {
|
||||
MS_ASSERT(in_tensors_.size() == 2);
|
||||
MS_ASSERT(out_tensors_.size() == 1);
|
||||
auto out_data = out_tensors_.front()->data_c();
|
||||
MS_ASSERT(out_data != nullptr);
|
||||
for (size_t i = 0; i < in_tensors().size(); i++) {
|
||||
if (in_tensors()[i]->data_c() != nullptr) {
|
||||
auto in_data = in_tensors_[i]->data_c();
|
||||
MS_ASSERT(in_data != nullptr);
|
||||
memcpy(out_data, in_data, in_tensors_[i]->Size());
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuMergeKernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
|
||||
const lite::InnerContext *ctx, const KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
if (desc.type != PrimitiveType_Merge) {
|
||||
MS_LOG(ERROR) << "type in desc is not Merge";
|
||||
free(parameter);
|
||||
return nullptr;
|
||||
}
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "ctx is nullptr";
|
||||
free(parameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto *kernel = new (std::nothrow) MergeCPUKernel(parameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
|
||||
free(parameter);
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Merge, CpuMergeKernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Merge, CpuMergeKernelCreator)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
typedef struct MergeParameter {
|
||||
OpParameter op_parameter_;
|
||||
} MergeParameter;
|
||||
|
||||
class MergeCPUKernel : public LiteKernel {
|
||||
public:
|
||||
MergeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
merge_param_ = reinterpret_cast<MergeParameter *>(op_parameter_);
|
||||
}
|
||||
~MergeCPUKernel() override {}
|
||||
bool IsReady() override;
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
private:
|
||||
MergeParameter *merge_param_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_
|
@ -0,0 +1,115 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/base/switch.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Switch;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int SwitchCPUKernel::PostProcess() {
|
||||
auto bool_tensor = in_tensors_.front();
|
||||
MS_ASSERT(bool_tensor != nullptr);
|
||||
MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
|
||||
MS_ASSERT(bool_tensor->shape().size() == 1);
|
||||
MS_ASSERT(bool_tensor->shape().front() == 1);
|
||||
auto *active = static_cast<bool *>(bool_tensor->data_c());
|
||||
if (active == nullptr) {
|
||||
MS_LOG(ERROR) << "data of bool tensor is nullptr";
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
size_t in_index = 1;
|
||||
size_t out_index = (*active) ? 0 : (out_tensors_.size() / 2);
|
||||
while (in_index < in_tensors_.size()) {
|
||||
in_index++;
|
||||
auto out_tensor = out_tensors_.at(out_index++);
|
||||
out_tensor->ResetRefCount();
|
||||
}
|
||||
return FreeInWorkTensor();
|
||||
}
|
||||
|
||||
int SwitchCPUKernel::Init() { return RET_OK; }
|
||||
|
||||
int SwitchCPUKernel::ReSize() { return RET_ERROR; }
|
||||
|
||||
// inputs: bool*1 data*n
|
||||
// output: true-data*n, false-data*n
|
||||
int SwitchCPUKernel::Run() {
|
||||
MS_ASSERT(in_tensors_.size() >= 2);
|
||||
MS_ASSERT(out_tensors_.size() == 2 * in_tensors_.size());
|
||||
auto bool_tensor = in_tensors_.front();
|
||||
MS_ASSERT(bool_tensor != nullptr);
|
||||
MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
|
||||
MS_ASSERT(bool_tensor->shape().size() == 1);
|
||||
MS_ASSERT(bool_tensor->shape().front() == 1);
|
||||
auto active = static_cast<bool *>(bool_tensor->data_c());
|
||||
if (active == nullptr) {
|
||||
MS_LOG(ERROR) << "data of bool tensor is nullptr";
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
size_t in_index = 1;
|
||||
size_t out_index = (*active) ? 0 : (out_tensors_.size() / 2);
|
||||
while (in_index < in_tensors_.size()) {
|
||||
auto in_tensor = in_tensors_.at(in_index++);
|
||||
auto out_tensor = out_tensors_.at(out_index++);
|
||||
MS_ASSERT(in_tensor != nullptr);
|
||||
MS_ASSERT(out_tensor != nullptr);
|
||||
auto input = reinterpret_cast<float *>(in_tensor->data_c());
|
||||
auto output = reinterpret_cast<float *>(out_tensor->data_c());
|
||||
MS_ASSERT(in_tensor->Size() == out_tensor->Size());
|
||||
if (input == nullptr || output == nullptr) {
|
||||
MS_LOG(ERROR) << "input tensor or output tensor have not been malloced";
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
memcpy(output, input, in_tensor->Size());
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuSwitchKernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
|
||||
const lite::InnerContext *ctx, const KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
if (desc.type != PrimitiveType_Switch) {
|
||||
MS_LOG(ERROR) << "type in desc is not Switch";
|
||||
free(parameter);
|
||||
return nullptr;
|
||||
}
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "ctx is nullptr";
|
||||
free(parameter);
|
||||
return nullptr;
|
||||
}
|
||||
auto *kernel = new (std::nothrow) SwitchCPUKernel(parameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
|
||||
free(parameter);
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Switch, CpuSwitchKernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Switch, CpuSwitchKernelCreator)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
typedef struct SwitchParameter {
|
||||
OpParameter op_parameter_;
|
||||
} SwitchParameter;
|
||||
|
||||
class SwitchCPUKernel : public LiteKernel {
|
||||
public:
|
||||
SwitchCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
switch_param_ = reinterpret_cast<SwitchParameter *>(op_parameter_);
|
||||
}
|
||||
~SwitchCPUKernel() override = default;
|
||||
int PostProcess() override;
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
private:
|
||||
SwitchParameter *switch_param_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue