修改了编码时忘记排重的问题

pull/1/head
lidapeng 5 years ago
parent a01779d547
commit 03bab40e30

@ -4,8 +4,8 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<groupId>com.github</groupId> <groupId>com.wlld</groupId>
<artifactId>ImageMarket</artifactId> <artifactId>easyAi</artifactId>
<version>1.0.0</version> <version>1.0.0</version>
<name>myBrain</name> <name>myBrain</name>

@ -46,6 +46,7 @@ public class Talk {
nub = size; nub = size;
} }
} }
//System.out.println(sentenceList.get(key).getKeyWords());
sentences.add(sentenceList.get(key)); sentences.add(sentenceList.get(key));
} }

@ -41,6 +41,23 @@ public class Tokenizer extends Frequency {
study(); study();
} }
private int getKey(List<String> words, String testWord) {
int nub = 0;
int size = words.size();
for (int i = 0; i < size; i++) {
String word = words.get(i);
if (testWord.hashCode() == word.hashCode() && testWord.equals(word)) {
nub = i + 1;
break;
}
}
if (nub == 0) {
words.add(testWord);
nub = words.size();
}
return nub;
}
private void number() {//分词编号 private void number() {//分词编号
System.out.println("开始编码:" + sentences.size()); System.out.println("开始编码:" + sentences.size());
for (Sentence sentence : sentences) { for (Sentence sentence : sentences) {
@ -51,10 +68,10 @@ public class Tokenizer extends Frequency {
if (wordTimes.size() < i + 1) { if (wordTimes.size() < i + 1) {
wordTimes.add(new ArrayList<>()); wordTimes.add(new ArrayList<>());
} }
String word = sentenceList.get(i);//当前关键字
List<String> list = wordTimes.get(i); List<String> list = wordTimes.get(i);
int nub = list.size(); int nub = getKey(list, word);
features.add(nub); features.add(nub);
list.add(sentenceList.get(i));
} }
} }
} }
@ -69,7 +86,7 @@ public class Tokenizer extends Frequency {
DataTable dataTable = new DataTable(column); DataTable dataTable = new DataTable(column);
dataTable.setKey("key"); dataTable.setKey("key");
//初始化随机森林 //初始化随机森林
RandomForest randomForest = new RandomForest(7); RandomForest randomForest = new RandomForest(11);
WordTemple.get().setRandomForest(randomForest);//保存随机森林到模版 WordTemple.get().setRandomForest(randomForest);//保存随机森林到模版
randomForest.init(dataTable); randomForest.init(dataTable);
for (Sentence sentence : sentences) { for (Sentence sentence : sentences) {

@ -16,7 +16,16 @@ public class WordTemple {
private List<WorldBody> allWorld = new ArrayList<>();//所有词集合 private List<WorldBody> allWorld = new ArrayList<>();//所有词集合
private List<List<String>> wordTimes = new ArrayList<>();//词编号 private List<List<String>> wordTimes = new ArrayList<>();//词编号
private RandomForest randomForest;//保存的随机森林模型 private RandomForest randomForest;//保存的随机森林模型
private double garbageTh = 0.5;//垃圾分类的阈值默认0.7 private double garbageTh = 0.5;//垃圾分类的阈值默认0.5
private double trustPunishment = 0.1;//信任惩罚
public double getTrustPunishment() {
return trustPunishment;
}
public void setTrustPunishment(double trustPunishment) {
this.trustPunishment = trustPunishment;
}
public double getGarbageTh() { public double getGarbageTh() {
return garbageTh; return garbageTh;

@ -44,21 +44,24 @@ public class RandomForest {
} }
public int forest(Object object) throws Exception {//随机森林识别 public int forest(Object object) throws Exception {//随机森林识别
Map<Integer, Integer> map = new HashMap<>(); Map<Integer, Double> map = new HashMap<>();
for (int i = 0; i < forest.length; i++) { for (int i = 0; i < forest.length; i++) {
Tree tree = forest[i]; Tree tree = forest[i];
int type = tree.judge(object); TreeWithTrust treeWithTrust = tree.judge(object);
int type = treeWithTrust.getType();
//System.out.println(type); //System.out.println(type);
double trust = treeWithTrust.getTrust();
if (map.containsKey(type)) { if (map.containsKey(type)) {
map.put(type, map.get(type) + 1); map.put(type, ArithUtil.add(map.get(type), trust));
} else { } else {
map.put(type, 1); map.put(type, trust);
} }
} }
int type = 0; int type = 0;
int nub = 0; double nub = 0;
for (Map.Entry<Integer, Integer> entry : map.entrySet()) { for (Map.Entry<Integer, Double> entry : map.entrySet()) {
int myNub = entry.getValue(); double myNub = entry.getValue();
//System.out.println("type==" + entry.getKey() + ",nub==" + myNub);
if (myNub > nub) { if (myNub > nub) {
type = entry.getKey(); type = entry.getKey();
nub = myNub; nub = myNub;
@ -71,6 +74,8 @@ public class RandomForest {
//一棵树属性的数量 //一棵树属性的数量
if (dataTable.getSize() > 4) { if (dataTable.getSize() > 4) {
int kNub = (int) ArithUtil.div(Math.log(dataTable.getSize()), Math.log(2)); int kNub = (int) ArithUtil.div(Math.log(dataTable.getSize()), Math.log(2));
//int kNub = dataTable.getSize() - 1;
// System.out.println("knNub==" + kNub);
for (int i = 0; i < forest.length; i++) { for (int i = 0; i < forest.length; i++) {
Tree tree = new Tree(getRandomData(dataTable, kNub)); Tree tree = new Tree(getRandomData(dataTable, kNub));
forest[i] = tree; forest[i] = tree;
@ -111,6 +116,7 @@ public class RandomForest {
list.remove(index); list.remove(index);
} }
myName.add(key); myName.add(key);
//System.out.println(myName);
DataTable data = new DataTable(myName); DataTable data = new DataTable(myName);
data.setKey(key); data.setKey(key);
return data; return data;

@ -1,5 +1,6 @@
package org.wlld.randomForest; package org.wlld.randomForest;
import org.wlld.naturalLanguage.WordTemple;
import org.wlld.tools.ArithUtil; import org.wlld.tools.ArithUtil;
import java.lang.reflect.Method; import java.lang.reflect.Method;
@ -199,20 +200,34 @@ public class Tree {//决策树
Class<?> body = ob.getClass(); Class<?> body = ob.getClass();
String methodName = "get" + name.substring(0, 1).toUpperCase() + name.substring(1); String methodName = "get" + name.substring(0, 1).toUpperCase() + name.substring(1);
Method method = body.getMethod(methodName); Method method = body.getMethod(methodName);
return (int) method.invoke(ob); int nub = (int) method.invoke(ob);
return nub;
} }
public int judge(Object ob) throws Exception {//进行类别判断 public TreeWithTrust judge(Object ob) throws Exception {//进行类别判断
if (rootNode != null) { if (rootNode != null) {
return goTree(ob, rootNode); TreeWithTrust treeWithTrust = new TreeWithTrust();
treeWithTrust.setTrust(1.0);
goTree(ob, rootNode, treeWithTrust, 0);
return treeWithTrust;
} else { } else {
throw new Exception("rootNode is null"); throw new Exception("rootNode is null");
} }
} }
private int goTree(Object ob, Node node) throws Exception {//从树顶向下攀爬 private void punishment(TreeWithTrust treeWithTrust) {//信任惩罚
//System.out.println("惩罚");
double trust = treeWithTrust.getTrust();//获取当前信任值
trust = ArithUtil.mul(trust, WordTemple.get().getTrustPunishment());
treeWithTrust.setTrust(trust);
}
private void goTree(Object ob, Node node, TreeWithTrust treeWithTrust, int times) throws Exception {//从树顶向下攀爬
if (!node.isEnd) { if (!node.isEnd) {
int myType = getTypeId(ob, node.key);//当前类别的ID int myType = getTypeId(ob, node.key);//当前类别的ID
if (myType == 0) {//做信任惩罚
punishment(treeWithTrust);
}
List<Node> nodeList = node.nodeList; List<Node> nodeList = node.nodeList;
boolean isOk = false; boolean isOk = false;
for (Node testNode : nodeList) { for (Node testNode : nodeList) {
@ -223,12 +238,23 @@ public class Tree {//决策树
} }
} }
if (!isOk) {//当前类别缺失,未知的属性值 if (!isOk) {//当前类别缺失,未知的属性值
punishment(treeWithTrust);
punishment(treeWithTrust);
int index = random.nextInt(nodeList.size()); int index = random.nextInt(nodeList.size());
node = nodeList.get(index); node = nodeList.get(index);
} }
return goTree(ob, node); times++;
goTree(ob, node, treeWithTrust, times);
} else { } else {
return node.type; //当以0作为结束的时候要做严厉的信任惩罚
if (node.typeId == 0) {
int nub = rootNode.attribute.size() - times;
//System.out.println("惩罚次数" + nub);
for (int i = 0; i < nub; i++) {
punishment(treeWithTrust);
}
}
treeWithTrust.setType(node.type);
} }
} }

@ -0,0 +1,27 @@
package org.wlld.randomForest;
/**
* @author lidapeng
* @description
* @date 8:37 2020/2/28
*/
public class TreeWithTrust {
private int type;//类别
private double trust;//可信度
public int getType() {
return type;
}
public void setType(int type) {
this.type = type;
}
public double getTrust() {
return trust;
}
public void setTrust(double trust) {
this.trust = trust;
}
}

@ -25,7 +25,7 @@ public class LangTest {
//识别过程 //识别过程
Talk talk = new Talk(); Talk talk = new Talk();
//我饿了,我想吃个饭 //我饿了,我想吃个饭
List<Integer> list = talk.talk("联系个开锁公司"); List<Integer> list = talk.talk("找个上门取件的快递员");
System.out.println(list); System.out.println(list);
} }
} }

@ -58,8 +58,8 @@ public class MatrixTest {
food.setSex(random.nextInt(cla)); food.setSex(random.nextInt(cla));
food.setH1(random.nextInt(cla)); food.setH1(random.nextInt(cla));
food.setH2(random.nextInt(cla)); food.setH2(random.nextInt(cla));
int type = tree.judge(food); int type = tree.judge(food).getType();
int type2 = tree2.judge(food); int type2 = tree2.judge(food).getType();
if (type != type2) { if (type != type2) {
System.out.println("出错,type1==" + type + ",type2==" + type2); System.out.println("出错,type1==" + type + ",type2==" + type2);
} else { } else {

Loading…
Cancel
Save