realign code

pull/13151/head
zhengjun10 4 years ago
parent 640119461c
commit eae045fcf7

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPERATOR_OPS_H_
#define MINDSPORE_CORE_OPERATOR_OPS_H_
#ifndef MINDSPORE_CORE_BASE_CORE_OPS_H_
#define MINDSPORE_CORE_BASE_CORE_OPS_H_
#include <iostream>
#include <string>
@ -182,6 +182,7 @@ inline const PrimitivePtr kPrimReverseV2 = std::make_shared<Primitive>("ReverseV
inline const PrimitivePtr kPrimReverseSequence = std::make_shared<Primitive>("ReverseSequence");
inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank");
inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear");
inline const PrimitivePtr kPrimResizeGrad = std::make_shared<Primitive>("ResizeGrad");
// NN
inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam");
@ -245,7 +246,6 @@ inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
inline const PrimitivePtr kPrimDetectionPostProcess = std::make_shared<Primitive>("DetectionPostProcess");
inline const PrimitivePtr kPrimBiasAdd = std::make_shared<Primitive>("BiasAdd");
inline const PrimitivePtr kPrimBiasGrad = std::make_shared<Primitive>("BiasGrad");
inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
inline const PrimitivePtr kPrimBiasSubGrad = std::make_shared<Primitive>("BiasSubGrad");
inline const PrimitivePtr kPrimBinaryCrossEntropy = std::make_shared<Primitive>("BinaryCrossEntropy");
@ -390,6 +390,7 @@ inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round");
inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp");
inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");
inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt");
inline const PrimitivePtr kPrimRsqrtGrad = std::make_shared<Primitive>("RsqrtGrad");
inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV");
inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace");
inline const PrimitivePtr kPrimNonMaxSuppression = std::make_shared<Primitive>("NonMaxSuppression");
@ -551,4 +552,4 @@ using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
} // namespace prim
} // namespace mindspore
#endif // MINDSPORE_CORE_OPERATOR_OPS_H_
#endif // MINDSPORE_CORE_BASE_CORE_OPS_H_

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/layer_norm_grad.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
void LayerNormGrad::Init(const int64_t begin_norm_axis, const int64_t begin_params_axis) {
this->set_begin_norm_axis(begin_norm_axis);
this->set_begin_params_axis(begin_params_axis);
}
void LayerNormGrad::set_begin_norm_axis(const int64_t begin_norm_axis) {
this->AddAttr(kBeginNormAxis, MakeValue(begin_norm_axis));
}
void LayerNormGrad::set_begin_params_axis(const int64_t begin_params_axis) {
this->AddAttr(kBeginParamsAxis, MakeValue(begin_params_axis));
}
int64_t LayerNormGrad::get_begin_norm_axis() const {
auto value_ptr = this->GetAttr(kBeginNormAxis);
return GetValue<int64_t>(value_ptr);
}
int64_t LayerNormGrad::get_begin_params_axis() const {
auto value_ptr = this->GetAttr(kBeginParamsAxis);
return GetValue<int64_t>(value_ptr);
}
// Registers the LayerNormGrad class under the name kNameLayerNormGrad via the REGISTER_PRIMITIVE_C macro.
REGISTER_PRIMITIVE_C(kNameLayerNormGrad, LayerNormGrad);
} // namespace ops
} // namespace mindspore

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_GRAD_LAYER_NORM_GRAD_H_
#define MINDSPORE_CORE_OPS_GRAD_LAYER_NORM_GRAD_H_
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameLayerNormGrad = "LayerNormGrad";
// Primitive describing the "LayerNormGrad" operator. It carries two int64_t
// attributes, the begin-norm axis and the begin-params axis, which are exposed
// through the setters/getters below.
class LayerNormGrad : public PrimitiveC {
 public:
  // Creates the primitive with the default operator name kNameLayerNormGrad.
  LayerNormGrad() : PrimitiveC(kNameLayerNormGrad) {}
  // Creates the primitive with a caller-supplied name. The name is taken by
  // const reference so no temporary copy is made just to forward it to the base.
  explicit LayerNormGrad(const std::string &k_name) : PrimitiveC(k_name) {}
  ~LayerNormGrad() = default;
  MS_DECLARE_PARENT(LayerNormGrad, PrimitiveC);

