grad passer add

pull/5939/head
kai00 4 years ago
parent 0aa9f900dd
commit fe2911021f

@ -170,6 +170,7 @@ union PrimitiveType {
AddFold,
SquaredDifference,
Flatten,
FlattenGrad,
TupleGetItem,
Div,
Where,

@ -134,7 +134,8 @@ table Minimum {
table Flatten {
}
table FlattenGrad {
}
table Concat {
axis: int;
n: int;

@ -46,8 +46,6 @@ int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodeP
} else if (prim.name() == "ReLU6") {
attr->type = schema::ActivationType_RELU6;
}
auto alpha = GetValue<float>(prim.GetAttr("alpha"));
attr->alpha = alpha;
this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";

@ -19,7 +19,27 @@ namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_ApplyMomentum;
}
if (this->primitive_->value.type != schema::PrimitiveType_ApplyMomentum) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
auto attr = std::make_unique<schema::ApplyMomentumT>();
this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
return RET_OK;
}
#else
int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);

@ -20,6 +20,7 @@
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
@ -31,6 +32,7 @@ class ApplyMomentum : public PrimitiveC {
MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC);
ApplyMomentum() = default;
explicit ApplyMomentum(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
ApplyMomentum() = default;

@ -41,7 +41,6 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";

@ -24,7 +24,35 @@ float BNGrad::GetMomentum() const { return this->primitive_->value.AsBNGrad()->m
void BNGrad::SetEps(float eps) { this->primitive_->value.AsBNGrad()->eps = eps; }
void BNGrad::SetMomentum(float momentum) { this->primitive_->value.AsBNGrad()->momentum = momentum; }
int BNGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_BNGrad;
}
if (this->primitive_->value.type != schema::PrimitiveType_BNGrad) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::BNGradInputT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->eps = GetValue<float>(prim.GetAttr("eps"));
attr->momentum = GetValue<float>(prim.GetAttr("momentum"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int BNGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);

@ -33,6 +33,7 @@ class BNGrad : public PrimitiveC {
explicit BNGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetEps(float eps);
void SetMomentum(float momentum);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
BNGrad() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;

@ -116,8 +116,6 @@ void Conv2DGradFilter::PopulaterConv2DMultiGroup(const Primitive &prim, schema::
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
}
attr->channelMultiplier = channel_mutiplier;
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release();
}
@ -168,8 +166,6 @@ void Conv2DGradFilter::PopulaterConv2DSingleGroup(const Primitive &prim,
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
primitive->value.type = schema::PrimitiveType_Conv2D;
primitive->value.value = attr.release();
}
int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {

@ -114,8 +114,6 @@ void Conv2DGradInput::PopulaterConv2DMultiGroup(const Primitive &prim, schema::P
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
}
attr->channelMultiplier = channel_mutiplier;
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release();
}
@ -166,8 +164,6 @@ void Conv2DGradInput::PopulaterConv2DSingleGroup(const Primitive &prim,
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
primitive->value.type = schema::PrimitiveType_Conv2D;
primitive->value.value = attr.release();
}
int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {

@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/depend.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Depend::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Depend;
}
if (this->primitive_->value.type != schema::PrimitiveType_Depend) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow)(schema::DependT);
if (attr == nullptr) {
MS_LOG(ERROR) << "attr is nullptr";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore

@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_SRC_OPS_DEPEND_H_
#define LITE_MINDSPORE_LITE_SRC_OPS_DEPEND_H_
#include <vector>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class Depend : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Depend, PrimitiveC);
Depend() = default;
explicit Depend(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Depend() = default;
#endif
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_SRC_OPS_Depend_H_

@ -0,0 +1,90 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/flatten_grad.h"
#include <memory>
namespace mindspore {
namespace lite {
int FlattenGrad::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
auto output = outputs_.front();
if (input == nullptr || output == nullptr) {
MS_LOG(ERROR) << "FlattenGrad input or output is null!";
return RET_ERROR;
}
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
return RET_INPUT_TENSOR_ERROR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto input_shape = input->shape();
std::vector<int> output_shape(2);
output_shape[0] = input_shape[0];
output_shape[1] = 1;
for (size_t i = 1; i < input_shape.size(); i++) {
output_shape[1] *= input_shape[i];
}
output->set_shape(output_shape);
return RET_OK;
}
#ifdef PRIMITIVE_WRITEABLE
int FlattenGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_FlattenGrad;
}
if (this->primitive_->value.type != schema::PrimitiveType_FlattenGrad) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::FlattenGradT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int FlattenGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateFlattenGrad(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FlattenGrad, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore

@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_FlattenGrad_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_FlattenGrad_GRAD_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class FlattenGrad : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(FlattenGrad, PrimitiveC);
FlattenGrad() = default;
explicit FlattenGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
FlattenGrad() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_FlattenGrad_H_

@ -136,7 +136,10 @@
#include "src/ops/power_grad.h"
#include "src/ops/softmax_cross_entropy.h"
#include "src/ops/bn_grad.h"
#include "src/ops/bn_grad_input.h"
#include "src/ops/arithmetic_grad.h"
#include "src/ops/depend.h"
#include "src/ops/flatten_grad.h"
#endif
@ -397,6 +400,12 @@ std::shared_ptr<PrimitiveC> PrimitiveC::UnPackFromPrimitive(const Primitive &pri
return NewPrimitiveC<BNGradInput>(prim, inputs, quantType);
} else if (op_type == "PowerGrad") {
return NewPrimitiveC<PowerGrad>(prim, inputs, quantType);
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType);
} else if (op_type == "Depend") {
return NewPrimitiveC<Depend>(prim, inputs, quantType);
} else if (op_type == "FlattenGrad") {
return NewPrimitiveC<FlattenGrad>(prim, inputs, quantType);
#endif
} else {
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type;
@ -638,6 +647,12 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
return new PowerGrad(primitive);
case schema::PrimitiveType_BNGradInput:
return new BNGradInput(primitive);
case schema::PrimitiveType_SoftmaxCrossEntroy:
return new SoftmaxCrossEntroy(primitive);
case schema::PrimitiveType_Depend:
return new Depend(primitive);
case schema::PrimitiveType_FlattenGrad:
return new FlattenGrad(primitive);
#endif
default:

@ -24,7 +24,33 @@ std::vector<int> SoftmaxCrossEntropy::GetAxis() const { return this->primitive_-
void SoftmaxCrossEntropy::SetAxis(const std::vector<int> &axis) {
this->primitive_->value.AsSoftmaxCrossEntropy()->axis = axis;
}
int SoftmaxCrossEntropy::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_SoftmaxCrossEntropy;
}
if (this->primitive_->value.type != schema::PrimitiveType_SoftmaxCrossEntropy) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::SoftmaxCrossEntropyT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
std::vector<int> SoftmaxCrossEntropy::GetAxis() const {

@ -33,7 +33,7 @@ class SoftmaxCrossEntropy : public PrimitiveC {
SoftmaxCrossEntropy() = default;
explicit SoftmaxCrossEntropy(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(const std::vector<int> &axis);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
SoftmaxCrossEntropy() = default;

@ -323,6 +323,18 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
} else if (value->isa<mindspore::BoolImm>()) {
auto valueAbstract = valueNode->abstract();
auto abstractScalar = utils::cast<abstract::AbstractScalarPtr>(valueAbstract);
auto typePtr = abstractScalar->GetTypeTrack();
paramTensor->dataType = typePtr->type_id();
paramTensor->dims = {1};
paramTensor->nodeType = schema::NodeType_ValueNode;
auto data = value->cast<mindspore::BoolImmPtr>();
paramTensor->data.emplace_back(data->value());
node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
} else if (value->isa<mindspore::ValueSequeue>()) {
MS_LOG(DEBUG) << "Value type is ValueSequence.";
return RET_OK;

Loading…
Cancel
Save