[035] Classifying Breast Cancer Detection Data with SVM in Java
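Background:
I have recently been studying SVM classification. Most of what I found online covers SVM theory and derivations, or running svm-train and svm-predict from the command line. Very little of it shows an actual data analysis from start to finish. While searching, I came across an excellent library, LIBSVM (A Library for Support Vector Machines), developed by Chih-Chung Chang and Chih-Jen Lin at National Taiwan University. It wraps SVM training and prediction conveniently, which makes data analysis easier. Just as importantly, its authors maintain a large collection of classification, regression and multi-label data sets for studying SVMs in depth from the data side: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/.
Official site: https://www.csie.ntu.edu.tw/~cjlin/libsvm/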
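Preparing the training and test data:
The data sets can be downloaded from the LIBSVM site. This example uses the UCI breast-cancer data set. Each line of the training and test files has the following format:
<label> <index1>:<value1> <index2>:<value2> ...
For example: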
- 4.000000 1:1099510.000000 2:10.000000 3:4.000000 4:3.000000 5:1.000000 6:3.000000 7:3.000000 8:6.000000 9:5.000000 10:2.000000
- 4.000000 1:1100524.000000 2:6.000000 3:10.000000 4:10.000000 5:2.000000 6:8.000000 7:10.000000 8:7.000000 9:3.000000 10:3.000000
- 4.000000 1:1102573.000000 2:5.000000 3:6.000000 4:5.000000 5:6.000000 6:10.000000 7:1.000000 8:3.000000 9:1.000000 10:1.000000
Link: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#breast-cancer
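Each line is a label followed by sparse index:value pairs, and any index that does not appear counts as 0. The sketch below parses one such line with plain Java. LibsvmLineParser and its Sample type are hypothetical helpers for illustration only; they are not part of LIBSVM.

```java
import java.util.TreeMap;

/** Hypothetical helper (not part of LIBSVM): parses one line of LIBSVM's sparse text format. */
public class LibsvmLineParser {

    /** One labelled sample: the class label plus a sparse index -> value map. */
    public static final class Sample {
        public final double label;
        public final TreeMap<Integer, Double> features;

        Sample(double label, TreeMap<Integer, Double> features) {
            this.label = label;
            this.features = features;
        }
    }

    /** Splits "<label> <index>:<value> ..." on whitespace; indices missing from the line are implicitly 0. */
    public static Sample parse(String line) {
        String[] tokens = line.trim().split("\\s+");
        double label = Double.parseDouble(tokens[0]);
        TreeMap<Integer, Double> features = new TreeMap<>();
        for (int i = 1; i < tokens.length; i++) {
            int colon = tokens[i].indexOf(':');
            int index = Integer.parseInt(tokens[i].substring(0, colon));
            double value = Double.parseDouble(tokens[i].substring(colon + 1));
            features.put(index, value);
        }
        return new Sample(label, features);
    }

    public static void main(String[] args) {
        Sample s = parse("4.000000 1:1099510.000000 2:10.000000 3:4.000000");
        // In this data set, label 4 = malignant and label 2 = benign.
        System.out.println(s.label + " " + s.features);
    }
}
```

Running main prints the label followed by the feature map. A TreeMap keeps the indices sorted in ascending order, which is also the order LIBSVM requires in its input files.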
Field meanings:
0. Class: (2 for benign, 4 for malignant)
1. Sample code number: id number
2. Clump Thickness: 1 - 10
3. Uniformity of Cell Size: 1 - 10
4. Uniformity of Cell Shape: 1 - 10
5. Marginal Adhesion: 1 - 10
6. Single Epithelial Cell Size: 1 - 10
7. Bare Nuclei: 1 - 10
8. Bland Chromatin: 1 - 10
9. Normal Nucleoli: 1 - 10
10. Mitoses: 1 - 10
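Project setup:
Create a Java project and add the LIBSVM JAR. You also need to add three source files from LIBSVM's java directory: svm_train.java, svm_scale.java and svm_predict.java. These classes are thin wrappers around the core LIBSVM API that accept the command-line arguments as a String[] parameter, so they can be called directly from code. The remaining file, svm_toy.java, is a graphical demo and can be left out. Note that in the stock LIBSVM distribution, svm_train.main and svm_predict.main return void. The code below assumes copies that have been modified so that svm_train.main returns the model file name and svm_predict.main returns the accuracy.
Put the training and test data files in the project root directory so they can be referenced by relative path.
The Java code that calls LIBSVM for classification is as follows: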
import java.io.IOException;
import libsvm.*;
/**JAVA test code for LibSVM
* @author yangliu
* @blog http://blog.csdn.net/yangliuy
* @mail yangliuyx@gmail.com
*/
public class LibSVMTest {
    public static void main(String[] args) throws IOException {
        // Test for svm_train and svm_predict
        // svm_train.main:
        //   param:  String[], the command-line arguments of svm-train
        //   return: String, the path of the model file
        // svm_predict.main:
        //   param:  String[], the command-line arguments of svm-predict, including the model file
        //   return: Double, the accuracy of the SVM classification
        String[] trainArgs = {"UCI-breast-cancer-tra"}; // training file
        String modelFile = svm_train.main(trainArgs);
        String[] testArgs = {"UCI-breast-cancer-test", modelFile, "UCI-breast-cancer-result"}; // test file, model file, result file
        Double accuracy = svm_predict.main(testArgs);
        System.out.println("SVM Classification is done! The accuracy is " + accuracy);
        // Test for cross validation
        //String[] crossValidationTrainArgs = {"-v", "10", "UCI-breast-cancer-tra"}; // 10-fold cross validation
        //modelFile = svm_train.main(crossValidationTrainArgs);
        //System.out.print("Cross validation is done! The modelFile is " + modelFile);
    }
}
Output:
.*
optimization finished, #iter = 1223
nu = 0.6996186233933985
obj = -271.992875483972, rho = 0.4257786283326366
nSV = 639, nBSV = 222
Total nSV = 639
Accuracy = 69.23076923076923% (27/39) (classification)
SVM Classification is done! The accuracy is 0.6923076923076923
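The accuracy is only 0.69 (27 of 39 test samples classified correctly). A likely contributor is that the features are not scaled: feature 1 is the sample code number, with values around 1,100,000, so in the default RBF kernel it dwarfs the other features, which only range from 1 to 10.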
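To see why unscaled features can hurt an RBF-kernel SVM, the sketch below computes the squared distance between the first two example rows with and without feature 1, and the corresponding kernel value K(x, z) = exp(-gamma * ||x - z||^2). It assumes dense arrays and gamma = 1/10 (svm-train's default is 1/number of features); IdDominanceDemo is a hypothetical name. The id difference alone contributes 1014^2 = 1,028,196 to the squared distance, against 182 from the other nine features, so the kernel value underflows to 0.0 and any two distinct samples look equally dissimilar.

```java
/**
 * Sketch (hypothetical class, not part of LIBSVM): how an unscaled id column
 * dominates the squared distance used by the RBF kernel.
 */
public class IdDominanceDemo {

    /** Squared Euclidean distance; when skipId is true, column 0 (feature 1, the id) is ignored. */
    static double squaredDistance(double[] a, double[] b, boolean skipId) {
        double sum = 0.0;
        for (int i = skipId ? 1 : 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** RBF kernel value exp(-gamma * ||a - b||^2), given the squared distance. */
    static double rbf(double squaredDist, double gamma) {
        return Math.exp(-gamma * squaredDist);
    }

    public static void main(String[] args) {
        // Features 1..10 of the first two example rows (feature 1 = sample code number).
        double[] x1 = {1099510, 10, 4, 3, 1, 3, 3, 6, 5, 2};
        double[] x2 = {1100524, 6, 10, 10, 2, 8, 10, 7, 3, 3};
        double gamma = 1.0 / x1.length; // assumed: svm-train's default gamma = 1/num_features

        double withId = squaredDistance(x1, x2, false);
        double withoutId = squaredDistance(x1, x2, true);
        System.out.println("squared distance with id:    " + withId);
        System.out.println("squared distance without id: " + withoutId);
        System.out.println("RBF kernel with id:    " + rbf(withId, gamma));
        System.out.println("RBF kernel without id: " + rbf(withoutId, gamma));
    }
}
```

Even without the id, the kernel value is only about 1.2e-8, because the remaining 1 to 10 features are unscaled as well. This is why scaling every feature, not just dropping the id, is the usual remedy.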
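Improving the program:
Use svm_scale.java to scale (normalize) the data. The scaled data has to be written to separate files, UCI-breast-cancer-tra-scale and UCI-breast-cancer-test-scale, which are then used for training and prediction.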
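svm-scale maps each feature j linearly onto an interval [l, u] (by default [-1, 1], set with the -l and -u options), using that feature's minimum and maximum over the data being scaled. This is the computation carried out in the output function:

$$
x'_j = l + (u - l)\,\frac{x_j - \min_j}{\max_j - \min_j}
$$

Features whose minimum equals their maximum carry no information and are skipped. For example, if a feature's observed range is 1 to 10, the value 4 maps to $-1 + 2\cdot\frac{3}{9} = -\frac{1}{3}$.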
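svm_scale.java needs a few changes, because the stock version only prints the scaled data to standard output. The modified functions also return the scaled text so that it can be written to a file.
Change output_target to: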
private String output_target(double value)
{
    if(y_scaling)
    {
        if(value == y_min)
            value = y_lower;
        else if(value == y_max)
            value = y_upper;
        else
            value = y_lower + (y_upper-y_lower) *
                (value-y_min) / (y_max-y_min);
    }
    System.out.print(value + " ");   // still echo to stdout, as the stock version does
    return value + " ";              // also return the text so run() can write it to a file
}
Change output to:
private String output(int index, double value)
{
    /* skip single-valued attribute */
    if(feature_max[index] == feature_min[index])
        return " ";
    if(value == feature_min[index])
        value = lower;
    else if(value == feature_max[index])
        value = upper;
    else
        value = lower + (upper-lower) *
            (value-feature_min[index])/
            (feature_max[index]-feature_min[index]);
    if(value != 0)
    {
        System.out.print(index + ":" + value + " ");
        new_num_nonzeros++;
        return index + ":" + value + " ";
    }
    return " ";
}
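Two parts of run need to change. First, add a -p option (the path of the output file for the scaled data) to the option switch: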
// requires a new field in svm_scale (not in the stock class): private String save_filePath = null;
switch(argv[i-1].charAt(1))
{
    case 'l': lower = Double.parseDouble(argv[i]); break;
    case 'u': upper = Double.parseDouble(argv[i]); break;
    case 'y':
        y_lower = Double.parseDouble(argv[i]);
        ++i;
        y_upper = Double.parseDouble(argv[i]);
        y_scaling = true;
        break;
    case 's': save_filename = argv[i]; break;
    case 'r': restore_filename = argv[i]; break;
    case 'p': save_filePath = argv[i]; break;   // new: output file for the scaled data
    default:
        System.err.println("unknown option");
        exit_with_help();
}
Second, change pass 3 so that every scaled line is also written to that file (svm_scale.java then needs to import java.io.BufferedWriter and com.yuan.util.FileStream):
// -p must be given; false = overwrite, whereas true would append a second copy of every row on a rerun
BufferedWriter bw = FileStream.fileWriterStream(save_filePath, false);
/* pass 3: scale */
while(readline(fp) != null)
{
    int next_index = 1;
    double target;
    double value;
    String dataLine = "";
    StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
    target = Double.parseDouble(st.nextToken());
    dataLine = output_target(target);
    while(st.hasMoreElements())
    {
        index = Integer.parseInt(st.nextToken());
        value = Double.parseDouble(st.nextToken());
        for (i = next_index; i<index; i++)
            dataLine += output(i, 0);
        dataLine += output(index, value);
        next_index = index + 1;
    }
    for(i=next_index;i<= max_index;i++)
        dataLine += output(i, 0);   // trailing features must be written to the file too
    System.out.print("\n");
    dataLine += "\n";
    FileStream.writerData(bw, dataLine);
}
if (new_num_nonzeros > num_nonzeros)
    System.err.print(
        "WARNING: original #nonzeros " + num_nonzeros+"\n"
        +" new #nonzeros " + new_num_nonzeros+"\n"
        +"Use -l 0 if many original feature values are zeros\n");
fp.close();
bw.close();
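The scaling ranges should come from the training data alone and then be reused unchanged for the test data. With svm-scale this is done by saving the ranges with -s on the training file and loading them with -r for the test file. The sketch below shows the same idea in plain Java on dense arrays. TrainTestScaler is a hypothetical class, not part of LIBSVM, and unlike svm_scale it writes 0 for features that are constant in the training data instead of dropping them.

```java
import java.util.Arrays;

/**
 * Sketch (hypothetical class): min-max scaling whose ranges are learned on the
 * training rows and reused for the test rows, mirroring "svm-scale -s" then "svm-scale -r".
 * Dense double[][] rows are an assumption for brevity; LIBSVM files are sparse.
 */
public class TrainTestScaler {
    private final double lower, upper;
    private double[] min, max;

    public TrainTestScaler(double lower, double upper) {
        this.lower = lower;
        this.upper = upper;
    }

    /** Records the per-feature minimum and maximum of the training rows. */
    public void fit(double[][] train) {
        int d = train[0].length;
        min = new double[d];
        max = new double[d];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : train) {
            for (int j = 0; j < d; j++) {
                min[j] = Math.min(min[j], row[j]);
                max[j] = Math.max(max[j], row[j]);
            }
        }
    }

    /** Scales rows with the training ranges; no clipping, so test values may leave [lower, upper]. */
    public double[][] transform(double[][] rows) {
        double[][] out = new double[rows.length][];
        for (int i = 0; i < rows.length; i++) {
            out[i] = new double[rows[i].length];
            for (int j = 0; j < rows[i].length; j++) {
                if (max[j] == min[j]) {
                    out[i][j] = 0.0; // svm_scale drops such features; 0 is a simplification here
                } else {
                    out[i][j] = lower + (upper - lower) * (rows[i][j] - min[j]) / (max[j] - min[j]);
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] train = {{1, 10}, {10, 1}, {4, 5}};
        double[][] test = {{12, 5}}; // 12 lies outside the training range of feature 1
        TrainTestScaler scaler = new TrainTestScaler(-1, 1);
        scaler.fit(train);
        System.out.println(Arrays.deepToString(scaler.transform(train)));
        System.out.println(Arrays.deepToString(scaler.transform(test)));
    }
}
```

A test value outside the training range maps outside [lower, upper]; svm_scale does not clip such values either.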
Create a new FileStream class for writing the data to disk:
package com.yuan.util;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
public class FileStream {
    // Opens fileName for writing; append = true adds to an existing file instead of overwriting it.
    public static BufferedWriter fileWriterStream(String fileName, boolean append){
        BufferedWriter fp_save = null;
        try {
            fp_save = new BufferedWriter(new FileWriter(fileName, append));
        } catch(IOException e) {
            System.err.println("can't open file " + fileName);
            System.exit(1);
        }
        return fp_save;
    }
    // Writes one chunk of text; the caller is responsible for closing the writer.
    public static void writerData(BufferedWriter bw, String data) throws IOException{
        bw.write(data);
    }
}
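FileStream leaves closing to the caller and exits the JVM when a file cannot be opened. A hypothetical alternative, sketched below with java.nio.file (List.of needs Java 9 or later), writes all lines in one call inside try-with-resources. Opening with TRUNCATE_EXISTING means a rerun overwrites the file instead of appending duplicate rows. ScaledFileWriter and the demo file name are assumptions for illustration.

```java
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.List;

/** Hypothetical alternative to FileStream: overwrite-by-default, auto-closing line writer. */
public class ScaledFileWriter {

    /** Writes each line followed by a line separator, replacing any previous file contents. */
    public static void writeLines(Path file, List<String> lines) throws IOException {
        try (BufferedWriter bw = Files.newBufferedWriter(file, StandardCharsets.UTF_8,
                StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING,
                StandardOpenOption.WRITE)) {
            for (String line : lines) {
                bw.write(line);
                bw.newLine();
            }
        } // the writer is flushed and closed here, even if write() throws
    }

    public static void main(String[] args) throws IOException {
        Path out = Paths.get("UCI-breast-cancer-tra-scale-demo");
        List<String> rows = List.of("4 1:-1 2:1", "2 1:1 2:-1");
        writeLines(out, rows);
        writeLines(out, rows); // second call overwrites, so the file still has 2 lines
        System.out.println(Files.readAllLines(out).size());
    }
}
```

Throwing IOException instead of calling System.exit lets the caller decide how to handle a missing directory or a read-only file.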
Modify the main method of the test class (called SVMClassifierTest here, LibSVMTest in the first listing):
// Test for svm_train and svm_predict on scaled data
// svm_train.main:
//   param:  String[], the command-line arguments of svm-train
//   return: String, the path of the model file
// svm_predict.main:
//   param:  String[], the command-line arguments of svm-predict, including the model file
//   return: Double, the accuracy of the SVM classification
svm_scale.main(new String[]{"-p", "UCI-breast-cancer-tra-scale", "UCI-breast-cancer-tra"}); // scale the training data and save it
svm_scale.main(new String[]{"-p", "UCI-breast-cancer-test-scale", "UCI-breast-cancer-test"}); // scale the test data and save it
// Note: the second call computes its own min/max from the test data. The LIBSVM guide recommends
// saving the training ranges with -s and reusing them for the test data with -r instead.
String[] scaleTrainArgs = {"UCI-breast-cancer-tra-scale"}; // scaled training file
String modelFile = svm_train.main(scaleTrainArgs);
String[] testArgs = {"UCI-breast-cancer-test-scale", modelFile, "UCI-breast-cancer-result"}; // scaled test file, model file, result file
Double accuracy = svm_predict.main(testArgs);
System.out.println("SVM Classification is done! The accuracy is " + accuracy);
// Test for cross validation
//String[] crossValidationTrainArgs = {"-v", "10", "UCI-breast-cancer-tra"}; // 10-fold cross validation
//modelFile = svm_train.main(crossValidationTrainArgs);
//System.out.print("Cross validation is done! The modelFile is " + modelFile);
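The commented-out -v 10 call asks svm-train for 10-fold cross validation instead of a single model: the training data is split into 10 parts, each part is predicted by a model trained on the other 9, and the overall cross-validation accuracy is reported. The sketch below shows only the splitting step, using a hypothetical KFoldSplit class. It is plain (non-stratified) k-fold, whereas LIBSVM's svm_cross_validation stratifies the folds by class for classification, so its actual splits will differ.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Sketch (hypothetical class): shuffle sample indices and deal them into k folds. */
public class KFoldSplit {

    /** Returns k disjoint folds that together contain every index 0..n-1 exactly once. */
    public static List<List<Integer>> folds(int n, int k, long seed) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < n; i++) idx.add(i);
        Collections.shuffle(idx, new Random(seed)); // fixed seed makes the split reproducible
        List<List<Integer>> result = new ArrayList<>();
        for (int f = 0; f < k; f++) result.add(new ArrayList<>());
        for (int i = 0; i < n; i++) result.get(i % k).add(idx.get(i)); // round-robin keeps sizes balanced
        return result;
    }

    public static void main(String[] args) {
        List<List<Integer>> split = folds(23, 10, 42L);
        for (int f = 0; f < split.size(); f++) {
            // In round f, the model is trained on every index not in split.get(f).
            System.out.println("fold " + f + " holds out " + split.get(f));
        }
    }
}
```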
Output:
*
optimization finished, #iter = 97
nu = 0.0711047842614367
obj = -78.46733678185721, rho = -0.9253740588830286
nSV = 99, nBSV = 83
Total nSV = 99
Accuracy = 89.74358974358975% (70/78) (classification)
SVM Classification is done! The accuracy is 0.8974358974358975
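The accuracy rises substantially, from 69.23% to 89.74%. Note, however, that this run reports 78 test samples instead of 39, which suggests the scaled test file contains every sample twice, for example because an earlier run's output was appended to rather than overwritten. Delete any existing *-scale files before rerunning.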
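The Accuracy line printed by svm_predict is the number of test samples whose predicted label equals the true label, divided by the number of test samples. A minimal sketch of that computation (AccuracyDemo is a hypothetical name):

```java
/** Sketch (hypothetical class): classification accuracy = correct predictions / total samples. */
public class AccuracyDemo {

    static double accuracy(double[] trueLabels, double[] predicted) {
        if (trueLabels.length != predicted.length) {
            throw new IllegalArgumentException("label arrays differ in length");
        }
        int correct = 0;
        for (int i = 0; i < trueLabels.length; i++) {
            if (trueLabels[i] == predicted[i]) correct++; // labels here are 2 (benign) or 4 (malignant)
        }
        return (double) correct / trueLabels.length;
    }

    public static void main(String[] args) {
        double[] truth = {2, 4, 4, 2};
        double[] pred = {2, 4, 2, 2};
        System.out.println(accuracy(truth, pred)); // 3 of 4 correct
    }
}
```

Because only the proportion of correct predictions matters, 70/78 is exactly the same value as 35/39: duplicating every test sample leaves the accuracy unchanged.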
This completes a basic use of LIBSVM from Java, together with the scaling improvement.
Reference:
Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support vector machines. ACM Transactions on Intelligent Systems and Technology, 2:27:1–27:27, 2011. Software available at http://www.csie.ntu.edu.tw/~cjlin/libsvm.