修改了编码时忘记排重的问题

pull/1/head
lidapeng 5 years ago
parent a01779d547
commit 03bab40e30

@ -4,8 +4,8 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.github</groupId>
<artifactId>ImageMarket</artifactId>
<groupId>com.wlld</groupId>
<artifactId>easyAi</artifactId>
<version>1.0.0</version>
<name>myBrain</name>

@ -46,6 +46,7 @@ public class Talk {
nub = size;
}
}
//System.out.println(sentenceList.get(key).getKeyWords());
sentences.add(sentenceList.get(key));
}

@ -41,6 +41,23 @@ public class Tokenizer extends Frequency {
study();
}
private int getKey(List<String> words, String testWord) {
int nub = 0;
int size = words.size();
for (int i = 0; i < size; i++) {
String word = words.get(i);
if (testWord.hashCode() == word.hashCode() && testWord.equals(word)) {
nub = i + 1;
break;
}
}
if (nub == 0) {
words.add(testWord);
nub = words.size();
}
return nub;
}
private void number() {//分词编号
System.out.println("开始编码:" + sentences.size());
for (Sentence sentence : sentences) {
@ -51,10 +68,10 @@ public class Tokenizer extends Frequency {
if (wordTimes.size() < i + 1) {
wordTimes.add(new ArrayList<>());
}
String word = sentenceList.get(i);//当前关键字
List<String> list = wordTimes.get(i);
int nub = list.size();
int nub = getKey(list, word);
features.add(nub);
list.add(sentenceList.get(i));
}
}
}
@ -69,7 +86,7 @@ public class Tokenizer extends Frequency {
DataTable dataTable = new DataTable(column);
dataTable.setKey("key");
//初始化随机森林
RandomForest randomForest = new RandomForest(7);
RandomForest randomForest = new RandomForest(11);
WordTemple.get().setRandomForest(randomForest);//保存随机森林到模版
randomForest.init(dataTable);
for (Sentence sentence : sentences) {

@ -16,7 +16,16 @@ public class WordTemple {
private List<WorldBody> allWorld = new ArrayList<>();//所有词集合
private List<List<String>> wordTimes = new ArrayList<>();//词编号
private RandomForest randomForest;//保存的随机森林模型
private double garbageTh = 0.5;//垃圾分类的阈值默认0.7
private double garbageTh = 0.5;//垃圾分类的阈值默认0.5
private double trustPunishment = 0.1;//信任惩罚
public double getTrustPunishment() {
return trustPunishment;
}
public void setTrustPunishment(double trustPunishment) {
this.trustPunishment = trustPunishment;
}
public double getGarbageTh() {
return garbageTh;

@ -44,21 +44,24 @@ public class RandomForest {
}
public int forest(Object object) throws Exception {//随机森林识别
Map<Integer, Integer> map = new HashMap<>();
Map<Integer, Double> map = new HashMap<>();
for (int i = 0; i < forest.length; i++) {
Tree tree = forest[i];
int type = tree.judge(object);
TreeWithTrust treeWithTrust = tree.judge(object);
int type = treeWithTrust.getType();
//System.out.println(type);
double trust = treeWithTrust.getTrust();
if (map.containsKey(type)) {
map.put(type, map.get(type) + 1);
map.put(type, ArithUtil.add(map.get(type), trust));
} else {
map.put(type, 1);
map.put(type, trust);
}
}
int type = 0;
int nub = 0;
for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
int myNub = entry.getValue();
double nub = 0;
for (Map.Entry<Integer, Double> entry : map.entrySet()) {
double myNub = entry.getValue();
//System.out.println("type==" + entry.getKey() + ",nub==" + myNub);
if (myNub > nub) {
type = entry.getKey();
nub = myNub;
@ -71,6 +74,8 @@ public class RandomForest {
//一棵树属性的数量
if (dataTable.getSize() > 4) {
int kNub = (int) ArithUtil.div(Math.log(dataTable.getSize()), Math.log(2));
//int kNub = dataTable.getSize() - 1;
// System.out.println("knNub==" + kNub);
for (int i = 0; i < forest.length; i++) {
Tree tree = new Tree(getRandomData(dataTable, kNub));
forest[i] = tree;
@ -111,6 +116,7 @@ public class RandomForest {
list.remove(index);
}
myName.add(key);
//System.out.println(myName);
DataTable data = new DataTable(myName);
data.setKey(key);
return data;

@ -1,5 +1,6 @@
package org.wlld.randomForest;
import org.wlld.naturalLanguage.WordTemple;
import org.wlld.tools.ArithUtil;
import java.lang.reflect.Method;
@ -199,20 +200,34 @@ public class Tree {//决策树
Class<?> body = ob.getClass();
String methodName = "get" + name.substring(0, 1).toUpperCase() + name.substring(1);
Method method = body.getMethod(methodName);
return (int) method.invoke(ob);
int nub = (int) method.invoke(ob);
return nub;
}
public int judge(Object ob) throws Exception {//进行类别判断
public TreeWithTrust judge(Object ob) throws Exception {//进行类别判断
if (rootNode != null) {
return goTree(ob, rootNode);
TreeWithTrust treeWithTrust = new TreeWithTrust();
treeWithTrust.setTrust(1.0);
goTree(ob, rootNode, treeWithTrust, 0);
return treeWithTrust;
} else {
throw new Exception("rootNode is null");
}
}
private int goTree(Object ob, Node node) throws Exception {//从树顶向下攀爬
private void punishment(TreeWithTrust treeWithTrust) {//信任惩罚
//System.out.println("惩罚");
double trust = treeWithTrust.getTrust();//获取当前信任值
trust = ArithUtil.mul(trust, WordTemple.get().getTrustPunishment());
treeWithTrust.setTrust(trust);
}
private void goTree(Object ob, Node node, TreeWithTrust treeWithTrust, int times) throws Exception {//从树顶向下攀爬
if (!node.isEnd) {
int myType = getTypeId(ob, node.key);//当前类别的ID
if (myType == 0) {//做信任惩罚
punishment(treeWithTrust);
}
List<Node> nodeList = node.nodeList;
boolean isOk = false;
for (Node testNode : nodeList) {
@ -223,12 +238,23 @@ public class Tree {//决策树
}
}
if (!isOk) {//当前类别缺失,未知的属性值
punishment(treeWithTrust);
punishment(treeWithTrust);
int index = random.nextInt(nodeList.size());
node = nodeList.get(index);
}
return goTree(ob, node);
times++;
goTree(ob, node, treeWithTrust, times);
} else {
return node.type;
//当以0作为结束的时候要做严厉的信任惩罚
if (node.typeId == 0) {
int nub = rootNode.attribute.size() - times;
//System.out.println("惩罚次数" + nub);
for (int i = 0; i < nub; i++) {
punishment(treeWithTrust);
}
}
treeWithTrust.setType(node.type);
}
}

@ -0,0 +1,27 @@
package org.wlld.randomForest;
/**
* @author lidapeng
* @description
* @date 8:37 2020/2/28
*/
public class TreeWithTrust {
private int type;//类别
private double trust;//可信度
public int getType() {
return type;
}
public void setType(int type) {
this.type = type;
}
public double getTrust() {
return trust;
}
public void setTrust(double trust) {
this.trust = trust;
}
}

@ -25,7 +25,7 @@ public class LangTest {
//识别过程
Talk talk = new Talk();
//我饿了,我想吃个饭
List<Integer> list = talk.talk("联系个开锁公司");
List<Integer> list = talk.talk("找个上门取件的快递员");
System.out.println(list);
}
}

@ -58,8 +58,8 @@ public class MatrixTest {
food.setSex(random.nextInt(cla));
food.setH1(random.nextInt(cla));
food.setH2(random.nextInt(cla));
int type = tree.judge(food);
int type2 = tree2.judge(food);
int type = tree.judge(food).getType();
int type2 = tree2.judge(food).getType();
if (type != type2) {
System.out.println("出错,type1==" + type + ",type2==" + type2);
} else {

Loading…
Cancel
Save