diff --git a/build.sh b/build.sh
index 136a2a5c75..f06f4c86d3 100755
--- a/build.sh
+++ b/build.sh
@@ -671,6 +671,12 @@ build_lite_java_arm64() {
     if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
         cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
         cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
+
+        cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
+        cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
+
+        cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
+        cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/native/libs/arm64-v8a/
     else
         cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
         cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
@@ -697,6 +703,12 @@ build_lite_java_arm32() {
     if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
         cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
         cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
+
+        cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
+        cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
+
+        cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
+        cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/native/libs/armeabi-v7a/
     else
         cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
        cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
@@ -706,10 +718,15 @@ build_lite_java_arm32() {
 
 build_lite_java_x86() {
     # build mindspore-lite x86
+    local inference_or_train=inference
+    if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
+        inference_or_train=train
+    fi
+
     if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then
-        local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64-${X86_64_SIMD}
+        local JTARBALL=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-${X86_64_SIMD}
     else
-        local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64
+        local JTARBALL=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64
     fi
     if [[ "X$INC_BUILD" == "Xoff" ]] || [[ ! -f "${BASEPATH}/mindspore/lite/build/java/${JTARBALL}.tar.gz" ]]; then
         build_lite "x86_64" "off" ""
-f "${BASEPATH}/mindspore/lite/build/java/${JTARBALL}.tar.gz" ]]; then build_lite "x86_64" "off" "" @@ -721,8 +738,20 @@ build_lite_java_x86() { [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/linux_x86/libs/ mkdir -p ${JAVA_PATH}/java/linux_x86/libs/ mkdir -p ${JAVA_PATH}/native/libs/linux_x86/ - cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/ - cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/ + if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/ + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/ + + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/linux_x86/libs/ + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/linux_x86/ + + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/java/linux_x86/libs/ + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/native/libs/linux_x86/ + else + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/ + cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/ + fi + [ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL} } build_jni_arm64() { @@ -776,7 +805,7 @@ build_jni_x86_64() { mkdir -pv java/jni cd java/jni cmake -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ - -DENABLE_VERBOSE=${ENABLE_VERBOSE} "${JAVA_PATH}/native/" + -DENABLE_VERBOSE=${ENABLE_VERBOSE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} "${JAVA_PATH}/native/" make -j$THREAD_NUM if [[ $? 
diff --git a/mindspore/lite/examples/train_lenet_java/pom.xml b/mindspore/lite/examples/train_lenet_java/pom.xml
new file mode 100644
index 0000000000..862ae5f179
--- /dev/null
+++ b/mindspore/lite/examples/train_lenet_java/pom.xml
@@ -0,0 +1,55 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <groupId>com.mindspore.lite.demo</groupId>
+    <artifactId>train_lenet_java</artifactId>
+    <version>1.0</version>
+
+    <properties>
+        <maven.compiler.source>8</maven.compiler.source>
+        <maven.compiler.target>8</maven.compiler.target>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.mindspore.lite</groupId>
+            <artifactId>mindspore-lite-java</artifactId>
+            <version>1.0</version>
+            <scope>system</scope>
+            <systemPath>${project.basedir}/lib/mindspore-lite-java.jar</systemPath>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <finalName>${project.name}</finalName>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-assembly-plugin</artifactId>
+                <configuration>
+                    <archive>
+                        <manifest>
+                            <mainClass>com.mindspore.lite.train_lenet.Main</mainClass>
+                        </manifest>
+                    </archive>
+                    <descriptorRefs>
+                        <descriptorRef>jar-with-dependencies</descriptorRef>
+                    </descriptorRefs>
+                </configuration>
+
+                <executions>
+                    <execution>
+                        <id>make-assemble</id>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>single</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>
+</project>
\ No newline at end of file
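A note on the pom above: the MindSpore Lite Java dependency is declared with `system` scope and an explicit `systemPath`, so Maven will not resolve it from any repository. The jar, presumably the one produced by the linux-x64 jar package that build.sh assembles above, has to be copied to `lib/mindspore-lite-java.jar` under the example's base directory before `mvn package` can compile the example.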
diff --git a/mindspore/lite/examples/train_lenet_java/resources/model/lenet_tod.ms b/mindspore/lite/examples/train_lenet_java/resources/model/lenet_tod.ms
new file mode 100644
index 0000000000..3857a843e3
Binary files /dev/null and b/mindspore/lite/examples/train_lenet_java/resources/model/lenet_tod.ms differ
diff --git a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/DataSet.java b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/DataSet.java
new file mode 100644
index 0000000000..47cb469c2e
--- /dev/null
+++ b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/DataSet.java
@@ -0,0 +1,131 @@
+package com.mindspore.lite.train_lenet;
+
+import java.io.BufferedInputStream;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.util.Vector;
+
+public class DataSet {
+    private long numOfClasses = 0;
+    private long expectedDataSize = 0;
+
+    public class DataLabelTuple {
+        public float[] data;
+        public int label;
+    }
+
+    Vector<DataLabelTuple> trainData;
+    Vector<DataLabelTuple> testData;
+
+    public void initializeMNISTDatabase(String dpath) {
+        numOfClasses = 10;
+        trainData = new Vector<DataLabelTuple>();
+        testData = new Vector<DataLabelTuple>();
+        readMNISTFile(dpath + "/train/train-images-idx3-ubyte", dpath + "/train/train-labels-idx1-ubyte", trainData);
+        readMNISTFile(dpath + "/test/t10k-images-idx3-ubyte", dpath + "/test/t10k-labels-idx1-ubyte", testData);
+
+        System.out.println("train data cnt: " + trainData.size());
+        System.out.println("test data cnt: " + testData.size());
+    }
+
+    private String bytesToHex(byte[] bytes) {
+        StringBuffer sb = new StringBuffer();
+        for (int i = 0; i < bytes.length; i++) {
+            String hex = Integer.toHexString(bytes[i] & 0xFF);
+            if (hex.length() < 2) {
+                sb.append(0);
+            }
+            sb.append(hex);
+        }
+        return sb.toString();
+    }
+
+    private void readFile(BufferedInputStream inputStream, byte[] bytes, int len) throws IOException {
+        int result = inputStream.read(bytes, 0, len);
+        if (result != len) {
+            System.err.println("expected to read " + len + " bytes, but " + result + " read");
+            System.exit(1);
+        }
+    }
+
+    public void readMNISTFile(String inputFileName, String labelFileName, Vector<DataLabelTuple> dataset) {
+        try {
+            BufferedInputStream ibin = new BufferedInputStream(new FileInputStream(inputFileName));
+            BufferedInputStream lbin = new BufferedInputStream(new FileInputStream(labelFileName));
+            byte[] bytes = new byte[4];
+
+            readFile(ibin, bytes, 4);
+            if (!"00000803".equals(bytesToHex(bytes))) { // 2051, the IDX image-file magic number
+                System.err.println("The dataset is not valid: " + bytesToHex(bytes));
+                return;
+            }
+            readFile(ibin, bytes, 4);
+            int inumber = Integer.parseInt(bytesToHex(bytes), 16);
+
+            readFile(lbin, bytes, 4);
+            if (!"00000801".equals(bytesToHex(bytes))) { // 2049, the IDX label-file magic number
+                System.err.println("The dataset label is not valid: " + bytesToHex(bytes));
+                return;
+            }
+            readFile(lbin, bytes, 4);
+            int lnumber = Integer.parseInt(bytesToHex(bytes), 16);
+            if (inumber != lnumber) {
+                System.err.println("input data cnt: " + inumber + " not equal label cnt: " + lnumber);
+                return;
+            }
+
+            // read all labels
+            byte[] labels = new byte[lnumber];
+            readFile(lbin, labels, lnumber);
+
+            // row, column
+            readFile(ibin, bytes, 4);
+            int n_rows = Integer.parseInt(bytesToHex(bytes), 16);
+            readFile(ibin, bytes, 4);
+            int n_cols = Integer.parseInt(bytesToHex(bytes), 16);
+            if (n_rows != 28 || n_cols != 28) {
+                System.err.println("invalid n_rows: " + n_rows + " n_cols: " + n_cols);
+                return;
+            }
+
+            // read images, padding each 28x28 image into a 32x32 float array
+            int image_size = n_rows * n_cols;
+            byte[] image_data = new byte[image_size];
+            for (int i = 0; i < lnumber; i++) {
+                float[] hwc_bin_image = new float[32 * 32];
+                readFile(ibin, image_data, image_size);
+                for (int r = 0; r < 32; r++) {
+                    for (int c = 0; c < 32; c++) {
+                        int index = r * 32 + c;
+                        if (r < 2 || r > 29 || c < 2 || c > 29) {
+                            hwc_bin_image[index] = 0; // two-pixel zero border
+                        } else {
+                            int data = image_data[(r - 2) * 28 + (c - 2)] & 0xff;
+                            hwc_bin_image[index] = (float) data / 255.0f; // normalize to [0, 1]
+                        }
+                    }
+                }
+
+                DataLabelTuple data_label_tuple = new DataLabelTuple();
+                data_label_tuple.data = hwc_bin_image;
+                data_label_tuple.label = labels[i] & 0xff;
+                dataset.add(data_label_tuple);
+            }
+        } catch (IOException e) {
+            System.err.println("Read dataset exception");
+        }
+    }
+
+    public void setExpectedDataSize(long data_size) {
+        expectedDataSize = data_size;
+    }
+
+    public long getNumOfClasses() {
+        return numOfClasses;
+    }
+
+    public Vector<DataLabelTuple> getTestData() {
+        return testData;
+    }
+
+    public Vector<DataLabelTuple> getTrainData() {
+        return trainData;
+    }
+}
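The header checks in readMNISTFile compare hex strings byte by byte. Since the IDX format stores its magic numbers and counts as big-endian 32-bit integers, the same parse can also be written with `DataInputStream.readInt()`, which reads big-endian by definition. A minimal sketch for comparison, not part of this change; the class name and layout are illustrative only:

```java
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;

public class IdxHeader {
    // Reads the IDX3 image-file header: magic, image count, rows, cols.
    public static int[] readImageHeader(String path) throws IOException {
        try (DataInputStream in = new DataInputStream(new FileInputStream(path))) {
            int magic = in.readInt(); // 0x00000803 == 2051 for image files
            if (magic != 2051) {
                throw new IOException("not an IDX image file, magic: " + magic);
            }
            int count = in.readInt(); // number of images
            int rows = in.readInt();  // 28 for MNIST
            int cols = in.readInt();  // 28 for MNIST
            return new int[]{count, rows, cols};
        }
    }
}
```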
diff --git a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java
new file mode 100644
index 0000000000..4f70d451fc
--- /dev/null
+++ b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/Main.java
@@ -0,0 +1,20 @@
+package com.mindspore.lite.train_lenet;
+
+import com.mindspore.lite.Version;
+
+public class Main {
+    public static void main(String[] args) {
+        System.out.println(Version.version());
+        if (args.length < 2) {
+            System.err.println("model path and dataset path must be provided.");
+            return;
+        }
+        String modelPath = args[0];
+        String datasetPath = args[1];
+
+        NetRunner net_runner = new NetRunner();
+        net_runner.trainModel(modelPath, datasetPath);
+    }
+}
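With the pom above, `mvn package` should produce an assembly jar under `target/`, named from `finalName` plus the `jar-with-dependencies` suffix (so presumably `target/train_lenet_java-jar-with-dependencies.jar`), with `Main` as its entry point. A typical invocation would then be `java -jar target/train_lenet_java-jar-with-dependencies.jar lenet_tod.ms /path/to/mnist`, with the native libraries from the linux-x64 package on `java.library.path`. The exact artifact name depends on the local Maven setup, so treat this as an illustration rather than a verified command.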
diff --git a/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java
new file mode 100644
index 0000000000..4fbcfaeed2
--- /dev/null
+++ b/mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java
@@ -0,0 +1,220 @@
+package com.mindspore.lite.train_lenet;
+
+import com.mindspore.lite.MSTensor;
+import com.mindspore.lite.TrainSession;
+import com.mindspore.lite.config.MSConfig;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.List;
+import java.util.Map;
+import java.util.Vector;
+
+public class NetRunner {
+    private int dataIndex = 0;
+    private int labelIndex = 1;
+    private TrainSession session;
+    private long batchSize;
+    private long dataSize; // one input data size, in bytes
+    private DataSet ds = new DataSet();
+    private long numOfClasses;
+    private long cycles = 3500;
+    private int idx = 1;
+    private String trainedFilePath = "trained.ms";
+
+    public void initAndFigureInputs(String modelPath) {
+        MSConfig msConfig = new MSConfig();
+        // arg 0: DeviceType:DT_CPU -> 0
+        // arg 1: ThreadNum -> 2
+        // arg 2: cpuBindMode:NO_BIND -> 0
+        // arg 3: enable_fp16 -> false
+        msConfig.init(0, 2, 0, false);
+        session = new TrainSession();
+        session.init(modelPath, msConfig);
+        session.setLearningRate(0.01f);
+
+        List<MSTensor> inputs = session.getInputs();
+        if (inputs.size() <= 1) {
+            System.err.println("model input size: " + inputs.size());
+            return;
+        }
+
+        dataIndex = 0;
+        labelIndex = 1;
+        batchSize = inputs.get(dataIndex).getShape()[0];
+        dataSize = inputs.get(dataIndex).size() / batchSize;
+        System.out.println("batch_size: " + batchSize);
+        int index = modelPath.lastIndexOf(".ms");
+        if (index == -1) {
+            System.out.println("The model " + modelPath + " should be named *.ms");
+            return;
+        }
+        trainedFilePath = modelPath.substring(0, index) + "_trained.ms";
+    }
+
+    public int initDB(String datasetPath) {
+        if (dataSize != 0) {
+            ds.setExpectedDataSize(dataSize);
+        }
+        ds.initializeMNISTDatabase(datasetPath);
+        numOfClasses = ds.getNumOfClasses();
+        if (numOfClasses != 10) {
+            System.err.println("unexpected num_of_classes: " + numOfClasses);
+            System.exit(1);
+        }
+
+        if (ds.testData.size() == 0) {
+            System.err.println("test data size is 0");
+            return -1;
+        }
+
+        return 0;
+    }
+
+    public float getLoss() {
+        MSTensor tensor = searchOutputsForSize(1);
+        return tensor.getFloatData()[0];
+    }
+
+    private MSTensor searchOutputsForSize(int size) {
+        Map<String, MSTensor> outputs = session.getOutputMapByTensor();
+        for (MSTensor tensor : outputs.values()) {
+            if (tensor.elementsNum() == size) {
+                return tensor;
+            }
+        }
+        System.err.println("cannot find an output tensor with " + size + " elements");
+        return null;
+    }
+
+    public int trainLoop() {
+        session.train();
+        float min_loss = 1000;
+        float max_acc = 0;
+        for (int i = 0; i < cycles; i++) {
+            fillInputData(ds.getTrainData(), false);
+            session.runGraph();
+            float loss = getLoss();
+            if (min_loss > loss) {
+                min_loss = loss;
+            }
+            if ((i + 1) % 500 == 0) {
+                float acc = calculateAccuracy(10); // only test 10 batches
+                if (max_acc < acc) {
+                    max_acc = acc;
+                }
+                System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_acc=" + max_acc);
+            }
+        }
+        return 0;
+    }
+
+    public float calculateAccuracy(long maxTests) {
+        float accuracy = 0;
+        Vector<DataSet.DataLabelTuple> test_set = ds.getTestData();
+        long tests = test_set.size() / batchSize;
+        if (maxTests != -1 && tests > maxTests) {
+            tests = maxTests;
+        }
+        session.eval();
+        for (long i = 0; i < tests; i++) {
+            Vector<Integer> labels = fillInputData(test_set, (maxTests == -1));
+            if (labels.size() != batchSize) {
+                System.err.println("unexpected labels size: " + labels.size() + " batch_size: " + batchSize);
+                System.exit(1);
+            }
+            session.runGraph();
+            MSTensor outputsv = searchOutputsForSize((int) (batchSize * numOfClasses));
+            if (outputsv == null) {
+                System.err.println("cannot find output tensor with size: " + batchSize * numOfClasses);
+                System.exit(1);
+            }
+            float[] scores = outputsv.getFloatData();
+            for (int b = 0; b < batchSize; b++) {
+                int max_idx = 0;
+                float max_score = scores[(int) (numOfClasses * b)];
+                for (int c = 0; c < numOfClasses; c++) {
+                    if (scores[(int) (numOfClasses * b + c)] > max_score) {
+                        max_score = scores[(int) (numOfClasses * b + c)];
+                        max_idx = c;
+                    }
+                }
+                if (labels.get(b) == max_idx) {
+                    accuracy += 1.0;
+                }
+            }
+        }
+        session.train();
+        accuracy /= (batchSize * tests);
+        return accuracy;
+    }
+
+    // fill batch_size samples into the model inputs on each call
+    Vector<Integer> fillInputData(Vector<DataSet.DataLabelTuple> dataset, boolean serially) {
+        Vector<Integer> labelsVec = new Vector<Integer>();
+        int totalSize = dataset.size();
+
+        List<MSTensor> inputs = session.getInputs();
+
+        int inputDataCnt = inputs.get(dataIndex).elementsNum();
+        float[] inputBatchData = new float[inputDataCnt];
+
+        int labelDataCnt = inputs.get(labelIndex).elementsNum();
+        int[] labelBatchData = new int[labelDataCnt];
+
+        for (int i = 0; i < batchSize; i++) {
+            if (serially) {
+                idx = (idx + 1) % totalSize;
+            } else {
+                idx = (int) (Math.random() * totalSize);
+            }
+
+            DataSet.DataLabelTuple dataLabelTuple = dataset.get(idx);
+            int label = dataLabelTuple.label;
+            System.arraycopy(dataLabelTuple.data, 0, inputBatchData, i * dataLabelTuple.data.length, dataLabelTuple.data.length);
+            labelBatchData[i] = label;
+            labelsVec.add(label);
+        }
+
+        ByteBuffer byteBuf = ByteBuffer.allocateDirect(inputBatchData.length * Float.BYTES);
+        byteBuf.order(ByteOrder.nativeOrder());
+        for (int i = 0; i < inputBatchData.length; i++) {
+            byteBuf.putFloat(inputBatchData[i]);
+        }
+        inputs.get(dataIndex).setData(byteBuf);
+
+        ByteBuffer labelByteBuf = ByteBuffer.allocateDirect(labelBatchData.length * 4);
+        labelByteBuf.order(ByteOrder.nativeOrder());
+        for (int i = 0; i < labelBatchData.length; i++) {
+            labelByteBuf.putInt(labelBatchData[i]);
+        }
+        inputs.get(labelIndex).setData(labelByteBuf);
+
+        return labelsVec;
+    }
System.out.println("==========Initing DataSet================"); + initDB(datasetPath); + System.out.println("==========Training Model==================="); + trainLoop(); + System.out.println("==========Evaluating The Trained Model============"); + float acc = calculateAccuracy(-1); + System.out.println("accuracy = " + acc); + + if (cycles > 0) { + if (session.saveToFile(trainedFilePath)) { + System.out.println("Trained model successfully saved: " + trainedFilePath); + } else { + System.err.println("Save model error."); + } + } + session.free(); + } + +}