增加TANH 激活函数,默认激活函数使用TANH

pull/8/head
Administrator 5 years ago
parent 49d3c71d39
commit c66f24d495

@ -6,8 +6,10 @@ import org.wlld.tools.ArithUtil;
public class Tanh implements ActiveFunction {
@Override
public double function(double x) {
double son = ArithUtil.sub(Math.exp(x), Math.exp(-x));
double mother = ArithUtil.add(Math.exp(x), Math.exp(-x));
double x1 = Math.exp(x);
double x2 = Math.exp(-x);
double son = ArithUtil.sub(x1, x2);
double mother = ArithUtil.add(x1, x2);
return ArithUtil.div(son, mother);
}

@ -281,7 +281,9 @@ public class Operation {//进行计算
Map<Integer, Double> map = new HashMap<>();
map.put(tagging, 1.0);
List<Double> feature = getFeature(myMatrix);
//System.out.println(feature);
if (templeConfig.isShowLog()) {
System.out.println(feature);
}
intoDnnNetwork(1, feature, templeConfig.getSensoryNerves(), true, map, null);
}

@ -44,7 +44,7 @@ public class TempleConfig {
private boolean boxReady = false;//边框已经学习完毕
private double iouTh = 0.5;//IOU阈值
private int lvqNub = 10;//lvq循环次数默认30
private VectorK vectorK;
private VectorK vectorK;//特征向量均值类
private boolean isThreeChannel = false;//是否启用三通道
private int classifier = Classifier.VAvg;//默认分类类别使用的是向量均值分类
private Normalization normalization = new Normalization();//统一归一化
@ -52,6 +52,8 @@ public class TempleConfig {
private int sensoryNerveNub;//输入神经元个数
private boolean isShowLog = false;
private ActiveFunction activeFunction = new Tanh();
private double studyPoint = 0;
public boolean isAccurate() {
return isAccurate;
}
@ -60,6 +62,10 @@ public class TempleConfig {
isAccurate = accurate;
}
public void setStudyPoint(double studyPoint) {
this.studyPoint = studyPoint;
}
public void setActiveFunction(ActiveFunction activeFunction) {
this.activeFunction = activeFunction;
}
@ -131,6 +137,10 @@ public class TempleConfig {
this.isShowLog = isShowLog;
}
public boolean isShowLog() {
return isShowLog;
}
public void startLvq() throws Exception {
switch (classifier) {
case Classifier.LVQ:
@ -231,7 +241,7 @@ public class TempleConfig {
initConvolutionVision(initPower, width, height);
break;
case StudyPattern.Cover_Pattern://覆盖学习模式
initNerveManager(initPower, 9, deep);
initNerveManager(initPower, 9, deep, studyPoint);
break;
}
}
@ -247,13 +257,13 @@ public class TempleConfig {
row = 5;
column = (int) (d * row);
}
initNerveManager(initPower, row * column, deep);
initNerveManager(initPower, row * column, deep,studyPoint);
}
private void initNerveManager(boolean initPower, int sensoryNerveNub
, int deep) throws Exception {
, int deep, double studyPoint) throws Exception {
nerveManager = new NerveManager(sensoryNerveNub, 9,
classificationNub, deep, activeFunction, false, isAccurate);
classificationNub, deep, activeFunction, false, isAccurate, studyPoint);
nerveManager.init(initPower, false, isShowLog);
}
@ -271,7 +281,7 @@ public class TempleConfig {
if (isThreeChannel) {
nub = nub * 3;
}
initNerveManager(true, nub, this.deep);
initNerveManager(true, nub, this.deep,studyPoint);
break;
case Classifier.LVQ:
lvq = new LVQ(classificationNub, lvqNub);
@ -306,7 +316,7 @@ public class TempleConfig {
private NerveManager initNerveManager(Map<Integer, Matrix> matrixMap, boolean initPower, int deep) throws Exception {
//初始化卷积神经网络
NerveManager convolutionNerveManager = new NerveManager(1, 1,
1, deep - 1, new ReLu(), true, isAccurate);
1, deep - 1, new ReLu(), true, isAccurate,studyPoint);
convolutionNerveManager.setMatrixMap(matrixMap);//给卷积网络管理器注入期望矩阵
convolutionNerveManager.init(initPower, true, isShowLog);
return convolutionNerveManager;
@ -411,10 +421,6 @@ public class TempleConfig {
return nerveManager.getSensoryNerves();
}
public void setStudy(double studyPoint) throws Exception {//设置学习率
nerveManager.setStudyPoint(studyPoint);
}
public void setStudyList(List<Double> list) {//设置每一层不同的学习率
if (studyPattern == StudyPattern.Accuracy_Pattern) {
//给卷积层设置层学习率

@ -44,19 +44,6 @@ public class NerveManager {
this.matrixMap = matrixMap;
}
public double getStudyPoint() {
return studyPoint;
}
public void setStudyPoint(double studyPoint) throws Exception {
//设置学习率
if (studyPoint <= 1 && studyPoint > 0) {
this.studyPoint = studyPoint;
} else {
throw new Exception("studyPoint Values range from 0 to 1");
}
}
private ModelParameter getDymModelParameter() throws Exception {//获取动态神经元参数
ModelParameter modelParameter = new ModelParameter();
List<DymNerveStudy> dymNerveStudies = new ArrayList<>();//动态神经元隐层
@ -210,7 +197,8 @@ public class NerveManager {
* @throws Exception
*/
public NerveManager(int sensoryNerveNub, int hiddenNerveNub, int outNerveNub
, int hiddenDepth, ActiveFunction activeFunction, boolean isDynamic, boolean isAccurate) throws Exception {
, int hiddenDepth, ActiveFunction activeFunction, boolean isDynamic, boolean isAccurate,
double studyPoint) throws Exception {
if (sensoryNerveNub > 0 && hiddenNerveNub > 0 && outNerveNub > 0 && hiddenDepth > 0 && activeFunction != null) {
this.hiddenNerveNub = hiddenNerveNub;
this.sensoryNerveNub = sensoryNerveNub;
@ -219,6 +207,9 @@ public class NerveManager {
this.activeFunction = activeFunction;
this.isDynamic = isDynamic;
this.isAccurate = isAccurate;
if (studyPoint > 0 && studyPoint < 1) {
this.studyPoint = studyPoint;
}
} else {
throw new Exception("param is null");
}

@ -22,7 +22,7 @@ public abstract class Nerve {
private int id;//同级神经元编号,注意在同层编号中ID应有唯一性
protected int upNub;//上一层神经元数量
protected int downNub;//下一层神经元的数量
protected Map<Long, List<Double>> features = new HashMap<>();
protected Map<Long, List<Double>> features = new HashMap<>();//上一层神经元输入的数值
protected Matrix nerveMatrix = new Matrix(3, 3);//权重矩阵可获取及注入
protected Map<Long, Matrix> matrixMap = new HashMap<>();//参数矩阵
protected double threshold;//此神经元的阈值需要取出
@ -265,9 +265,7 @@ public abstract class Nerve {
double w = dendrites.get(i + 1);
//System.out.println("w==" + w + ",value==" + value);
sigma = ArithUtil.add(ArithUtil.mul(w, value), sigma);
//logger.debug("name:{},eventId:{},id:{},myId:{},w:{},value:{}", name, eventId, i + 1, id, w, value);
}
//logger.debug("当前神经元线性变化已经完成,name:{},id:{}", name, getId());
return ArithUtil.sub(sigma, threshold);
}

@ -46,7 +46,7 @@ public class OutNerve extends Nerve {
if (E.containsKey(getId())) {
this.E = E.get(getId());
} else {
this.E = -1;
this.E = 0;
}
if (isShowLog) {
System.out.println("E==" + this.E + ",out==" + out + ",nerveId==" + getId());
@ -71,10 +71,6 @@ public class OutNerve extends Nerve {
Matrix myMatrix = dynamicNerve(matrix, eventId, isKernelStudy);
if (isKernelStudy) {//回传
Matrix matrix1 = matrixMapE.get(E);
// if (isShowLog) {
// System.out.println("E================" + E);
// System.out.println(myMatrix.getString());
// }
if (matrix1.getX() <= myMatrix.getX() && matrix1.getY() <= myMatrix.getY()) {
double g = getGradient(myMatrix, matrix1);
backMatrix(g, eventId);

@ -0,0 +1,153 @@
package coverTest;
import com.alibaba.fastjson.JSON;
import org.wlld.MatrixTools.Matrix;
import org.wlld.ModelData;
import org.wlld.config.Classifier;
import org.wlld.config.StudyPattern;
import org.wlld.imageRecognition.Operation;
import org.wlld.imageRecognition.Picture;
import org.wlld.imageRecognition.TempleConfig;
import org.wlld.nerveEntity.ModelParameter;
import org.wlld.tools.ArithUtil;
import java.util.HashMap;
import java.util.Map;
public class FoodTest {
public static void main(String[] args) throws Exception {
food();
}
public static void food() throws Exception {
Picture picture = new Picture();
TempleConfig templeConfig = new TempleConfig(false, true);
templeConfig.setClassifier(Classifier.DNN);
templeConfig.isShowLog(true);
templeConfig.init(StudyPattern.Accuracy_Pattern, true, 640, 640, 4);
ModelParameter modelParameter2 = JSON.parseObject(ModelData.DATA3, ModelParameter.class);
templeConfig.insertModel(modelParameter2);
Operation operation = new Operation(templeConfig);
// 一阶段
// for (int j = 0; j < 1; j++) {
// for (int i = 1; i < 1900; i++) {//一阶段
// System.out.println("study1===================" + i);
// //读取本地URL地址图片,并转化成矩阵
// Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
// Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");
// Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg");
// Matrix d = picture.getImageMatrixByLocal("D:\\share\\picture/d" + i + ".jpg");
// //将图像矩阵和标注加入进行学习Accuracy_Pattern 模式 进行第二次学习
// //第二次学习的时候,第三个参数必须是 true
// operation.learning(a, 1, false);
// operation.learning(b, 2, false);
// operation.learning(c, 3, false);
// operation.learning(d, 4, false);
// }
// }
//二阶段
// for (int i = 1; i < 1900; i++) {
// System.out.println("avg==" + i);
// Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
// Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");
// Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg");
// Matrix d = picture.getImageMatrixByLocal("D:\\share\\picture/d" + i + ".jpg");
// operation.normalization(a, templeConfig.getConvolutionNerveManager());
// operation.normalization(b, templeConfig.getConvolutionNerveManager());
// operation.normalization(c, templeConfig.getConvolutionNerveManager());
// operation.normalization(d, templeConfig.getConvolutionNerveManager());
// }
// templeConfig.getNormalization().avg();
for (int j = 0; j < 1; j++) {
for (int i = 1; i < 1900; i++) {
System.out.println("j==" + j + ",study2==================" + i);
//读取本地URL地址图片,并转化成矩阵
Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");
Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg");
Matrix d = picture.getImageMatrixByLocal("D:\\share\\picture/d" + i + ".jpg");
//将图像矩阵和标注加入进行学习Accuracy_Pattern 模式 进行第二次学习
//第二次学习的时候,第三个参数必须是 true
operation.learning(a, 1, true);
operation.learning(b, 2, true);
operation.learning(c, 3, true);
operation.learning(d, 4, true);
}
}
templeConfig.finishStudy();//结束学习
ModelParameter modelParameter = templeConfig.getModel();
String model = JSON.toJSONString(modelParameter);
System.out.println(model);
// ModelParameter modelParameter2 = JSON.parseObject(model, ModelParameter.class);
// TempleConfig templeConfig2 = new TempleConfig(false);
// templeConfig2.init(StudyPattern.Accuracy_Pattern, true, 1000, 1000, 2);
// templeConfig2.insertModel(modelParameter2);
// Operation operation2 = new Operation(templeConfig2);
int wrong = 0;
int allNub = 0;
for (int i = 1900; i <= 1998; i++) {
//读取本地URL地址图片,并转化成矩阵
Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");
Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg");
Matrix d = picture.getImageMatrixByLocal("D:\\share\\picture/d" + i + ".jpg");
//将图像矩阵和标注加入进行学习Accuracy_Pattern 模式 进行第二次学习
//第二次学习的时候,第三个参数必须是 true
allNub += 4;
int an = operation.toSee(a);
//System.out.println("an============1");
if (an != 1) {
//System.out.println("a错了");
wrong++;
}
int bn = operation.toSee(b);
//System.out.println("bn============2");
if (bn != 2) {
// System.out.println("b错了");
wrong++;
}
int cn = operation.toSee(c);
// System.out.println("cn============3");
if (cn != 3) {
//System.out.println("c错了");
wrong++;
}
int dn = operation.toSee(d);
// System.out.println("dn============4");
if (dn != 4) {
// System.out.println("d错了");
wrong++;
}
}
double wrongPoint = ArithUtil.div(wrong, allNub);
System.out.println("错误率1" + (wrongPoint * 100) + "%");
}
public static void test1() throws Exception {//覆盖率计算
Picture picture = new Picture();
TempleConfig templeConfig = new TempleConfig(false, true);
templeConfig.init(StudyPattern.Cover_Pattern, true, 320, 240, 2);
Operation operation = new Operation(templeConfig);
Map<Integer, Double> rightTagging = new HashMap<>();//分类标注
Map<Integer, Double> wrongTagging = new HashMap<>();//分类标注
rightTagging.put(1, 1.0);
wrongTagging.put(2, 1.0);
Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/picture/yes1.jpg");
Matrix wrong = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/picture/no4.jpg");
int a = 1;
for (int i = 0; i < a; i++) {
operation.coverStudy(right, rightTagging, wrong, wrongTagging);
}
System.out.println("学习完成");
long sys = System.currentTimeMillis();
double point = operation.coverPoint(right, 1);
long sys2 = System.currentTimeMillis();
long sys3 = sys2 - sys;
double point2 = operation.coverPoint(wrong, 1);
System.out.println("识别耗时:" + sys3);
System.out.println("测试覆盖1" + point + ",测试覆盖2:" + point2);
}
}

@ -1,108 +0,0 @@
package org.wlld;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.wlld.MatrixTools.Matrix;
import org.wlld.config.StudyPattern;
import org.wlld.function.Sigmod;
import org.wlld.imageRecognition.Operation;
import org.wlld.imageRecognition.Picture;
import org.wlld.imageRecognition.TempleConfig;
import org.wlld.imageRecognition.border.Frame;
import org.wlld.imageRecognition.border.FrameBody;
import org.wlld.nerveCenter.NerveManager;
import org.wlld.nerveEntity.ModelParameter;
import org.wlld.nerveEntity.SensoryNerve;
import java.util.*;
/**
* Hello world!
*/
public class App {
public static void main(String[] args) throws Exception {
test3();
}
public static void test3() throws Exception {
NerveManager nerveManager = new NerveManager(3, 6, 3
, 3, new Sigmod(), false, true);
nerveManager.init(true, false, false);//初始化
List<Map<Integer, Double>> data = new ArrayList<>();//正样本
List<Map<Integer, Double>> dataB = new ArrayList<>();//负样本
List<Map<Integer, Double>> dataC = new ArrayList<>();//负样本
Random random = new Random();
for (int i = 0; i < 4000; i++) {
Map<Integer, Double> map1 = new HashMap<>();
Map<Integer, Double> map2 = new HashMap<>();
Map<Integer, Double> map3 = new HashMap<>();
map1.put(0, 1 + random.nextDouble());
map1.put(1, 1 + random.nextDouble());
map1.put(2, 0.0);
//产生鲜明区分
map2.put(0, random.nextDouble());
map2.put(1, random.nextDouble());
map2.put(2, 0.0);
//
map3.put(0, 2 + random.nextDouble());
map3.put(1, 2 + random.nextDouble());
map3.put(2, 0.0);
data.add(map1);
dataB.add(map2);
dataC.add(map3);
}
Map<Integer, Double> right = new HashMap<>();
Map<Integer, Double> wrong = new HashMap<>();
Map<Integer, Double> other = new HashMap<>();
right.put(1, 1.0);
wrong.put(2, 1.0);
other.put(3, 1.0);
for (int i = 0; i < data.size(); i++) {
Map<Integer, Double> map1 = data.get(i);
Map<Integer, Double> map2 = dataB.get(i);
Map<Integer, Double> map3 = dataC.get(i);
post(nerveManager.getSensoryNerves(), map1, right, null, true);
post(nerveManager.getSensoryNerves(), map2, wrong, null, true);
post(nerveManager.getSensoryNerves(), map3, other, null, true);
}
List<Map<Integer, Double>> data2 = new ArrayList<>();
List<Map<Integer, Double>> data2B = new ArrayList<>();
List<Map<Integer, Double>> data2C = new ArrayList<>();//负样本
for (int i = 0; i < 20; i++) {
Map<Integer, Double> map1 = new HashMap<>();
Map<Integer, Double> map2 = new HashMap<>();
Map<Integer, Double> map3 = new HashMap<>();
map1.put(0, 1 + random.nextDouble());
map1.put(1, 1 + random.nextDouble());
map1.put(2, 0.0);
map2.put(0, random.nextDouble());
map2.put(1, random.nextDouble());
map2.put(2, 0.0);
map3.put(0, 2 + random.nextDouble());
map3.put(1, 2 + random.nextDouble());
map3.put(2, 0.0);
data2.add(map1);
data2B.add(map2);
data2C.add(map3);
}
Back back = new Back();
for (Map<Integer, Double> map : data2) {
post(nerveManager.getSensoryNerves(), map, null, back, false);
System.out.println("=====================");
}
}
public static void post(List<SensoryNerve> sensoryNerveList, Map<Integer, Double> data
, Map<Integer, Double> tagging, Back back, boolean isStudy) throws Exception {
int size = sensoryNerveList.size();
for (int i = 0; i < size; i++) {
sensoryNerveList.get(i).postMessage(1, data.get(i), isStudy, tagging, back);
}
}
}

File diff suppressed because it is too large Load Diff

@ -14,21 +14,6 @@ import java.util.*;
*/
public class LangTest {
public static void main(String[] args) throws Exception {
List<Double> listAll = new ArrayList<>();
List<Double> list = new ArrayList<>();
List<Double> list2 = new ArrayList<>();
List<Double> list3 = new ArrayList<>();
list.add(1.0);
list.add(2.0);
list2.add(3.0);
list2.add(4.0);
list3.add(5.0);
list3.add(6.0);
listAll.addAll(list);
listAll.addAll(list2);
listAll.addAll(list3);
System.out.println(listAll);
//test1();
}
public static void test1() throws Exception {

@ -1,126 +0,0 @@
package org.wlld;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.wlld.MatrixTools.Matrix;
import org.wlld.MatrixTools.MatrixOperation;
import org.wlld.randomForest.DataTable;
import org.wlld.randomForest.Node;
import org.wlld.randomForest.Tree;
import java.awt.*;
import java.util.*;
/**
* @author lidapeng
* @description
* @date 3:35 2020/1/23
*/
public class MatrixTest {
public static void main(String[] args) throws Exception {
test4();
}
public static void test4() throws Exception {
Set<String> column = new HashSet<>();
column.add("height");
column.add("weight");
column.add("sex");
column.add("h1");
column.add("h2");
DataTable dataTable = new DataTable(column);
dataTable.setKey("sex");
Random random = new Random();
int cla = 3;
for (int i = 0; i < 50; i++) {
Food food = new Food();
food.setHeight(random.nextInt(cla));
food.setWeight(random.nextInt(cla));
food.setSex(random.nextInt(cla));
food.setH1(random.nextInt(cla));
food.setH2(random.nextInt(cla));
dataTable.insert(food);
}
Tree tree = new Tree(dataTable);
tree.study();
Node node = tree.getRootNode();
String a = JSON.toJSONString(node);
Node node1 = JSONObject.parseObject(a, Node.class);
////
Tree tree2 = new Tree(dataTable);
tree2.setRootNode(node1);
for (int i = 0; i < 10; i++) {
Food food = new Food();
food.setHeight(random.nextInt(cla));
food.setWeight(random.nextInt(cla));
food.setSex(random.nextInt(cla));
food.setH1(random.nextInt(cla));
food.setH2(random.nextInt(cla));
int type = tree.judge(food).getType();
int type2 = tree2.judge(food).getType();
if (type != type2) {
System.out.println("出错,type1==" + type + ",type2==" + type2);
} else {
System.out.println(type);
}
}
System.out.println("结束");
}
public static void test3() throws Exception {
Matrix matrix = new Matrix(4, 3);
Matrix matrixY = new Matrix(4, 1);
String b = "[7]#" +
"[8]#" +
"[9]#" +
"[19]#";
matrixY.setAll(b);
String a = "[1,2,17]#" +
"[3,4,18]#" +
"[5,6,10]#" +
"[15,16,13]#";
matrix.setAll(a);
//将参数矩阵转置
Matrix matrix1 = MatrixOperation.transPosition(matrix);
//转置的参数矩阵乘以参数矩阵
Matrix matrix2 = MatrixOperation.mulMatrix(matrix1, matrix);
//求上一步的逆矩阵
Matrix matrix3 = MatrixOperation.getInverseMatrixs(matrix2);
//逆矩阵乘以转置矩阵
Matrix matrix4 = MatrixOperation.mulMatrix(matrix3, matrix1);
//最后乘以输出矩阵,生成权重矩阵
Matrix matrix5 = MatrixOperation.mulMatrix(matrix4, matrixY);
System.out.println(matrix5.getString());
}
public static void test1() throws Exception {
Matrix matrix = new Matrix(2, 2);
Matrix matrix2 = new Matrix(1, 5);
String b = "[6,7,8,9,10]#";
String a = "[1,2]#" +
"[3,4]#";
matrix.setAll(a);
matrix2.setAll(b);
Matrix matrix1 = MatrixOperation.matrixToVector(matrix, true);
matrix1 = MatrixOperation.push(matrix1, 5, true);
matrix1 = MatrixOperation.pushVector(matrix1, matrix2, true);
System.out.println(matrix1.getString());
}
public static void test2() throws Exception {
Matrix matrix = new Matrix(2, 2);
Matrix matrix1 = new Matrix(1, 5);
String a = "[1,2]#" +
"[3,4]#";
String b = "[6,7,8,9,10]#";
matrix.setAll(a);
matrix1.setAll(b);
matrix1 = MatrixOperation.matrixToVector(matrix1, false);
matrix = MatrixOperation.matrixToVector(matrix, false);
matrix = MatrixOperation.push(matrix, 5, true);
matrix = MatrixOperation.pushVector(matrix, matrix1, false);
System.out.println(matrix.getString());
}
}

File diff suppressed because one or more lines are too long

@ -32,7 +32,7 @@ public class NerveDemo1 {
* @param activeFunction
* @param isDynamic
*/
NerveManager nerveManager = new NerveManager(2, 6, 1, 4, new Sigmod(), false, true);
NerveManager nerveManager = new NerveManager(2, 6, 1, 4, new Sigmod(), false, true, 0);
nerveManager.init(true, false, false);
@ -106,6 +106,79 @@ public class NerveDemo1 {
}
public static void test3() throws Exception {
NerveManager nerveManager = new NerveManager(3, 6, 3
, 3, new Sigmod(), false, true, 0);
nerveManager.init(true, false, false);//初始化
List<Map<Integer, Double>> data = new ArrayList<>();//正样本
List<Map<Integer, Double>> dataB = new ArrayList<>();//负样本
List<Map<Integer, Double>> dataC = new ArrayList<>();//负样本
Random random = new Random();
for (int i = 0; i < 4000; i++) {
Map<Integer, Double> map1 = new HashMap<>();
Map<Integer, Double> map2 = new HashMap<>();
Map<Integer, Double> map3 = new HashMap<>();
map1.put(0, 1 + random.nextDouble());
map1.put(1, 1 + random.nextDouble());
map1.put(2, 0.0);
//产生鲜明区分
map2.put(0, random.nextDouble());
map2.put(1, random.nextDouble());
map2.put(2, 0.0);
//
map3.put(0, 2 + random.nextDouble());
map3.put(1, 2 + random.nextDouble());
map3.put(2, 0.0);
data.add(map1);
dataB.add(map2);
dataC.add(map3);
}
Map<Integer, Double> right = new HashMap<>();
Map<Integer, Double> wrong = new HashMap<>();
Map<Integer, Double> other = new HashMap<>();
right.put(1, 1.0);
wrong.put(2, 1.0);
other.put(3, 1.0);
for (int i = 0; i < data.size(); i++) {
Map<Integer, Double> map1 = data.get(i);
Map<Integer, Double> map2 = dataB.get(i);
Map<Integer, Double> map3 = dataC.get(i);
post(nerveManager.getSensoryNerves(), map1, right, null, true);
post(nerveManager.getSensoryNerves(), map2, wrong, null, true);
post(nerveManager.getSensoryNerves(), map3, other, null, true);
}
List<Map<Integer, Double>> data2 = new ArrayList<>();
List<Map<Integer, Double>> data2B = new ArrayList<>();
List<Map<Integer, Double>> data2C = new ArrayList<>();//负样本
for (int i = 0; i < 20; i++) {
Map<Integer, Double> map1 = new HashMap<>();
Map<Integer, Double> map2 = new HashMap<>();
Map<Integer, Double> map3 = new HashMap<>();
map1.put(0, 1 + random.nextDouble());
map1.put(1, 1 + random.nextDouble());
map1.put(2, 0.0);
map2.put(0, random.nextDouble());
map2.put(1, random.nextDouble());
map2.put(2, 0.0);
map3.put(0, 2 + random.nextDouble());
map3.put(1, 2 + random.nextDouble());
map3.put(2, 0.0);
data2.add(map1);
data2B.add(map2);
data2C.add(map3);
}
Back back = new Back();
for (Map<Integer, Double> map : data2) {
post(nerveManager.getSensoryNerves(), map, null, back, false);
System.out.println("=====================");
}
}
/**
*
*

Loading…
Cancel
Save