|
|
|
@ -55,19 +55,19 @@ public class TrainSession {
|
|
|
|
|
public List<MSTensor> getInputs() {
|
|
|
|
|
List<Long> ret = this.getInputs(this.sessionPtr);
|
|
|
|
|
ArrayList<MSTensor> tensors = new ArrayList<MSTensor>();
|
|
|
|
|
for (Long ms_tensor_addr : ret) {
|
|
|
|
|
MSTensor msTensor = new MSTensor(ms_tensor_addr);
|
|
|
|
|
for (Long msTensorAddr : ret) {
|
|
|
|
|
MSTensor msTensor = new MSTensor(msTensorAddr);
|
|
|
|
|
tensors.add(msTensor);
|
|
|
|
|
}
|
|
|
|
|
return tensors;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public MSTensor getInputsByTensorName(String tensorName) {
|
|
|
|
|
Long tensor_addr = this.getInputsByTensorName(this.sessionPtr, tensorName);
|
|
|
|
|
if(tensor_addr == null){
|
|
|
|
|
Long tensorAddr = this.getInputsByTensorName(this.sessionPtr, tensorName);
|
|
|
|
|
if(tensorAddr == null) {
|
|
|
|
|
return null;
|
|
|
|
|
}
|
|
|
|
|
MSTensor msTensor = new MSTensor(tensor_addr);
|
|
|
|
|
MSTensor msTensor = new MSTensor(tensorAddr);
|
|
|
|
|
return msTensor;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -98,11 +98,11 @@ public class TrainSession {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public MSTensor getOutputByTensorName(String tensorName) {
|
|
|
|
|
Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName);
|
|
|
|
|
if(tensor_addr == null){
|
|
|
|
|
Long tensorAddr = getOutputByTensorName(this.sessionPtr, tensorName);
|
|
|
|
|
if(tensorAddr == null) {
|
|
|
|
|
return null;
|
|
|
|
|
}
|
|
|
|
|
return new MSTensor(tensor_addr);
|
|
|
|
|
return new MSTensor(tensorAddr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public void free() {
|
|
|
|
@ -111,11 +111,11 @@ public class TrainSession {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public boolean resize(List<MSTensor> inputs, int[][] dims) {
|
|
|
|
|
long[] inputs_array = new long[inputs.size()];
|
|
|
|
|
long[] inputsArray = new long[inputs.size()];
|
|
|
|
|
for (int i = 0; i < inputs.size(); i++) {
|
|
|
|
|
inputs_array[i] = inputs.get(i).getMSTensorPtr();
|
|
|
|
|
inputsArray[i] = inputs.get(i).getMSTensorPtr();
|
|
|
|
|
}
|
|
|
|
|
return this.resize(this.sessionPtr, inputs_array, dims);
|
|
|
|
|
return this.resize(this.sessionPtr, inputsArray, dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public boolean saveToFile(String modelFilename) {
|
|
|
|
|