增加随机森林模型获取和注入

pull/1/head
lidapeng 5 years ago
parent a84bd584a8
commit 6155c099e9

@ -194,7 +194,12 @@
//创建模板读取累
TemplateReader templateReader = new TemplateReader();
//读取语言模版第一个参数是模版地址第二个参数是编码方式第三个参数是是否是WIN系统
//同时也是学习过程
templateReader.read("/Users/lidapeng/Desktop/myDocment/a1.txt", "UTF-8", IOConst.NOT_WIN);
//学习结束获取模型参数
//WordModel wordModel = WordTemple.get().getModel();
//不用学习注入模型参数
//WordTemple.get().insertModel(wordModel);
Talk talk = new Talk();
//输入语句进行识别若有标点符号会形成LIST中的每个元素
//返回的集合中每个值代表了输入语句,每个标点符号前语句的分类

@ -13,6 +13,33 @@ public class WordModel {
private RfModel rfModel;//随机森林模型
private List<WorldBody> allWorld;//所有词集合
private List<List<String>> wordTimes;//所有分词编号
private double garbageTh;//垃圾分类的阈值默认0.5
private double trustPunishment;//信任惩罚
private double trustTh;//信任阈值
public double getGarbageTh() {
return garbageTh;
}
public void setGarbageTh(double garbageTh) {
this.garbageTh = garbageTh;
}
public double getTrustPunishment() {
return trustPunishment;
}
public void setTrustPunishment(double trustPunishment) {
this.trustPunishment = trustPunishment;
}
public double getTrustTh() {
return trustTh;
}
public void setTrustTh(double trustTh) {
this.trustTh = trustTh;
}
public RfModel getRfModel() {
return rfModel;

@ -19,6 +19,27 @@ public class WordTemple {
private double garbageTh = 0.5;//垃圾分类的阈值默认0.5
private double trustPunishment = 0.1;//信任惩罚
public WordModel getModel() {//获取模型
WordModel wordModel = new WordModel();
wordModel.setAllWorld(allWorld);
wordModel.setWordTimes(wordTimes);
wordModel.setGarbageTh(garbageTh);
wordModel.setTrustPunishment(trustPunishment);
wordModel.setTrustTh(randomForest.getTrustTh());
wordModel.setRfModel(randomForest.getModel());
return wordModel;
}
public void insertModel(WordModel wordModel) throws Exception {//注入模型
allWorld = wordModel.getAllWorld();
wordTimes = wordModel.getWordTimes();
garbageTh = wordModel.getGarbageTh();
trustPunishment = wordModel.getTrustPunishment();
randomForest = new RandomForest();
randomForest.setTrustTh(wordModel.getTrustTh());
randomForest.insertModel(wordModel.getRfModel());
}
public double getTrustPunishment() {
return trustPunishment;
}

@ -14,10 +14,17 @@ public class RandomForest {
private Tree[] forest;
private double trustTh = 0.1;//信任阈值
public void setTrustTh(double trustTh) {//设置信任阈值
public double getTrustTh() {
return trustTh;
}
public void setTrustTh(double trustTh) {
this.trustTh = trustTh;
}
public RandomForest() {
}
public RandomForest(int treeNub) throws Exception {
if (treeNub > 0) {
forest = new Tree[treeNub];
@ -29,8 +36,12 @@ public class RandomForest {
public void insertModel(RfModel rfModel) throws Exception {//注入模型
if (rfModel != null) {
Map<Integer, Node> nodeMap = rfModel.getNodeMap();
for (int i = 0; i < forest.length; i++) {
forest[i].setRootNode(nodeMap.get(i));
forest = new Tree[nodeMap.size()];
for (Map.Entry<Integer, Node> entry : nodeMap.entrySet()) {
int key = entry.getKey();
Tree tree = new Tree();
forest[key] = tree;
tree.setRootNode(entry.getValue());
}
} else {
throw new Exception("model is null");

@ -36,6 +36,9 @@ public class Tree {//决策树
private double gainRatio;
}
public Tree() {
}
public Tree(DataTable dataTable) throws Exception {
if (dataTable != null && dataTable.getKey() != null) {
this.dataTable = dataTable;

@ -17,15 +17,19 @@ public class LangTest {
}
public static void test() throws Exception {
//学习过程 过程(长期内存持有)
//模版类
//创建模板读取累
TemplateReader templateReader = new TemplateReader();
//读取模版
//读取语言模版第一个参数是模版地址第二个参数是编码方式第三个参数是是否是WIN系统
//同时也是学习过程
templateReader.read("/Users/lidapeng/Desktop/myDocment/a1.txt", "UTF-8", IOConst.NOT_WIN);
//识别过程
//学习结束获取模型参数
//WordModel wordModel = WordTemple.get().getModel();
//不用学习注入模型参数
//WordTemple.get().insertModel(wordModel);
Talk talk = new Talk();
//我饿了,我想吃个饭
List<Integer> list = talk.talk("语速");
//输入语句进行识别若有标点符号会形成LIST中的每个元素
//返回的集合中每个值代表了输入语句,每个标点符号前语句的分类
List<Integer> list = talk.talk("帮我配把锁");
System.out.println(list);
}
}

Loading…
Cancel
Save