diff --git a/src/main/java/org/wlld/naturalLanguage/Talk.java b/src/main/java/org/wlld/naturalLanguage/Talk.java index 1dfa92f..004c2d5 100644 --- a/src/main/java/org/wlld/naturalLanguage/Talk.java +++ b/src/main/java/org/wlld/naturalLanguage/Talk.java @@ -25,8 +25,16 @@ public class Talk { wordTimes = wordTemple.getWordTimes(); } - public List talk(String sentence) throws Exception { - List typeList = new ArrayList<>(); + public List> getSplitWord(String sentence) {//单纯进行拆词 + List sentences = splitSentence(sentence); + List> words = new ArrayList<>(); + for (Sentence sentence1 : sentences) { + words.add(sentence1.getKeyWords()); + } + return words; + } + + private List splitSentence(String sentence) { String rgm = null; if (sentence.indexOf(",") > -1) { rgm = ","; @@ -42,7 +50,8 @@ public class Talk { //拆词 List sentences = new ArrayList<>(); for (int i = 0; i < sens.length; i++) { - List sentenceList = catchSentence(sentence); + String mySentence = sens[i]; + List sentenceList = catchSentence(mySentence); int key = 0; int nub = 0; for (int j = 0; j < sentenceList.size(); j++) { @@ -57,43 +66,52 @@ public class Talk { //System.out.println(sentenceList.get(key).getKeyWords()); sentences.add(sentenceList.get(key)); } + return sentences; + } - //进行识别 - if (randomForest != null) { - for (Sentence sentence1 : sentences) { - List features = sentence1.getFeatures(); - List keyWords = sentence1.getKeyWords();//拆分的关键词 - int wrong = 0; - int wordNumber = keyWords.size(); - for (int i = 0; i < 8; i++) { - int nub = 0; - if (keyWords.size() > i) { - List words = wordTimes.get(i); - nub = getNub(words, keyWords.get(i)); - if (nub == 0) {//出现了不认识的词 - wrong++; + public List talk(String sentence) throws Exception { + if (!wordTemple.isSplitWord()) { + List typeList = new ArrayList<>(); + List sentences = splitSentence(sentence); + //进行识别 + if (randomForest != null) { + for (Sentence sentence1 : sentences) { + List features = sentence1.getFeatures(); + List keyWords = sentence1.getKeyWords();//拆分的关键词 + int wrong = 0; + int wordNumber = keyWords.size(); + for (int i = 0; i < 8; i++) { + int nub = 0; + if (keyWords.size() > i) { + List words = wordTimes.get(i); + nub = getNub(words, keyWords.get(i)); + if (nub == 0) {//出现了不认识的词 + wrong++; + } } + features.add(nub); } - features.add(nub); - } - int type = 0; - if (ArithUtil.div(wrong, wordNumber) < wordTemple.getGarbageTh()) { - LangBody langBody = new LangBody(); - langBody.setA1(features.get(0)); - langBody.setA2(features.get(1)); - langBody.setA3(features.get(2)); - langBody.setA4(features.get(3)); - langBody.setA5(features.get(4)); - langBody.setA6(features.get(5)); - langBody.setA7(features.get(6)); - langBody.setA8(features.get(7)); - type = randomForest.forest(langBody); + int type = 0; + if (ArithUtil.div(wrong, wordNumber) < wordTemple.getGarbageTh()) { + LangBody langBody = new LangBody(); + langBody.setA1(features.get(0)); + langBody.setA2(features.get(1)); + langBody.setA3(features.get(2)); + langBody.setA4(features.get(3)); + langBody.setA5(features.get(4)); + langBody.setA6(features.get(5)); + langBody.setA7(features.get(6)); + langBody.setA8(features.get(7)); + type = randomForest.forest(langBody); + } + typeList.add(type); } - typeList.add(type); + return typeList; + } else { + throw new Exception("forest is not study"); } - return typeList; } else { - throw new Exception("forest is not study"); + throw new Exception("isSplitWord is true"); } } diff --git a/src/main/java/org/wlld/naturalLanguage/Tokenizer.java b/src/main/java/org/wlld/naturalLanguage/Tokenizer.java index 2f4a8d5..9f2181e 100644 --- a/src/main/java/org/wlld/naturalLanguage/Tokenizer.java +++ b/src/main/java/org/wlld/naturalLanguage/Tokenizer.java @@ -44,9 +44,11 @@ public class Tokenizer extends Frequency { } restructure();//对集合中的词进行词频统计 //这里分词已经结束,对词进行编号 - number(); - //进入随机森林进行学习 - study(); + if (!wordTemple.isSplitWord()) { + number(); + //进入随机森林进行学习 + study(); + } } private int getKey(List words, String testWord) { diff --git a/src/main/java/org/wlld/naturalLanguage/WordTemple.java b/src/main/java/org/wlld/naturalLanguage/WordTemple.java index 402249e..7fc2205 100644 --- a/src/main/java/org/wlld/naturalLanguage/WordTemple.java +++ b/src/main/java/org/wlld/naturalLanguage/WordTemple.java @@ -20,6 +20,15 @@ public class WordTemple { private double trustPunishment = 0.1;//信任惩罚 private double trustTh = 0.1;//信任阈值,相当于一次信任惩罚的数值 private int treeNub = 11;//丛林里面树的数量 + private boolean isSplitWord = false;//是否使用拆分词模式,默认是不使用 + + public boolean isSplitWord() { + return isSplitWord; + } + + public void setSplitWord(boolean splitWord) { + isSplitWord = splitWord; + } public int getTreeNub() { return treeNub; diff --git a/src/test/java/org/wlld/LangTest.java b/src/test/java/org/wlld/LangTest.java index 8f34caa..1df3657 100644 --- a/src/test/java/org/wlld/LangTest.java +++ b/src/test/java/org/wlld/LangTest.java @@ -72,13 +72,18 @@ public class LangTest { TemplateReader templateReader = new TemplateReader(); WordTemple wordTemple = new WordTemple();//初始化语言模版 wordTemple.setTreeNub(9); + // wordTemple.setSplitWord(true); //读取语言模版,第一个参数是模版地址,第二个参数是编码方式 (教程里的第三个参数已经省略) //同时也是学习过程 templateReader.read("/Users/lidapeng/Desktop/myDocument/model.txt", "UTF-8", wordTemple); Talk talk = new Talk(wordTemple); //输入语句进行识别,若有标点符号会形成LIST中的每个元素 //返回的集合中每个值代表了输入语句,每个标点符号前语句的分类 - List list = talk.talk("空调坏了"); +// List> lists = talk.getSplitWord("空调坏了,帮我修一修"); +// for (List list : lists) { +// System.out.println(list); +// } + List list = talk.talk("空调坏了,帮我修一修"); System.out.println(list); } } diff --git a/src/test/java/org/wlld/NerveDemo1.java b/src/test/java/org/wlld/NerveDemo1.java index 858da82..76806f2 100644 --- a/src/test/java/org/wlld/NerveDemo1.java +++ b/src/test/java/org/wlld/NerveDemo1.java @@ -31,64 +31,57 @@ public class NerveDemo1 { * @param hiddenNerverNub 隐层神经元个数 * @param outNerveNub 输出神经元个数 * @param hiddenDepth 隐层深度 + * @param isAccurate 是否保留精度 * @param activeFunction 激活函数 * @param isDynamic 是否是动态神经元 */ -// NerveManager nerveManager = new NerveManager(2, 6, 1, 4, new Sigmod(), -// false, true, 0, RZ.NOT_RZ, 0); -// nerveManager.init(true, false, false, false); -// -// -// //创建训练 -// List> list_right = new LinkedList<>();//存放正确的值 -// List> list_wrong = new LinkedList<>();//存放错误的值 -// Random random = new Random(); -// for (int i = 0; i < 10000; i++) { -// Map mp1 = new HashMap<>(); -// Map mp2 = new HashMap<>(); -// mp1.put(0, random.nextDouble()); -// mp1.put(1, random.nextDouble()); -// mp2.put(0, -random.nextDouble());//负样本:负数永远小于0 -// mp2.put(1, -random.nextDouble()); -// list_right.add(mp1); -// list_wrong.add(mp2); -// } + NerveManager nerveManager = new NerveManager(2, 6, 1, 2, new Tanh(), + false, true, 0, RZ.NOT_RZ, 0); + nerveManager.init(true, false, false, false); + //创建训练 + List> list_right = new LinkedList<>();//存放正确的值 + List> list_wrong = new LinkedList<>();//存放错误的值 + Random random = new Random(); + for (int i = 0; i < 1000; i++) { + Map mp1 = new HashMap<>(); + Map mp2 = new HashMap<>(); + mp1.put(0, random.nextDouble()); + mp1.put(1, random.nextDouble()); + mp2.put(0, -random.nextDouble());//负样本:负数永远小于0 + mp2.put(1, -random.nextDouble()); + list_right.add(mp1); + list_wrong.add(mp2); + } // // //做一个正标注和负标注 -// Map right = new HashMap<>(); -// Map wrong = new HashMap<>(); -// right.put(1, 1.0); -// wrong.put(1, 0.0); -// -// //开始训练 -// for (int i = 0; i < list_right.size(); i++) { -// Map mp1 = list_right.get(i); -// Map mp2 = list_wrong.get(i); -// //这里的post的训练 -// post(nerveManager.getSensoryNerves(), mp1, right, null, true); -// post(nerveManager.getSensoryNerves(), mp2, wrong, null, true); -// } + Map right = new HashMap<>(); + Map wrong = new HashMap<>(); + right.put(1, 1.0); + wrong.put(1, 0.0); + + //开始训练 + for (int i = 0; i < list_right.size(); i++) { + Map mp1 = list_right.get(i); + Map mp2 = list_wrong.get(i); + //这里的post的训练 + post(nerveManager.getSensoryNerves(), mp1, right, null, true); + post(nerveManager.getSensoryNerves(), mp2, wrong, null, true); + } // // //测试 这里测试10个数据 -// List> test1 = new LinkedList<>(); -// for (int i = 0; i < 10; i++) { -// Map mp1 = new HashMap<>(); -// if (i == 4) {//在第五次的时候给一个错误的值//看看能否正确识别 -// mp1.put(0, -random.nextDouble()); -// mp1.put(1, -random.nextDouble()); -// test1.add(mp1); -// continue; -// } -// mp1.put(0, random.nextDouble()); -// mp1.put(1, random.nextDouble()); -// test1.add(mp1); -// } + List> test1 = new LinkedList<>(); + for (int i = 0; i < 10; i++) { + Map mp1 = new HashMap<>(); + mp1.put(0, -random.nextDouble()); + mp1.put(1, -random.nextDouble()); + test1.add(mp1); + } // //查看结果 -// Back back = new Back(); -// for (Map test_data : test1) { -// //这里的post是进行学习 -// post(nerveManager.getSensoryNerves(), test_data, null, back, false); -// } + Back back = new Back(); + for (Map test_data : test1) { + //这里的post是进行学习 + post(nerveManager.getSensoryNerves(), test_data, null, back, false); + } /* * 输出结果: * @@ -105,65 +98,6 @@ public class NerveDemo1 { * * 很显然第五个数值非常小,意味着不是我们想要的结果 */ - - test3(); - } - - public static void test3() throws Exception { - NerveManager nerveManager = new NerveManager(2, 6, 2 - , 2, new Tanh(), - false, true, 0.01, RZ.NOT_RZ, 0); - nerveManager.init(true, false, true, false);//初始化 - List> data = new ArrayList<>();//正样本 - List> dataB = new ArrayList<>();//负样本 - double a1 = 0.5463803496429489; - double a2 = 1.0917521555875922; - double b1 = 1.0012872347982982; - double b2 = 0.6597094176788427; - Random random = new Random(); - for (int i = 0; i < 3000; i++) { - Map map1 = new HashMap<>(); - Map map2 = new HashMap<>(); - map1.put(0, a1); - map1.put(1, a2); - //产生鲜明区分 - map2.put(0, b1); - map2.put(1, b2); - - data.add(map1); - dataB.add(map2); - } - Map right = new HashMap<>(); - Map wrong = new HashMap<>(); - right.put(1, 1.0); - - wrong.put(2, 1.0); - for (int i = 0; i < data.size(); i++) { - Map map1 = data.get(i); - Map map2 = dataB.get(i); - post(nerveManager.getSensoryNerves(), map1, right, null, true); - System.out.println("================"); - post(nerveManager.getSensoryNerves(), map2, wrong, null, true); - } - - List> data2 = new ArrayList<>(); - List> data2B = new ArrayList<>(); - for (int i = 0; i < 20; i++) { - Map map1 = new HashMap<>(); - Map map2 = new HashMap<>(); - map1.put(0, a1); - map1.put(1, a2); - - map2.put(0, b1); - map2.put(1, b2); - data2.add(map1); - data2B.add(map2); - } - Back back = new Back(); - for (Map map : data2B) { - post(nerveManager.getSensoryNerves(), map, null, back, false); - System.out.println("====================="); - } } /**