增加二叉树回归及后剪枝

pull/57/head
lidapeng 4 years ago
parent dfe0b9576d
commit 58db5df18c

@ -34,15 +34,17 @@ public class Forest extends Frequency {
private int id;//本节点的id
private boolean isRemove = false;//是否已经被移除了
private boolean notRemovable = false;//不可移除
private int minGrain;//最小粒度
public Forest(int featureSize, double shrinkParameter, Matrix pc, Map<Integer, Forest> forestMap
, int id) {
, int id, int minGrain) {
this.featureSize = featureSize;
this.shrinkParameter = shrinkParameter;
this.pc = pc;
w = new double[featureSize];
this.forestMap = forestMap;
this.id = id;
this.minGrain = minGrain;
}
public double getMedian() {
@ -186,7 +188,7 @@ public class Forest extends Frequency {
public void cut() throws Exception {
int y = resultMatrix.getX();
if (y > 200) {
if (y > minGrain) {
double[] dm = findG();
int z = y / 2;
median = dm[z];
@ -203,8 +205,8 @@ public class Forest extends Frequency {
int rightId = leftId + 1;
//System.out.println("id:" + id + ",size:" + dm.length);
forestMap.put(id, this);
forestLeft = new Forest(featureSize, shrinkParameter, pc, forestMap, leftId);
forestRight = new Forest(featureSize, shrinkParameter, pc, forestMap, rightId);
forestLeft = new Forest(featureSize, shrinkParameter, pc, forestMap, leftId, minGrain);
forestRight = new Forest(featureSize, shrinkParameter, pc, forestMap, rightId, minGrain);
forestRight.setFather(this);
forestLeft.setFather(this);
Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左

@ -34,7 +34,7 @@ public class RegressionForest extends Frequency {
this.cosSize = cosSize;
}
public RegressionForest(int size, int featureNub, double shrinkParameter) throws Exception {//初始化
public RegressionForest(int size, int featureNub, double shrinkParameter, int minGrain) throws Exception {//初始化
if (size > 0 && featureNub > 0) {
this.featureNub = featureNub;
w = new double[featureNub];
@ -42,7 +42,7 @@ public class RegressionForest extends Frequency {
conditionMatrix = new Matrix(size, featureNub);
resultMatrix = new Matrix(size, 1);
createG();
forest = new Forest(featureNub, shrinkParameter, pc, forestMap, 1);
forest = new Forest(featureNub, shrinkParameter, pc, forestMap, 1, minGrain);
forestMap.put(1, forest);
forest.setW(w);
forest.setConditionMatrix(conditionMatrix);

@ -22,7 +22,7 @@ public class ForestTest {
public static void test() throws Exception {//对分段回归进行测试
int size = 2000;
RegressionForest regressionForest = new RegressionForest(size, 3, 0.01);
RegressionForest regressionForest = new RegressionForest(size, 3, 0.01, 200);
regressionForest.setCosSize(40);
List<double[]> a = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b = fun(0.7, 0.3, 0.1, size, 2, 2);

Loading…
Cancel
Save