  // Sets both axis attributes in one call; each defaults to 1.
  void Init(const int64_t begin_norm_axis = 1, const int64_t begin_params_axis = 1);

  // Attribute setters.
  void set_begin_norm_axis(const int64_t begin_norm_axis);
  void set_begin_params_axis(const int64_t begin_params_axis);

  // Attribute getters; they return the value previously stored by the matching setter.
  int64_t get_begin_norm_axis() const;
  int64_t get_begin_params_axis() const;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_GRAD_LAYER_NORM_GRAD_H_

@ -0,0 +1,52 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/resize_grad.h"
#include <map>
#include <string>
#include <vector>
#include <memory>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
void ResizeGrad::Init(const ResizeMethod method, const bool align_corners) {
this->set_method(method);
this->set_align_corners(align_corners);
}
// Stores the resize method as the kMethod attribute.
// The value is stored as an int64_t, so the ResizeMethod argument is converted
// with an explicit named cast (get_method() performs the reverse conversion).
void ResizeGrad::set_method(const ResizeMethod method) {
  auto method_id = static_cast<int64_t>(method);
  this->AddAttr(kMethod, MakeValue(method_id));
}
void ResizeGrad::set_align_corners(const bool align_corners) { this->AddAttr(kAlignCorners, MakeValue(align_corners)); }
// Returns the resize method held by the kMethod attribute.
// The attribute is read as an int64_t (see set_method()) and converted back to
// ResizeMethod with an explicit named cast.
// NOTE(review): the attribute is read without a null check, so this presumably
// requires set_method()/Init() to have been called first — confirm.
ResizeMethod ResizeGrad::get_method() const {
  auto value_ptr = GetAttr(kMethod);
  return static_cast<ResizeMethod>(GetValue<int64_t>(value_ptr));
}
// Returns the bool held by the kAlignCorners attribute.
// NOTE(review): the attribute is read without a null check, so this presumably
// requires set_align_corners()/Init() to have been called first — confirm.
bool ResizeGrad::get_align_corners() const {
  return GetValue<bool>(GetAttr(kAlignCorners));
}
// Registers the ResizeGrad class under the name kNameResizeGrad via the REGISTER_PRIMITIVE_C macro.
REGISTER_PRIMITIVE_C(kNameResizeGrad, ResizeGrad);
} // namespace ops
} // namespace mindspore

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_GRAD_RESIZE_GRAD_H_
#define MINDSPORE_CORE_OPS_GRAD_RESIZE_GRAD_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameResizeGrad = "ResizeGrad";
// Primitive describing the "ResizeGrad" operator. It carries two attributes:
// the resize method and an align-corners flag.
class ResizeGrad : public PrimitiveC {
 public:
  // Creates the primitive with the operator name kNameResizeGrad.
  ResizeGrad() : PrimitiveC(kNameResizeGrad) {}
  ~ResizeGrad() = default;
  MS_DECLARE_PARENT(ResizeGrad, PrimitiveC);

  // Sets both attributes in one call.
  void Init(const ResizeMethod method, const bool align_corners);

  // Resize method attribute.
  void set_method(const ResizeMethod method);
  ResizeMethod get_method() const;

