diff --git a/.idea/compiler.xml b/.idea/compiler.xml index 6aa88ff..d280c68 100644 --- a/.idea/compiler.xml +++ b/.idea/compiler.xml @@ -6,8 +6,8 @@ - + diff --git a/pom.xml b/pom.xml index 5aaac0d..d4b0ccd 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ com.github ImageMarket - 1.0.0 + 1.0.2 myBrain diff --git a/src/main/java/org/wlld/naturalLanguage/Talk.java b/src/main/java/org/wlld/naturalLanguage/Talk.java index 16180b4..ff2dd4a 100644 --- a/src/main/java/org/wlld/naturalLanguage/Talk.java +++ b/src/main/java/org/wlld/naturalLanguage/Talk.java @@ -2,6 +2,7 @@ package org.wlld.naturalLanguage; import org.wlld.randomForest.RandomForest; +import org.wlld.tools.ArithUtil; import java.util.ArrayList; import java.util.List; @@ -16,7 +17,8 @@ public class Talk { private RandomForest randomForest = WordTemple.get().getRandomForest();//获取随机森林模型 private List> wordTimes = WordTemple.get().getWordTimes(); - public void talk(String sentence) throws Exception { + public List talk(String sentence) throws Exception { + List typeList = new ArrayList<>(); String rgm = null; if (sentence.indexOf(",") > -1) { rgm = ","; @@ -41,29 +43,38 @@ public class Talk { if (randomForest != null) { for (Sentence sentence1 : sentences) { List features = sentence1.getFeatures(); - List keyWords = sentence1.getKeyWords(); + List keyWords = sentence1.getKeyWords();//拆分的关键词 + int wrong = 0; + int wordNumber = keyWords.size(); for (int i = 0; i < 8; i++) { int nub = 0; if (keyWords.size() > i) { List words = wordTimes.get(i); nub = getNub(words, keyWords.get(i)); + if (nub == 0) {//出现了不认识的词 + wrong++; + } } features.add(nub); } - LangBody langBody = new LangBody(); - langBody.setA1(features.get(0)); - langBody.setA2(features.get(1)); - langBody.setA3(features.get(2)); - langBody.setA4(features.get(3)); - langBody.setA5(features.get(4)); - langBody.setA6(features.get(5)); - langBody.setA7(features.get(6)); - langBody.setA8(features.get(7)); - int type = randomForest.forest(langBody); - System.out.println("type==" + type); + int type = 0; + if (ArithUtil.div(wrong, wordNumber) < WordTemple.get().getGarbageTh()) { + LangBody langBody = new LangBody(); + langBody.setA1(features.get(0)); + langBody.setA2(features.get(1)); + langBody.setA3(features.get(2)); + langBody.setA4(features.get(3)); + langBody.setA5(features.get(4)); + langBody.setA6(features.get(5)); + langBody.setA7(features.get(6)); + langBody.setA8(features.get(7)); + type = randomForest.forest(langBody); + } + typeList.add(type); } + return typeList; } else { - System.out.println("随机森林没有训练"); + throw new Exception("forest is not study"); } } diff --git a/src/main/java/org/wlld/naturalLanguage/Tokenizer.java b/src/main/java/org/wlld/naturalLanguage/Tokenizer.java index 7aebc42..cb433ac 100644 --- a/src/main/java/org/wlld/naturalLanguage/Tokenizer.java +++ b/src/main/java/org/wlld/naturalLanguage/Tokenizer.java @@ -69,7 +69,7 @@ public class Tokenizer extends Frequency { DataTable dataTable = new DataTable(column); dataTable.setKey("key"); //初始化随机森林 - RandomForest randomForest = new RandomForest(5); + RandomForest randomForest = new RandomForest(7); WordTemple.get().setRandomForest(randomForest);//保存随机森林到模版 randomForest.init(dataTable); for (Sentence sentence : sentences) { diff --git a/src/main/java/org/wlld/naturalLanguage/WordConst.java b/src/main/java/org/wlld/naturalLanguage/WordConst.java index 0e2c1ac..be484e2 100644 --- a/src/main/java/org/wlld/naturalLanguage/WordConst.java +++ b/src/main/java/org/wlld/naturalLanguage/WordConst.java @@ -2,19 +2,9 @@ package org.wlld.naturalLanguage; public class WordConst { public static double Word_Noise = 0.7;//收缩程度 - public static final int SHOP = 1;//购买类型 - public static final int FOOD = 3;//食物类型 - public static final int DRINK = 4;//饮品类型 - public static final int OTHER = 5;//家庭日用(油盐酱醋卫生纸之类的) - public static final int SMOKE = 10;//烟草 - public static final int ADD = 6;//订单增0.5037412492 - public static final int DEL = 7;//订单删 - public static final int UPDATE = 8;//订单改 - public static final int SELECT = 9;//订单查 - public static final int TALK = 2;//聊天类型 - public static final int ALL = 11;//全文本 - public static final int CHANGE = 12;//分类文本 - public static final int DROP = 13;//消文本 - public static final int CURD = 14;//对订单增删改查类型 - public static final int ANS = 0;//聊天回复 + public static final int Water = 2;//送水 + public static final int Nanny = 3;//保姆 + public static final int Unlock = 4;//开锁 + public static final int Express = 5;//快递 + public static final int Telephone = 6;//充话费 } diff --git a/src/main/java/org/wlld/naturalLanguage/WordTemple.java b/src/main/java/org/wlld/naturalLanguage/WordTemple.java index 7583c53..69ea7bc 100644 --- a/src/main/java/org/wlld/naturalLanguage/WordTemple.java +++ b/src/main/java/org/wlld/naturalLanguage/WordTemple.java @@ -16,6 +16,15 @@ public class WordTemple { private List allWorld = new ArrayList<>();//所有词集合 private List> wordTimes = new ArrayList<>();//词编号 private RandomForest randomForest;//保存的随机森林模型 + private double garbageTh = 0.5;//垃圾分类的阈值默认0.7 + + public double getGarbageTh() { + return garbageTh; + } + + public void setGarbageTh(double garbageTh) { + this.garbageTh = garbageTh; + } public RandomForest getRandomForest() { return randomForest; diff --git a/src/main/java/org/wlld/randomForest/RandomForest.java b/src/main/java/org/wlld/randomForest/RandomForest.java index ef93efe..0891ed0 100644 --- a/src/main/java/org/wlld/randomForest/RandomForest.java +++ b/src/main/java/org/wlld/randomForest/RandomForest.java @@ -48,6 +48,7 @@ public class RandomForest { for (int i = 0; i < forest.length; i++) { Tree tree = forest[i]; int type = tree.judge(object); + //System.out.println(type); if (map.containsKey(type)) { map.put(type, map.get(type) + 1); } else { @@ -81,7 +82,7 @@ public class RandomForest { public void study() throws Exception {//学习 for (int i = 0; i < forest.length; i++) { - System.out.println("开始学习==" + i + ",treeNub==" + forest.length); + //System.out.println("开始学习==" + i + ",treeNub==" + forest.length); Tree tree = forest[i]; tree.study(); } diff --git a/src/test/java/org/wlld/LangTest.java b/src/test/java/org/wlld/LangTest.java index 98f7d77..9e68b2b 100644 --- a/src/test/java/org/wlld/LangTest.java +++ b/src/test/java/org/wlld/LangTest.java @@ -4,6 +4,8 @@ import org.wlld.naturalLanguage.IOConst; import org.wlld.naturalLanguage.Talk; import org.wlld.naturalLanguage.TemplateReader; +import java.util.List; + /** * @author lidapeng * @description @@ -16,11 +18,9 @@ public class LangTest { public static void test() throws Exception { TemplateReader templateReader = new TemplateReader(); - templateReader.read("/Users/lidapeng/Desktop/myDocment/a2.txt", "UTF-8", IOConst.NOT_WIN); + templateReader.read("/Users/lidapeng/Desktop/myDocment/a1.txt", "UTF-8", IOConst.NOT_WIN); Talk talk = new Talk(); - talk.talk("我要吃面包"); - talk.talk("我渴了"); - talk.talk("我要去看望你"); - talk.talk("我买两盒烟"); + List list = talk.talk("我草尼玛"); + System.out.println(list); } }