diff --git a/README.md b/README.md index 00e027d..d1bd905 100644 --- a/README.md +++ b/README.md @@ -194,7 +194,12 @@ //创建模板读取累 TemplateReader templateReader = new TemplateReader(); //读取语言模版,第一个参数是模版地址,第二个参数是编码方式,第三个参数是是否是WIN系统 + //同时也是学习过程 templateReader.read("/Users/lidapeng/Desktop/myDocment/a1.txt", "UTF-8", IOConst.NOT_WIN); + //学习结束获取模型参数 + //WordModel wordModel = WordTemple.get().getModel(); + //不用学习注入模型参数 + //WordTemple.get().insertModel(wordModel); Talk talk = new Talk(); //输入语句进行识别,若有标点符号会形成LIST中的每个元素 //返回的集合中每个值代表了输入语句,每个标点符号前语句的分类 diff --git a/src/main/java/org/wlld/naturalLanguage/WordModel.java b/src/main/java/org/wlld/naturalLanguage/WordModel.java index e6d98cf..d3a4fbe 100644 --- a/src/main/java/org/wlld/naturalLanguage/WordModel.java +++ b/src/main/java/org/wlld/naturalLanguage/WordModel.java @@ -13,6 +13,33 @@ public class WordModel { private RfModel rfModel;//随机森林模型 private List allWorld;//所有词集合 private List> wordTimes;//所有分词编号 + private double garbageTh;//垃圾分类的阈值默认0.5 + private double trustPunishment;//信任惩罚 + private double trustTh;//信任阈值 + + public double getGarbageTh() { + return garbageTh; + } + + public void setGarbageTh(double garbageTh) { + this.garbageTh = garbageTh; + } + + public double getTrustPunishment() { + return trustPunishment; + } + + public void setTrustPunishment(double trustPunishment) { + this.trustPunishment = trustPunishment; + } + + public double getTrustTh() { + return trustTh; + } + + public void setTrustTh(double trustTh) { + this.trustTh = trustTh; + } public RfModel getRfModel() { return rfModel; diff --git a/src/main/java/org/wlld/naturalLanguage/WordTemple.java b/src/main/java/org/wlld/naturalLanguage/WordTemple.java index e78f2b3..d254e64 100644 --- a/src/main/java/org/wlld/naturalLanguage/WordTemple.java +++ b/src/main/java/org/wlld/naturalLanguage/WordTemple.java @@ -19,6 +19,27 @@ public class WordTemple { private double garbageTh = 0.5;//垃圾分类的阈值默认0.5 private double trustPunishment = 0.1;//信任惩罚 + public 
WordModel getModel() {//获取模型 + WordModel wordModel = new WordModel(); + wordModel.setAllWorld(allWorld); + wordModel.setWordTimes(wordTimes); + wordModel.setGarbageTh(garbageTh); + wordModel.setTrustPunishment(trustPunishment); + wordModel.setTrustTh(randomForest.getTrustTh()); + wordModel.setRfModel(randomForest.getModel()); + return wordModel; + } + + public void insertModel(WordModel wordModel) throws Exception {//注入模型 + allWorld = wordModel.getAllWorld(); + wordTimes = wordModel.getWordTimes(); + garbageTh = wordModel.getGarbageTh(); + trustPunishment = wordModel.getTrustPunishment(); + randomForest = new RandomForest(); + randomForest.setTrustTh(wordModel.getTrustTh()); + randomForest.insertModel(wordModel.getRfModel()); + } + public double getTrustPunishment() { return trustPunishment; } diff --git a/src/main/java/org/wlld/randomForest/RandomForest.java b/src/main/java/org/wlld/randomForest/RandomForest.java index 59319f6..c961400 100644 --- a/src/main/java/org/wlld/randomForest/RandomForest.java +++ b/src/main/java/org/wlld/randomForest/RandomForest.java @@ -14,10 +14,17 @@ public class RandomForest { private Tree[] forest; private double trustTh = 0.1;//信任阈值 - public void setTrustTh(double trustTh) {//设置信任阈值 + public double getTrustTh() { + return trustTh; + } + + public void setTrustTh(double trustTh) { this.trustTh = trustTh; } + public RandomForest() { + } + public RandomForest(int treeNub) throws Exception { if (treeNub > 0) { forest = new Tree[treeNub]; @@ -29,8 +36,12 @@ public class RandomForest { public void insertModel(RfModel rfModel) throws Exception {//注入模型 if (rfModel != null) { Map nodeMap = rfModel.getNodeMap(); - for (int i = 0; i < forest.length; i++) { - forest[i].setRootNode(nodeMap.get(i)); + forest = new Tree[nodeMap.size()]; + for (Map.Entry entry : nodeMap.entrySet()) { + int key = entry.getKey(); + Tree tree = new Tree(); + forest[key] = tree; + tree.setRootNode(entry.getValue()); } } else { throw new Exception("model is null"); diff 
--git a/src/main/java/org/wlld/randomForest/Tree.java b/src/main/java/org/wlld/randomForest/Tree.java index 5975490..2e17b6e 100644 --- a/src/main/java/org/wlld/randomForest/Tree.java +++ b/src/main/java/org/wlld/randomForest/Tree.java @@ -36,6 +36,9 @@ public class Tree {//决策树 private double gainRatio; } + public Tree() { + } + public Tree(DataTable dataTable) throws Exception { if (dataTable != null && dataTable.getKey() != null) { this.dataTable = dataTable; diff --git a/src/test/java/org/wlld/LangTest.java b/src/test/java/org/wlld/LangTest.java index 103e174..4af2699 100644 --- a/src/test/java/org/wlld/LangTest.java +++ b/src/test/java/org/wlld/LangTest.java @@ -17,15 +17,19 @@ public class LangTest { } public static void test() throws Exception { - //学习过程 过程(长期内存持有) - //模版类 + //创建模板读取类 TemplateReader templateReader = new TemplateReader(); - //读取模版 + //读取语言模版,第一个参数是模版地址,第二个参数是编码方式,第三个参数是是否是WIN系统 + //同时也是学习过程 templateReader.read("/Users/lidapeng/Desktop/myDocment/a1.txt", "UTF-8", IOConst.NOT_WIN); - //识别过程 + //学习结束获取模型参数 + //WordModel wordModel = WordTemple.get().getModel(); + //不用学习注入模型参数 + //WordTemple.get().insertModel(wordModel); Talk talk = new Talk(); - //我饿了,我想吃个饭 - List list = talk.talk("语速"); + //输入语句进行识别,若有标点符号会形成LIST中的每个元素 + //返回的集合中每个值代表了输入语句,每个标点符号前语句的分类 + List list = talk.talk("帮我配把锁"); System.out.println(list); } }