RANSAC算法的Java语言实现

上了一学期的当代摄影测量课,讲的理论和方法啊感觉都挺高深的,听得云里雾里的,就是搞不清楚怎么实现,在哪里用。我是比较喜欢通过实际地操作去学东西的,于是就把老师上课讲的几个算法记下来,找了些资料钻了一下,然后动手把程序写了一遍,豁然开朗。

在摄影测量和遥感领域中,特征提取是一个比较关键的问题。而RANSAC是一种比较经典的特征提取算法。RANSAC是随机抽样一致算法(RANdom SAmple Cosensus)的缩写,是Fischler和 Bolles在1981年提出的一个在一组观测数据集中估计模型参数的迭代方法。它属于一种不确定算法,每次运算都会有一定的概率得到一个合理的结果,这个概率随着迭代次数的增加而增加,所以总能通过增加迭代次数来过的一个比较满意的结果。

RANSAC算法中假设数据由“内点”和“外点”组成。其中内点是指那些在位置分布符合指定模型参数的点;而“外点”是那些和模型参数不相匹配的点。“外点”有可能来自极端的数据噪音点、测量错误或对数据解释错误。算法中假定,如果给定一个内点集合,可以最终估算出一组最佳的模型参数使得模型与给定的内点集合相吻合。

RANSAC算法的输入参数包括一组观测数据、一个可以用来匹配或拟合的参数化模型以及一些置信参数。RANSAC算法的过程主要是通过迭代的方式选择原始数据的一组子集作假设检验,最终得到最优的模型参数。在每次迭代中,假设被选择的子数据集是“内点”,然后对这些“内点”做如下检验:

  1. 利用内点计算模型参数,看模型和假设的“内点”是否相符相符。
  2. 将所有的“外点”与上一步计算出的模型相比较,如果能够很好地匹配这个模型,则将该点放入“内点”集合。
  3.  如果有“内点”集合中有足够多的点,则认为本次迭代所估计的模型是较好的。
  4.  根据新的“内点”集合来重新计算模型的参数。
  5. 最后,通过根据“内点”与计算出的模型得出模型的误差指数,并以此作为模型的评价标准。

如果本次所计算出的模型优于之前所计算出的最优模型,则将当前的模型保存为最优模型,否则继续下一轮迭代。经过指定的迭代次数后,返回最优的那一个模型。

还是举个例子说明一下比较容易理解。如下图所示,有一个由许多点组成的集合,这里面有一部分点呈比较明显地沿一条直线分布,而其他的点则是杂乱地分布,这个算法要解决的问题就在是已知所有点坐标的情况下如何才能比较准确地得到那条直线的位置,也就是求出直线方程的各个参数。也许你会马上想到用最小二乘发做线性回归,但是当那些干扰点所占的比例很大时,计算结果就会有很大的误差。RANSAC算法是这样解决这一问题的,首先任意选两个点作为初始内点集,可以计算出一个直线方程,然后依次检验剩余的点是否符合这一方程,如果是则将这个点加入内点集,最后根据内点的比例判断这次计算的结果是不是当前最好的,如果是则根据这些内点用最小二乘法求出直线参数并记录下来,然后在进行下一轮的计算,当经过指定的循环次数或结果的足够好是停止计算,所记录的当前最优参数即为要求的结果。

用RANSAC算法提取直线特征
用RANSAC算法提取直线特征

RANSAC的一个优势是它具有很好的鲁棒性,即使在有很多“外点”的情况下也能做出精度很高的模型估计。但是RANSAC算法的计算时间无法很好地控制,通常的做法是根据结果所需要满足的概率p、外点比例的最大值w和估算模型所需的样品个数,使用以下公式计算出需要迭代的次数。

k = \frac{\log(1 - p)}{\log(1 - w^n)}

知道了这些,程序就很容易搞定了。用Java来做,也顺便复习了一下泛型和2D绘图。考虑的代码的重用性,把算法独立写成一个类,把模型估算器写成接口,最后用提取直线做例子来进行测试。

代码1:tigerlihao/cv/ransac/Ransac.java

package tigerlihao.cv.ransac;

import java.util.ArrayList;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

/**
 * @author tigerlihao
 *
 * @param <T>
 *            样本的类型
 * @param <S>
 *            参数的类型
 */
public class Ransac<T, S> {
    private List<S> parameters = null;
    private ParameterEstimator<T, S> paramEstimator;
    private boolean[] bestVotes;
    private int numForEstimate;
    private double maximalOutlierPercentage;

    /**
     * @return 最优内点集
     */
    public boolean[] getBestVotes() {
        return bestVotes;
    }

    /**
     * @return 最优模型参数
     */
    public List<S> getParameters() {
        return parameters;
    }

    /**
     * Ransac对象的构造方法
     *
     * @param paramEstimator
     *            所使用的参数估计器
     * @param numForEstimate
     *            估计模型参数所需的最小样本数
     * @param maximalOutlierPercentage
     *            外点的最大百分比
     */
    public Ransac(ParameterEstimator<T, S> paramEstimator, int numForEstimate,
            double maximalOutlierPercentage) {
        this.paramEstimator = paramEstimator;
        this.numForEstimate = numForEstimate;
        this.maximalOutlierPercentage = maximalOutlierPercentage;
    }

