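The 1-nearest-neighbor (1NN) algorithm in machine learning is called IB1 in Weka, short for Instance-Based 1: a lazy learning algorithm that classifies using only the single nearest instance. Below are my notes from reading the IB1 source code in Weka.
First, add weka-src.jar to the build path; otherwise you cannot step into the source.
1) Load the data, call the IB1 classifier, and print its predictions, so that there is something to trace later.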
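import java.io.File;
import weka.classifiers.Evaluation;
import weka.classifiers.lazy.IB1;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;

try {
    // Path reproduced as it appears in the original post (the directory name is garbled);
    // substitute the location of your own contact-lenses.arff.
    File file = new File("F:\\\ools/lib/data/contact-lenses.arff");
    ArffLoader loader = new ArffLoader();
    loader.setFile(file);
    Instances ins = loader.getDataSet();
    // Always set the class index before using the instances; otherwise using them throws an exception.
    ins.setClassIndex(ins.numAttributes() - 1);
    IB1 cfs = new IB1();
    cfs.buildClassifier(ins);
    Instance testInst;
    Evaluation testingEvaluation = new Evaluation(ins);
    int length = ins.numInstances();
    for (int i = 0; i < length; i++) {
        testInst = ins.instance(i);
        // Classify each test instance to see how the classifier performs.
        double predictValue = cfs.classifyInstance(testInst);
        System.out.println(testInst.classValue() + "--" + predictValue);
    }
    // System.out.println("Classifier accuracy: " + (1 - testingEvaluation.errorRate()));
} catch (Exception e) {
    e.printStackTrace();
}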
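The loop in step 1 prints each actual class value next to its prediction but never tallies them: the Evaluation object is created, and the accuracy line is commented out. As a rough illustration of what such a tally computes, here is a minimal sketch in plain Java. The class AccuracySketch and its method are hypothetical names for this post, not part of Weka, and the arrays stand in for the printed class indices.

```java
// Hypothetical helper (not Weka API): tallies how often prediction == actual class index.
final class AccuracySketch {
    // Fraction of positions where the predicted class index equals the actual one.
    static double accuracy(double[] actual, double[] predicted) {
        if (actual.length == 0 || actual.length != predicted.length) {
            throw new IllegalArgumentException("need two non-empty arrays of equal length");
        }
        int correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (actual[i] == predicted[i]) {
                correct++;
            }
        }
        return (double) correct / actual.length;
    }

    public static void main(String[] args) {
        // Made-up class indices, only to exercise the method.
        double[] actual = {0.0, 1.0, 2.0, 1.0};
        double[] predicted = {0.0, 1.0, 1.0, 1.0};
        System.out.println("accuracy = " + accuracy(actual, predicted));
    }
}
```

Because step 1 tests on the same instances it trained on, a 1NN classifier will usually find each instance as its own nearest neighbor, so a tally like this says little about performance on unseen data.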
2) Ctrl-click buildClassifier to step into its source. The IB1 class overrides this abstract method as follows:
public void buildClassifier(Instances instances) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_Train = new Instances(instances, 0, instances.numInstances());
    m_MinArray = new double[m_Train.numAttributes()];
    m_MaxArray = new double[m_Train.numAttributes()];
    for (int i = 0; i < m_Train.numAttributes(); i++) {
        m_MinArray[i] = m_MaxArray[i] = Double.NaN;
    }
    Enumeration enu = m_Train.enumerateInstances();
    while (enu.hasMoreElements()) {
        updateMinMax((Instance) enu.nextElement());
    }
}
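(1) getCapabilities().testWithFail(instances) checks whether the classifier can handle the data: IB1 cannot handle string attributes or a numeric class.
(2) Instances without a class label are removed, working on a copy of the data.
(3) m_MinArray and m_MaxArray hold the per-attribute minimum and maximum. Each is a double array with one entry per attribute (not per instance), and every entry starts as Double.NaN.
(4) Every training instance is visited to update the minimum and maximum; this happens in updateMinMax, traced next.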
3) The source of IB1's updateMinMax method:
private void updateMinMax(Instance instance) {
    for (int j = 0; j < m_Train.numAttributes(); j++) {
        if ((m_Train.attribute(j).isNumeric()) && (!instance.isMissing(j))) {
            if (Double.isNaN(m_MinArray[j])) {
                m_MinArray[j] = instance.value(j);
                m_MaxArray[j] = instance.value(j);
            } else {
                if (instance.value(j) < m_MinArray[j]) {
                    m_MinArray[j] = instance.value(j);
                } else {
                    if (instance.value(j) > m_MaxArray[j]) {
                        m_MaxArray[j] = instance.value(j);
                    }
                }
            }
        }
    }
}
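(1) Attributes that are not numeric are skipped, as are attributes whose value is missing in this instance.
(2) Double.isNaN(m_MinArray[j]) is true only while no value has been recorded for attribute j; in that case both the minimum and the maximum are set to the current value. Otherwise the value is compared against the stored minimum and maximum, and whichever it exceeds is updated.
At this point the IB1 model has been trained. (Some may ask: doesn't a lazy algorithm skip training altogether? As I see it, building the classifier only initializes m_Train and computes each attribute's minimum and maximum over all instances, in preparation for the distance computation in the next step.)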
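To make the bookkeeping of steps 2 and 3 concrete, here is a minimal sketch in plain Java that tracks per-attribute ranges the same way. MinMaxSketch is a hypothetical stand-in written for this post, not Weka code; it assumes every attribute is numeric and uses NaN in a row to mark a missing value.

```java
import java.util.Arrays;

// Hypothetical stand-in for IB1's m_MinArray / m_MaxArray bookkeeping (not Weka code).
final class MinMaxSketch {
    final double[] min;
    final double[] max;

    MinMaxSketch(int numAttributes) {
        min = new double[numAttributes];
        max = new double[numAttributes];
        // NaN means "no value recorded yet", as in buildClassifier.
        Arrays.fill(min, Double.NaN);
        Arrays.fill(max, Double.NaN);
    }

    // Fold one row into the ranges; a NaN entry in the row is treated as missing and skipped.
    void update(double[] row) {
        for (int j = 0; j < row.length; j++) {
            double v = row[j];
            if (Double.isNaN(v)) {
                continue;
            }
            if (Double.isNaN(min[j])) {
                min[j] = v;
                max[j] = v;
            } else if (v < min[j]) {
                min[j] = v;
            } else if (v > max[j]) {
                max[j] = v;
            }
        }
    }

    public static void main(String[] args) {
        MinMaxSketch ranges = new MinMaxSketch(2);
        ranges.update(new double[] {3.0, Double.NaN});
        ranges.update(new double[] {1.0, 10.0});
        ranges.update(new double[] {5.0, 4.0});
        System.out.println("min = " + Arrays.toString(ranges.min));
        System.out.println("max = " + Arrays.toString(ranges.max));
    }
}
```

The first value seen for an attribute initializes both bounds at once, which is why the else-if chain never needs to check both comparisons for the same value.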
Next, the prediction code.
4) Stepping into the classifyInstance method, the source is:
public double classifyInstance(Instance instance) throws Exception {
    if (m_Train.numInstances() == 0) {
        throw new Exception("No training instances!");
    }
    double distance, minDistance = Double.MAX_VALUE, classValue = 0;
    updateMinMax(instance);
    Enumeration enu = m_Train.enumerateInstances();
    while (enu.hasMoreElements()) {
        Instance trainInstance = (Instance) enu.nextElement();
        if (!trainInstance.classIsMissing()) {
            distance = distance(instance, trainInstance);
            if (distance < minDistance) {
                minDistance = distance;
                classValue = trainInstance.classValue();
            }
        }
    }
    return classValue;
}
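(1) updateMinMax is called first, so the stored minimum and maximum also cover the test instance.
(2) The distance method computes the distance from the test instance to every training instance. minDistance tracks the smallest distance seen so far and classValue the class of that nearest instance, which is returned as the prediction. Because the comparison is a strict <, a tie goes to the training instance encountered first.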
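The scan in classifyInstance can be sketched in plain Java without any Weka types. NearestNeighborSketch and its methods are hypothetical names for this post; the sketch assumes all attributes are numeric, already normalized, and free of missing values, so it uses a plain squared Euclidean distance rather than IB1's distance method.

```java
// Hypothetical 1-NN linear scan over plain arrays (not Weka code).
final class NearestNeighborSketch {
    // Squared Euclidean distance; assumes both rows are complete and already normalized.
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    // Returns the label of the closest training row; strict '<' keeps the earliest row on ties.
    static double classify(double[][] train, double[] labels, double[] query) {
        if (train.length == 0) {
            throw new IllegalStateException("No training instances!");
        }
        double minDistance = Double.MAX_VALUE;
        double classValue = 0;
        for (int r = 0; r < train.length; r++) {
            double d = squaredDistance(query, train[r]);
            if (d < minDistance) {
                minDistance = d;
                classValue = labels[r];
            }
        }
        return classValue;
    }

    public static void main(String[] args) {
        // Made-up training rows and class indices.
        double[][] train = {{0.0, 0.0}, {1.0, 1.0}};
        double[] labels = {0.0, 1.0};
        System.out.println(classify(train, labels, new double[] {0.9, 0.8}));
        System.out.println(classify(train, labels, new double[] {0.5, 0.5}));
    }
}
```

The second query in main is equidistant from both rows, so it shows the tie rule: the label of the first training row is kept.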
5) Stepping into the distance method, the source is:
private double distance(Instance first, Instance second) {
    double diff, distance = 0;
    for (int i = 0; i < m_Train.numAttributes(); i++) {
        if (i == m_Train.classIndex()) {
            continue;
        }
        if (m_Train.attribute(i).isNominal()) {
            // If attribute is nominal
            if (first.isMissing(i) || second.isMissing(i)
                || ((int) first.value(i) != (int) second.value(i))) {
                distance += 1;
            }
        } else {
            // If attribute is numeric
            if (first.isMissing(i) || second.isMissing(i)) {
                if (first.isMissing(i) && second.isMissing(i)) {
                    diff = 1;
                } else {
                    if (second.isMissing(i)) {
                        diff = norm(first.value(i), i);
                    } else {
                        diff = norm(second.value(i), i);
                    }
                    if (diff < 0.5) {
                        diff = 1.0 - diff;
                    }
                }
            } else {
                diff = norm(first.value(i), i) - norm(second.value(i), i);
            }
            distance += diff * diff;
        }
    }
    return distance;
}
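The loop visits every attribute except the class. A nominal attribute adds 1 when the two values differ or either one is missing; a numeric attribute adds the squared difference of the normalized values. The norm method normalizes a value to a real number in [0, 1]. No square root is taken, so the result is a squared distance, which does not change which neighbor is nearest.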
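The missing-value branch for numeric attributes is the least obvious part of distance, so here is a minimal sketch of just that rule in plain Java. MissingDiffSketch is a hypothetical name, not Weka code; it assumes the inputs are already normalized to [0, 1] and uses NaN to mark a missing value.

```java
// Hypothetical helper reproducing the numeric branch of the distance method (not Weka code).
final class MissingDiffSketch {
    // a and b are normalized values in [0, 1]; NaN stands for a missing value.
    static double diff(double a, double b) {
        boolean aMissing = Double.isNaN(a);
        boolean bMissing = Double.isNaN(b);
        if (aMissing && bMissing) {
            return 1.0; // both unknown: assume the largest possible difference
        }
        if (aMissing || bMissing) {
            double present = aMissing ? b : a;
            // Assume the missing value sits at whichever end of [0, 1] is farther away.
            return present < 0.5 ? 1.0 - present : present;
        }
        return a - b;
    }

    // Sum of squared per-attribute differences, as accumulated in the distance loop.
    static double squaredDistance(double[] first, double[] second) {
        double distance = 0;
        for (int i = 0; i < first.length; i++) {
            double d = diff(first[i], second[i]);
            distance += d * d;
        }
        return distance;
    }

    public static void main(String[] args) {
        System.out.println(diff(0.25, Double.NaN));
        System.out.println(diff(0.75, 0.25));
        System.out.println(squaredDistance(new double[] {0.25, Double.NaN},
                                           new double[] {0.75, 0.25}));
    }
}
```

The effect is pessimistic: a missing value never makes two instances look closer than they could possibly be, since the difference used is always at least 0.5.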
6) Stepping into the norm method, the source is:
private double norm(double x, int i) {
    if (Double.isNaN(m_MinArray[i])
        || Utils.eq(m_MaxArray[i], m_MinArray[i])) {
        return 0;
    } else {
        return (x - m_MinArray[i]) / (m_MaxArray[i] - m_MinArray[i]);
    }
}
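The normalization is min-max scaling: (x - m_MinArray[i]) / (m_MaxArray[i] - m_MinArray[i]). If no value has been recorded for the attribute, or its minimum and maximum are equal, norm returns 0.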
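A small worked sketch of this normalization in plain Java follows. NormSketch is a hypothetical name, not Weka code, and the tolerance EPS is an assumption standing in for whatever Utils.eq uses to decide that the minimum and maximum are equal.

```java
// Hypothetical min-max normalization mirroring the norm method (not Weka code).
final class NormSketch {
    // Assumed tolerance for "min equals max"; the real check is delegated to Utils.eq.
    static final double EPS = 1e-6;

    static double norm(double x, double min, double max) {
        if (Double.isNaN(min) || Math.abs(max - min) < EPS) {
            return 0; // no range recorded, or a constant attribute
        }
        return (x - min) / (max - min);
    }

    public static void main(String[] args) {
        // With a recorded range of [20, 60], 30 lies a quarter of the way along.
        System.out.println(norm(30.0, 20.0, 60.0));
        // A constant attribute has no spread, so it contributes nothing to the distance.
        System.out.println(norm(5.0, 5.0, 5.0));
    }
}
```

Since classifyInstance folds the test instance into the ranges before computing distances, every value passed to norm lies between the stored minimum and maximum, which is what keeps the result inside [0, 1].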
For pseudocode of the algorithm itself, please look up the papers on the nearest-neighbor classifier; I won't reproduce it here.