增加二叉树回归及后剪枝

pull/57/head
lidapeng 4 years ago
parent dfe0b9576d
commit 58db5df18c

@ -34,15 +34,17 @@ public class Forest extends Frequency {
private int id;//本节点的id private int id;//本节点的id
private boolean isRemove = false;//是否已经被移除了 private boolean isRemove = false;//是否已经被移除了
private boolean notRemovable = false;//不可移除 private boolean notRemovable = false;//不可移除
private int minGrain;//最小粒度
public Forest(int featureSize, double shrinkParameter, Matrix pc, Map<Integer, Forest> forestMap public Forest(int featureSize, double shrinkParameter, Matrix pc, Map<Integer, Forest> forestMap
, int id) { , int id, int minGrain) {
this.featureSize = featureSize; this.featureSize = featureSize;
this.shrinkParameter = shrinkParameter; this.shrinkParameter = shrinkParameter;
this.pc = pc; this.pc = pc;
w = new double[featureSize]; w = new double[featureSize];
this.forestMap = forestMap; this.forestMap = forestMap;
this.id = id; this.id = id;
this.minGrain = minGrain;
} }
public double getMedian() { public double getMedian() {
@ -186,7 +188,7 @@ public class Forest extends Frequency {
public void cut() throws Exception { public void cut() throws Exception {
int y = resultMatrix.getX(); int y = resultMatrix.getX();
if (y > 200) { if (y > minGrain) {
double[] dm = findG(); double[] dm = findG();
int z = y / 2; int z = y / 2;
median = dm[z]; median = dm[z];
@ -203,8 +205,8 @@ public class Forest extends Frequency {
int rightId = leftId + 1; int rightId = leftId + 1;
//System.out.println("id:" + id + ",size:" + dm.length); //System.out.println("id:" + id + ",size:" + dm.length);
forestMap.put(id, this); forestMap.put(id, this);
forestLeft = new Forest(featureSize, shrinkParameter, pc, forestMap, leftId); forestLeft = new Forest(featureSize, shrinkParameter, pc, forestMap, leftId, minGrain);
forestRight = new Forest(featureSize, shrinkParameter, pc, forestMap, rightId); forestRight = new Forest(featureSize, shrinkParameter, pc, forestMap, rightId, minGrain);
forestRight.setFather(this); forestRight.setFather(this);
forestLeft.setFather(this); forestLeft.setFather(this);
Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左 Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左

@ -34,7 +34,7 @@ public class RegressionForest extends Frequency {
this.cosSize = cosSize; this.cosSize = cosSize;
} }
public RegressionForest(int size, int featureNub, double shrinkParameter) throws Exception {//初始化 public RegressionForest(int size, int featureNub, double shrinkParameter, int minGrain) throws Exception {//初始化
if (size > 0 && featureNub > 0) { if (size > 0 && featureNub > 0) {
this.featureNub = featureNub; this.featureNub = featureNub;
w = new double[featureNub]; w = new double[featureNub];
@ -42,7 +42,7 @@ public class RegressionForest extends Frequency {
conditionMatrix = new Matrix(size, featureNub); conditionMatrix = new Matrix(size, featureNub);
resultMatrix = new Matrix(size, 1); resultMatrix = new Matrix(size, 1);
createG(); createG();
forest = new Forest(featureNub, shrinkParameter, pc, forestMap, 1); forest = new Forest(featureNub, shrinkParameter, pc, forestMap, 1, minGrain);
forestMap.put(1, forest); forestMap.put(1, forest);
forest.setW(w); forest.setW(w);
forest.setConditionMatrix(conditionMatrix); forest.setConditionMatrix(conditionMatrix);

@ -22,7 +22,7 @@ public class ForestTest {
public static void test() throws Exception {//对分段回归进行测试 public static void test() throws Exception {//对分段回归进行测试
int size = 2000; int size = 2000;
RegressionForest regressionForest = new RegressionForest(size, 3, 0.01); RegressionForest regressionForest = new RegressionForest(size, 3, 0.01, 200);
regressionForest.setCosSize(40); regressionForest.setCosSize(40);
List<double[]> a = fun(0.1, 0.2, 0.3, size, 2, 1); List<double[]> a = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b = fun(0.7, 0.3, 0.1, size, 2, 2); List<double[]> b = fun(0.7, 0.3, 0.1, size, 2, 2);

Loading…
Cancel
Save