    /**
     * 执行RANSAC算法的方法
     *
     * @param data
     *            样本集合
     * @param desiredProbabilityForNoOutliers
     *            所需要的精度
     * @return 最优情况下的内点百分比
     */
    public double compute(List<T> data, double desiredProbabilityForNoOutliers) {
        int dataSize = data.size();
        if (dataSize < numForEstimate || maximalOutlierPercentage >= 1.0) {
            return 0.0;
        }
        List<T> exactedData = new ArrayList<T>();
        List<T> leastSqData;
        List<S> exactedParams;
        int bestSize, curSize, tryTimes;
        bestVotes = new boolean[dataSize];
        boolean[] curVotes = new boolean[dataSize];
        boolean[] notChosen = new boolean[dataSize];
        Set<int[]> chosenSubSets = new HashSet<int[]>();
        int[] curSubSetIndexes;
        double outlierPercentage = maximalOutlierPercentage;
        double numerator = Math.log(1.0 - desiredProbabilityForNoOutliers);
        double denominator = Math.log(1 - Math.pow(
                             1 - maximalOutlierPercentage, numForEstimate));
        if (parameters != null) {
            parameters.clear();
        } else {
            parameters = new ArrayList<S>();
        }
        bestSize = -1;
        Random random = new Random(new Date().getTime());
        tryTimes = (int) Math.round(numerator / denominator);
        for (int i = 0; i < tryTimes; i++) {
            for (int j = 0; j < notChosen.length; j++) {
                notChosen[j] = true;
            }
            curSubSetIndexes = new int[numForEstimate];
            exactedData.clear();
            // 随机选取样本
            for (int j = 0; j < numForEstimate; j++) {
                int selectedIndex = random.nextInt(dataSize - j);
                int k, l;
                for (k = 0, l = -1; k < dataSize && l < selectedIndex; k++) {
                    if (notChosen[k]) {
                        l++;
                    }
                }
                k--;
                exactedData.add(data.get(k));
                notChosen[k] = false;
            }
            for (int j = 0, k = 0; j < dataSize; j++) {
                if (!notChosen[j]) {
                    curSubSetIndexes[k] = j;
                    k++;
                }
            }
            // 若子集未选区过则执行测试
            if (chosenSubSets.add(curSubSetIndexes)) {
                exactedParams = paramEstimator.estimate(exactedData);
                curSize = 0;
                for (int j = 0; j < notChosen.length; j++) {
                    curVotes[j] = false;
                }
                for (int j = 0; j < dataSize; j++) {
                    if (paramEstimator.agree(exactedParams, data.get(j))) {
                        curVotes[j] = true;
                        curSize++;
                    }
                }
                if (curSize > bestSize) {
                    bestSize = curSize;
                    System.arraycopy(curVotes, 0, bestVotes, 0, dataSize);
                }
                outlierPercentage = 1.0 - (double) curSize / (double) dataSize;
                if (outlierPercentage < maximalOutlierPercentage) {
                    maximalOutlierPercentage = outlierPercentage;
                    denominator = Math.log(1 - Math.pow(
                                  1 - maximalOutlierPercentage, numForEstimate));
                    tryTimes = (int) Math.round(numerator / denominator);
                }
            } else {
                i--;
            }
        }
        chosenSubSets.clear();

        // 对当前最优子集使用最小二乘法计算最优参数
        leastSqData = new ArrayList<T>();
        for (int i = 0; i < dataSize; i++) {
            if (bestVotes[i]) {
                leastSqData.add(data.get(i));
            }
        }
        parameters = paramEstimator.leastSquaresEstimate(leastSqData);

        return (double) bestSize / (double) dataSize;
    }
}

代码2:tigerlihao/cv/ransac/ParameterEstimator.java

package tigerlihao.cv.ransac;

import java.util.List;

/**
 * 模型估计器接口
 *
 * @author tigerlihao
 *
 * @param <T>
 *            样本的类型
 * @param <S>
 *            参数的类型
 */
public interface ParameterEstimator<T, S> {
    /**
     * 执行准确参数估计的方法
     *
     * @param data
     *            用于估计的样本集合
     * @return
     *            模型参数列表
     */
    public List<S> estimate(List<T> data);

    /**
     * 执行最小二乘法估计的方法
     *
     * @param data
     *            用于估计的样本集合
     * @return
     *            模型参数列表
     */
    public List<S> leastSquaresEstimate(List<T> data);

    /**
     * 测试样本是否符合模型参数的方法
     *
     * @param parameters
     *            模型参数
     * @param data
     *            待测样本
     */
    public boolean agree(List<S> parameters, T data);
}

以点集中提取直线为例的测试结果:

RANSAC测试结果
RANSAC测试结果

 

图中空心的点为干扰点,实心的点为内点,蓝色的线为原始的直线,红色的线为估算出的直线。可以看出RANSAC算法的精度是比较好的。

看来上课只能听个皮毛,自己动手做一做还是有很大收获的。

《RANSAC算法的Java语言实现》有一个想法

  1. 哥们,你的技术不错,能留个联系方式么?我是作手机移动GIS的,http://hi.baidu.com/geochenyj 这是我的博客,多交流!

发表评论

邮箱地址不会被公开。 必填项已用*标注

Time limit is exhausted. Please reload CAPTCHA.