!7432 Adam + Sparse softmax and bug fix
Merge pull request !7432 from yonibaehr/export
commit 4843f6aba0
src/ops/adam.cc
@@ -0,0 +1,91 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/ops/adam.h"
#include <memory>

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool Adam::GetUseNesterov() const { return this->primitive_->value.AsAdam()->useNesterov; }
int Adam::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Adam;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Adam) {
    MS_LOG(ERROR) << "Primitive type error: " << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = std::make_unique<schema::AdamT>();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
    attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov"));

    this->primitive_->value.value = attr.release();
    if (this->primitive_->value.value == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
#else
bool Adam::GetUseNesterov() const { return this->primitive_->value_as_Adam()->useNesterov(); }
int Adam::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_Adam();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_Adam return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateAdam(*fbb, attr->useNesterov());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Adam, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
#endif

int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
  if (10 != inputs.size()) {
    MS_LOG(ERROR) << "Adam should have 10 input tensors";
    return RET_ERROR;
  }

  if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum() || inputs[0]->ElementsNum() != inputs[2]->ElementsNum() ||
      inputs[0]->ElementsNum() != inputs[9]->ElementsNum() || inputs[3]->ElementsNum() != 1 ||
      inputs[4]->ElementsNum() != 1 || inputs[5]->ElementsNum() != 1 || inputs[6]->ElementsNum() != 1 ||
      inputs[7]->ElementsNum() != 1 || inputs[8]->ElementsNum() != 1) {
    MS_LOG(ERROR) << "invalid input tensor size!";
    return RET_ERROR;
  }
  if (!outputs.empty()) {
    auto *out = outputs.front();
    MS_ASSERT(out != nullptr);
    out->set_data_type(inputs[0]->data_type());
    out->SetFormat(inputs[0]->GetFormat());
    out->set_shape({1});
  }

  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore
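Reviewer note: the ten inputs validated above presumably follow the conventional ApplyAdam operand order — var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad. That ordering is inferred from the shape constraints, not stated in this diff, and the names below are hypothetical. A minimal sketch of the same checks written against named indices:

#include <vector>
#include "src/ops/primitive_c.h"  // per this diff, declares lite::Tensor for InferShape

// Hypothetical index names for Adam's ten inputs (conventional ApplyAdam
// layout; an assumption, not taken from this diff).
enum AdamInput {
  kVar = 0, kM, kV, kBeta1Power, kBeta2Power, kLr, kBeta1, kBeta2, kEpsilon, kGrad
};

// The constraints InferShape enforces: the optimizer state tensors and the
// gradient match var's element count; the six hyper-parameters are scalars.
bool AdamInputsValid(const std::vector<lite::Tensor *> &in) {
  if (in.size() != 10) return false;
  int n = in[kVar]->ElementsNum();
  if (in[kM]->ElementsNum() != n || in[kV]->ElementsNum() != n ||
      in[kGrad]->ElementsNum() != n) {
    return false;
  }
  for (int i = kBeta1Power; i <= kEpsilon; ++i) {
    if (in[i]->ElementsNum() != 1) return false;
  }
  return true;
}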
src/ops/adam.h
@@ -0,0 +1,47 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_OPS_ADAM_H_
#define MINDSPORE_LITE_SRC_OPS_ADAM_H_

#include <vector>
#include <set>
#include <cmath>
#include <memory>

#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class Adam : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Adam, PrimitiveC);
  Adam() = default;
  explicit Adam(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  Adam() = default;

  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
  bool GetUseNesterov() const;
};
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_OPS_ADAM_H_
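Reviewer note: the non-writable branch of these classes re-serializes the op through the generated flatbuffer helpers, as adam.cc shows. A minimal round-trip sketch of that pattern; the function name is hypothetical, the snippet assumes it sits inside namespace mindspore::lite (so schema:: resolves as in the diff), and it calls only the helpers the diff itself uses:

#include "src/ops/primitive_c.h"  // per this diff, brings in flatbuffers and the generated schema

// Hypothetical helper: serialize an Adam primitive the same way
// Adam::UnPackToFlatBuilder does, returning the finished buffer.
flatbuffers::DetachedBuffer SerializeAdam(bool use_nesterov) {
  flatbuffers::FlatBufferBuilder fbb(1024);  // initial size is illustrative
  auto val_offset = schema::CreateAdam(fbb, use_nesterov);
  auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Adam, val_offset.o);
  fbb.Finish(prim_offset);
  return fbb.Release();  // the inference build reads attributes back via value_as_Adam()
}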
src/ops/assign.cc
@@ -0,0 +1,82 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/ops/assign.h"
#include <memory>

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Assign::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Assign;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Assign) {
    MS_LOG(ERROR) << "Primitive type error: " << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    this->primitive_->value.value = new (std::nothrow) schema::AssignT();
    if (this->primitive_->value.value == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
#else
int Assign::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_Assign();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_Assign return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateAssign(*fbb);
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Assign, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
#endif

int Assign::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
  if (2 != inputs.size()) {
    MS_LOG(ERROR) << "Assign should have 2 input tensors";
    return RET_ERROR;
  }

  if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum()) {
    MS_LOG(ERROR) << "invalid input tensor size!";
    return RET_ERROR;
  }

  if (!outputs.empty()) {
    auto *out = outputs.front();
    MS_ASSERT(out != nullptr);
    out->set_data_type(inputs[0]->data_type());
    out->SetFormat(inputs[0]->GetFormat());
    out->set_shape({1});
  }
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore
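Reviewer note: Adam::UnPackAttr fills its attribute object through std::make_unique and then release()s it into the raw-pointer union field, while Assign::UnPackAttr above uses new (std::nothrow) directly; either way the union ends up owning a raw pointer. A self-contained sketch of the release idiom with stand-in types (the real destination is the schema::PrimitiveT value union; Attr/Owner/FillAttr are hypothetical):

#include <memory>

// Stand-in types for illustration only; the real code stores schema::AdamT*
// or schema::AssignT* in the PrimitiveT value union, which owns the pointer.
struct Attr { bool useNesterov = false; };
struct Owner {
  Attr *value = nullptr;
  ~Owner() { delete value; }
};

bool FillAttr(Owner *owner) {
  auto attr = std::make_unique<Attr>();  // exception-safe while we populate it
  attr->useNesterov = true;
  owner->value = attr.release();         // transfer ownership to the union field
  return owner->value != nullptr;
}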
src/ops/assign.h
@@ -0,0 +1,43 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_OPS_ASSIGN_H_
#define MINDSPORE_LITE_SRC_OPS_ASSIGN_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class Assign : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Assign, PrimitiveC);
  Assign() = default;
  explicit Assign(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  Assign() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_OPS_ASSIGN_H_
src/ops/group_conv2d_grad_input.cc
@@ -0,0 +1,172 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/ops/group_conv2d_grad_input.h"

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int GroupConv2DGradInput::GetFormat() const { return this->primitive_->value.AsGroupConv2DGradInput()->format; }
int GroupConv2DGradInput::GetGroup() const { return this->primitive_->value.AsGroupConv2DGradInput()->group; }
int GroupConv2DGradInput::GetChannelIn() const { return this->primitive_->value.AsGroupConv2DGradInput()->channelIn; }
int GroupConv2DGradInput::GetChannelOut() const { return this->primitive_->value.AsGroupConv2DGradInput()->channelOut; }
int GroupConv2DGradInput::GetKernelW() const { return this->primitive_->value.AsGroupConv2DGradInput()->kernelW; }
int GroupConv2DGradInput::GetKernelH() const { return this->primitive_->value.AsGroupConv2DGradInput()->kernelH; }
int GroupConv2DGradInput::GetStrideW() const { return this->primitive_->value.AsGroupConv2DGradInput()->strideW; }
int GroupConv2DGradInput::GetStrideH() const { return this->primitive_->value.AsGroupConv2DGradInput()->strideH; }
int GroupConv2DGradInput::GetPadMode() const { return this->primitive_->value.AsGroupConv2DGradInput()->padMode; }
int GroupConv2DGradInput::GetPadUp() const { return this->primitive_->value.AsGroupConv2DGradInput()->padUp; }
int GroupConv2DGradInput::GetPadDown() const { return this->primitive_->value.AsGroupConv2DGradInput()->padDown; }
int GroupConv2DGradInput::GetPadLeft() const { return this->primitive_->value.AsGroupConv2DGradInput()->padLeft; }
int GroupConv2DGradInput::GetPadRight() const { return this->primitive_->value.AsGroupConv2DGradInput()->padRight; }
int GroupConv2DGradInput::GetDilateW() const { return this->primitive_->value.AsGroupConv2DGradInput()->dilateW; }
int GroupConv2DGradInput::GetDilateH() const { return this->primitive_->value.AsGroupConv2DGradInput()->dilateH; }
bool GroupConv2DGradInput::GetHasBias() const { return this->primitive_->value.AsGroupConv2DGradInput()->hasBias; }
int GroupConv2DGradInput::GetActivationType() const {
  return this->primitive_->value.AsGroupConv2DGradInput()->activationType;
}

void GroupConv2DGradInput::SetFormat(int format) {
  this->primitive_->value.AsGroupConv2DGradInput()->format = (schema::Format)format;
}
void GroupConv2DGradInput::SetGroup(int group) { this->primitive_->value.AsGroupConv2DGradInput()->group = group; }
void GroupConv2DGradInput::SetChannelIn(int channel_in) {
  this->primitive_->value.AsGroupConv2DGradInput()->channelIn = channel_in;
}
void GroupConv2DGradInput::SetChannelOut(int channel_out) {
  this->primitive_->value.AsGroupConv2DGradInput()->channelOut = channel_out;
}
void GroupConv2DGradInput::SetKernelW(int kernel_w) {
  this->primitive_->value.AsGroupConv2DGradInput()->kernelW = kernel_w;
}
void GroupConv2DGradInput::SetKernelH(int kernel_h) {
  this->primitive_->value.AsGroupConv2DGradInput()->kernelH = kernel_h;
}
void GroupConv2DGradInput::SetStrideW(int stride_w) {
  this->primitive_->value.AsGroupConv2DGradInput()->strideW = stride_w;
}
void GroupConv2DGradInput::SetStrideH(int stride_h) {
  this->primitive_->value.AsGroupConv2DGradInput()->strideH = stride_h;
}
void GroupConv2DGradInput::SetPadMode(int pad_mode) {
  this->primitive_->value.AsGroupConv2DGradInput()->padMode = (schema::PadMode)pad_mode;
}
void GroupConv2DGradInput::SetPadUp(int pad_up) { this->primitive_->value.AsGroupConv2DGradInput()->padUp = pad_up; }
void GroupConv2DGradInput::SetPadDown(int pad_down) {
  this->primitive_->value.AsGroupConv2DGradInput()->padDown = pad_down;
}
void GroupConv2DGradInput::SetPadLeft(int pad_left) {
  this->primitive_->value.AsGroupConv2DGradInput()->padLeft = pad_left;
}
void GroupConv2DGradInput::SetPadRight(int pad_right) {
  this->primitive_->value.AsGroupConv2DGradInput()->padRight = pad_right;
}
void GroupConv2DGradInput::SetDilateW(int dilate_w) {
  this->primitive_->value.AsGroupConv2DGradInput()->dilateW = dilate_w;
}
void GroupConv2DGradInput::SetDilateH(int dilate_h) {
  this->primitive_->value.AsGroupConv2DGradInput()->dilateH = dilate_h;
}
void GroupConv2DGradInput::SetHasBias(bool has_bias) {
  this->primitive_->value.AsGroupConv2DGradInput()->hasBias = has_bias;
}
void GroupConv2DGradInput::SetActivationType(int activation_type) {
  this->primitive_->value.AsGroupConv2DGradInput()->activationType = (schema::ActivationType)activation_type;
}
#else
int GroupConv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_GroupConv2DGradInput();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_GroupConv2DGradInput return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateGroupConv2DGradInput(
    *fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
    attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
    attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GroupConv2DGradInput, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
int GroupConv2DGradInput::GetFormat() const { return this->primitive_->value_as_GroupConv2DGradInput()->format(); }
int GroupConv2DGradInput::GetGroup() const { return this->primitive_->value_as_GroupConv2DGradInput()->group(); }
int GroupConv2DGradInput::GetChannelIn() const {
  return this->primitive_->value_as_GroupConv2DGradInput()->channelIn();
}
int GroupConv2DGradInput::GetChannelOut() const {
  return this->primitive_->value_as_GroupConv2DGradInput()->channelOut();
}
int GroupConv2DGradInput::GetKernelW() const { return this->primitive_->value_as_GroupConv2DGradInput()->kernelW(); }
int GroupConv2DGradInput::GetKernelH() const { return this->primitive_->value_as_GroupConv2DGradInput()->kernelH(); }
int GroupConv2DGradInput::GetStrideW() const { return this->primitive_->value_as_GroupConv2DGradInput()->strideW(); }
int GroupConv2DGradInput::GetStrideH() const { return this->primitive_->value_as_GroupConv2DGradInput()->strideH(); }
int GroupConv2DGradInput::GetPadMode() const { return this->primitive_->value_as_GroupConv2DGradInput()->padMode(); }
int GroupConv2DGradInput::GetPadUp() const { return this->primitive_->value_as_GroupConv2DGradInput()->padUp(); }
int GroupConv2DGradInput::GetPadDown() const { return this->primitive_->value_as_GroupConv2DGradInput()->padDown(); }
int GroupConv2DGradInput::GetPadLeft() const { return this->primitive_->value_as_GroupConv2DGradInput()->padLeft(); }
int GroupConv2DGradInput::GetPadRight() const { return this->primitive_->value_as_GroupConv2DGradInput()->padRight(); }
int GroupConv2DGradInput::GetDilateW() const { return this->primitive_->value_as_GroupConv2DGradInput()->dilateW(); }
int GroupConv2DGradInput::GetDilateH() const { return this->primitive_->value_as_GroupConv2DGradInput()->dilateH(); }
bool GroupConv2DGradInput::GetHasBias() const { return this->primitive_->value_as_GroupConv2DGradInput()->hasBias(); }
int GroupConv2DGradInput::GetActivationType() const {
  return this->primitive_->value_as_GroupConv2DGradInput()->activationType();
}

#endif

int GroupConv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
  if (3 != inputs.size()) {
    MS_LOG(ERROR) << "Conv2d Grad Input should have 3 inputs";
    return RET_ERROR;
  }
  if (1 != outputs.size()) {
    MS_LOG(ERROR) << "Conv2d Grad Input should have one output";
    return RET_ERROR;
  }

  auto *in0 = inputs.at(0);
  auto *in = inputs.at(2);
  MS_ASSERT(in0 != nullptr);
  MS_ASSERT(in != nullptr);

  std::vector<int> output_shape;
  int *out_shape = reinterpret_cast<int *>(in->MutableData());
  int new_size = in->ElementsNum();
  if (in0->GetFormat() == in->GetFormat()) {
    for (int i = 0; i < new_size; i++) output_shape.push_back(out_shape[i]);
  } else {
    if ((in0->GetFormat() == schema::Format_NHWC) && (in->GetFormat() == schema::Format_NCHW)) {
      output_shape.push_back(out_shape[0]);
      output_shape.push_back(out_shape[2]);
      output_shape.push_back(out_shape[3]);
      output_shape.push_back(out_shape[1]);
    } else {
      MS_LOG(ERROR) << "Shape convert is not supported";
      return RET_ERROR;
    }
  }

  auto *out = outputs.at(0);
  MS_ASSERT(out != nullptr);
  out->set_shape(output_shape);
  out->set_data_type(in0->data_type());
  out->SetFormat(in0->GetFormat());

  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore
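Reviewer note: the format branch above reads a shape stored in NCHW order and emits it in the output's NHWC order by picking indices 0, 2, 3, 1. A standalone sketch of that reorder (plain ints, tensor plumbing omitted; the function name is hypothetical):

#include <vector>

// NCHW -> NHWC shape reorder as performed in GroupConv2DGradInput::InferShape:
// keep N first, then take H (index 2), W (index 3), and C (index 1) from NCHW.
std::vector<int> NchwShapeToNhwc(const std::vector<int> &nchw) {
  return {nchw[0], nchw[2], nchw[3], nchw[1]};
}
// Example: {8, 64, 28, 28} in NCHW becomes {8, 28, 28, 64} in NHWC.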
src/ops/group_conv2d_grad_input.h
@@ -0,0 +1,79 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_
#define MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_

#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include <string>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class GroupConv2DGradInput : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(GroupConv2DGradInput, PrimitiveC);
  GroupConv2DGradInput() = default;
  explicit GroupConv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  void SetFormat(int format);
  void SetGroup(int group);
  void SetChannelIn(int channel_in);
  void SetChannelOut(int channel_out);
  void SetKernelW(int kernel_w);
  void SetKernelH(int kernel_h);
  void SetStrideW(int stride_w);
  void SetStrideH(int stride_h);
  void SetPadMode(int pad_mode);
  void SetPadUp(int pad_up);
  void SetPadDown(int pad_down);
  void SetPadLeft(int pad_left);
  void SetPadRight(int pad_right);
  void SetDilateW(int dilate_w);
  void SetDilateH(int dilate_h);
  void SetHasBias(bool has_bias);
  void SetActivationType(int activation_type);
#else
  GroupConv2DGradInput() = default;

  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
  int GetFormat() const;
  int GetGroup() const;
  int GetChannelIn() const;
  int GetChannelOut() const;
  int GetKernelW() const;
  int GetKernelH() const;
  int GetStrideW() const;
  int GetStrideH() const;
  int GetPadMode() const;
  int GetPadUp() const;
  int GetPadDown() const;
  int GetPadLeft() const;
  int GetPadRight() const;
  int GetDilateW() const;
  int GetDilateH() const;
  bool GetHasBias() const;
  int GetActivationType() const;
};
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_
Some files were not shown because too many files have changed in this diff.