diff --git a/build.sh b/build.sh
index 136a2a5c75..f06f4c86d3 100755
--- a/build.sh
+++ b/build.sh
@@ -671,6 +671,12 @@ build_lite_java_arm64() {
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
+
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
+
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/native/libs/arm64-v8a/
else
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
@@ -697,6 +703,12 @@ build_lite_java_arm32() {
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
+
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
+
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/native/libs/armeabi-v7a/
else
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
@@ -706,10 +718,15 @@ build_lite_java_arm32() {
build_lite_java_x86() {
# build mindspore-lite x86
+ local inference_or_train=inference
+ if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
+ inference_or_train=train
+ fi
+
if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then
- local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64-${X86_64_SIMD}
+ local JTARBALL=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-${X86_64_SIMD}
else
- local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64
+ local JTARBALL=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64
fi
if [[ "X$INC_BUILD" == "Xoff" ]] || [[ ! -f "${BASEPATH}/mindspore/lite/build/java/${JTARBALL}.tar.gz" ]]; then
build_lite "x86_64" "off" ""
@@ -721,8 +738,20 @@ build_lite_java_x86() {
[ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/linux_x86/libs/
mkdir -p ${JAVA_PATH}/java/linux_x86/libs/
mkdir -p ${JAVA_PATH}/native/libs/linux_x86/
- cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
- cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
+ if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
+
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/linux_x86/libs/
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/linux_x86/
+
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/java/linux_x86/libs/
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/native/libs/linux_x86/
+ else
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
+ cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
+ fi
+ [ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
}
build_jni_arm64() {
@@ -776,7 +805,7 @@ build_jni_x86_64() {
mkdir -pv java/jni
cd java/jni
cmake -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
- -DENABLE_VERBOSE=${ENABLE_VERBOSE} "${JAVA_PATH}/native/"
+ -DENABLE_VERBOSE=${ENABLE_VERBOSE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} "${JAVA_PATH}/native/"
make -j$THREAD_NUM
if [[ $? -ne 0 ]]; then
echo "---------------- mindspore lite: build jni x86_64 failed----------------"
@@ -825,11 +854,16 @@ build_java() {
cd ${JAVA_PATH}/java/app/build
zip -r mindspore-lite-maven-${VERSION_STR}.zip mindspore
+ local inference_or_train=inference
+ if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
+ inference_or_train=train
+ fi
+
# build linux x86 jar
if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then
- local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-inference-linux-x64-${X86_64_SIMD}-jar
+ local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-${X86_64_SIMD}-jar
else
- local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-inference-linux-x64-jar
+ local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-jar
fi
check_java_home
build_lite_java_x86
@@ -843,15 +877,17 @@ build_java() {
gradle releaseJar
# install and package
mkdir -p ${JAVA_PATH}/java/linux_x86/build/lib
- cp ${JAVA_PATH}/java/linux_x86/libs/*.so ${JAVA_PATH}/java/linux_x86/build/lib/jar
+ cp ${JAVA_PATH}/java/linux_x86/libs/*.so* ${JAVA_PATH}/java/linux_x86/build/lib/jar
cd ${JAVA_PATH}/java/linux_x86/build/
+
cp -r ${JAVA_PATH}/java/linux_x86/build/lib ${JAVA_PATH}/java/linux_x86/build/${LINUX_X86_PACKAGE_NAME}
tar czvf ${LINUX_X86_PACKAGE_NAME}.tar.gz ${LINUX_X86_PACKAGE_NAME}
# copy output
cp ${JAVA_PATH}/java/app/build/mindspore-lite-maven-${VERSION_STR}.zip ${BASEPATH}/output
cp ${LINUX_X86_PACKAGE_NAME}.tar.gz ${BASEPATH}/output
+
cd ${BASEPATH}/output
- [ -n "${VERSION_STR}" ] && rm -rf ${BASEPATH}/mindspore/lite/build/java/mindspore-lite-${VERSION_STR}-inference-linux-x64
+ [ -n "${VERSION_STR}" ] && rm -rf ${BASEPATH}/mindspore/lite/build/java/mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64
exit 0
}
diff --git a/mindspore/lite/examples/train_lenet_java/pom.xml b/mindspore/lite/examples/train_lenet_java/pom.xml
new file mode 100644
index 0000000000..862ae5f179
--- /dev/null
+++ b/mindspore/lite/examples/train_lenet_java/pom.xml
@@ -0,0 +1,55 @@
+
+
+ 4.0.0
+
+ com.mindspore.lite.demo
+ train_lenet_java
+ 1.0
+
+
+ 8
+ 8
+
+
+
+
+
+ com.mindspore.lite
+ mindspore-lite-java
+ 1.0
+ system
+ ${project.basedir}/lib/mindspore-lite-java.jar
+
+
+
+
+ ${project.name}
+
+
+ org.apache.maven.plugins
+ maven-assembly-plugin
+
+
+
+ com.mindspore.lite.train_lenet.Main
+
+
+
+ jar-with-dependencies
+
+
+
+
+ make-assemble
+ package
+
+ single
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/mindspore/lite/examples/train_lenet_java/resources/model/lenet_tod.ms b/mindspore/lite/examples/train_lenet_java/resources/model/lenet_tod.ms
new file mode 100644
index 0000000000..3857a843e3
Binary files /dev/null and b/mindspore/lite/examples/train_lenet_java/resources/model/lenet_tod.ms differ
diff --git a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/DataSet.java b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/DataSet.java
new file mode 100644
index 0000000000..47cb469c2e
--- /dev/null
+++ b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/DataSet.java
@@ -0,0 +1,131 @@
+package com.mindspore.lite.train_lenet;
+
+
+import java.io.BufferedInputStream;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.util.Vector;
+
+public class DataSet {
+ private long numOfClasses = 0;
+ private long expectedDataSize = 0;
+ public class DataLabelTuple {
+ public float[] data;
+ public int label;
+ }
+ Vector trainData;
+ Vector testData;
+
+ public void initializeMNISTDatabase(String dpath) {
+ numOfClasses = 10;
+ trainData = new Vector();
+ testData = new Vector();
+ readMNISTFile(dpath + "/train/train-images-idx3-ubyte", dpath+"/train/train-labels-idx1-ubyte", trainData);
+ readMNISTFile(dpath + "/test/t10k-images-idx3-ubyte", dpath+"/test/t10k-labels-idx1-ubyte", testData);
+
+ System.out.println("train data cnt: " + trainData.size());
+ System.out.println("test data cnt: " + testData.size());
+ }
+
+ private String bytesToHex(byte[] bytes) {
+ StringBuffer sb = new StringBuffer();
+ for (int i = 0; i < bytes.length; i++) {
+ String hex = Integer.toHexString(bytes[i] & 0xFF);
+ if (hex.length() < 2) {
+ sb.append(0);
+ }
+ sb.append(hex);
+ }
+ return sb.toString();
+ }
+
+ private void readFile(BufferedInputStream inputStream, byte[] bytes, int len) throws IOException {
+ int result = inputStream.read(bytes, 0, len);
+ if (result != len) {
+ System.err.println("expected read " + len + " bytes, but " + result + " read");
+ System.exit(1);
+ }
+ }
+ public void readMNISTFile(String inputFileName, String labelFileName, Vector dataset) {
+ try {
+ BufferedInputStream ibin = new BufferedInputStream(new FileInputStream(inputFileName));
+ BufferedInputStream lbin = new BufferedInputStream(new FileInputStream(labelFileName));
+ byte[] bytes = new byte[4];
+
+ readFile(ibin, bytes, 4);
+ if (!"00000803".equals(bytesToHex(bytes))) { // 2051
+ System.err.println("The dataset is not valid: " + bytesToHex(bytes));
+ return;
+ }
+ readFile(ibin, bytes, 4);
+ int inumber = Integer.parseInt(bytesToHex(bytes), 16);
+
+ readFile(lbin, bytes, 4);
+ if (!"00000801".equals(bytesToHex(bytes))) { // 2049
+ System.err.println("The dataset label is not valid: " + bytesToHex(bytes));
+ return;
+ }
+ readFile(lbin, bytes, 4);
+ int lnumber = Integer.parseInt(bytesToHex(bytes), 16);
+ if (inumber != lnumber) {
+ System.err.println("input data cnt: " + inumber + " not equal label cnt: " + lnumber);
+ return;
+ }
+
+ // read all labels
+ byte[] labels = new byte[lnumber];
+ readFile(lbin, labels, lnumber);
+
+ // row, column
+ readFile(ibin, bytes, 4);
+ int n_rows = Integer.parseInt(bytesToHex(bytes), 16);
+ readFile(ibin, bytes, 4);
+ int n_cols = Integer.parseInt(bytesToHex(bytes), 16);
+ if (n_rows != 28 || n_cols != 28) {
+ System.err.println("invalid n_rows: " + n_rows + " n_cols: " + n_cols);
+ return;
+ }
+ // read images
+ int image_size = n_rows * n_cols;
+ byte[] image_data = new byte[image_size];
+ for (int i = 0; i < lnumber; i++) {
+ float [] hwc_bin_image = new float[32 * 32];
+ readFile(ibin, image_data, image_size);
+ for (int r = 0; r < 32; r++) {
+ for (int c = 0; c < 32; c++) {
+ int index = r * 32 + c;
+ if (r < 2 || r > 29 || c < 2 || c > 29) {
+ hwc_bin_image[index] = 0;
+ } else {
+ int data = image_data[(r-2)*28 + (c-2)] & 0xff;
+ hwc_bin_image[index] = (float)data / 255.0f;
+ }
+ }
+ }
+
+ DataLabelTuple data_label_tupel = new DataLabelTuple();
+ data_label_tupel.data = hwc_bin_image;
+ data_label_tupel.label = labels[i] & 0xff;
+ dataset.add(data_label_tupel);
+ }
+ } catch (IOException e) {
+ System.err.println("Read Dateset exception");
+ }
+ }
+
+ public void setExpectedDataSize(long data_size) {
+ expectedDataSize = data_size;
+ }
+
+ public long getNumOfClasses() {
+ return numOfClasses;
+ }
+
+ public Vector getTestData() {
+ return testData;
+ }
+
+ public Vector getTrainData() {
+ return trainData;
+ }
+}
diff --git a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java
new file mode 100644
index 0000000000..4f70d451fc
--- /dev/null
+++ b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java
@@ -0,0 +1,20 @@
+package com.mindspore.lite.train_lenet;
+
+import com.mindspore.lite.Version;
+
+public class Main {
+ public static void main(String[] args) {
+ System.out.println(Version.version());
+ if (args.length < 2) {
+ System.err.println("model path and dataset path must be provided.");
+ return;
+ }
+ String modelPath = args[0];
+ String datasetPath = args[1];
+
+ NetRunner net_runner = new NetRunner();
+ net_runner.trainModel(modelPath, datasetPath);
+ }
+
+
+}
diff --git a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java
new file mode 100644
index 0000000000..4fbcfaeed2
--- /dev/null
+++ b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java
@@ -0,0 +1,220 @@
+package com.mindspore.lite.train_lenet;
+
+import com.mindspore.lite.MSTensor;
+import com.mindspore.lite.TrainSession;
+import com.mindspore.lite.config.MSConfig;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.List;
+import java.util.Map;
+import java.util.Vector;
+
+public class NetRunner {
+ private int dataIndex = 0;
+ private int labelIndex = 1;
+ private TrainSession session;
+ private long batchSize;
+ private long dataSize; // one input data size, in byte
+ private DataSet ds = new DataSet();
+ private long numOfClasses;
+ private long cycles = 3500;
+ private int idx = 1;
+ private String trainedFilePath = "trained.ms";
+
+ public void initAndFigureInputs(String modelPath) {
+ MSConfig msConfig = new MSConfig();
+ // arg 0: DeviceType:DT_CPU -> 0
+ // arg 1: ThreadNum -> 2
+ // arg 2: cpuBindMode:NO_BIND -> 0
+ // arg 3: enable_fp16 -> false
+ msConfig.init(0, 2, 0, false);
+ session = new TrainSession();
+ session.init(modelPath, msConfig);
+ session.setLearningRate(0.01f);
+
+ List inputs = session.getInputs();
+ if (inputs.size() <= 1) {
+ System.err.println("model input size: " + inputs.size());
+ return;
+ }
+
+ dataIndex = 0;
+ labelIndex = 1;
+ batchSize = inputs.get(dataIndex).getShape()[0];
+ dataSize = inputs.get(dataIndex).size() / batchSize;
+ System.out.println("batch_size: " + batchSize);
+ int index = modelPath.lastIndexOf(".ms");
+ if (index == -1) {
+ System.out.println("The model " + modelPath + " should be named *.ms");
+ return;
+ }
+ trainedFilePath = modelPath.substring(0, index) + "_trained.ms";
+ }
+
+ public int initDB(String datasetPath) {
+ if (dataSize != 0) {
+ ds.setExpectedDataSize(dataSize);
+ }
+ ds.initializeMNISTDatabase(datasetPath);
+ numOfClasses = ds.getNumOfClasses();
+ if (numOfClasses != 10) {
+ System.err.println("unexpected num_of_class: " + numOfClasses);
+ System.exit(1);
+ }
+
+ if (ds.testData.size() == 0) {
+ System.err.println("test data size is 0");
+ return -1;
+ }
+
+ return 0;
+ }
+
+ public float getLoss() {
+ MSTensor tensor = searchOutputsForSize(1);
+ return tensor.getFloatData()[0];
+ }
+
+ private MSTensor searchOutputsForSize(int size) {
+ Map outputs = session.getOutputMapByTensor();
+ for (MSTensor tensor : outputs.values()) {
+ if (tensor.elementsNum() == size) {
+ return tensor;
+ }
+ }
+ System.err.println("can not find output the tensor which element num is " + size);
+ return null;
+ }
+
+ public int trainLoop() {
+ session.train();
+ float min_loss = 1000;
+ float max_acc = 0;
+ for (int i = 0; i < cycles; i++) {
+ fillInputData(ds.getTrainData(), false);
+ session.runGraph();
+ float loss = getLoss();
+ if (min_loss > loss) {
+ min_loss = loss;
+ }
+ if ((i + 1) % 500 == 0) {
+ float acc = calculateAccuracy(10); // only test 10 batch size
+ if (max_acc < acc) {
+ max_acc = acc;
+ }
+ System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc);
+ }
+
+ }
+ return 0;
+ }
+
+ public float calculateAccuracy(long maxTests) {
+ float accuracy = 0;
+ Vector test_set = ds.getTestData();
+ long tests = test_set.size() / batchSize;
+ if (maxTests != -1 && tests < maxTests) {
+ tests = maxTests;
+ }
+ session.eval();
+ for (long i = 0; i < tests; i++) {
+ Vector labels = fillInputData(test_set, (maxTests == -1));
+ if (labels.size() != batchSize) {
+ System.err.println("unexpected labels size: " + labels.size() + " batch_size size: " + batchSize);
+ System.exit(1);
+ }
+ session.runGraph();
+ MSTensor outputsv = searchOutputsForSize((int) (batchSize * numOfClasses));
+ if (outputsv == null) {
+ System.err.println("can not find output tensor with size: " + batchSize * numOfClasses);
+ System.exit(1);
+ }
+ float[] scores = outputsv.getFloatData();
+ for (int b = 0; b < batchSize; b++) {
+ int max_idx = 0;
+ float max_score = scores[(int) (numOfClasses * b)];
+ for (int c = 0; c < numOfClasses; c++) {
+ if (scores[(int) (numOfClasses * b + c)] > max_score) {
+ max_score = scores[(int) (numOfClasses * b + c)];
+ max_idx = c;
+ }
+
+ }
+ if (labels.get(b) == max_idx) {
+ accuracy += 1.0;
+ }
+ }
+ }
+ session.train();
+ accuracy /= (batchSize * tests);
+ return accuracy;
+ }
+
+ // each time fill batch_size data
+ Vector fillInputData(Vector dataset, boolean serially) {
+ Vector labelsVec = new Vector();
+ int totalSize = dataset.size();
+
+ List inputs = session.getInputs();
+
+ int inputDataCnt = inputs.get(dataIndex).elementsNum();
+ float[] inputBatchData = new float[inputDataCnt];
+
+ int labelDataCnt = inputs.get(labelIndex).elementsNum();
+ int[] labelBatchData = new int[labelDataCnt];
+
+ for (int i = 0; i < batchSize; i++) {
+ if (serially) {
+ idx = (++idx) % totalSize;
+ } else {
+ idx = (int) (Math.random() * totalSize);
+ }
+
+ int label = 0;
+ DataSet.DataLabelTuple dataLabelTuple = dataset.get(idx);
+ label = dataLabelTuple.label;
+ System.arraycopy(dataLabelTuple.data, 0, inputBatchData, (int) (i * dataLabelTuple.data.length), dataLabelTuple.data.length);
+ labelBatchData[i] = label;
+ labelsVec.add(label);
+ }
+
+ ByteBuffer byteBuf = ByteBuffer.allocateDirect(inputBatchData.length * Float.BYTES);
+ byteBuf.order(ByteOrder.nativeOrder());
+ for (int i = 0; i < inputBatchData.length; i++) {
+ byteBuf.putFloat(inputBatchData[i]);
+ }
+ inputs.get(dataIndex).setData(byteBuf);
+
+ ByteBuffer labelByteBuf = ByteBuffer.allocateDirect(labelBatchData.length * 4);
+ labelByteBuf.order(ByteOrder.nativeOrder());
+ for (int i = 0; i < labelBatchData.length; i++) {
+ labelByteBuf.putInt(labelBatchData[i]);
+ }
+ inputs.get(labelIndex).setData(labelByteBuf);
+
+ return labelsVec;
+ }
+
+ public void trainModel(String modelPath, String datasetPath) {
+ System.out.println("==========Loading Model, Create Train Session=============");
+ initAndFigureInputs(modelPath);
+ System.out.println("==========Initing DataSet================");
+ initDB(datasetPath);
+ System.out.println("==========Training Model===================");
+ trainLoop();
+ System.out.println("==========Evaluating The Trained Model============");
+ float acc = calculateAccuracy(-1);
+ System.out.println("accuracy = " + acc);
+
+ if (cycles > 0) {
+ if (session.saveToFile(trainedFilePath)) {
+ System.out.println("Trained model successfully saved: " + trainedFilePath);
+ } else {
+ System.err.println("Save model error.");
+ }
+ }
+ session.free();
+ }
+
+}