
[035] Classifying Breast Cancer Detection Data with an SVM in Java

Created: 2016-04-26

Background:

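I have recently been studying SVM classification. Most related material online covers SVM theory and derivations, or explains how to run svm-train and svm-predict from the command line. Very little of it walks through an actual data analysis. While searching, I came across an excellent library, LIBSVM (A Library for Support Vector Machines). It is developed by Chih-Chung Chang and Chih-Jen Lin at National Taiwan University, with support from the National Science Council of Taiwan. LIBSVM wraps SVM training and prediction in a convenient API, which makes data analysis much easier. More importantly, its authors have collected a large number of datasets for classification, regression, and multi-label problems, which make it possible to study SVMs in depth from the data side: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/.
Official site: https://www.csie.ntu.edu.tw/~cjlin/libsvm/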

Preparing the training and test data:

The datasets can be downloaded from the LIBSVM website. This example uses the UCI breast-cancer dataset. Each line of the training and test files has the form:

<label> <index1>:<value1> <index2>:<value2> ...

For example:


4.000000 1:1099510.000000 2:10.000000 3:4.000000 4:3.000000 5:1.000000 6:3.000000 7:3.000000 8:6.000000 9:5.000000 10:2.000000
4.000000 1:1100524.000000 2:6.000000 3:10.000000 4:10.000000 5:2.000000 6:8.000000 7:10.000000 8:7.000000 9:3.000000 10:3.000000
4.000000 1:1102573.000000 2:5.000000 3:6.000000 4:5.000000 5:6.000000 6:10.000000 7:1.000000 8:3.000000 9:1.000000 10:1.000000

Link: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#breast-cancer
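Each line is therefore a label followed by sparse index:value pairs. The following minimal sketch parses one such line. The class LibsvmLine is illustrative and not part of LIBSVM. Like svm_scale, it treats whitespace and ':' as delimiters, so the tokens after the label alternate between index and value.

```java
import java.util.StringTokenizer;
import java.util.TreeMap;

/** Illustrative (hypothetical) parser for one line of LIBSVM's sparse text format. */
class LibsvmLine {
    final double label;
    final TreeMap<Integer, Double> features; // feature index -> value, sorted by index

    LibsvmLine(double label, TreeMap<Integer, Double> features) {
        this.label = label;
        this.features = features;
    }

    /** Parses "<label> <index1>:<value1> <index2>:<value2> ...". */
    static LibsvmLine parse(String line) {
        // Whitespace and ':' are both delimiters, so tokens alternate: index, value, index, ...
        StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
        double label = Double.parseDouble(st.nextToken());
        TreeMap<Integer, Double> features = new TreeMap<>();
        while (st.hasMoreTokens()) {
            int index = Integer.parseInt(st.nextToken());
            double value = Double.parseDouble(st.nextToken());
            features.put(index, value);
        }
        return new LibsvmLine(label, features);
    }

    public static void main(String[] args) {
        LibsvmLine row = parse("4.000000 1:1099510.000000 2:10.000000 3:4.000000");
        System.out.println(row.label + " " + row.features); // 4.0 {1=1099510.0, 2=10.0, 3=4.0}
    }
}
```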

Field meanings (0 is the label; 1 to 10 are the feature indices):
0. Class: 2 for benign, 4 for malignant
1. Sample code number: id number
2. Clump Thickness: 1 - 10
3. Uniformity of Cell Size: 1 - 10
4. Uniformity of Cell Shape: 1 - 10
5. Marginal Adhesion: 1 - 10
6. Single Epithelial Cell Size: 1 - 10
7. Bare Nuclei: 1 - 10
8. Bland Chromatin: 1 - 10
9. Normal Nucleoli: 1 - 10
10. Mitoses: 1 - 10

Project setup:

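Create a Java project and add the LibSVM JAR. You also need to add three files from LIBSVM's java directory: svm_train.java, svm_scale.java, and svm_predict.java. These classes implement the command-line tools on top of the libsvm package. Their main methods take the command-line arguments as a String[], so they can be called directly from code like an API. The test code relies on modified versions of these classes, in which svm_train.main returns the path of the model file and svm_predict.main returns the accuracy. In the stock LIBSVM sources, both main methods return void. The remaining file, svm_toy.java, is a graphical demo and does not need to be added.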
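These wrappers take the same String[] as the command line. That means svm-train options are passed as extra array elements placed before the data file name. If only the file name is given, svm-train uses its defaults: C-SVC, an RBF kernel, C = 1, and gamma = 1/number of features. The following minimal sketch shows a hypothetical helper that assembles such an array. buildTrainArgs is not part of LIBSVM. Only the standard svm-train flags -s, -t, -c, -g, and -v are used.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper that builds the String[] expected by the svm_train wrapper. */
class TrainArgs {
    /**
     * @param c        cost parameter (svm-train -c)
     * @param gamma    RBF kernel gamma (svm-train -g)
     * @param folds    n for n-fold cross validation (svm-train -v); values below 2 disable it
     * @param dataFile training file in LIBSVM format
     */
    static String[] buildTrainArgs(double c, double gamma, int folds, String dataFile) {
        List<String> args = new ArrayList<>();
        args.add("-s"); args.add("0");                   // svm_type 0: C-SVC
        args.add("-t"); args.add("2");                   // kernel_type 2: RBF
        args.add("-c"); args.add(Double.toString(c));
        args.add("-g"); args.add(Double.toString(gamma));
        if (folds >= 2) {                                // -v requires n >= 2
            args.add("-v"); args.add(Integer.toString(folds));
        }
        args.add(dataFile);                              // options first, data file last
        return args.toArray(new String[0]);
    }

    public static void main(String[] unused) {
        String[] a = buildTrainArgs(1.0, 0.1, 10, "UCI-breast-cancer-tra");
        System.out.println(Arrays.toString(a));
        // [-s, 0, -t, 2, -c, 1.0, -g, 0.1, -v, 10, UCI-breast-cancer-tra]
        // The array would then be passed to the wrapper, e.g. svm_train.main(a)
    }
}
```

The LIBSVM authors recommend choosing C and gamma for an RBF kernel by cross-validating over a grid of values. An array built this way is a convenient unit for such a loop.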

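Place the training and test data files in the project's root directory so they can be referenced by relative path.
The Java code that calls the LibSVM API to classify the data is as follows: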

import java.io.IOException;

import libsvm.*;

/**JAVA test code for LibSVM
 * @author yangliu
 * @blog http://blog.csdn.net/yangliuy
 * @mail yangliuyx@gmail.com
 */

public class LibSVMTest {

    public static void main(String[] args) throws IOException {
        // Test for svm_train and svm_predict (modified LIBSVM wrappers):
        // svm_train.main
        //     param:  String[], the same arguments as the svm-train command line
        //     return: String, path of the generated model file
        // svm_predict.main
        //     param:  String[], the same arguments as the svm-predict command line,
        //             i.e. test file, model file, and output file
        //     return: Double, the classification accuracy
        String[] trainArgs = {"UCI-breast-cancer-tra"}; // training file
        String modelFile = svm_train.main(trainArgs);
        String[] testArgs = {"UCI-breast-cancer-test", modelFile, "UCI-breast-cancer-result"}; // test file, model file, prediction output file
        Double accuracy = svm_predict.main(testArgs);
        System.out.println("SVM Classification is done! The accuracy is " + accuracy);

        //Test for cross validation
        //String[] crossValidationTrainArgs = {"-v", "10", "UCI-breast-cancer-tra"};// 10 fold cross validation
        //modelFile = svm_train.main(crossValidationTrainArgs);
        //System.out.print("Cross validation is done! The modelFile is " + modelFile);
    }

}

Output:

.*
optimization finished, #iter = 1223
nu = 0.6996186233933985
obj = -271.992875483972, rho = 0.4257786283326366
nSV = 639, nBSV = 222
Total nSV = 639
Accuracy = 69.23076923076923% (27/39) (classification)
SVM Classification is done! The accuracy is 0.6923076923076923

The accuracy is only about 0.69 (27 of 39 test samples classified correctly).
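The accuracy figure is the number of test samples whose predicted label matches the true label, divided by the number of test samples. The true label is the first token of each test line. For the run above, this is 27/39 ≈ 0.6923. The following minimal sketch performs the same computation. It assumes the predictions are available one label per line, as in svm-predict's output file when the -b option is not used. The class, its method names, and the two example lines are illustrative.

```java
import java.util.List;

/** Illustrative accuracy check: compares predicted labels with true labels. */
class AccuracyCheck {
    /** Fraction of positions where the predicted label equals the true label. */
    static double accuracy(List<Double> trueLabels, List<Double> predicted) {
        if (trueLabels.isEmpty() || trueLabels.size() != predicted.size()) {
            throw new IllegalArgumentException("label lists must be non-empty and of equal length");
        }
        int correct = 0;
        for (int i = 0; i < trueLabels.size(); i++) {
            if (trueLabels.get(i).doubleValue() == predicted.get(i).doubleValue()) {
                correct++;
            }
        }
        return (double) correct / trueLabels.size();
    }

    /** The true label is the first whitespace-separated token of a LIBSVM-format line. */
    static double labelOf(String libsvmLine) {
        return Double.parseDouble(libsvmLine.trim().split("\\s+")[0]);
    }

    public static void main(String[] args) {
        // Two hypothetical test lines: one malignant (4), one benign (2)
        List<Double> truth = List.of(labelOf("4.000000 1:5.0"), labelOf("2.000000 1:3.0"));
        List<Double> predicted = List.of(4.0, 4.0);
        System.out.println(accuracy(truth, predicted)); // 0.5
    }
}
```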

Improving the program:

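Use svm_scale.java to scale (normalize) the data. The scaled data is saved to two separate files, UCI-breast-cancer-tra-scale and UCI-breast-cancer-test-scale, and these files are then used for training and testing.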
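svm_scale maps each feature $j$ linearly from its observed range $[\min_j, \max_j]$ to $[\text{lower}, \text{upper}]$, which defaults to $[-1, 1]$. This is the formula implemented in the output function. The worked example below assumes that feature 2 (Clump Thickness) spans its full documented range of 1 to 10 in the training file.

```latex
x'_j = \mathrm{lower} + (\mathrm{upper} - \mathrm{lower})\,\frac{x_j - \min_j}{\max_j - \min_j}

% worked example: feature 2, \min_2 = 1, \max_2 = 10, [\mathrm{lower}, \mathrm{upper}] = [-1, 1]
x_2 = 4:\quad x'_2 = -1 + 2 \cdot \frac{4 - 1}{10 - 1} = -1 + \frac{2}{3} = -\frac{1}{3}
```

Values equal to $\min_j$ or $\max_j$ map exactly to lower or upper. A feature with $\min_j = \max_j$ is skipped. Both behaviours match the special cases in output.

Before scaling, feature 1 (the sample code number, about $1.1 \times 10^6$ in the example lines) is five to six orders of magnitude larger than features 2 to 10. It therefore very likely dominates the distances in the default RBF kernel, which is a plausible reason the unscaled model performed poorly.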

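svm_scale.java needs a few changes so that, besides printing the scaled data, it also returns it as text that can be written to a file.
Change output_target to: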

private String output_target(double value)
    {
        if(y_scaling)
        {
            if(value == y_min)
                value = y_lower;
            else if(value == y_max)
                value = y_upper;
            else
                value = y_lower + (y_upper-y_lower) *
                (value-y_min) / (y_max-y_min);
        }

        System.out.print(value + " ");
        return value + " ";
    }

Change output to:

private String output(int index, double value)
    {
        /* skip single-valued attribute */
        if(feature_max[index] == feature_min[index])
            return " ";

        if(value == feature_min[index])
            value = lower;
        else if(value == feature_max[index])
            value = upper;
        else
            value = lower + (upper-lower) * 
                (value-feature_min[index])/
                (feature_max[index]-feature_min[index]);

        if(value != 0)
        {
            System.out.print(index + ":" + value + " ");
            new_num_nonzeros++;
            return index + ":" + value + " ";
        }
        return " ";
    }

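Two parts of run need to change. The first is option parsing, which gains a -p option for the output path. The second is the scaling pass, which now writes each scaled line to that file: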

// Part 1 (option parsing): new -p option.
// Requires a new field in svm_scale: private String save_filePath = null;
switch(argv[i-1].charAt(1))
            {
                case 'l': lower = Double.parseDouble(argv[i]);  break;
                case 'u': upper = Double.parseDouble(argv[i]);  break;
                case 'y':
                      y_lower = Double.parseDouble(argv[i]);
                      ++i;
                      y_upper = Double.parseDouble(argv[i]);
                      y_scaling = true;
                      break;
                case 's': save_filename = argv[i];  break;
                case 'r': restore_filename = argv[i];   break;
                case 'p': save_filePath = argv[i];  break;   // new: path of the scaled output file
                default:
                      System.err.println("unknown option");
                      exit_with_help();
            }

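        // Part 2 (scaling pass). Requires imports of java.io.BufferedWriter and com.yuan.util.FileStream.
        // append = false: overwrite the file, so rerunning the program does not duplicate lines
        BufferedWriter bw = FileStream.fileWriterStream(save_filePath, false);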

        /* pass 3: scale */
        while(readline(fp) != null)
        {
            int next_index = 1;
            double target;
            double value;
            String dataLine = "";

            StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
            target = Double.parseDouble(st.nextToken());
            dataLine = output_target(target);
            while(st.hasMoreElements())
            {
                index = Integer.parseInt(st.nextToken());
                value = Double.parseDouble(st.nextToken());
                for (i = next_index; i<index; i++)
                    dataLine += output(i, 0);
                dataLine += output(index, value);
                next_index = index + 1;
            }

            for(i=next_index;i<= max_index;i++)
                dataLine += output(i, 0);   // trailing implicit zeros must be written to the file too
            System.out.print("\n");
            dataLine += "\n";
            FileStream.writerData(bw, dataLine);
        }
        if (new_num_nonzeros > num_nonzeros)
            System.err.print(
             "WARNING: original #nonzeros " + num_nonzeros + "\n"
            +"         new      #nonzeros " + new_num_nonzeros + "\n"
            +"Use -l 0 if many original feature values are zeros\n");

        fp.close();
        bw.close();

Create a new FileStream class for writing the scaled data to disk:

package com.yuan.util;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;

public class FileStream {


    public static BufferedWriter fileWriterStream(String fileName, boolean append){
        BufferedWriter fp_save = null;
        try {
            fp_save = new BufferedWriter(new FileWriter(fileName, append));
        } catch(IOException e) {
            System.err.println("can't open file " + fileName);
            System.exit(1);
        }
        return fp_save;
    }

    public static void writerData(BufferedWriter bw, String data) throws IOException{
        bw.write(data);
    }
}
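When append is true, FileStream opens the file in append mode, and the caller must remember to close the writer. The following sketch shows an alternative using java.nio.file.Files.newBufferedWriter inside try-with-resources. With no OpenOption arguments, newBufferedWriter creates the file or truncates an existing one, and try-with-resources closes it automatically, so rerunning the program does not stack duplicate lines. ScaledFileWriter and writeLines are illustrative names, not part of the original project.

```java
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

/** Illustrative alternative to FileStream: overwrite-by-default, auto-closed writer. */
class ScaledFileWriter {
    /** Writes the lines to fileName, replacing any previous content. */
    static void writeLines(String fileName, List<String> lines) throws IOException {
        Path path = Paths.get(fileName);
        // No OpenOptions given: defaults are CREATE, TRUNCATE_EXISTING, WRITE (UTF-8)
        try (BufferedWriter bw = Files.newBufferedWriter(path)) {
            for (String line : lines) {
                bw.write(line);
                bw.newLine(); // platform line separator
            }
        } // bw is flushed and closed here, even if write() throws
    }

    public static void main(String[] args) throws IOException {
        Path tmp = Files.createTempFile("scaled", ".txt");
        writeLines(tmp.toString(), List.of("4.0 1:-1.0", "2.0 1:1.0"));
        writeLines(tmp.toString(), List.of("4.0 1:-1.0", "2.0 1:1.0")); // simulate a rerun
        System.out.println(Files.readAllLines(tmp).size()); // 2, not 4
        Files.delete(tmp);
    }
}
```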

Modify the test class (called SVMClassifierTest in the author's project; it plays the role of LibSVMTest above):

        // Test for svm_train and svm_predict on scaled data (modified LIBSVM wrappers):
        // svm_train.main
        //     param:  String[], the same arguments as the svm-train command line
        //     return: String, path of the generated model file
        // svm_predict.main
        //     param:  String[], the same arguments as the svm-predict command line
        //     return: Double, the classification accuracy
        svm_scale.main(new String[]{"-p", "UCI-breast-cancer-tra-scale", "UCI-breast-cancer-tra"});   // scale the training data and save it
        svm_scale.main(new String[]{"-p", "UCI-breast-cancer-test-scale", "UCI-breast-cancer-test"}); // scale the test data and save it

        String[] scaleTrainArgs = {"UCI-breast-cancer-tra-scale"}; // scaled training file
        String modelFile = svm_train.main(scaleTrainArgs);

        String[] testArgs = {"UCI-breast-cancer-test-scale", modelFile, "UCI-breast-cancer-result"}; // scaled test file, model file, prediction output file
        Double accuracy = svm_predict.main(testArgs);
        System.out.println("SVM Classification is done! The accuracy is " + accuracy);

        // Test for cross validation
        //String[] crossValidationTrainArgs = {"-v", "10", "UCI-breast-cancer-tra"}; // 10-fold cross validation
        //modelFile = svm_train.main(crossValidationTrainArgs);
        //System.out.print("Cross validation is done! The modelFile is " + modelFile);
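The code above scales the training and test files independently, so each file gets its own min/max per feature. LIBSVM's practical guide recommends scaling the test data with the ranges computed on the training data. With svm_scale, you do this by saving the ranges on the training pass with -s and restoring them on the test pass with -r. For example, you could call svm_scale.main(new String[]{"-s", "range", "-p", "UCI-breast-cancer-tra-scale", "UCI-breast-cancer-tra"}) followed by svm_scale.main(new String[]{"-r", "range", "-p", "UCI-breast-cancer-test-scale", "UCI-breast-cancer-test"}). The file name range is only an example.

The following minimal standalone sketch shows the same fit-on-train, apply-to-test principle on dense rows. MinMaxScaler, fit, and transform are hypothetical names, not LIBSVM API.

```java
import java.util.Arrays;

/** Hypothetical min-max scaler: ranges are learned from training data only. */
class MinMaxScaler {
    private final double lower, upper;
    private double[] min, max;

    MinMaxScaler(double lower, double upper) {
        this.lower = lower;
        this.upper = upper;
    }

    /** Learns per-column min/max from the training rows (the role of svm_scale -s). */
    void fit(double[][] train) {
        int d = train[0].length;
        min = new double[d];
        max = new double[d];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : train) {
            for (int j = 0; j < d; j++) {
                min[j] = Math.min(min[j], row[j]);
                max[j] = Math.max(max[j], row[j]);
            }
        }
    }

    /** Applies the stored ranges (the role of svm_scale -r); results may fall outside [lower, upper]. */
    double[][] transform(double[][] rows) {
        if (min == null) throw new IllegalStateException("call fit() first");
        double[][] out = new double[rows.length][];
        for (int i = 0; i < rows.length; i++) {
            out[i] = new double[rows[i].length];
            for (int j = 0; j < rows[i].length; j++) {
                out[i][j] = (max[j] == min[j])
                        ? 0.0  // constant column: svm_scale omits it, which in sparse format means 0
                        : lower + (upper - lower) * (rows[i][j] - min[j]) / (max[j] - min[j]);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        MinMaxScaler scaler = new MinMaxScaler(-1, 1);
        scaler.fit(new double[][]{{1, 10}, {10, 1}});               // ranges from training rows only
        double[][] test = scaler.transform(new double[][]{{4, 12}}); // 12 exceeds the training max
        System.out.println(Arrays.toString(test[0]));               // about [-0.333, 1.444]
    }
}
```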

Result:

*
optimization finished, #iter = 97
nu = 0.0711047842614367
obj = -78.46733678185721, rho = -0.9253740588830286
nSV = 99, nBSV = 83
Total nSV = 99
Accuracy = 89.74358974358975% (70/78) (classification)
SVM Classification is done! The accuracy is 0.8974358974358975

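Accuracy rises substantially, from about 69.2% to 89.7%. The test count, however, is 78 rather than 39. If the scaled files are opened in append mode, each run of the program appends another copy of every line, so the likely explanation is that the program was run twice. The fact that 70/78 is exactly 35/39 is consistent with a duplicated test set.
This completes the basic use of LIBSVM and its improvement through scaling.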

Reference:

Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector machines. ACM Transactions on Intelligent Systems and Technology, 2:27:1–27:27, 2011. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm.
