!45 增加分词api

Merge pull request !45 from 逐光/test
pull/45/MERGE
逐光 5 years ago committed by Gitee
commit 0fd848ccb8

@ -25,8 +25,16 @@ public class Talk {
wordTimes = wordTemple.getWordTimes(); wordTimes = wordTemple.getWordTimes();
} }
public List<Integer> talk(String sentence) throws Exception { public List<List<String>> getSplitWord(String sentence) {//单纯进行拆词
List<Integer> typeList = new ArrayList<>(); List<Sentence> sentences = splitSentence(sentence);
List<List<String>> words = new ArrayList<>();
for (Sentence sentence1 : sentences) {
words.add(sentence1.getKeyWords());
}
return words;
}
private List<Sentence> splitSentence(String sentence) {
String rgm = null; String rgm = null;
if (sentence.indexOf(",") > -1) { if (sentence.indexOf(",") > -1) {
rgm = ","; rgm = ",";
@ -42,7 +50,8 @@ public class Talk {
//拆词 //拆词
List<Sentence> sentences = new ArrayList<>(); List<Sentence> sentences = new ArrayList<>();
for (int i = 0; i < sens.length; i++) { for (int i = 0; i < sens.length; i++) {
List<Sentence> sentenceList = catchSentence(sentence); String mySentence = sens[i];
List<Sentence> sentenceList = catchSentence(mySentence);
int key = 0; int key = 0;
int nub = 0; int nub = 0;
for (int j = 0; j < sentenceList.size(); j++) { for (int j = 0; j < sentenceList.size(); j++) {
@ -57,43 +66,52 @@ public class Talk {
//System.out.println(sentenceList.get(key).getKeyWords()); //System.out.println(sentenceList.get(key).getKeyWords());
sentences.add(sentenceList.get(key)); sentences.add(sentenceList.get(key));
} }
return sentences;
}
//进行识别 public List<Integer> talk(String sentence) throws Exception {
if (randomForest != null) { if (!wordTemple.isSplitWord()) {
for (Sentence sentence1 : sentences) { List<Integer> typeList = new ArrayList<>();
List<Integer> features = sentence1.getFeatures(); List<Sentence> sentences = splitSentence(sentence);
List<String> keyWords = sentence1.getKeyWords();//拆分的关键词 //进行识别
int wrong = 0; if (randomForest != null) {
int wordNumber = keyWords.size(); for (Sentence sentence1 : sentences) {
for (int i = 0; i < 8; i++) { List<Integer> features = sentence1.getFeatures();
int nub = 0; List<String> keyWords = sentence1.getKeyWords();//拆分的关键词
if (keyWords.size() > i) { int wrong = 0;
List<String> words = wordTimes.get(i); int wordNumber = keyWords.size();
nub = getNub(words, keyWords.get(i)); for (int i = 0; i < 8; i++) {
if (nub == 0) {//出现了不认识的词 int nub = 0;
wrong++; if (keyWords.size() > i) {
List<String> words = wordTimes.get(i);
nub = getNub(words, keyWords.get(i));
if (nub == 0) {//出现了不认识的词
wrong++;
}
} }
features.add(nub);
} }
features.add(nub); int type = 0;
} if (ArithUtil.div(wrong, wordNumber) < wordTemple.getGarbageTh()) {
int type = 0; LangBody langBody = new LangBody();
if (ArithUtil.div(wrong, wordNumber) < wordTemple.getGarbageTh()) { langBody.setA1(features.get(0));
LangBody langBody = new LangBody(); langBody.setA2(features.get(1));
langBody.setA1(features.get(0)); langBody.setA3(features.get(2));
langBody.setA2(features.get(1)); langBody.setA4(features.get(3));
langBody.setA3(features.get(2)); langBody.setA5(features.get(4));
langBody.setA4(features.get(3)); langBody.setA6(features.get(5));
langBody.setA5(features.get(4)); langBody.setA7(features.get(6));
langBody.setA6(features.get(5)); langBody.setA8(features.get(7));
langBody.setA7(features.get(6)); type = randomForest.forest(langBody);
langBody.setA8(features.get(7)); }
type = randomForest.forest(langBody); typeList.add(type);
} }
typeList.add(type); return typeList;
} else {
throw new Exception("forest is not study");
} }
return typeList;
} else { } else {
throw new Exception("forest is not study"); throw new Exception("isSplitWord is true");
} }
} }

@ -44,9 +44,11 @@ public class Tokenizer extends Frequency {
} }
restructure();//对集合中的词进行词频统计 restructure();//对集合中的词进行词频统计
//这里分词已经结束,对词进行编号 //这里分词已经结束,对词进行编号
number(); if (!wordTemple.isSplitWord()) {
//进入随机森林进行学习 number();
study(); //进入随机森林进行学习
study();
}
} }
private int getKey(List<String> words, String testWord) { private int getKey(List<String> words, String testWord) {

@ -20,6 +20,15 @@ public class WordTemple {
private double trustPunishment = 0.1;//信任惩罚 private double trustPunishment = 0.1;//信任惩罚
private double trustTh = 0.1;//信任阈值,相当于一次信任惩罚的数值 private double trustTh = 0.1;//信任阈值,相当于一次信任惩罚的数值
private int treeNub = 11;//丛林里面树的数量 private int treeNub = 11;//丛林里面树的数量
private boolean isSplitWord = false;//是否使用拆分词模式,默认是不使用
/**
 * Reports whether split-word mode is enabled on this template.
 *
 * @return the current value of the {@code isSplitWord} field
 */
public boolean isSplitWord() {
    return this.isSplitWord;
}
/**
 * Enables or disables split-word mode on this template.
 *
 * @param splitWord the new value stored in the {@code isSplitWord} field
 */
public void setSplitWord(boolean splitWord) {
    this.isSplitWord = splitWord;
}
public int getTreeNub() { public int getTreeNub() {
return treeNub; return treeNub;

@ -72,13 +72,18 @@ public class LangTest {
TemplateReader templateReader = new TemplateReader(); TemplateReader templateReader = new TemplateReader();
WordTemple wordTemple = new WordTemple();//初始化语言模版 WordTemple wordTemple = new WordTemple();//初始化语言模版
wordTemple.setTreeNub(9); wordTemple.setTreeNub(9);
// wordTemple.setSplitWord(true);
//读取语言模版,第一个参数是模版地址,第二个参数是编码方式 (教程里的第三个参数已经省略) //读取语言模版,第一个参数是模版地址,第二个参数是编码方式 (教程里的第三个参数已经省略)
//同时也是学习过程 //同时也是学习过程
templateReader.read("/Users/lidapeng/Desktop/myDocument/model.txt", "UTF-8", wordTemple); templateReader.read("/Users/lidapeng/Desktop/myDocument/model.txt", "UTF-8", wordTemple);
Talk talk = new Talk(wordTemple); Talk talk = new Talk(wordTemple);
//输入语句进行识别若有标点符号会形成LIST中的每个元素 //输入语句进行识别若有标点符号会形成LIST中的每个元素
//返回的集合中每个值代表了输入语句,每个标点符号前语句的分类 //返回的集合中每个值代表了输入语句,每个标点符号前语句的分类
List<Integer> list = talk.talk("空调坏了"); // List<List<String>> lists = talk.getSplitWord("空调坏了,帮我修一修");
// for (List<String> list : lists) {
// System.out.println(list);
// }
List<Integer> list = talk.talk("空调坏了,帮我修一修");
System.out.println(list); System.out.println(list);
} }
} }

@ -31,64 +31,57 @@ public class NerveDemo1 {
* @param hiddenNerverNub * @param hiddenNerverNub
* @param outNerveNub * @param outNerveNub
* @param hiddenDepth * @param hiddenDepth
* @param isAccurate
* @param activeFunction * @param activeFunction
* @param isDynamic * @param isDynamic
*/ */
// NerveManager nerveManager = new NerveManager(2, 6, 1, 4, new Sigmod(), NerveManager nerveManager = new NerveManager(2, 6, 1, 2, new Tanh(),
// false, true, 0, RZ.NOT_RZ, 0); false, true, 0, RZ.NOT_RZ, 0);
// nerveManager.init(true, false, false, false); nerveManager.init(true, false, false, false);
// //创建训练
// List<Map<Integer, Double>> list_right = new LinkedList<>();//存放正确的值
// //创建训练 List<Map<Integer, Double>> list_wrong = new LinkedList<>();//存放错误的值
// List<Map<Integer, Double>> list_right = new LinkedList<>();//存放正确的值 Random random = new Random();
// List<Map<Integer, Double>> list_wrong = new LinkedList<>();//存放错误的值 for (int i = 0; i < 1000; i++) {
// Random random = new Random(); Map<Integer, Double> mp1 = new HashMap<>();
// for (int i = 0; i < 10000; i++) { Map<Integer, Double> mp2 = new HashMap<>();
// Map<Integer, Double> mp1 = new HashMap<>(); mp1.put(0, random.nextDouble());
// Map<Integer, Double> mp2 = new HashMap<>(); mp1.put(1, random.nextDouble());
// mp1.put(0, random.nextDouble()); mp2.put(0, -random.nextDouble());//负样本:负数永远小于0
// mp1.put(1, random.nextDouble()); mp2.put(1, -random.nextDouble());
// mp2.put(0, -random.nextDouble());//负样本:负数永远小于0 list_right.add(mp1);
// mp2.put(1, -random.nextDouble()); list_wrong.add(mp2);
// list_right.add(mp1); }
// list_wrong.add(mp2);
// }
// //
// //做一个正标注和负标注 // //做一个正标注和负标注
// Map<Integer, Double> right = new HashMap<>(); Map<Integer, Double> right = new HashMap<>();
// Map<Integer, Double> wrong = new HashMap<>(); Map<Integer, Double> wrong = new HashMap<>();
// right.put(1, 1.0); right.put(1, 1.0);
// wrong.put(1, 0.0); wrong.put(1, 0.0);
//
// //开始训练 //开始训练
// for (int i = 0; i < list_right.size(); i++) { for (int i = 0; i < list_right.size(); i++) {
// Map<Integer, Double> mp1 = list_right.get(i); Map<Integer, Double> mp1 = list_right.get(i);
// Map<Integer, Double> mp2 = list_wrong.get(i); Map<Integer, Double> mp2 = list_wrong.get(i);
// //这里的post的训练 //这里的post的训练
// post(nerveManager.getSensoryNerves(), mp1, right, null, true); post(nerveManager.getSensoryNerves(), mp1, right, null, true);
// post(nerveManager.getSensoryNerves(), mp2, wrong, null, true); post(nerveManager.getSensoryNerves(), mp2, wrong, null, true);
// } }
// //
// //测试 这里测试10个数据 // //测试 这里测试10个数据
// List<Map<Integer, Double>> test1 = new LinkedList<>(); List<Map<Integer, Double>> test1 = new LinkedList<>();
// for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
// Map<Integer, Double> mp1 = new HashMap<>(); Map<Integer, Double> mp1 = new HashMap<>();
// if (i == 4) {//在第五次的时候给一个错误的值//看看能否正确识别 mp1.put(0, -random.nextDouble());
// mp1.put(0, -random.nextDouble()); mp1.put(1, -random.nextDouble());
// mp1.put(1, -random.nextDouble()); test1.add(mp1);
// test1.add(mp1); }
// continue;
// }
// mp1.put(0, random.nextDouble());
// mp1.put(1, random.nextDouble());
// test1.add(mp1);
// }
// //查看结果 // //查看结果
// Back back = new Back(); Back back = new Back();
// for (Map<Integer, Double> test_data : test1) { for (Map<Integer, Double> test_data : test1) {
// //这里的post是进行学习 //这里的post是进行学习
// post(nerveManager.getSensoryNerves(), test_data, null, back, false); post(nerveManager.getSensoryNerves(), test_data, null, back, false);
// } }
/* /*
* : * :
* *
@ -105,65 +98,6 @@ public class NerveDemo1 {
* *
* , * ,
*/ */
test3();
}
public static void test3() throws Exception {
NerveManager nerveManager = new NerveManager(2, 6, 2
, 2, new Tanh(),
false, true, 0.01, RZ.NOT_RZ, 0);
nerveManager.init(true, false, true, false);//初始化
List<Map<Integer, Double>> data = new ArrayList<>();//正样本
List<Map<Integer, Double>> dataB = new ArrayList<>();//负样本
double a1 = 0.5463803496429489;
double a2 = 1.0917521555875922;
double b1 = 1.0012872347982982;
double b2 = 0.6597094176788427;
Random random = new Random();
for (int i = 0; i < 3000; i++) {
Map<Integer, Double> map1 = new HashMap<>();
Map<Integer, Double> map2 = new HashMap<>();
map1.put(0, a1);
map1.put(1, a2);
//产生鲜明区分
map2.put(0, b1);
map2.put(1, b2);
data.add(map1);
dataB.add(map2);
}
Map<Integer, Double> right = new HashMap<>();
Map<Integer, Double> wrong = new HashMap<>();
right.put(1, 1.0);
wrong.put(2, 1.0);
for (int i = 0; i < data.size(); i++) {
Map<Integer, Double> map1 = data.get(i);
Map<Integer, Double> map2 = dataB.get(i);
post(nerveManager.getSensoryNerves(), map1, right, null, true);
System.out.println("================");
post(nerveManager.getSensoryNerves(), map2, wrong, null, true);
}
List<Map<Integer, Double>> data2 = new ArrayList<>();
List<Map<Integer, Double>> data2B = new ArrayList<>();
for (int i = 0; i < 20; i++) {
Map<Integer, Double> map1 = new HashMap<>();
Map<Integer, Double> map2 = new HashMap<>();
map1.put(0, a1);
map1.put(1, a2);
map2.put(0, b1);
map2.put(1, b2);
data2.add(map1);
data2B.add(map2);
}
Back back = new Back();
for (Map<Integer, Double> map : data2B) {
post(nerveManager.getSensoryNerves(), map, null, back, false);
System.out.println("=====================");
}
} }
/** /**

Loading…
Cancel
Save