  // Align-corners attribute.
  void set_align_corners(const bool align_corners);
  bool get_align_corners() const;
};
// Declaration of the inference entry point for ResizeGrad (presumably shape/type inference — confirm).
// NOTE(review): no definition of this function is visible in resize_grad.cc; confirm it is
// defined elsewhere before relying on it, otherwise any use will fail at link time.
AbstractBasePtr ResizeGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
// Shared-pointer alias for ResizeGrad.
using PrimResizeGradPtr = std::shared_ptr<ResizeGrad>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_GRAD_RESIZE_GRAD_H_

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/rsqrt_grad.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
// Registers the RsqrtGrad class under the name kNameRsqrtGrad via the REGISTER_PRIMITIVE_C macro.
REGISTER_PRIMITIVE_C(kNameRsqrtGrad, RsqrtGrad);
} // namespace ops
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_GRAD_RSQRT_GRAD_H_
#define MINDSPORE_CORE_OPS_GRAD_RSQRT_GRAD_H_
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameRsqrtGrad = "RsqrtGrad";
// Primitive describing the "RsqrtGrad" operator. It has no attributes of its own.
class RsqrtGrad : public PrimitiveC {
 public:
  // Creates the primitive and declares its I/O names:
  // inputs "out_backprop" and "input", output "output".
  RsqrtGrad() : PrimitiveC(kNameRsqrtGrad) {
    InitIOName({"out_backprop", "input"}, {"output"});
  }
  ~RsqrtGrad() = default;
  MS_DECLARE_PARENT(RsqrtGrad, PrimitiveC);

  // Nothing to initialize; provided so the class offers the same Init() entry point as other ops.
  void Init() {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_GRAD_RSQRT_GRAD_H_

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/sqrt_grad.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
// Registers the SqrtGrad class under the name kNameSqrtGrad via the REGISTER_PRIMITIVE_C macro.
REGISTER_PRIMITIVE_C(kNameSqrtGrad, SqrtGrad);
} // namespace ops
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_GRAD_SQRT_GRAD_H_
#define MINDSPORE_CORE_OPS_GRAD_SQRT_GRAD_H_
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSqrtGrad = "SqrtGrad";
// Primitive describing the "SqrtGrad" operator. It has no attributes of its own.
class SqrtGrad : public PrimitiveC {
 public:
  // Creates the primitive and declares its I/O names:
  // inputs "out_backprop" and "input", output "output".
  SqrtGrad() : PrimitiveC(kNameSqrtGrad) {
    InitIOName({"out_backprop", "input"}, {"output"});
  }
  ~SqrtGrad() = default;
  MS_DECLARE_PARENT(SqrtGrad, PrimitiveC);

