diff --git a/mindspore/lite/include/train_session.h b/mindspore/lite/include/train_session.h index fbe00bc8fb..0a7faf440a 100644 --- a/mindspore/lite/include/train_session.h +++ b/mindspore/lite/include/train_session.h @@ -20,7 +20,6 @@ #include #include #include "include/lite_session.h" -#include "include/train_model.h" namespace mindspore { namespace session { @@ -33,19 +32,23 @@ class TrainSession : public session::LiteSession { /// \brief Static method to create a TrainSession object /// + /// \param[in] model_buf A buffer that was read from a MS model file + /// \param[in] size Length of the buffer /// \param[in] context Defines the context of the session to be created + /// \param[in] train_mode training mode to initialize Session with /// /// \return Pointer of MindSpore Lite TrainSession - static TrainSession *CreateSession(lite::Context *context); + static TrainSession *CreateSession(const char *model_buf, size_t size, lite::Context *context, + bool train_mode = false); - /// \brief Compile MindSpore Lite train model - /// - /// \note CompileTrainGraph should be called before RunGraph + /// \brief Static method to create a TrainSession object /// - /// \param[in] model Define the model to be compiled + /// \param[in] filename Filename to read flatbuffer from + /// \param[in] context Defines the context of the session to be created + /// \param[in] train_mode training mode to initialize Session with /// - /// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h - virtual int CompileTrainGraph(lite::TrainModel *model) = 0; + /// \return Pointer of MindSpore Lite TrainSession + static TrainSession *CreateSession(const std::string &filename, lite::Context *context, bool train_mode = false); /// \brief Export the trained model into a buffer /// diff --git a/mindspore/lite/nnacl/fp32_grad/softmax_grad.h b/mindspore/lite/nnacl/fp32_grad/softmax_grad.h index 5acef0c2e3..a758c86e90 100644 --- a/mindspore/lite/nnacl/fp32_grad/softmax_grad.h +++ b/mindspore/lite/nnacl/fp32_grad/softmax_grad.h @@ -30,6 +30,7 @@ typedef struct SoftmaxCrossEntropyParameter { unsigned int number_of_classes_; int n_dim_; int input_shape_[5]; + int is_grad; } SoftmaxCrossEntropyParameter; void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr, float *sum_data, float *sum_mul, diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 467b43a67a..7eab3e35ae 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -253,6 +253,7 @@ union PrimitiveType { All, Assert, Adder, + SparseSoftmaxCrossEntropy } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 1c2fcdb3da..84f5443b10 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -301,12 +301,14 @@ table BatchNorm { } table BiasGrad { - axis: [int]; } table SoftmaxCrossEntropy { - axis: [int]; +} + +table SparseSoftmaxCrossEntropy { + isGrad: int; } table make_tuple { diff --git a/mindspore/lite/src/ops/bias_grad.cc b/mindspore/lite/src/ops/bias_grad.cc index 95385a0534..162c807c05 100644 --- a/mindspore/lite/src/ops/bias_grad.cc +++ b/mindspore/lite/src/ops/bias_grad.cc @@ -23,9 +23,6 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector BiasGrad::GetAxis() const { return this->primitive_->value.AsBiasGrad()->axis; } - -void BiasGrad::SetAxis(const std::vector &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; } int BiasGrad::UnPackAttr(const 
Primitive &prim, const std::vector &inputs) { if (this->primitive_ == nullptr) { this->primitive_ = new (std::nothrow) schema::PrimitiveT; @@ -45,11 +42,11 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector &i MS_LOG(ERROR) << "new primitiveT value failed"; return RET_ERROR; } - if (prim.GetAttr("axis") == nullptr) { - MS_LOG(WARNING) << "get axis failed"; - attr->axis = {0}; - } else { - attr->axis = CastToInt(prim.GetAttr("axis")); + + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; } this->primitive_->value.value = attr; } @@ -64,21 +61,12 @@ int BiasGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer MS_LOG(ERROR) << "value_as_BiasGrad return nullptr"; return RET_ERROR; } - std::vector axis; - if (attr->axis() != nullptr) { - for (int i = 0; i < static_cast(attr->axis()->size()); i++) { - axis.push_back(attr->axis()->data()[i]); - } - } - auto val_offset = schema::CreateBiasGradDirect(*fbb, &axis); + + auto val_offset = schema::CreateBiasGrad(*fbb); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BiasGrad, val_offset.o); fbb->Finish(prim_offset); return RET_OK; } -std::vector BiasGrad::GetAxis() const { - auto fb_vector = this->primitive_->value_as_BiasGrad()->axis(); - return std::vector(fb_vector->begin(), fb_vector->end()); -} PrimitiveC *BiasGradCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index f40699231d..08be441b0c 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -82,7 +82,8 @@ void ConvertConvWeight(const ParameterPtr ¶m_node) { auto weight = std::dynamic_pointer_cast(param); MS_ASSERT(weight != nullptr); - std::unique_ptr buf(new (std::nothrow) T[weight->tensor_shape_size()]); + std::unique_ptr buf(new (std::nothrow) T[weight->tensor_shape_size()]); + if (buf == nullptr) { MS_LOG(ERROR) << "new buf failed"; return; @@ -150,9 +151,13 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT attr->padRight = pad_list[3]; auto dilation = CastToInt(prim.GetAttr("dilation")); +#ifdef SUPPORT_TRAIN + attr->dilateH = dilation[2]; + attr->dilateW = dilation[3]; +#else attr->dilateH = dilation[0]; attr->dilateW = dilation[1]; - +#endif auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; attr->kernelW = kernel_size[1]; diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.cc b/mindspore/lite/src/ops/conv2d_grad_filter.cc index e7bd75da16..61574dc337 100644 --- a/mindspore/lite/src/ops/conv2d_grad_filter.cc +++ b/mindspore/lite/src/ops/conv2d_grad_filter.cc @@ -110,8 +110,8 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vectorpadRight = pad_list[3]; auto dilation = CastToInt(prim.GetAttr("dilation")); - attr->dilateH = dilation[0]; - attr->dilateW = dilation[1]; + attr->dilateH = dilation[2]; + attr->dilateW = dilation[3]; auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; diff --git a/mindspore/lite/src/ops/conv2d_grad_input.cc b/mindspore/lite/src/ops/conv2d_grad_input.cc index 1173a659f9..6872c5090e 100644 --- a/mindspore/lite/src/ops/conv2d_grad_input.cc +++ b/mindspore/lite/src/ops/conv2d_grad_input.cc @@ -111,8 +111,8 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vectorpadRight = pad_list[3]; auto 
dilation = CastToInt(prim.GetAttr("dilation")); - attr->dilateH = dilation[0]; - attr->dilateW = dilation[1]; + attr->dilateH = dilation[2]; + attr->dilateW = dilation[3]; auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); attr->kernelH = kernel_size[0]; diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc index 722bc8ee57..4d638e0e0b 100644 --- a/mindspore/lite/src/ops/deconv2d.cc +++ b/mindspore/lite/src/ops/deconv2d.cc @@ -76,7 +76,7 @@ void ConvertConvWeight(const ParameterPtr ¶m_node) { auto weight = std::dynamic_pointer_cast(param); MS_ASSERT(weight != nullptr); - std::unique_ptr buf(new (std::nothrow) T[weight->tensor_shape_size()]); + std::unique_ptr buf(new (std::nothrow) T[weight->tensor_shape_size()]); if (buf == nullptr) { MS_LOG(ERROR) << "new buf failed"; return; diff --git a/mindspore/lite/src/ops/dropout.cc b/mindspore/lite/src/ops/dropout.cc index 32c6baa7a8..fae4e37ca3 100644 --- a/mindspore/lite/src/ops/dropout.cc +++ b/mindspore/lite/src/ops/dropout.cc @@ -89,7 +89,6 @@ int Dropout::InferShape(std::vector inputs_, std::vector out output0->set_shape(input->shape()); output0->set_data_type(input->data_type()); output0->set_format(input->format()); - if (outputs_.size() > 1) { auto output1 = outputs_[1]; MS_ASSERT(output1 != nullptr); @@ -97,7 +96,6 @@ int Dropout::InferShape(std::vector inputs_, std::vector out output1->set_data_type(input->data_type()); output1->set_format(input->format()); } - return RET_OK; } diff --git a/mindspore/lite/src/ops/dropout_grad.cc b/mindspore/lite/src/ops/dropout_grad.cc index e459dc65b3..f7dcb3f0f1 100644 --- a/mindspore/lite/src/ops/dropout_grad.cc +++ b/mindspore/lite/src/ops/dropout_grad.cc @@ -92,9 +92,7 @@ int DropoutGrad::InferShape(std::vector inputs_, std::vector output->set_shape(input->shape()); output->set_data_type(input->data_type()); output->set_format(input->format()); - return RET_OK; } - } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/populate/tile_populate.cc b/mindspore/lite/src/ops/populate/tile_populate.cc index 7608a307f1..b7e9ff3c62 100644 --- a/mindspore/lite/src/ops/populate/tile_populate.cc +++ b/mindspore/lite/src/ops/populate/tile_populate.cc @@ -31,6 +31,13 @@ OpParameter *PopulateTileParameter(const mindspore::lite::PrimitiveC *primitive) memset(tile_param, 0, sizeof(TileParameter)); tile_param->op_parameter_.type_ = primitive->Type(); auto param = reinterpret_cast(const_cast(primitive)); +#ifdef SUPPORT_TRAIN + auto multiples = param->GetMultiples(); + tile_param->in_dim_ = multiples.size(); + for (int i = 0; i < tile_param->in_dim_; ++i) { + tile_param->multiples_[i] = multiples[i]; + } +#else auto dims = param->GetDims(); auto multiples = param->GetMultiples(); for (size_t i = 0; i < kDimension_4d; ++i) { @@ -39,6 +46,7 @@ OpParameter *PopulateTileParameter(const mindspore::lite::PrimitiveC *primitive) for (size_t i = 0; i < dims.size(); ++i) { tile_param->multiples_[dims[i]] = multiples[i]; } +#endif return reinterpret_cast(tile_param); } diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 17a31645fb..1c75d2ab52 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -161,6 +161,7 @@ #include "src/ops/group_conv2d_grad_input.h" #include "src/ops/power_grad.h" #include "src/ops/softmax_cross_entropy.h" +#include "src/ops/sparse_softmax_cross_entropy.h" #include "src/ops/bn_grad.h" #include "src/ops/arithmetic_grad.h" #include "src/ops/depend.h" @@ 
-578,6 +579,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: #ifdef SUPPORT_TRAIN } else if (op_type == "SoftmaxCrossEntropyWithLogits") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "SparseSoftmaxCrossEntropyWithLogits") { + return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "BiasAddGrad") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "ApplyMomentum") { @@ -916,6 +919,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) ArithmeticGrad(primitive); case schema::PrimitiveType_SoftmaxCrossEntropy: return new (std::nothrow) SoftmaxCrossEntropy(primitive); + case schema::PrimitiveType_SparseSoftmaxCrossEntropy: + return new (std::nothrow) SparseSoftmaxCrossEntropy(primitive); case schema::PrimitiveType_PowerGrad: return new (std::nothrow) PowerGrad(primitive); case schema::PrimitiveType_Depend: diff --git a/mindspore/lite/src/ops/softmax_cross_entropy.cc b/mindspore/lite/src/ops/softmax_cross_entropy.cc index 66af7ba552..483bd7363b 100644 --- a/mindspore/lite/src/ops/softmax_cross_entropy.cc +++ b/mindspore/lite/src/ops/softmax_cross_entropy.cc @@ -23,11 +23,6 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector SoftmaxCrossEntropy::GetAxis() const { return this->primitive_->value.AsSoftmaxCrossEntropy()->axis; } - -void SoftmaxCrossEntropy::SetAxis(const std::vector &axis) { - this->primitive_->value.AsSoftmaxCrossEntropy()->axis = axis; -} int SoftmaxCrossEntropy::UnPackAttr(const Primitive &prim, const std::vector &inputs) { if (this->primitive_ == nullptr) { this->primitive_ = new (std::nothrow) schema::PrimitiveT; @@ -48,7 +43,6 @@ int SoftmaxCrossEntropy::UnPackAttr(const Primitive &prim, const std::vectoraxis = {0}; this->primitive_->value.value = attr; if (this->primitive_->value.value == nullptr) { MS_LOG(ERROR) << "primitive value is nullptr"; @@ -59,10 +53,6 @@ int SoftmaxCrossEntropy::UnPackAttr(const Primitive &prim, const std::vector SoftmaxCrossEntropy::GetAxis() const { - auto fb_vector = this->primitive_->value_as_SoftmaxCrossEntropy()->axis(); - return std::vector(fb_vector->begin(), fb_vector->end()); -} int SoftmaxCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != fbb); @@ -71,13 +61,8 @@ int SoftmaxCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive, MS_LOG(ERROR) << "value_as_SoftmaxCrossEntropy return nullptr"; return RET_ERROR; } - std::vector axis; - if (attr->axis() != nullptr) { - for (int i = 0; i < static_cast(attr->axis()->size()); i++) { - axis.push_back(attr->axis()->data()[i]); - } - } - auto val_offset = schema::CreateSoftmaxCrossEntropyDirect(*fbb, &axis); + + auto val_offset = schema::CreateSoftmaxCrossEntropy(*fbb); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SoftmaxCrossEntropy, val_offset.o); fbb->Finish(prim_offset); return RET_OK; @@ -100,6 +85,7 @@ int SoftmaxCrossEntropy::InferShape(std::vector inputs, std::vector outshape; + outshape.push_back(in0->shape()[0]); outshape.push_back(1); out->set_shape(outshape); out->set_data_type(in0->data_type()); diff --git a/mindspore/lite/src/ops/sparse_softmax_cross_entropy.cc b/mindspore/lite/src/ops/sparse_softmax_cross_entropy.cc new file mode 100644 index 0000000000..751afb084d --- /dev/null +++ b/mindspore/lite/src/ops/sparse_softmax_cross_entropy.cc @@ -0,0 +1,120 @@ +/** + * 
Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/sparse_softmax_cross_entropy.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int SparseSoftmaxCrossEntropy::GetIsGrad() const { + return this->primitive_->value.AsSparseSoftmaxCrossEntropy()->isGrad; +} + +void SparseSoftmaxCrossEntropy::SetIsGrad(int isGrad) { + this->primitive_->value.AsSparseSoftmaxCrossEntropy()->isGrad = isGrad; +} + +int SparseSoftmaxCrossEntropy::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_SparseSoftmaxCrossEntropy; + } + if (this->primitive_->value.type != schema::PrimitiveType_SparseSoftmaxCrossEntropy) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::SparseSoftmaxCrossEntropyT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + + attr->isGrad = GetValue(prim.GetAttr("is_grad")); + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} +#else + +int SparseSoftmaxCrossEntropy::GetIsGrad() const { + return this->primitive_->value_as_SparseSoftmaxCrossEntropy()->isGrad(); +} +int SparseSoftmaxCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive, + flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_SparseSoftmaxCrossEntropy(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_SparseSoftmaxCrossEntropy return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateSparseSoftmaxCrossEntropy(*fbb, attr->isGrad()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SparseSoftmaxCrossEntropy, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *SparseSoftmaxCrossEntropyCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry SparseSoftmaxCrossEntropyRegistry(schema::PrimitiveType_SparseSoftmaxCrossEntropy, + SparseSoftmaxCrossEntropyCreator); +#endif + +int SparseSoftmaxCrossEntropy::InferShape(std::vector inputs, std::vector outputs) { + if (2 != inputs.size()) { + MS_LOG(ERROR) << "SparseSoftmaxCrossEntropy should have exactly two inputs"; + return RET_ERROR; + } + + if (1 != outputs.size()) { + MS_LOG(ERROR) << "SparseSoftmaxCrossEntropy should have one output"; + return RET_ERROR; + } + auto *in0 = inputs.front(); + MS_ASSERT(in0 != nullptr);
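[Note] A brief aside before the rest of this InferShape hunk: the op's single output is either logits-shaped (isGrad != 0, the backward pass) or a scalar loss (isGrad == 0). As a minimal, hypothetical sketch of what the non-grad path corresponds to -- mean negative log-likelihood of the labeled class, assuming probabilities already softmax-normalized and labels in [0, classes); names are illustrative, not the kernel API:

#include <cmath>
#include <cstddef>

// Mean -log(p[label]) over the batch; eps guards log(0) the same way the
// kernels in this patch clamp their inputs before taking the log.
float SparseCrossEntropyLoss(const float *probs, const int *labels, size_t batch, size_t classes) {
  const float eps = 1e-6f;
  float total_loss = 0.0f;
  for (size_t i = 0; i < batch; ++i) {
    float p = probs[i * classes + static_cast<size_t>(labels[i])];
    total_loss -= std::log(p <= 0.0f ? eps : p);
  }
  return total_loss / static_cast<float>(batch);  // scalar, matching the {1} output shape below
}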
+ auto *out = outputs.front(); + MS_ASSERT(out != nullptr); + + if (GetIsGrad() != 0) { + out->set_shape(in0->shape()); + out->set_data_type(in0->data_type()); + out->set_format(in0->format()); + } else { + std::vector outshape; + outshape.push_back(1); + out->set_shape(outshape); + out->set_data_type(in0->data_type()); + out->set_format(in0->format()); + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/sparse_softmax_cross_entropy.h b/mindspore/lite/src/ops/sparse_softmax_cross_entropy.h new file mode 100644 index 0000000000..21cfbad3ef --- /dev/null +++ b/mindspore/lite/src/ops/sparse_softmax_cross_entropy.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_H_ +#define MINDSPORE_LITE_SRC_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_H_ + +#include +#include +#include +#include + +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class SparseSoftmaxCrossEntropy : public PrimitiveC { + public: + SparseSoftmaxCrossEntropy() = default; + ~SparseSoftmaxCrossEntropy() = default; +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(SparseSoftmaxCrossEntropy, PrimitiveC); + explicit SparseSoftmaxCrossEntropy(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + void SetIsGrad(int isGrad); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; +#else + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; + + int GetIsGrad() const; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_OPS_SPARSE_SOFTMAX_CROSS_ENTROPY_H_ diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index 5dd481cbc2..754e4fee19 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -140,6 +140,21 @@ int Tile::InferShape(std::vector inputs_, std::vector output std::vector out_shape; std::vector multiples = GetMultiples(); + +#ifdef SUPPORT_TRAIN + const size_t in_dims = input->shape().size(); + const size_t delta_dims = in_dims - multiples.size(); + + size_t i = 0; + for (; i < delta_dims; ++i) { + int tmp = input->shape().at(i); + out_shape.push_back(tmp); + } + for (; i < in_dims; ++i) { + int tmp = input->shape().at(i) * (multiples[i - delta_dims]); + out_shape.push_back(tmp); + } +#else std::vector dims = GetDims(); const size_t in_dims = input->shape().size(); @@ -150,7 +165,7 @@ int Tile::InferShape(std::vector inputs_, std::vector output for (size_t i = 0; i < dims.size(); ++i) { out_shape[dims[i]] = input->shape()[dims[i]] * (multiples[i]); } - +#endif output->set_shape(out_shape); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc index 21ca174cca..e040e0f337 100644 --- 
a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc @@ -45,22 +45,28 @@ int AdamCPUKernel::Execute(int task_id) { auto eps = reinterpret_cast(in_tensors_[8]->MutableData())[0]; auto gradient = reinterpret_cast(in_tensors_[9]->MutableData()); size_t elem_num = in_tensors_[0]->ElementsNum(); - if (fabs(1 - beta1_power) <= 0.0f) { - MS_LOG(ERROR) << "divisor cannot be 0"; + + if ((1.f - beta1_power) <= 0.0f) { + MS_LOG(ERROR) << "divisor cannot be 0 or below"; + return RET_ERROR; + } + if ((1.f - beta2_power) < 0.0f) { + MS_LOG(ERROR) << "sqrt cannot be negative"; return RET_ERROR; } - auto update_lr = learning_rate * std::sqrt(1 - beta2_power) / (1 - beta1_power); + + auto update_lr = learning_rate * std::sqrt(1.f - beta2_power) / (1.f - beta1_power); if (adam_param_->use_nesterov_) { // Nadam for (size_t i = 0; i < elem_num; ++i) { - m[i] += (gradient[i] - m[i]) * (1 - beta1); - v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2); - weight[i] -= update_lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + eps); + m[i] += (gradient[i] - m[i]) * (1.f - beta1); + v[i] += (gradient[i] * gradient[i] - v[i]) * (1.f - beta2); + weight[i] -= update_lr * (m[i] * beta1 + (1.f - beta1) * gradient[i]) / (std::sqrt(v[i]) + eps); } } else { for (size_t i = 0; i < elem_num; ++i) { - m[i] += (gradient[i] - m[i]) * (1 - beta1); - v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2); + m[i] += (gradient[i] - m[i]) * (1.f - beta1); + v[i] += (gradient[i] * gradient[i] - v[i]) * (1.f - beta2); weight[i] -= update_lr * m[i] / (std::sqrt(v[i]) + eps); } } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.cc index 367112b877..7fa2eafa8b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.cc @@ -115,7 +115,6 @@ kernel::LiteKernel *CpuDropoutFp32KernelCreator(const std::vectorInit(); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.cc index b4af30a444..bb62ba40f8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.cc @@ -102,7 +102,6 @@ kernel::LiteKernel *CpuDropoutGradFp32KernelCreator(const std::vectorInit(); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc index a7ae2b03ba..309a583c91 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc @@ -39,16 +39,37 @@ int SgdCPUKernel::Execute(int task_id) { auto gradient = reinterpret_cast(in_tensors_[1]->MutableData()); float moment = reinterpret_cast(in_tensors_[4]->MutableData())[0]; size_t elem_num = in_tensors_[0]->ElementsNum(); - - if (sgd_param_->use_nesterov_) { - for (size_t i = 0; i < elem_num; ++i) { - accumulate[i] = accumulate[i] * moment + gradient[i]; - weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; + auto stat = reinterpret_cast(in_tensors_[5]->MutableData()); + + if (stat[0] > 0) { + stat[0] = 0; + memcpy(accumulate, gradient, elem_num * sizeof(float)); + if (sgd_param_->use_nesterov_) { + for (size_t i = 0; i < elem_num; ++i) { + weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; + } + } else { + for (size_t i = 0; i < elem_num; ++i) { 
+ weight[i] -= accumulate[i] * learning_rate; + } } } else { - for (size_t i = 0; i < elem_num; ++i) { - accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - sgd_param_->dampening_); - weight[i] -= accumulate[i] * learning_rate; + if (moment > 0.f) { + if (sgd_param_->use_nesterov_) { + for (size_t i = 0; i < elem_num; ++i) { + accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - sgd_param_->dampening_); + weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; + } + } else { + for (size_t i = 0; i < elem_num; ++i) { + accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - sgd_param_->dampening_); + weight[i] -= accumulate[i] * learning_rate; + } + } + } else { + for (size_t i = 0; i < elem_num; ++i) { + weight[i] -= gradient[i] * learning_rate; + } } } return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc index a094af8213..4f2f4ed0b4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc @@ -34,27 +34,29 @@ int SoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return RET_OK; } void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *labels, const float *logits, float *grads, float *output2) const { float eps = 1e-6; - float total_loss = 0.0; if (grads != nullptr) { for (int i = 0; i < param_->batch_size_; ++i) { + float loss = 0.f; for (size_t j = 0; j < param_->number_of_classes_; ++j) { float logit = -logf(logits[i * param_->number_of_classes_ + j] <= 0.0 ? eps : logits[i * param_->number_of_classes_ + j]); grads[i * param_->number_of_classes_ + j] = (logits[i * param_->number_of_classes_ + j] - labels[i * param_->number_of_classes_ + j]); - total_loss += labels[i * param_->number_of_classes_ + j] * logit; + loss += labels[i * param_->number_of_classes_ + j] * logit; } + output2[i] = loss; } } else { for (int i = 0; i < param_->batch_size_; ++i) { + float loss = 0.f; for (size_t j = 0; j < param_->number_of_classes_; ++j) { float logit = -logf(logits[i * param_->number_of_classes_ + j] <= 0.0 ? 
eps : logits[i * param_->number_of_classes_ + j]); - total_loss += labels[i * param_->number_of_classes_ + j] * logit; + loss += labels[i * param_->number_of_classes_ + j] * logit; } + output2[i] = loss; } } - output2[0] = total_loss / param_->batch_size_; } int SoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc index 4e2b339081..c5768e839c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc @@ -25,7 +25,7 @@ using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; -using mindspore::schema::PrimitiveType_SoftmaxCrossEntropy; +using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropy; namespace mindspore::kernel { @@ -51,10 +51,9 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int * return RET_OK; } -int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, float *grads, - float *output) const { +int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, + float *grads) const { size_t row_start = 0; - float total_loss = 0; for (int i = 0; i < param->batch_size_; ++i) { if (labels[i] < 0) { MS_LOG(ERROR) << "label value must >= 0"; @@ -65,7 +64,6 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *lab MS_LOG(ERROR) << "error label input!"; return RET_ERROR; } else { - total_loss -= logf(losses[i * param->number_of_classes_ + label]); for (size_t j = 0; j < param->number_of_classes_; ++j) { size_t index = row_start + j; if (j == label) { @@ -77,18 +75,14 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *lab } row_start += param->number_of_classes_; } - output[0] = total_loss / param->batch_size_; return RET_OK; } int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) { + auto sce_param = reinterpret_cast(op_parameter_); auto ins = reinterpret_cast(in_tensors_.at(0)->data_c()); auto labels = reinterpret_cast(in_tensors_.at(1)->data_c()); float *out = reinterpret_cast(out_tensors_.at(0)->data_c()); - float *grads = nullptr; - if (IsTrain() && out_tensors_.size() > 1) { - grads = reinterpret_cast(out_tensors_.at(1)->MutableData()); - } size_t data_size = in_tensors_.at(0)->ElementsNum(); MS_ASSERT(out != nullptr); MS_ASSERT(labels != nullptr); @@ -99,8 +93,8 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) { std::fill(losses_, losses_ + data_size, 0.f); std::fill(sum_data_, sum_data_ + sm_params_.input_shape_[0], 0.f); Softmax(ins, losses_, sum_data_, &sm_params_); - if (IsTrain()) { - GradPostExecute(labels, losses_, grads, out); + if (sce_param->is_grad) { + GradPostExecute(labels, losses_, out); } else { ForwardPostExecute(labels, losses_, out); } @@ -133,12 +127,12 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() { param->batch_size_ = dims[0]; for (unsigned int i = 0; i < dims.size(); i++) param->input_shape_[i] = dims[i]; if (2 != this->in_tensors_.size()) { - MS_LOG(ERROR) << "softmax entropy loss should have two inputs"; + MS_LOG(ERROR) << "sparse softmax entropy loss should have two inputs"; return RET_ERROR; } auto *in0 = in_tensors_.front(); if 
(in0 == nullptr) { - MS_LOG(ERROR) << "softmax etropy loss in0 have no data"; + MS_LOG(ERROR) << "sparse softmax entropy loss in0 has no data"; return RET_ERROR; } size_t data_size = in_tensors_.at(0)->ElementsNum(); @@ -155,7 +149,7 @@ kernel::LiteKernel *CpuSparseSoftmaxCrossEntropyFp32KernelCreator( const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::InnerContext *ctx, const kernel::KernelKey &desc, const mindspore::lite::PrimitiveC *primitive) { MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_SoftmaxCrossEntropy); + MS_ASSERT(desc.type == schema::PrimitiveType_SparseSoftmaxCrossEntropy); auto *kernel = new (std::nothrow) SparseSoftmaxCrossEntropyWithLogitsCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { @@ -172,4 +166,6 @@ kernel::LiteKernel *CpuSparseSoftmaxCrossEntropyFp32KernelCreator( } return kernel; } +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SparseSoftmaxCrossEntropy, + CpuSparseSoftmaxCrossEntropyFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h index d26b53b6d6..57e39cf2d8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.h @@ -38,7 +38,7 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel { ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override {} int ForwardPostExecute(const int *labels, const float *losses, float *output) const; - int GradPostExecute(const int *labels, const float *losses, float *grads, float *output) const; + int GradPostExecute(const int *labels, const float *losses, float *grads) const; int Init() override; int ReSize() override; diff --git a/mindspore/lite/src/train/train_model.cc b/mindspore/lite/src/train/train_model.cc index ee93a62c47..3594ad3dbe 100644 --- a/mindspore/lite/src/train/train_model.cc +++ b/mindspore/lite/src/train/train_model.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/ops/primitive_c.h" -#include "include/train_model.h" +#include "src/train/train_model.h" #include "src/common/log_adapter.h" #include "include/errorcode.h" #include "src/common/graph_util.h" diff --git a/mindspore/lite/include/train_model.h b/mindspore/lite/src/train/train_model.h similarity index 91% rename from mindspore/lite/include/train_model.h rename to mindspore/lite/src/train/train_model.h index e1a3366761..486e3e03d0 100644 --- a/mindspore/lite/include/train_model.h +++ b/mindspore/lite/src/train/train_model.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ -#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_MODEL_H_ -#define MINDSPORE_LITE_INCLUDE_TRAIN_MODEL_H_ +#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ +#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ #include #include "include/model.h" @@ -50,4 +50,4 @@ struct TrainModel : public lite::Model { } // namespace lite } // namespace mindspore -#endif // MINDSPORE_LITE_INCLUDE_TRAIN_MODEL_H_ +#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc index be9c602263..9988f3c2df 100644 --- a/mindspore/lite/src/train/train_populate_parameter.cc +++ b/mindspore/lite/src/train/train_populate_parameter.cc @@ -19,6 +19,7 @@ #include "src/ops/pooling_grad.h" #include "nnacl/pooling_parameter.h" #include "src/ops/softmax_cross_entropy.h" +#include "src/ops/sparse_softmax_cross_entropy.h" #include "nnacl/fp32_grad/softmax_grad.h" #include "src/ops/activation_grad.h" #include "nnacl/fp32/activation_fp32.h" @@ -146,6 +147,26 @@ OpParameter *PopulateSgdParameter(const mindspore::lite::PrimitiveC *primitive) return reinterpret_cast(p); } +OpParameter *PopulateSparseSoftmaxCrossEntropyParameter(const mindspore::lite::PrimitiveC *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; + return nullptr; + } + SoftmaxCrossEntropyParameter *sce_param = + reinterpret_cast(malloc(sizeof(SoftmaxCrossEntropyParameter))); + if (sce_param == nullptr) { + MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed."; + return nullptr; + } + auto sce_primitive = reinterpret_cast( + const_cast(primitive)); + + sce_param->is_grad = sce_primitive->GetIsGrad(); + + sce_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(sce_param); +} + OpParameter *PopulateSoftmaxCrossEntropyParameter(const mindspore::lite::PrimitiveC *primitive) { if (primitive == nullptr) { MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; @@ -157,6 +178,7 @@ OpParameter *PopulateSoftmaxCrossEntropyParameter(const mindspore::lite::Primiti MS_LOG(ERROR) << "malloc SoftmaxCrossEntropyParameter failed."; return nullptr; } + sce_param->is_grad = 0; sce_param->op_parameter_.type_ = primitive->Type(); return reinterpret_cast(sce_param); } @@ -468,6 +490,8 @@ void PopulateTrainParameters() { lite::Registry BiasGradParameterRegistry(schema::PrimitiveType_BiasGrad, PopulateBiasGradParameter); lite::Registry SoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SoftmaxCrossEntropy, PopulateSoftmaxCrossEntropyParameter); + lite::Registry SparseSoftmaxCrossEntropyParameterRegistry(schema::PrimitiveType_SparseSoftmaxCrossEntropy, + PopulateSparseSoftmaxCrossEntropyParameter); lite::Registry ActivationParameterRegistry(schema::PrimitiveType_ActivationGrad, PopulateActivationGradParameter); lite::Registry TupleGetItemParameterRegistry(schema::PrimitiveType_TupleGetItem, DefaultPopulateParameter); lite::Registry DependParameterRegistry(schema::PrimitiveType_Depend, DefaultPopulateParameter); diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index 2497a004c5..5ce7a67f2b 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -21,8 +21,8 @@ #include #include #include +#include #include "include/errorcode.h" -#include "include/train_model.h" #include "src/common/utils.h" #include "src/tensor.h" #include "src/train/loss_kernel.h" @@ -72,18 +72,9 @@ void 
TrainSession::RestoreOps(const std::vector &restore) { void TrainSession::AllocWorkSpace() { size_t workspace_size = 0; - for (auto ori_kernel : kernels_) { - if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) { - if (workspace_size < ori_kernel->workspace_size()) { - workspace_size = ori_kernel->workspace_size(); - } - } else { - auto sub_graph = reinterpret_cast(ori_kernel); - for (auto kernel : sub_graph->nodes()) { - if (workspace_size < kernel->workspace_size()) { - workspace_size = kernel->workspace_size(); - } - } + for (auto kernel : this->train_kernels_) { + if (workspace_size < kernel->workspace_size()) { + workspace_size = kernel->workspace_size(); + } } } mindspore::kernel::LiteKernel::AllocWorkspace(workspace_size); @@ -92,40 +83,27 @@ void TrainSession::AllocWorkSpace() { int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; } int TrainSession::CompileTrainGraph(mindspore::lite::TrainModel *model) { - if (model == nullptr) { - MS_LOG(ERROR) << "model is null"; - return RET_ERROR; - } model_ = model; + auto restore = ReplaceOps(); auto ret = lite::LiteSession::CompileGraph(model); if (ret != RET_OK) { - MS_LOG(ERROR) << "Compile train graph failed"; + MS_LOG(ERROR) << "failed to compile train model"; return RET_ERROR; } - orig_output_map_ = output_node_map_; + orig_output_node_map_ = output_node_map_; orig_output_tensor_map_ = output_tensor_map_; - for (auto inTensor : inputs_) { - inTensor->MutableData(); - } + + for (auto inTensor : inputs_) inTensor->MutableData(); RestoreOps(restore); + CompileTrainKernels(); // Prepare a list of train kernels + CompileInferenceKernels(); // Prepare a list of eval kernels + CompileOptimizedKernels(); // Prepare a list of kernels which are optimized (weight update step) + CompileTrainOutputs(); // prepare outputs in train mode + CompileEvalOutputs(); // prepare outputs in eval mode AllocWorkSpace(); - MarkOptimizedKernels(); - CompileTrainKernels(); - if (train_mode_) { - auto ret1 = Train(); - if (ret1 != RET_OK) { - MS_LOG(ERROR) << "faild to initialize network in train mode"; - return RET_ERROR; - } - } else { - auto ret1 = Eval(); - if (ret1 != RET_OK) { - MS_LOG(ERROR) << "faild to initialize network in eval mode"; - return RET_ERROR; - } - } - return ret; + + return RET_OK; } TrainSession::~TrainSession() { @@ -180,219 +158,144 @@ int TrainSession::SaveToFile(const std::string &filename) const { } int TrainSession::Train() { - for (auto ori_kernel : kernels_) { - MS_ASSERT(nullptr != ori_kernel); - if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) { - auto ret = ori_kernel->Train(); - if (ret != RET_OK) { - MS_LOG(ERROR) << ori_kernel->name() << " failed to set train mode"; - return RET_ERROR; - } - } else { - auto sub_graph = reinterpret_cast(ori_kernel); - MS_ASSERT(nullptr != sub_graph); - for (auto kernel : sub_graph->nodes()) { - MS_ASSERT(nullptr != kernel); - auto ret = kernel->Train(); - if (ret != RET_OK) { - MS_LOG(ERROR) << kernel->name() << " failed to set train mode"; - return RET_ERROR; - } - } - } - } - output_node_map_.clear(); - output_tensor_map_.clear(); + // shift kernels to train mode train_mode_ = true; - for (auto ori_kernel : kernels_) { - MS_ASSERT(nullptr != ori_kernel); - if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) { - UpdateOutputMapByLossKernel(ori_kernel); - } else { - auto sub_graph = reinterpret_cast(ori_kernel); - MS_ASSERT(nullptr != sub_graph); - for (auto kernel : sub_graph->nodes()) { - MS_ASSERT(nullptr != kernel); - UpdateOutputMapByLossKernel(kernel);
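[Note] An aside before the remainder of this hunk: with the per-subgraph walks above deleted, the new Train() and Eval() bodies collapse into one pattern -- flip the mode of every kernel in train_kernels_, then swap in the matching precomputed output maps. A condensed sketch of that pattern (SwitchMode is a hypothetical helper, not in the patch; the members are the ones this patch introduces):

int TrainSession::SwitchMode(bool train) {
  for (auto kernel : this->train_kernels_) {
    auto ret = train ? kernel->Train() : kernel->Eval();  // per-kernel mode flip
    if (ret != RET_OK) {
      MS_LOG(ERROR) << kernel->name() << " failed to switch mode";
      return RET_ERROR;
    }
  }
  train_mode_ = train;
  // Output maps were precomputed once by CompileTrainOutputs()/CompileEvalOutputs(),
  // so switching modes is just a swap rather than a graph walk.
  output_node_map_ = train ? train_output_node_map_ : eval_output_node_map_;
  output_tensor_map_ = train ? train_output_tensor_map_ : eval_output_tensor_map_;
  return RET_OK;
}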
- } + for (auto kernel : this->train_kernels_) { + MS_ASSERT(nullptr != kernel); + auto ret = kernel->Train(); + if (ret != RET_OK) { + MS_LOG(ERROR) << kernel->name() << " failed to set train mode"; + return RET_ERROR; } } + // set train outputs + output_node_map_ = train_output_node_map_; + output_tensor_map_ = train_output_tensor_map_; + return RET_OK; } -void TrainSession::UpdateOutputMapByLossKernel(const kernel::LiteKernel *kernel) { - if (kernel != nullptr && IsLossKernel(kernel)) { - auto *ms_tensor = kernel->out_tensors().at(0); - if (ms_tensor != nullptr) { - (void)ms_tensor->MutableData(); - output_node_map_[kernel->name()].emplace_back(ms_tensor); - auto index = TSFindTensor(tensors_, ms_tensor); - if (index != tensors_.size()) { - output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor)); - } +int TrainSession::Eval() { + // shift kernels to eval mode + train_mode_ = false; + for (auto kernel : this->train_kernels_) { + MS_ASSERT(kernel != nullptr); + auto ret = kernel->Eval(); + if (ret != RET_OK) { + MS_LOG(ERROR) << kernel->name() << " failed to set eval mode"; + return RET_ERROR; } } + // set eval outputs + output_node_map_ = eval_output_node_map_; + output_tensor_map_ = eval_output_tensor_map_; + return RET_OK; } -void TrainSession::UpdateOutputMapByInKernel(const kernel::LiteKernel *kernel) { - if (kernel != nullptr && IsLossKernel(kernel)) { - for (auto in_kernel : kernel->in_kernels()) { - if (output_node_map_.find(in_kernel->name()) == output_node_map_.end()) { - auto *ms_tensor = in_kernel->out_tensors().at(0); - if (ms_tensor != nullptr) { - output_node_map_[in_kernel->name()].emplace_back(ms_tensor); - auto index = TSFindTensor(tensors_, ms_tensor); - if (index != tensors_.size()) { - output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor)); +void TrainSession::CompileEvalOutputs() { + eval_output_node_map_.clear(); + eval_output_tensor_map_.clear(); + for (auto kernel : this->train_kernels_) { + if (IsLossKernel(kernel)) { + for (auto in_kernel : kernel->in_kernels()) { + // insert if not already in + if (eval_output_node_map_.find(in_kernel->name()) == eval_output_node_map_.end()) { + auto *ms_tensor = in_kernel->out_tensors().at(0); + if (ms_tensor != nullptr) { + eval_output_node_map_[in_kernel->name()].emplace_back(ms_tensor); + auto index = TSFindTensor(tensors_, ms_tensor); + if (index != tensors_.size()) { + eval_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor)); + } } } } } } + if (eval_output_node_map_.size() == 0) eval_output_node_map_ = orig_output_node_map_; + if (eval_output_tensor_map_.size() == 0) eval_output_tensor_map_ = orig_output_tensor_map_; } -int TrainSession::Eval() { - for (auto ori_kernel : kernels_) { - MS_ASSERT(nullptr != ori_kernel); - if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) { - auto ret = ori_kernel->Eval(); - if (ret != RET_OK) { - MS_LOG(ERROR) << ori_kernel->name() << " failed to set eval mode"; - return RET_ERROR; - } - } else { - auto sub_graph = reinterpret_cast(ori_kernel); - MS_ASSERT(nullptr != sub_graph); - for (auto kernel : sub_graph->nodes()) { - MS_ASSERT(nullptr != kernel); - auto ret = kernel->Eval(); - if (ret != RET_OK) { - MS_LOG(ERROR) << kernel->name() << " failed to set eval mode"; - return RET_ERROR; +void TrainSession::CompileTrainOutputs() { + train_output_node_map_.clear(); + train_output_tensor_map_.clear(); + for (auto kernel : this->train_kernels_) { + if (orig_output_node_map_.find(kernel->name()) == 
orig_output_node_map_.end()) continue; + // Mask out optimizer out tensors + if (IsMaskOutput(kernel)) continue; + // insert if not already in + if (train_output_node_map_.find(kernel->name()) == train_output_node_map_.end()) { + auto *ms_tensor = kernel->out_tensors().at(0); + if (ms_tensor != nullptr) { + train_output_node_map_[kernel->name()].emplace_back(ms_tensor); + auto index = TSFindTensor(tensors_, ms_tensor); + if (index != tensors_.size()) { + train_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor)); } } } } - output_node_map_ = orig_output_map_; - output_tensor_map_ = orig_output_tensor_map_; - - train_mode_ = false; - for (auto ori_kernel : kernels_) { - if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) { - UpdateOutputMapByInKernel(ori_kernel); - } else { - auto sub_graph = reinterpret_cast(ori_kernel); - for (auto kernel : sub_graph->nodes()) { - UpdateOutputMapByInKernel(kernel); - } - } - } - if (inference_kernels_.size() == 0) { - BuildInferenceKernelsMap(); - } - return RET_OK; + if (train_output_node_map_.size() == 0) train_output_node_map_ = orig_output_node_map_; + if (train_output_tensor_map_.size() == 0) train_output_tensor_map_ = orig_output_tensor_map_; } void TrainSession::BuildInferenceKernelsRecursive(kernel::LiteKernel *kernel, std::vector *v) { - if (std::find(v->begin(), v->end(), kernel) == v->end()) { // kernel is not in vector - v->push_back(kernel); + if (std::find(v->begin(), v->end(), kernel) == v->end()) { // kernel is not already in vector + if (!IsLossKernel(kernel)) v->push_back(kernel); for (auto in_node : kernel->in_kernels()) { BuildInferenceKernelsRecursive(in_node, v); } } } -void TrainSession::BuildInferenceKernelsMap() { - std::vector req_kernels; - for (auto kernel : this->kernels_) { - if (kernel->subgraph_type() == kernel::kNotSubGraph) { - if (IsLossKernel(kernel)) { // For each loss in the system add backward tree - for (auto in_node : kernel->in_kernels()) { - BuildInferenceKernelsRecursive(in_node, &req_kernels); - } - } - } else { - auto sub_graph = reinterpret_cast(kernel); - for (auto sub_kernel : sub_graph->nodes()) { - if (IsLossKernel(sub_kernel)) { // For each loss in the system add backward tree - for (auto in_node : sub_kernel->in_kernels()) { - BuildInferenceKernelsRecursive(in_node, &req_kernels); - } - } - } - } - } - - inference_kernels_.clear(); +void TrainSession::CompileTrainKernels() { + train_kernels_.clear(); for (auto ori_kernel : kernels_) { if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) { - if (std::find(req_kernels.begin(), req_kernels.end(), ori_kernel) != req_kernels.end()) { - inference_kernels_.push_back(ori_kernel); - } + train_kernels_.push_back(ori_kernel); } else { auto sub_graph = reinterpret_cast(ori_kernel); for (auto kernel : sub_graph->nodes()) { - if (std::find(req_kernels.begin(), req_kernels.end(), kernel) != req_kernels.end()) { - inference_kernels_.push_back(kernel); - } + train_kernels_.push_back(kernel); } } } - if (inference_kernels_.size() == 0) { - inference_kernels_ = this->kernels_; - } } -void TrainSession::CompileTrainKernels() { - train_kernels_.clear(); - for (auto ori_kernel : kernels_) { - if (ori_kernel->subgraph_type() == kernel::kNotSubGraph) { - train_kernels_.push_back(ori_kernel); - } else { - auto sub_graph = reinterpret_cast(ori_kernel); - for (auto kernel : sub_graph->nodes()) { - train_kernels_.push_back(kernel); +void TrainSession::CompileInferenceKernels() { + std::vector req_kernels; + for (auto kernel : this->train_kernels_) { 
+ if (IsLossKernel(kernel)) { // For each loss in the system add backward tree + for (auto in_node : kernel->in_kernels()) { + BuildInferenceKernelsRecursive(in_node, &req_kernels); } } } + inference_kernels_.clear(); + for (auto ori_kernel : this->train_kernels_) { + if (std::find(req_kernels.begin(), req_kernels.end(), ori_kernel) != req_kernels.end()) { + inference_kernels_.push_back(ori_kernel); + } + } + if (inference_kernels_.size() == 0) { + inference_kernels_ = this->train_kernels_; + } } -void TrainSession::MarkOptimizedKernels() { +void TrainSession::CompileOptimizedKernels() { std::vector ot; - for (auto kernel : this->kernels_) { - if (kernel->subgraph_type() == kernel::kNotSubGraph) { - if (IsOptimizer(kernel)) { - std::copy(kernel->in_tensors().begin(), kernel->in_tensors().end(), std::back_inserter(ot)); - } - } else { - auto sub_graph = reinterpret_cast(kernel); - for (auto sub_kernel : sub_graph->nodes()) { - if (IsOptimizer(sub_kernel)) { - std::copy(sub_kernel->in_tensors().begin(), sub_kernel->in_tensors().end(), std::back_inserter(ot)); - } - } + for (auto kernel : this->train_kernels_) { + if (IsOptimizer(kernel)) { + std::copy(kernel->in_tensors().begin(), kernel->in_tensors().end(), std::back_inserter(ot)); } } - for (auto kernel : this->kernels_) { - if (kernel->subgraph_type() == kernel::kNotSubGraph) { - if (!IsOptimizer(kernel)) { - for (auto it : kernel->in_tensors()) { - if (std::find(ot.begin(), ot.end(), it) != ot.end()) { - kernel->set_trainable(true); - break; - } - } - } - } else { - auto sub_graph = reinterpret_cast(kernel); - for (auto sub_kernel : sub_graph->nodes()) { - if (!IsOptimizer(sub_kernel)) { - for (auto it : sub_kernel->in_tensors()) { - if (std::find(ot.begin(), ot.end(), it) != ot.end()) { - sub_kernel->set_trainable(true); - break; - } - } + + for (auto kernel : this->train_kernels_) { + if (!IsOptimizer(kernel)) { + for (auto it : kernel->in_tensors()) { + if (std::find(ot.begin(), ot.end(), it) != ot.end()) { + kernel->set_trainable(true); + break; } } } @@ -400,19 +303,31 @@ void TrainSession::MarkOptimizedKernels() { } bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) const { - return (kernel->Type() == schema::PrimitiveType_SoftmaxCrossEntropy); + return (kernel->Type() == schema::PrimitiveType_SoftmaxCrossEntropy || + kernel->Type() == schema::PrimitiveType_SparseSoftmaxCrossEntropy); } bool TrainSession::IsOptimizer(kernel::LiteKernel *kernel) const { return ((kernel->Type() == schema::PrimitiveType_Adam) || (kernel->Type() == schema::PrimitiveType_Sgd) || (kernel->Type() == schema::PrimitiveType_ApplyMomentum)); } +bool TrainSession::IsMaskOutput(kernel::LiteKernel *kernel) const { + return (IsOptimizer(kernel) || (kernel->Type() == schema::PrimitiveType_Assign)); +} } // namespace lite -session::TrainSession *session::TrainSession::CreateSession(lite::Context *context) { +session::TrainSession *session::TrainSession::CreateSession(const char *model_buf, size_t size, lite::Context *context, + bool train_mode) { + auto model = mindspore::lite::TrainModel::Import(model_buf, size); + if (model == nullptr) { + MS_LOG(ERROR) << "create model for train session failed"; + return nullptr; + } + auto session = new (std::nothrow) lite::TrainSession(); if (session == nullptr) { + delete model; MS_LOG(ERROR) << "create session failed"; return nullptr; } @@ -422,7 +337,59 @@ session::TrainSession *session::TrainSession::CreateSession(lite::Context *conte delete session; return nullptr; } + + ret = 
session->CompileTrainGraph(model); + if (ret != mindspore::lite::RET_OK) { + MS_LOG(ERROR) << "Compiling Train Graph session failed"; + delete session; + return nullptr; + } + + if (train_mode) { + ret = session->Train(); + } else { + ret = session->Eval(); + } + if (ret != mindspore::lite::RET_OK) { + MS_LOG(ERROR) << "Could not switch to train mode " << train_mode; + delete session; + return nullptr; + } + return session; } +session::TrainSession *session::TrainSession::CreateSession(const std::string &filename, lite::Context *context, + bool train_mode) { + std::ifstream ifs(filename); + if (!ifs.good()) { + MS_LOG(ERROR) << "File: " << filename << " does not exist"; + return nullptr; + } + + if (!ifs.is_open()) { + MS_LOG(ERROR) << "File: " << filename << " open failed"; + return nullptr; + } + + ifs.seekg(0, std::ios::end); + auto size = ifs.tellg(); + if (size == 0) { + MS_LOG(ERROR) << "Could not read file " << filename; + return nullptr; + } + std::unique_ptr buf(new (std::nothrow) char[size]); + if (buf == nullptr) { + MS_LOG(ERROR) << "malloc buf failed, file: " << filename; + ifs.close(); + return nullptr; + } + + ifs.seekg(0, std::ios::beg); + ifs.read(buf.get(), size); + ifs.close(); + + return session::TrainSession::CreateSession(buf.get(), size, context, train_mode); +} + } // namespace mindspore diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h index e11a8d92d3..45c1b42462 100644 --- a/mindspore/lite/src/train/train_session.h +++ b/mindspore/lite/src/train/train_session.h @@ -21,7 +21,7 @@ #include #include "src/ops/primitive_c.h" #include "include/train_session.h" -#include "include/train_model.h" +#include "src/train/train_model.h" #include "src/lite_session.h" /* @@ -52,7 +52,7 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override; int CompileGraph(lite::Model *model) override; - int CompileTrainGraph(lite::TrainModel *model) override; + virtual int CompileTrainGraph(lite::TrainModel *model); void *ExportToBuf(char *buf, size_t *len) const override; int SaveToFile(const std::string &filename) const override; @@ -80,24 +80,34 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: return lite::LiteSession::Resize(inputs, dims); } - void UpdateOutputMapByInKernel(const kernel::LiteKernel *kernel); - void UpdateOutputMapByLossKernel(const kernel::LiteKernel *kernel); - protected: void AllocWorkSpace(); bool IsLossKernel(const kernel::LiteKernel *kernel) const; bool IsOptimizer(kernel::LiteKernel *kernel) const; - virtual void MarkOptimizedKernels(); + bool IsMaskOutput(kernel::LiteKernel *kernel) const; virtual std::vector ReplaceOps(); virtual void RestoreOps(const std::vector &restore); - virtual void BuildInferenceKernelsMap(); - virtual void BuildInferenceKernelsRecursive(kernel::LiteKernel *ker, std::vector *req_kernels); virtual void CompileTrainKernels(); + virtual void CompileInferenceKernels(); + virtual void CompileOptimizedKernels(); + virtual void CompileTrainOutputs(); + virtual void CompileEvalOutputs(); + TrainModel *model_ = nullptr; - std::unordered_map> orig_output_map_; + std::unordered_map> orig_output_node_map_; std::unordered_map orig_output_tensor_map_; + + std::unordered_map> eval_output_node_map_; + std::unordered_map eval_output_tensor_map_; + + std::unordered_map> train_output_node_map_; + std::unordered_map train_output_tensor_map_; +
std::vector inference_kernels_; std::vector train_kernels_; + + private: + void BuildInferenceKernelsRecursive(kernel::LiteKernel *ker, std::vector *req_kernels); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/test/models_ms_train.cfg b/mindspore/lite/test/models_ms_train.cfg index 58c5e278f2..de471ebe4d 100644 --- a/mindspore/lite/test/models_ms_train.cfg +++ b/mindspore/lite/test/models_ms_train.cfg @@ -1,7 +1,7 @@ mini_alexnet -#mobilenetv1 -#mobilenetv2 -#mobilenetv3 # this model got error when RunX86 +# mobilenetv1 +mobilenetv2 +mobilenetv3 lenet #effnet effnet_tune diff --git a/mindspore/lite/test/run_net_train.sh b/mindspore/lite/test/run_net_train.sh index 756f8c4f10..088e5d5080 100755 --- a/mindspore/lite/test/run_net_train.sh +++ b/mindspore/lite/test/run_net_train.sh @@ -16,7 +16,7 @@ function Run_Export(){ echo ${model_name}'_train_export.py' >> "${export_log_file}" echo 'exporting' ${model_name} echo 'docker run --user $(id -u):$(id -g) --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true mindspore_dev:5 python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" - docker run --user $(id -u):$(id -g) --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true mindspore_dev:5 python ${models_path}'/'${model_name}_train_export.py + docker run --user $(id -u):$(id -g) --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true mindspore_dev:5 python ${models_path}'/'${model_name}_train_export.py ${epoch_num} if [ $? = 0 ]; then export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file} else @@ -74,18 +74,19 @@ function Run_x86() { echo ${model_name}'_train' >> "${run_x86_log_file}" echo 'cd '${x86_path}'/mindspore-lite-'${version}'-runtime-x86-'${process_unit_x86}-train >> "${run_x86_log_file}" cd ${x86_path}/mindspore-lite-${version}-runtime-x86-${process_unit_x86}-train || return 1 - echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./net_train/net_train --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${input_path}'/'${model_name}'_input1.bin,'${train_io_path}'/'${model_name}'_input2.bin --expectedDataFile='${train_io_path}'/'${model_name}'_outputs.bin --exportFile='${ms_models_path}'/'${model_name}'_train_exported.ms' >> "${run_x86_log_file}" + echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./net_train/net_train --epochs='${epoch_num}' --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${input_path}'/'${model_name}'_input1.bin,'${train_io_path}'/'${model_name}'_input2.bin --expectedDataFile='${train_io_path}'/'${model_name}'_outputs.bin --exportFile='${ms_models_path}'/'${model_name}'_train_exported.ms' >> "${run_x86_log_file}" echo '-------------------------------------------------------------------------------' >> "${run_x86_log_file}" LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib \ ${run_valgrind}./net_train/net_train \ --modelFile=${ms_models_path}/${model_name}_train.ms \ --inDataFile=${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin \ --expectedDataFile=${train_io_path}/${model_name}_outputs.bin \ 
diff --git a/mindspore/lite/test/models_ms_train.cfg b/mindspore/lite/test/models_ms_train.cfg
index 58c5e278f2..de471ebe4d 100644
--- a/mindspore/lite/test/models_ms_train.cfg
+++ b/mindspore/lite/test/models_ms_train.cfg
@@ -1,7 +1,7 @@
 mini_alexnet
-#mobilenetv1
-#mobilenetv2
-#mobilenetv3 # this model got error when RunX86
+# mobilenetv1
+mobilenetv2
+mobilenetv3
 lenet
 #effnet
 effnet_tune
diff --git a/mindspore/lite/test/run_net_train.sh b/mindspore/lite/test/run_net_train.sh
index 756f8c4f10..088e5d5080 100755
--- a/mindspore/lite/test/run_net_train.sh
+++ b/mindspore/lite/test/run_net_train.sh
@@ -16,7 +16,7 @@ function Run_Export(){
     echo ${model_name}'_train_export.py' >> "${export_log_file}"
     echo 'exporting' ${model_name}
    echo 'docker run --user $(id -u):$(id -g) --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true mindspore_dev:5 python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}"
-    docker run --user $(id -u):$(id -g) --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true mindspore_dev:5 python ${models_path}'/'${model_name}_train_export.py
+    docker run --user $(id -u):$(id -g) --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true mindspore_dev:5 python ${models_path}'/'${model_name}_train_export.py ${epoch_num}
     if [ $? = 0 ]; then
         export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file}
     else
@@ -74,18 +74,19 @@ function Run_x86() {
         echo ${model_name}'_train' >> "${run_x86_log_file}"
         echo 'cd '${x86_path}'/mindspore-lite-'${version}'-runtime-x86-'${process_unit_x86}-train >> "${run_x86_log_file}"
         cd ${x86_path}/mindspore-lite-${version}-runtime-x86-${process_unit_x86}-train || return 1
-        echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./net_train/net_train --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${input_path}'/'${model_name}'_input1.bin,'${train_io_path}'/'${model_name}'_input2.bin --expectedDataFile='${train_io_path}'/'${model_name}'_outputs.bin --exportFile='${ms_models_path}'/'${model_name}'_train_exported.ms' >> "${run_x86_log_file}"
+        echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./net_train/net_train --epochs='${epoch_num}' --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${input_path}'/'${model_name}'_input1.bin,'${train_io_path}'/'${model_name}'_input2.bin --expectedDataFile='${train_io_path}'/'${model_name}'_outputs.bin --exportFile='${ms_models_path}'/'${model_name}'_train_exported.ms' >> "${run_x86_log_file}"
         echo '-------------------------------------------------------------------------------' >> "${run_x86_log_file}"
         LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib \
         ${run_valgrind}./net_train/net_train \
         --modelFile=${ms_models_path}/${model_name}_train.ms \
         --inDataFile=${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin \
         --expectedDataFile=${train_io_path}/${model_name}_outputs.bin \
-        --exportFile=${ms_models_path}/${model_name}_train_exported.ms >> "${run_x86_log_file}"
+        --exportFile=${ms_models_path}/${model_name}_train_exported.ms >> "${run_x86_log_file}" \
+        --epochs=${epoch_num}
         if [ $? = 0 ]; then
             run_result='x86: '${model_name}'_train pass'; echo ${run_result} >> ${run_net_train_result_file}
         else
-            run_result='x86: '${model_name}'_train failed'; echo ${run_result} >> ${run_net_train_result_file}; return 1
+            run_result='x86: '${model_name}'_train failed'; echo ${run_result} >> ${run_net_train_result_file}
         fi
     done < ${models_mindspore_train_config}
 }
@@ -160,12 +161,12 @@ function Run_arm() {
         echo 'chmod 777 net_train' >> ${adb_cmd_run_file}
         if [ "$1" == arm64 ]; then
             echo 'cp /data/local/tmp/libc++_shared.so ./' >> ${adb_cmd_run_file}
-            echo 'export LD_LIBRARY_PATH=/data/local/tmp/net_train_test;./net_train --modelFile='${model_name}'_train.ms --inDataFile=/data/local/tmp/net_train_test/'${model_name}'_input1.bin,/data/local/tmp/net_train_test/'${model_name}'_input2.bin --expectedDataFile=/data/local/tmp/net_train_test/'${model_name}'_outputs.bin --exportFile='${model_name}'_train_exported.ms' >> "${run_arm_log_file}"
-            echo 'export LD_LIBRARY_PATH=/data/local/tmp/net_train_test;./net_train --modelFile='${model_name}'_train.ms --inDataFile=/data/local/tmp/net_train_test/'${model_name}'_input1.bin,/data/local/tmp/net_train_test/'${model_name}'_input2.bin --expectedDataFile=/data/local/tmp/net_train_test/'${model_name}'_outputs.bin --exportFile='${model_name}'_train_exported.ms' >> "${adb_cmd_run_file}"
+            echo 'export LD_LIBRARY_PATH=./:/data/local/tmp/net_train_test;./net_train --epochs='${epoch_num}' --modelFile='${model_name}'_train.ms --inDataFile=/data/local/tmp/net_train_test/'${model_name}'_input1.bin,/data/local/tmp/net_train_test/'${model_name}'_input2.bin --expectedDataFile=/data/local/tmp/net_train_test/'${model_name}'_outputs.bin' >> "${run_arm_log_file}"
+            echo 'export LD_LIBRARY_PATH=./:/data/local/tmp/net_train_test;./net_train --epochs='${epoch_num}' --modelFile='${model_name}'_train.ms --inDataFile=/data/local/tmp/net_train_test/'${model_name}'_input1.bin,/data/local/tmp/net_train_test/'${model_name}'_input2.bin --expectedDataFile=/data/local/tmp/net_train_test/'${model_name}'_outputs.bin' >> "${adb_cmd_run_file}"
         elif [ "$1" == arm32 ]; then
             echo 'cp /data/local/tmp/arm32/libc++_shared.so ./' >> ${adb_cmd_run_file}
-            echo 'export LD_LIBRARY_PATH=/data/local/tmp/net_train_test;./net_train --modelFile='${model_name}'_train.ms --inDataFile=/data/local/tmp/net_train_test/'${model_name}'_input1.bin,/data/local/tmp/net_train_test/'${model_name}'_input2.bin --expectedDataFile=/data/local/tmp/net_train_test/'${model_name}'_outputs.bin --exportFile='${model_name}'_train_exported.ms' >> "${run_arm_log_file}"
-            echo 'export LD_LIBRARY_PATH=/data/local/tmp/net_train_test;./net_train --modelFile='${model_name}'_train.ms --inDataFile=/data/local/tmp/net_train_test/'${model_name}'_input1.bin,/data/local/tmp/net_train_test/'${model_name}'_input2.bin --expectedDataFile=/data/local/tmp/net_train_test/'${model_name}'_outputs.bin --exportFile='${model_name}'_train_exported.ms' >> "${adb_cmd_run_file}"
+            echo 'export LD_LIBRARY_PATH=./:/data/local/tmp/:/data/local/tmp/net_train_test;./net_train --epochs='${epoch_num}' --modelFile='${model_name}'_train.ms --inDataFile=/data/local/tmp/net_train_test/'${model_name}'_input1.bin,/data/local/tmp/net_train_test/'${model_name}'_input2.bin --expectedDataFile=/data/local/tmp/net_train_test/'${model_name}'_outputs.bin' >> "${run_arm_log_file}"
+            echo 'export LD_LIBRARY_PATH=./:/data/local/tmp/:/data/local/tmp/net_train_test;./net_train --epochs='${epoch_num}' --modelFile='${model_name}'_train.ms --inDataFile=/data/local/tmp/net_train_test/'${model_name}'_input1.bin,/data/local/tmp/net_train_test/'${model_name}'_input2.bin --expectedDataFile=/data/local/tmp/net_train_test/'${model_name}'_outputs.bin' >> "${adb_cmd_run_file}"
         fi

         adb -s ${device_id} shell < ${adb_cmd_run_file} >> ${run_arm_log_file}
@@ -203,16 +204,18 @@ function Print_Result() {

 basepath=$(pwd)
 echo ${basepath}
-train_io_path=""

 # Example:run_net_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408"
 # For running on arm64, use -t to set platform tools path (for using adb commands)
-while getopts "r:m:d:i:e:v" opt; do
+epoch_num=1
+train_io_path=""
+while getopts "r:m:d:i:e:vt:" opt; do
     case ${opt} in
         r)
            release_path=${OPTARG}
            echo "release_path is ${OPTARG}"
            ;;
        m)
+            models_path=${OPTARG}"/models_train"
            echo "models_path is ${OPTARG}"
            ;;
@@ -229,9 +232,13 @@ while getopts "r:m:d:i:e:v" opt; do
            echo "enable_export = ${OPTARG}"
            ;;
        v)
-            run_valgrind="valgrind "
+            run_valgrind="valgrind --log-file=valgrind.log "
            echo "Run x86 with valgrind"
-            ;;
+            ;;
+        t)
+            epoch_num=${OPTARG}
+            echo "train epoch num is ${OPTARG}"
+            ;;
        ?)
            echo "unknown para"
            exit 1;;
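With the changes above, a single -t value now flows from the driver into the export script and into every net_train invocation as --epochs; a typical run would look like run_net_train.sh -r <release_path> -m <models_path> -i <train_io_path> -d <device_id> -t 3 (arguments hypothetical). Note that the untouched comment above the getopts loop still describes -t as the platform-tools path, which no longer matches the new t) branch.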
"${run_arm_log_file}" + echo 'export LD_LIBRARY_PATH=./:/data/local/tmp/:/data/local/tmp/net_train_test;./net_train --epochs='${epoch_num}' --modelFile='${model_name}'_train.ms --inDataFile=/data/local/tmp/net_train_test/'${model_name}'_input1.bin,/data/local/tmp/net_train_test/'${model_name}'_input2.bin --expectedDataFile=/data/local/tmp/net_train_test/'${model_name}'_outputs.bin' >> "${adb_cmd_run_file}" fi adb -s ${device_id} shell < ${adb_cmd_run_file} >> ${run_arm_log_file} @@ -203,16 +204,18 @@ function Print_Result() { basepath=$(pwd) echo ${basepath} -train_io_path="" # Example:run_net_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408" # For running on arm64, use -t to set platform tools path (for using adb commands) -while getopts "r:m:d:i:e:v" opt; do +epoch_num=1 +train_io_path="" +while getopts "r:m:d:i:e:vt:" opt; do case ${opt} in r) release_path=${OPTARG} echo "release_path is ${OPTARG}" ;; m) + models_path=${OPTARG}"/models_train" echo "models_path is ${OPTARG}" ;; @@ -229,9 +232,13 @@ while getopts "r:m:d:i:e:v" opt; do echo "enable_export = ${OPTARG}" ;; v) - run_valgrind="valgrind " + run_valgrind="valgrind --log-file=valgrind.log " echo "Run x86 with valgrind" - ;; + ;; + t) + epoch_num=${OPTARG} + echo "train epoch num is ${OPTARG}" + ;; ?) echo "unknown para" exit 1;; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc index ba626f30c8..29fb911a23 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc @@ -181,7 +181,6 @@ TEST_F(TestBNGradFp32, BNTtrainFp32) { auto kernel_obj = creator(inputs, outputs, reinterpret_cast(bn_param), &context, desc, nullptr); ASSERT_NE(kernel_obj, nullptr); mindspore::kernel::LiteKernel::AllocWorkspace(kernel_obj->workspace_size()); - float *save_mean = reinterpret_cast(save_mean_tensor.MutableData()); float *save_var = reinterpret_cast(save_var_tensor.MutableData()); for (int i = 0; i < channels; i++) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc index 9c938f7e2a..7ed5805484 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc @@ -273,10 +273,7 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupFilterGrad) { auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), &context, desc, nullptr); ASSERT_NE(kernel, nullptr); mindspore::kernel::LiteKernel::AllocWorkspace(kernel->workspace_size()); - // warm up loop - for (int i = 0; i < 3; i++) { - kernel->Run(); - } + kernel->Run(); int loop_count = 100; auto time_start = mindspore::lite::GetTimeUs(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc index df7a5ae31c..66b16b5567 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc @@ -205,10 +205,7 @@ TEST_F(TestDeConvolutionGradFp32, 
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc
index cc648fe5b8..4da792359a 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc
@@ -23,7 +23,6 @@
 #include
 #include "schema/inner/model_generated.h"
-#include "mindspore/lite/include/train_model.h"
 #include "common/common_test.h"
 #include "include/train_session.h"
 #include "include/context.h"
@@ -131,7 +130,6 @@ TEST_F(NetworkTest, tuning_layer) {
   node->primitive->value.type = schema::PrimitiveType_SoftmaxCrossEntropy;
   auto primitive = new schema::SoftmaxCrossEntropyT;
   ASSERT_NE(primitive, nullptr);
-  primitive->axis.push_back(0);
   node->primitive->value.value = primitive;
   node->name = "SoftmaxCrossEntropy";
   meta_graph->nodes.emplace_back(std::move(node));
@@ -144,7 +142,6 @@ TEST_F(NetworkTest, tuning_layer) {
   node->primitive->value.type = schema::PrimitiveType_BiasGrad;
   auto primitive = new schema::BiasGradT;
   ASSERT_NE(primitive, nullptr);
-  primitive->axis.push_back(0);
   node->primitive->value.value = primitive;
   node->name = "BiasGrad";
   meta_graph->nodes.emplace_back(std::move(node));
@@ -360,17 +357,13 @@ TEST_F(NetworkTest, tuning_layer) {
   const char *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
   std::cout << "build fb size= " << size << std::endl;

-  auto model = lite::TrainModel::Import(content, size);
-  ASSERT_NE(nullptr, model);
   meta_graph.reset();
   content = nullptr;

   lite::Context context;
   context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
   context.thread_num_ = 1;
-  auto session = session::TrainSession::CreateSession(&context);
+  auto session = session::TrainSession::CreateSession(content, size, &context);
   ASSERT_NE(nullptr, session);
-  auto ret = session->CompileTrainGraph(model);
-  ASSERT_EQ(lite::RET_OK, ret);

   session->Train();
   session->Train();  // Just double check that calling Train twice does not cause a problem
@@ -398,7 +391,7 @@ TEST_F(NetworkTest, tuning_layer) {
   std::fill(labels, labels + labelTensor->ElementsNum(), 0.f);
   for (int i = 0; i < BATCH_SIZE; i++) labels[i * NUM_CLASSES + (i * 97) % NUM_CLASSES] = 1.0;

-  ret = session->RunGraph();
+  auto ret = session->RunGraph();
   ASSERT_EQ(lite::RET_OK, ret);
   auto outputs = session->GetOutputsByNodeName("SoftmaxCrossEntropy");
   ASSERT_EQ(outputs.size(), 1);
@@ -514,23 +507,14 @@ int32_t runNet(mindspore::session::LiteSession *session, const std::string &in,
 }

 TEST_F(NetworkTest, efficient_net) {
-  char *buf = nullptr;
-  size_t net_size = 0;
-
-  std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms";
-  ReadFile(net.c_str(), &net_size, &buf);
-  auto model = lite::TrainModel::Import(buf, net_size);
-  delete[] buf;
   auto context = new lite::Context;
   ASSERT_NE(context, nullptr);
   context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
   context->thread_num_ = 1;

-  auto session = session::TrainSession::CreateSession(context);
+  std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms";
+  auto session = session::TrainSession::CreateSession(net, context, false);
   ASSERT_NE(session, nullptr);
-  auto ret = session->CompileTrainGraph(model);
-  ASSERT_EQ(lite::RET_OK, ret);
-  session->Eval();

   std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin";
   std::string out = "./test_data/nets/effNet_output_y_1_1000.bin";
@@ -540,58 +524,6 @@
   ASSERT_EQ(res, 0);
 }

-TEST_F(NetworkTest, retina_net) {
-  char *buf = nullptr;
-  size_t net_size = 0;
-
-  std::string net = "./test_data/nets/retinaface1.ms";
-  ReadFile(net.c_str(), &net_size, &buf);
-  // auto model = lite::TrainModel::Import(buf, net_size);
-  auto model = lite::Model::Import(buf, net_size);
-  delete[] buf;
-  auto context = new lite::Context;
-  ASSERT_NE(context, nullptr);
-  context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
-  context->thread_num_ = 1;
-
-  // auto session = session::TrainSession::CreateSession(context);
-  auto session = session::LiteSession::CreateSession(context);
-  ASSERT_NE(session, nullptr);
-  auto ret = session->CompileGraph(model);
-  EXPECT_EQ(lite::RET_OK, ret);
-  // session->Eval();
-
-  std::string in = "./test_data/nets/test1.hwc_normalized_f32";
-  std::cout << "----- Output 0 -----" << std::endl;
-  std::string out = "./test_data/nets/test1_loc.f32";
-  int final_res = 0;
-  auto res = runNet(session, in, out, "448", true);
-  // ASSERT_EQ(res, 0);
-  if (res != 0) {
-    final_res = res;
-  }
-
-  std::cout << "----- Output 1 -----" << std::endl;
-  out = "./test_data/nets/test1_conf.f32";
-  res = runNet(session, in, out, "435", true);
-  // ASSERT_EQ(res, 0);
-  if (res != 0) {
-    final_res |= res;
-  }
-  std::cout << "----- Output 2 -----" << std::endl;
-  out = "./test_data/nets/test1_landms.f32";
-  res = runNet(session, in, out, "421", true);
-  if (res != 0) {
-    final_res |= res;
-  }
-
-  EXPECT_EQ(final_res, 0);
-
-  delete model;
-  delete session;
-  delete context;
-}
-
 TEST_F(NetworkTest, mobileface_net) {
   char *buf = nullptr;
   size_t net_size = 0;
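The tests now pass the flatbuffer (or a filename) straight to TrainSession::CreateSession, which imports and compiles internally, so CompileTrainGraph disappears. One thing to watch in tuning_layer as patched: content is set to nullptr two lines before being handed to CreateSession, which looks like an ordering slip. Below is a minimal sketch of the buffer-based variant, assuming the ReadFile helper from src/common/file_utils.h and a hypothetical model path; whether the session copies the buffer is an assumption flagged in the comments:

#include <cstddef>

#include "include/context.h"
#include "include/train_session.h"
#include "src/common/file_utils.h"  // for mindspore::lite::ReadFile

int main() {
  size_t size = 0;
  char *buf = mindspore::lite::ReadFile("./model.ms", &size);  // hypothetical path
  if (buf == nullptr) return 1;

  mindspore::lite::Context context;
  context.thread_num_ = 1;

  auto *session = mindspore::session::TrainSession::CreateSession(buf, size, &context);
  delete[] buf;  // assumption: CreateSession copies the buffer; keep buf alive otherwise
  if (session == nullptr) return 1;

  session->Train();  // training mode, as the tuning_layer test does
  auto ret = session->RunGraph();
  delete session;
  return ret;
}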
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc
index 4d6b756090..a58994bf39 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc
@@ -55,9 +55,9 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) {

   std::vector<lite::Tensor *> inputs = {&y_tensor, &l_tensor};

-  auto loss = new float[1];
+  auto loss = new float[6];
   ASSERT_NE(loss, nullptr);
-  std::vector<int> dim_dw({1});
+  std::vector<int> dim_dw({6, 1});
   lite::Tensor loss_tensor(TypeId::kNumberTypeFloat32, dim_dw);
   loss_tensor.set_data(loss);
   auto grad = new float[24];
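The loss output grows from a single scalar to shape {6, 1}: with 24 gradient elements next to 6 loss values, this looks like a 6-sample, 4-class problem, and the kernel is now expected to report one loss per batch element. For reference, the per-sample math is sketched below; this is plain reference code, not the nnacl kernel:

#include <cmath>
#include <cstdio>

// Per-sample softmax cross entropy:
// losses[b] = -sum_c onehot[b][c] * log(softmax(logits[b])[c])
void SoftmaxCrossEntropy(const float *logits, const float *onehot_labels,
                         int batch, int classes, float *losses) {
  for (int b = 0; b < batch; ++b) {
    const float *row = logits + b * classes;
    float max_v = row[0];
    for (int c = 1; c < classes; ++c) max_v = (row[c] > max_v) ? row[c] : max_v;
    float sum = 0.0f;
    for (int c = 0; c < classes; ++c) sum += std::exp(row[c] - max_v);
    float loss = 0.0f;
    for (int c = 0; c < classes; ++c) {
      float log_softmax = (row[c] - max_v) - std::log(sum);
      loss -= onehot_labels[b * classes + c] * log_softmax;
    }
    losses[b] = loss;  // one loss per batch element, matching dims {batch, 1}
  }
}

int main() {
  const float logits[8] = {2.f, 0.f, 0.f, 0.f, 0.f, 2.f, 0.f, 0.f};  // 2 samples, 4 classes
  const float labels[8] = {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f};
  float losses[2];
  SoftmaxCrossEntropy(logits, labels, 2, 4, losses);
  std::printf("%f %f\n", losses[0], losses[1]);
  return 0;
}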
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index 48de5d5a44..98ad9000b4 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -73,10 +73,14 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
     if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) ||
         IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) {
       hasDepend = true;
+      bool maskOut = (dependNode->inputs().size() == 3) ? true : false;
       for (size_t j = 1; j < dependNode->inputs().size(); ++j) {
         AnfNodePtr dependInputNode = dependNode->input(j);
         if (dependInputNode->isa<CNode>()) {
           inputs.emplace_back(dependInputNode);
+          if (maskOut) {
+            break;
+          }
         }
       }
     } else {
@@ -220,6 +224,11 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
       ret = RET_MEMORY_FAILED;
       break;
     }
+#ifdef SUPPORT_TRAIN
+    RemoveIfMakeTuple(cnode);
+    RemoveIfDepend(cnode);
+#endif
+
     if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) ||
 #ifdef SUPPORT_TRAIN
         (primitive_c->Type() == schema::PrimitiveType_Depend) ||
@@ -228,9 +237,8 @@
         (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) {
       continue;
     }
+#ifndef SUPPORT_TRAIN
     RemoveIfMakeTuple(cnode);
-#ifdef SUPPORT_TRAIN
-    RemoveIfDepend(cnode);
 #endif
     auto primT = primitive_c->primitiveT();
     auto node = std::make_unique<schema::CNodeT>();
@@ -489,6 +497,11 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_ano
   paramTensor->format = schema::Format(valueLite->format());
   paramTensor->dataType = valueLite->tensor_type();
   paramTensor->dims = valueLite->tensor_shape();
+#ifdef SUPPORT_TRAIN
+  if (paramTensor->dims.size() == 0) {
+    paramTensor->dims = {1};
+  }
+#endif
   auto ret = memcpy_s(paramTensor->data.data(), valueLite->tensor_size() * sizeof(uint8_t), valueLite->tensor_addr(),
                       valueLite->tensor_size());
   if (ret != EOK) {
diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc
index 8462983573..c0f3ba4b9b 100644
--- a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc
+++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc
@@ -733,13 +733,6 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
     outputFuncGraph->set_return(return_node);
     MS_LOG(INFO) << "Construct funcgraph finined, all success.";
   } else {
-#ifdef SUPPORT_TRAIN
-    auto ret_node = outputFuncGraph->get_return();
-    if (ret_node) {
-      ret_node->add_input(cnode_ptr);
-      return true;
-    }
-#endif
     const onnx::ValueInfoProto &output_node = importProto.output(0);
     const onnx::TypeProto &output_typeproto = output_node.type();
     int output_type = output_typeproto.tensor_type().elem_type();
@@ -805,29 +798,13 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG
       MS_LOG(ERROR) << "primitive_c is nullptr";
       return RET_ERROR;
     }
-
-#ifdef SUPPORT_TRAIN
-    if (primitive_c->Type() == schema::PrimitiveType_MakeTuple) {
-      last_cnode_ptr = cnode_ptr;
-      if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) {
-        MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed";
-        status = RET_ERROR;
-      }
-    }
-#endif
   }
   if (status != RET_OK) {
     return status;
   }
-#ifdef SUPPORT_TRAIN
-  if (last_cnode_ptr != cnode_ptr) {
-#else
-  {
-#endif
-    if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) {
-      MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed";
-      status = RET_ERROR;
-    }
+  if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) {
+    MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed";
+    status = RET_ERROR;
   }
   return status;
 }
diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc
index 8c182cda35..bd96c53a3e 100644
--- a/mindspore/lite/tools/common/node_util.cc
+++ b/mindspore/lite/tools/common/node_util.cc
@@ -137,10 +137,19 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT
                                                               schema::PrimitiveType_L2Norm};

 static const std::vector<schema::PrimitiveType> needInsertOpList = {
+#ifdef SUPPORT_TRAIN
+  schema::PrimitiveType_Eltwise,      schema::PrimitiveType_Activation,
+  schema::PrimitiveType_Concat,       schema::PrimitiveType_Power,
+  schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,
+  schema::PrimitiveType_Split,        schema::PrimitiveType_Slice,
+  schema::PrimitiveType_Crop
+#else
   schema::PrimitiveType_Eltwise,      schema::PrimitiveType_Activation,
   schema::PrimitiveType_Concat,       schema::PrimitiveType_Power,
   schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,
   schema::PrimitiveType_Split,        schema::PrimitiveType_Slice,
-  schema::PrimitiveType_Crop,         schema::PrimitiveType_Mul,
-  schema::PrimitiveType_Maximum};
+  schema::PrimitiveType_Crop,         schema::PrimitiveType_Mul,
+  schema::PrimitiveType_Maximum
+#endif
+};

 static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index a862719436..876389d20b 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -109,9 +109,11 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
     fusion_pm->AddPass(remove_unused_transpose_pass);
   }
   auto const_fold_pm = std::make_shared("const fold fusion pass manager", false);
-  auto inne_context_ptr = std::make_shared();
-  inne_context_ptr->Init();
-  const_fold_pm->AddPass(std::make_shared(inne_context_ptr));
+  if (!config->trainModel) {
+    auto inne_context_ptr = std::make_shared();
+    inne_context_ptr->Init();
+    const_fold_pm->AddPass(std::make_shared(inne_context_ptr));
+  }
   const_fold_pm->AddPass(std::make_shared());
   fusion_pm->AddPass(std::make_shared());
   convert_pm->AddPass(std::make_shared());
"ReadInputFile error, " << status; return status; @@ -331,20 +321,6 @@ int NetTrain::RunExportedNet() { MS_LOG(INFO) << "start reading exported model file"; std::cout << "start reading exported model file" << std::endl; - size_t size = 0; - char *graph_buf = ReadFile(flags_->export_file_.c_str(), &size); - if (graph_buf == nullptr) { - MS_LOG(ERROR) << "Read exported model file failed while running " << model_name.c_str(); - std::cerr << "Read exported model file failed while running " << model_name.c_str() << std::endl; - return RET_ERROR; - } - auto model = lite::TrainModel::Import(graph_buf, size); - delete[](graph_buf); - if (model == nullptr) { - MS_LOG(ERROR) << "Import exported model file failed while running " << model_name.c_str(); - std::cerr << "Import exported model file failed while running " << model_name.c_str() << std::endl; - return RET_ERROR; - } auto context = std::make_shared(); if (context == nullptr) { MS_LOG(ERROR) << "New context failed while running " << model_name.c_str(); @@ -362,18 +338,12 @@ int NetTrain::RunExportedNet() { context->thread_num_ = flags_->num_threads_; // context->enable_float16_ = flags_->enable_fp16_; - session_ = session::TrainSession::CreateSession(context.get()); + session_ = session::TrainSession::CreateSession(flags_->export_file_.c_str(), context.get()); if (session_ == nullptr) { MS_LOG(ERROR) << "CreateSession failed while running ", model_name.c_str(); std::cout << "CreateSession failed while running ", model_name.c_str(); return RET_ERROR; } - auto ret = session_->CompileTrainGraph(model); - if (ret != RET_OK) { - MS_LOG(ERROR) << "CompileGraph failed while running ", model_name.c_str(); - std::cout << "CompileGraph failed while running ", model_name.c_str(); - return ret; - } ms_inputs_ = session_->GetInputs(); auto end_prepare_time = GetTimeUs(); @@ -383,13 +353,13 @@ int NetTrain::RunExportedNet() { // Load input MS_LOG(INFO) << "start generate input data"; auto status = LoadInput(); - if (status != 0) { + if (status != RET_OK) { MS_LOG(ERROR) << "Generate input data error"; return status; } status = session_->RunGraph(); - if (status != 0) { + if (status != RET_OK) { MS_LOG(ERROR) << "Inference error " << status; std::cerr << "Inference error " << status << std::endl; return status; @@ -405,7 +375,7 @@ int NetTrain::RunExportedNet() { delete data.second; } data_.clear(); - if (status != 0) { + if (status != RET_OK) { MS_LOG(ERROR) << "Run MarkAccuracy on exported model error: " << status; std::cout << "Run MarkAccuracy on exported model error: " << status << std::endl; return status; @@ -421,20 +391,6 @@ int NetTrain::RunNetTrain() { MS_LOG(INFO) << "start reading model file"; std::cout << "start reading model file" << std::endl; - size_t size = 0; - char *graph_buf = ReadFile(flags_->model_file_.c_str(), &size); - if (graph_buf == nullptr) { - MS_LOG(ERROR) << "Read model file failed while running " << model_name.c_str(); - std::cerr << "Read model file failed while running " << model_name.c_str() << std::endl; - return RET_ERROR; - } - auto model = lite::TrainModel::Import(graph_buf, size); - delete[](graph_buf); - if (model == nullptr) { - MS_LOG(ERROR) << "Import model file failed while running " << model_name.c_str(); - std::cerr << "Import model file failed while running " << model_name.c_str() << std::endl; - return RET_ERROR; - } auto context = std::make_shared(); if (context == nullptr) { MS_LOG(ERROR) << "New context failed while running " << model_name.c_str(); @@ -451,18 +407,12 @@ int NetTrain::RunNetTrain() { } 
@@ -421,20 +391,6 @@ int NetTrain::RunNetTrain() {
   MS_LOG(INFO) << "start reading model file";
   std::cout << "start reading model file" << std::endl;

-  size_t size = 0;
-  char *graph_buf = ReadFile(flags_->model_file_.c_str(), &size);
-  if (graph_buf == nullptr) {
-    MS_LOG(ERROR) << "Read model file failed while running " << model_name.c_str();
-    std::cerr << "Read model file failed while running " << model_name.c_str() << std::endl;
-    return RET_ERROR;
-  }
-  auto model = lite::TrainModel::Import(graph_buf, size);
-  delete[](graph_buf);
-  if (model == nullptr) {
-    MS_LOG(ERROR) << "Import model file failed while running " << model_name.c_str();
-    std::cerr << "Import model file failed while running " << model_name.c_str() << std::endl;
-    return RET_ERROR;
-  }
   auto context = std::make_shared<Context>();
   if (context == nullptr) {
     MS_LOG(ERROR) << "New context failed while running " << model_name.c_str();
@@ -451,18 +407,12 @@ int NetTrain::RunNetTrain() {
   }
   context->thread_num_ = flags_->num_threads_;
   // context->enable_float16_ = flags_->enable_fp16_;
-  session_ = session::TrainSession::CreateSession(context.get());
+  session_ = session::TrainSession::CreateSession(flags_->model_file_.c_str(), context.get());
   if (session_ == nullptr) {
     MS_LOG(ERROR) << "CreateSession failed while running ", model_name.c_str();
     std::cout << "CreateSession failed while running ", model_name.c_str();
     return RET_ERROR;
   }
-  auto ret = session_->CompileTrainGraph(model);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "CompileGraph failed while running ", model_name.c_str();
-    std::cout << "CompileGraph failed while running ", model_name.c_str();
-    return ret;
-  }

   session_->Train();

@@ -474,13 +424,13 @@ int NetTrain::RunNetTrain() {
   // Load input
   MS_LOG(INFO) << "start generate input data";
   auto status = LoadInput();
-  if (status != 0) {
+  if (status != RET_OK) {
     MS_LOG(ERROR) << "Generate input data error";
     return status;
   }
   if (flags_->epochs_ > 0) {
     status = MarkPerformance();
-    if (status != 0) {
+    if (status != RET_OK) {
       MS_LOG(ERROR) << "Run MarkPerformance error: " << status;
       std::cout << "Run MarkPerformance error: " << status << std::endl;
       return status;
@@ -494,24 +444,22 @@ int NetTrain::RunNetTrain() {
       delete data.second;
     }
     data_.clear();
-    if (status != 0) {
+    if (status != RET_OK) {
       MS_LOG(ERROR) << "Run MarkAccuracy error: " << status;
       std::cout << "Run MarkAccuracy error: " << status << std::endl;
       return status;
     }
   }
   if (!flags_->export_file_.empty()) {
-    size_t tsize = 0;
-    auto buf = session_->ExportToBuf(nullptr, &tsize);
-    if (buf == nullptr) {
-      MS_LOG(ERROR) << "Run ExportToBuf error";
-      std::cout << "Run ExportToBuf error";
+    auto ret = session_->SaveToFile(flags_->export_file_);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "SaveToFile error";
+      std::cout << "Run SaveToFile error";
       return RET_ERROR;
     }
-    SaveFile(flags_->export_file_, buf, size);
     status = RunExportedNet();
-    if (status != 0) {
+    if (status != RET_OK) {
       MS_LOG(ERROR) << "Run Exported model error: " << status;
       std::cout << "Run Exported model error: " << status << std::endl;
       return status;
@@ -754,14 +702,14 @@ int RunNetTrain(int argc, const char **argv) {
   NetTrain net_trainer(&flags);
   auto status = net_trainer.Init();
-  if (status != 0) {
+  if (status != RET_OK) {
     MS_LOG(ERROR) << "NetTrain init Error : " << status;
     std::cerr << "NetTrain init Error : " << status << std::endl;
     return RET_ERROR;
   }

   status = net_trainer.RunNetTrain();
-  if (status != 0) {
+  if (status != RET_OK) {
     MS_LOG(ERROR) << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
                   << " Failed : " << status;
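Export now goes through the session itself: SaveToFile replaces the ExportToBuf-plus-SaveFile sequence, and in passing retires a latent bug in the removed code, where the exported length came back in tsize but SaveFile was called with size. A round-trip sketch against the same calls; paths are hypothetical and error handling is trimmed:

#include <string>

#include "include/context.h"
#include "include/train_session.h"

// Train, save, reload. SaveToFile and CreateSession follow the calls visible
// in the net_train diff; the step count and file names are illustrative.
int RoundTrip(const std::string &model_file, const std::string &export_file,
              mindspore::lite::Context *context) {
  auto *session = mindspore::session::TrainSession::CreateSession(model_file, context);
  if (session == nullptr) return -1;
  session->Train();
  session->RunGraph();  // one training step; real code loads inputs first
  int ret = session->SaveToFile(export_file);  // replaces ExportToBuf + manual write
  delete session;
  if (ret != 0) return ret;  // RET_OK is 0

  // The exported file round-trips through the same entry point.
  auto *reloaded = mindspore::session::TrainSession::CreateSession(export_file, context);
  if (reloaded == nullptr) return -1;
  delete reloaded;
  return 0;
}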
diff --git a/mindspore/lite/tools/net_train/net_train.h b/mindspore/lite/tools/net_train/net_train.h
index df600a06ac..6ef29365a5 100644
--- a/mindspore/lite/tools/net_train/net_train.h
+++ b/mindspore/lite/tools/net_train/net_train.h
@@ -29,7 +29,6 @@
 #include
 #include
 #include
-#include "include/train_model.h"
 #include "tools/common/flag_parser.h"
 #include "src/common/file_utils.h"
 #include "src/common/utils.h"
@@ -129,8 +128,10 @@ class MS_API NetTrain {
     MS_ASSERT(input != nullptr);
     static int i = 0;
     auto inData = reinterpret_cast<float *>(input->MutableData());
+    size_t tensorSize = input->ElementsNum();
+    size_t len = (tensorSize < 20) ? tensorSize : 20;
     std::cout << "InData" << i++ << ": ";
-    for (size_t j = 0; j < 20; j++) {
+    for (size_t j = 0; j < len; j++) {
       std::cout << inData[j] << " ";
     }
     std::cout << std::endl;
@@ -190,10 +191,8 @@ class MS_API NetTrain {
         }
       } else {
         // just assume that atol = rtol
-        if (absoluteError > 1e-5) {
-          meanError += absoluteError / (fabs(calibTensor->data.at(j)) + FLT_MIN);
-          errorCount++;
-        }
+        meanError += absoluteError / (fabs(calibTensor->data.at(j)) + FLT_MIN);
+        errorCount++;
       }
     }
   }
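The accuracy metric changes meaningfully here: previously only elements whose absolute error exceeded 1e-5 entered the mean, which biased the average toward the worst elements; now every compared element contributes its relative error. Distilled below; the final division by errorCount happens outside this hunk, so that normalization is an assumption:

#include <cfloat>
#include <cmath>
#include <cstdio>

// Mean relative error over all elements: every element contributes
// |out - calib| / (|calib| + FLT_MIN), as in the updated loop above.
float MeanRelativeError(const float *out, const float *calib, int n) {
  float mean_error = 0.0f;
  int error_count = 0;
  for (int j = 0; j < n; ++j) {
    float absolute_error = std::fabs(out[j] - calib[j]);
    mean_error += absolute_error / (std::fabs(calib[j]) + FLT_MIN);
    ++error_count;
  }
  return (error_count > 0) ? mean_error / error_count : 0.0f;  // normalization assumed
}

int main() {
  const float out[3] = {1.00f, 2.02f, 2.98f};
  const float calib[3] = {1.0f, 2.0f, 3.0f};
  std::printf("mean relative error: %f\n", MeanRelativeError(out, calib, 3));
  return 0;
}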