  // Nothing to initialize; provided so the class offers the same Init() entry point as other ops.
  void Init() {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_GRAD_SQRT_GRAD_H_

@ -15,13 +15,20 @@
"""densenet_train_export."""
import sys
import os
import numpy as np
from train_utils import SaveInOut, TrainWrap
from official.cv.densenet121.src.network.densenet import DenseNet121
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
sys.path.append(os.environ['CLOUD_MODEL_ZOO'] + 'official/cv/densenet121/')
#pylint: disable=wrong-import-position
from official.cv.densenet121.src.network.densenet import DenseNet121
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
n = DenseNet121(num_classes=10)

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mobilenetv2_train_export."""
"""resnet_train_export"""
import sys
import numpy as np

@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""vgg_train_export."""
import sys
import numpy as np
from train_utils import SaveInOut, TrainWrap
from official.cv.vgg16.src.vgg import vgg16
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
# Run in PyNative mode on GPU with graph dumping disabled.
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)

batch = 2

# Build the 10-class VGG16 network and combine it with a loss and optimizer via TrainWrap.
n = vgg16(num_classes=10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
optimizer = nn.Momentum(n.trainable_params(), 0.01, 0.9, use_nesterov=False)
net = TrainWrap(n, loss_fn, optimizer)

# Example inputs for the export: random images and all-zero labels.
x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32)
label = Tensor(np.zeros([batch, 10], dtype=np.float32))

export(net, x, label, file_name="mindir/vgg_train", file_format='MINDIR')

# When a first command-line argument is given, it is used as the path prefix passed to SaveInOut.
if len(sys.argv) > 1:
    SaveInOut(sys.argv[1] + "vgg", x, label, n, net)

@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""xception_train_export"""
import sys
import numpy as np
from train_utils import SaveInOut, TrainWrap
from official.cv.xception.src.Xception import Xception
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
# Run in PyNative mode on GPU with graph dumping disabled.
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)

# Build the 1000-class Xception network and replace its dropout layer with keep_prob=1.0.
n = Xception(num_classes=1000)
n.dropout = nn.Dropout(keep_prob=1.0)

# Combine the network with a loss and an SGD optimizer via TrainWrap.
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
optimizer = nn.SGD(
    n.trainable_params(),
    learning_rate=0.01,
    momentum=0.9,
    dampening=0.0,
    weight_decay=0.0,
    nesterov=True,
    loss_scale=1.0,
)
net = TrainWrap(n, loss_fn, optimizer)

# Example inputs for the export: random images and all-zero labels.
batch = 2
x = Tensor(np.random.randn(batch, 3, 299, 299), mstype.float32)
label = Tensor(np.zeros([batch, 1000], dtype=np.float32))

export(net, x, label, file_name="mindir/xception_train", file_format='MINDIR')

# When a first command-line argument is given, it is used as the path prefix passed to SaveInOut.
if len(sys.argv) > 1:
    SaveInOut(sys.argv[1] + "xception", x, label, n, net)

@ -8,5 +8,7 @@ effnet_tune
resnet
googlenet
nin
#shufflenetv2
#densenet
densenet
shufflenetv2
vgg noarm32
xception

@ -2,10 +2,9 @@
display_usage()
{
echo "Usage: prepare.sh [-d mindspore_docker] [-r release.tar.gz] [-i]"
echo "Usage: prepare.sh [-d mindspore_docker] [-i]"
echo "Options:"
echo " -d docker where mindspore is installed. If no docker is provided script will use local python"
echo " -r release tarball"
echo " -i create input and output files"
}
@ -20,9 +19,6 @@ checkopts()
d)
DOCKER=$OPTARG
;;
r)
TARBALL=$OPTARG
;;
i)
TRAIN_IO="train_io/"
;;
@ -55,16 +51,6 @@ echo ' ' > ${export_result_file}
CLOUD_MODEL_ZOO=../../../../model_zoo/
checkopts "$@"
if [ "$TARBALL" == "" ]; then
file=$(ls ../../../../output/mindspore-lite-*-train-linux-x64.tar.gz)
if [ -f ${file} ]; then
TARBALL=${file}
else
echo "release.tar.gz was not found"
display_usage
exit 1
fi
fi
if [ -z "${DOCKER}" ]; then
echo "MindSpore docker was not provided, attempting to run locally"
@ -76,13 +62,14 @@ if [ ! -z "${TRAIN_IO}" ]; then
fi
while read line; do
model_name=${line}
# Split the entry on spaces and keep the first field as the model name.
# IFS (not "LFS") is the variable `read` uses for field splitting; it is set
# only for this one command. The here-string is quoted so the line is passed
# to `read` verbatim.
IFS=" " read -r -a line_array <<< "${line}"
model_name=${line_array[0]}
if [[ $model_name == \#* ]]; then
continue
continue
fi
echo 'exporting' ${model_name}
if [ ! -z "${DOCKER}" ]; then
docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER} /bin/bash -c "PYTHONPATH=${CLOUD_MODEL_ZOO} python models/${model_name}_train_export.py ${TRAIN_IO} && chmod 444 mindir/${model_name}_train.mindir"
docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER} /bin/bash -c "CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} PYTHONPATH=${CLOUD_MODEL_ZOO} python models/${model_name}_train_export.py ${TRAIN_IO} && chmod 444 mindir/${model_name}_train.mindir"
else
PYTHONPATH=${CLOUD_MODEL_ZOO} python models/${model_name}_train_export.py ${TRAIN_IO}
fi

@ -2,8 +2,12 @@ BASE_DIR=$(realpath ../../../../)
APP:=bin/net_runner
MSLIB:=mindspore-lite
LMDLIB:=-lminddata-lite -ljpeg
LHIAILIB:=-lhiai_ir_build -lhiai_ir -lhiai
MSDIR:=$(realpath package-$(TARGET)/lib)
ifneq ("$(wildcard $(MSDIR)/libhiai.so)","")
LHIAILIB:=-lhiai_ir_build -lhiai_ir -lhiai
else
LHIAILIB:=
endif
SRC:=src/net_runner.cc
OBJ:=$(SRC:.cc=.o)

@ -96,7 +96,8 @@ void NetRunner::InitAndFigureInputs() {
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
context.thread_num_ = 2;
loop_ = mindspore::session::TrainLoop::CreateTrainLoop(ms_file_, &context);
auto session = mindspore::session::TrainSession::CreateSession(ms_file_, &context);
loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session, &context);
session_ = loop_->train_session();
MS_ASSERT(nullptr != session_);

@ -0,0 +1,5 @@
*.mindir
*.ms
msl
package-*
dataset

@ -40,11 +40,11 @@ class TrainLoop {
public:
/// \brief Static method to create a TrainLoop object
///
/// \param[in] filename Filename to read flatbuffer from
/// \param[in] train_session Train session object as return from CreateSession\CreateTransferSession API
/// \param[in] context Defines the context of the session to be created
///
/// \return Pointer of MindSpore Lite TrainLoop
static TrainLoop *CreateTrainLoop(const std::string &model_filename, lite::Context *context, int batch_size = -1);
static TrainLoop *CreateTrainLoop(session::TrainSession *train_session, lite::Context *context, int batch_size = -1);
/// \brief Class destructor
virtual ~TrainLoop() = default;

@ -44,11 +44,40 @@ constexpr int RET_EXIT = 2;
/// \brief Base class for train-loop callbacks. Every hook has a default implementation that
/// does nothing (EpochEnd returns RET_CONTINUE), so subclasses override only the hooks they need.
class TrainLoopCallBack {
public:
virtual ~TrainLoopCallBack() = default;
/// \brief This method is called once, before the network starts executing
///
/// \param[in] cb_data info about current execution
virtual void Begin(const TrainLoopCallBackData &cb_data) {}
/// \brief This method is called once, after the network execution has finished
///
/// \param[in] cb_data info about current execution
virtual void End(const TrainLoopCallBackData &cb_data) {}
/// \brief This method is called at the beginning of each epoch
///
/// \param[in] cb_data info about current execution
virtual void EpochBegin(const TrainLoopCallBackData &cb_data) {}
/// \brief This method is called after each epoch has run
///
/// \param[in] cb_data info about current execution
///
/// \return indication of whether to continue the train loop:
/// RET_CONTINUE -- continue training
/// RET_STOP_TRAINING -- stop training (e.g., because the target accuracy was reached)
/// RET_EXIT -- exit training (because of an error of some sort)
virtual int EpochEnd(const TrainLoopCallBackData &cb_data) { return RET_CONTINUE; }
/// \brief This method is called at the beginning of each step
///
/// \param[in] cb_data info about current execution
virtual void StepBegin(const TrainLoopCallBackData &cb_data) {}
/// \brief This method is called after each step has run
///
/// \param[in] cb_data info about current execution
virtual void StepEnd(const TrainLoopCallBackData &cb_data) {}
};

@ -142,6 +142,14 @@ public class TrainSession {
return this.setLearningRate(this.sessionPtr, learning_rate);
}
/**
 * Sets up virtual batching on the native session with explicit hyper-parameters.
 *
 * @param virtualBatchMultiplier multiplier forwarded to the native session.
 * @param learningRate learning rate forwarded to the native session.
 * @param momentum momentum forwarded to the native session.
 * @return the boolean result reported by the native call.
 */
public boolean setupVirtualBatch(int virtualBatchMultiplier, float learningRate, float momentum) {
    final boolean nativeResult =
            setupVirtualBatch(sessionPtr, virtualBatchMultiplier, learningRate, momentum);
    return nativeResult;
}
/**
 * Sets up virtual batching on the native session, passing -1.0f for both the learning rate
 * and the momentum (presumably a "use defaults / leave unchanged" sentinel; confirm against
 * the native implementation).
 *
 * @param virtualBatchMultiplier multiplier forwarded to the native session.
 * @return the boolean result reported by the native call.
 */
public boolean setupVirtualBatch(int virtualBatchMultiplier) {
    final boolean nativeResult =
            setupVirtualBatch(sessionPtr, virtualBatchMultiplier, -1.0f, -1.0f);
    return nativeResult;
}
private native long createSession(String modelFilename, long msConfigPtr);
private native void bindThread(long sessionPtr, boolean if_bind);
@ -175,4 +183,6 @@ public class TrainSession {
private native boolean isEval(long sessionPtr);
private native boolean setLearningRate(long sessionPtr, float learning_rate);
private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum);
}

@ -303,3 +303,18 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setLe
auto ret = train_session_ptr->SetLearningRate(learning_rate);
return (jboolean)(ret == mindspore::lite::RET_OK);
}
// JNI entry point backing com.mindspore.lite.TrainSession.setupVirtualBatch.
//
// session_ptr:            native TrainSession address passed from Java as a jlong.
// virtualBatchMultiplier: forwarded unchanged to TrainSession::SetupVirtualBatch.
// learningRate, momentum: forwarded unchanged to TrainSession::SetupVirtualBatch.
//
// Returns true only when SetupVirtualBatch reports mindspore::lite::RET_OK; returns false
// (after logging) when the session address is null.
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setupVirtualBatch(JNIEnv *env, jobject thiz,
                                                                                             jlong session_ptr,
                                                                                             jint virtualBatchMultiplier,
                                                                                             jfloat learningRate,
                                                                                             jfloat momentum) {
  auto *session_pointer = reinterpret_cast<void *>(session_ptr);
  if (session_pointer == nullptr) {
    MS_LOGE("Session pointer from java is nullptr");
    return static_cast<jboolean>(false);
  }
  auto *train_session_ptr = static_cast<mindspore::session::TrainSession *>(session_pointer);
  auto ret = train_session_ptr->SetupVirtualBatch(virtualBatchMultiplier, learningRate, momentum);
  return static_cast<jboolean>(ret == mindspore::lite::RET_OK);
}

@ -244,6 +244,7 @@ set(LITE_KERNEL_SRC
${LITE_DIR}/nnacl/infer/hashtable_lookup_infer.c
${LITE_DIR}/nnacl/infer/invert_permutation_infer.c
${LITE_DIR}/nnacl/infer/layer_norm_infer.c
${LITE_DIR}/nnacl/infer/layer_norm_grad_infer.c
${LITE_DIR}/nnacl/infer/lin_space_infer.c
${LITE_DIR}/nnacl/infer/lsh_projection_infer.c
${LITE_DIR}/nnacl/infer/lstm_infer.c

@ -65,11 +65,10 @@ void LayerNormGammaAndBeta(float *dst, const float *src, const float *gamma_data
}
int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data,
LayerNormParameter *param, size_t task_id) {
LayerNormParameter *param, float *out_mean, float *out_deno, size_t task_id) {
if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) {
return NNACL_NULL_PTR;
}
int step = UP_DIV(param->norm_outer_size_, param->op_parameter_.thread_num_);
int thread_end = MSMIN((task_id + 1) * step, param->norm_outer_size_);
for (int i = task_id * step; i < thread_end; i++) {
@ -79,7 +78,10 @@ int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_
float square_mean = 0.0f;
LayerNormMeanAndSquare(src_norm, param->norm_inner_size_, &mean, &square_mean);
const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_);
if ((out_mean != NULL) && (out_deno != NULL)) {
out_mean[i] = mean;
out_deno[i] = deno;
}
if (param->norm_outer_size_ <= param->params_outer_size_) {
for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) {
const float *src_param = src_norm + x * param->params_inner_size_;

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_H_
#define MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_H_
#ifndef MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_FP32_H_
#define MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_FP32_H_
#include "nnacl/op_base.h"
#include "nnacl/layer_norm_parameter.h"
@ -24,9 +24,9 @@ extern "C" {
#endif
int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data,
LayerNormParameter *param, size_t task_id);
LayerNormParameter *param, float *out_mean, float *out_deno, size_t task_id);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_H_
#endif // MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_FP32_H_

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save