本文整理汇总了Java中smile.data.AttributeDataset类的典型用法代码示例。如果您正在寻找以下问题的答案：Java AttributeDataset类的具体用法是什么？Java AttributeDataset怎么用？有哪些Java AttributeDataset的使用示例？那么这里精选的类代码示例或许可以为您提供帮助。
AttributeDataset类属于smile.data包，下文共展示了AttributeDataset类的40个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于我们的系统推荐更好的Java代码示例。
示例1: main
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Demo entry point: loads the ALL/AML microarray test data from a RES file and
 * displays it as a heatmap in a window. The String labels returned by
 * {@code data.toArray(String[])} are passed as gene names and the attribute
 * names are passed as array names.
 *
 * @param args command-line arguments (unused)
 */
public static void main(String[] args) {
    try {
        RESParser parser = new RESParser();
        AttributeDataset data = parser.parse("RES", smile.data.parser.IOUtils.getTestDataFile("microarray/all_aml_test.res"));
        double[][] x = data.toArray(new double[data.size()][]);
        String[] genes = data.toArray(new String[data.size()]);
        String[] arrays = new String[data.attributes().length];
        for (int i = 0; i < arrays.length; i++) {
            arrays[i] = data.attributes()[i].getName();
        }
        JFrame frame = new JFrame("Heatmap");
        // A new JFrame is 0x0 until sized. Size it before centering: otherwise
        // setLocationRelativeTo(null) centers an empty window and the heatmap
        // opens in a frame too small to see.
        frame.setSize(1000, 1000);
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.setLocationRelativeTo(null);
        frame.getContentPane().add(Heatmap.plot(genes, arrays, x, Palette.jet(256)));
        frame.setVisible(true);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:22,
代码来源:HeatmapDemo.java
示例2: testTest_3args_1
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of test method, of class Validation.
 *
 * <p>Trains LDA on the USPS training set (class label taken from column 0) and
 * checks that {@code Validation.test} reports an accuracy of 0.8724 (within
 * 1E-4) on the USPS test set. Any exception, such as a missing data file, is
 * rethrown so the test fails instead of passing after printing the error.
 */
@Test
public void testTest_3args_1() {
    System.out.println("test");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        LDA lda = new LDA(x, y);
        double accuracy = Validation.test(lda, testx, testy);
        System.out.println("accuracy = " + accuracy);
        assertEquals(0.8724, accuracy, 1E-4);
    } catch (Exception ex) {
        // Rethrow instead of printing: swallowing would let the test pass
        // even when the data could not be loaded.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:26,
代码来源:ValidationTest.java
示例3: testTest_4args_1
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of test method, of class Validation.
 *
 * <p>Trains LDA on the USPS training set (class label taken from column 0) and
 * checks that {@code Validation.test} with a single {@code Accuracy} measure
 * reports 0.8724 (within 1E-4) on the USPS test set. Exceptions are rethrown
 * so that a data-loading failure fails the test.
 */
@Test
public void testTest_4args_1() {
    System.out.println("test");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        LDA lda = new LDA(x, y);
        ClassificationMeasure[] measures = {new Accuracy()};
        double[] accuracy = Validation.test(lda, testx, testy, measures);
        System.out.println("accuracy = " + accuracy[0]);
        assertEquals(0.8724, accuracy[0], 1E-4);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:27,
代码来源:ValidationTest.java
示例4: testLoocv_3args_1
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of loocv method, of class Validation.
 *
 * <p>Runs leave-one-out cross-validation of an LDA trainer on the Iris data
 * (response at attribute index 4) and checks the accuracy is 0.8533 (within
 * 1E-4). Exceptions are rethrown so that a data-loading failure fails the test.
 */
@Test
public void testLoocv_3args_1() {
    System.out.println("loocv");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        ClassifierTrainer<double[]> trainer = new LDA.Trainer();
        double accuracy = Validation.loocv(trainer, x, y);
        System.out.println("LOOCV accuracy = " + accuracy);
        assertEquals(0.8533, accuracy, 1E-4);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:23,
代码来源:ValidationTest.java
示例5: testLoocv_3args_2
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of loocv method, of class Validation.
 *
 * <p>Runs leave-one-out cross-validation of an RBF network trainer (20
 * centers, Euclidean distance) on the CPU data (response at attribute index 6,
 * features standardized) and prints the returned value as RMSE. No value is
 * asserted; exceptions are rethrown so that a data-loading failure fails the
 * test.
 */
@Test
public void testLoocv_3args_2() {
    System.out.println("loocv");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] y = data.toArray(new double[data.size()]);
        double[][] x = data.toArray(new double[data.size()][]);
        Math.standardize(x);
        RBFNetwork.Trainer<double[]> trainer = new RBFNetwork.Trainer<>(new EuclideanDistance());
        trainer.setNumCenters(20);
        double rmse = Validation.loocv(trainer, x, y);
        System.out.println("RMSE = " + rmse);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:23,
代码来源:ValidationTest.java
示例6: testCv_4args_1
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of cv method, of class Validation.
 *
 * <p>Runs 10-fold cross-validation of an LDA trainer on the Iris data
 * (response at attribute index 4) and prints the accuracy. No value is
 * asserted; exceptions are rethrown so that a data-loading failure fails the
 * test.
 */
@Test
public void testCv_4args_1() {
    System.out.println("cv");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        ClassifierTrainer<double[]> trainer = new LDA.Trainer();
        double accuracy = Validation.cv(10, trainer, x, y);
        System.out.println("10-fold CV accuracy = " + accuracy);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:22,
代码来源:ValidationTest.java
示例7: testCv_4args_2
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of cv method, of class Validation.
 *
 * <p>Runs 10-fold cross-validation of an RBF network trainer (20 centers,
 * Euclidean distance) on the standardized CPU data (response at attribute
 * index 6) and prints the returned value as RMSE. No value is asserted;
 * exceptions are rethrown so that a data-loading failure fails the test.
 */
@Test
public void testCv_4args_2() {
    System.out.println("cv");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] y = data.toArray(new double[data.size()]);
        double[][] x = data.toArray(new double[data.size()][]);
        Math.standardize(x);
        RBFNetwork.Trainer<double[]> trainer = new RBFNetwork.Trainer<>(new EuclideanDistance());
        trainer.setNumCenters(20);
        double rmse = Validation.cv(10, trainer, x, y);
        System.out.println("RMSE = " + rmse);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:23,
代码来源:ValidationTest.java
示例8: testCv_5args_1
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of cv method, of class Validation.
 *
 * <p>Runs 10-fold cross-validation of an LDA trainer on the Iris data
 * (response at attribute index 4) with an {@code Accuracy} measure and prints
 * each measure's result. No value is asserted; exceptions are rethrown so that
 * a data-loading failure fails the test.
 */
@Test
public void testCv_5args_1() {
    System.out.println("cv");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        ClassifierTrainer<double[]> trainer = new LDA.Trainer();
        ClassificationMeasure[] measures = {new Accuracy()};
        double[] results = Validation.cv(10, trainer, x, y, measures);
        for (int i = 0; i < measures.length; i++) {
            System.out.println(measures[i] + " = " + results[i]);
        }
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:24,
代码来源:ValidationTest.java
示例9: testCv_5args_2
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of cv method, of class Validation.
 *
 * <p>Runs 10-fold cross-validation of an RBF network trainer (20 centers,
 * Euclidean distance) on the standardized CPU data (response at attribute
 * index 6) with RMSE and mean-absolute-deviation measures, and prints both.
 * No value is asserted; exceptions are rethrown so that a data-loading failure
 * fails the test.
 */
@Test
public void testCv_5args_2() {
    System.out.println("cv");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] y = data.toArray(new double[data.size()]);
        double[][] x = data.toArray(new double[data.size()][]);
        Math.standardize(x);
        RBFNetwork.Trainer<double[]> trainer = new RBFNetwork.Trainer<>(new EuclideanDistance());
        trainer.setNumCenters(20);
        RegressionMeasure[] measures = {new RMSE(), new MeanAbsoluteDeviation()};
        double[] results = Validation.cv(10, trainer, x, y, measures);
        System.out.println("RMSE = " + results[0]);
        System.out.println("MAD = " + results[1]);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:25,
代码来源:ValidationTest.java
示例10: testBootstrap_4args_1
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of bootstrap method, of class Validation.
 *
 * <p>Runs 100 bootstrap rounds of an LDA trainer on the Iris data (response at
 * attribute index 4) and prints the mean and standard deviation of the
 * per-round accuracies. No value is asserted; exceptions are rethrown so that
 * a data-loading failure fails the test.
 */
@Test
public void testBootstrap_4args_1() {
    System.out.println("bootstrap");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        ClassifierTrainer<double[]> trainer = new LDA.Trainer();
        double[] accuracy = Validation.bootstrap(100, trainer, x, y);
        System.out.println("100-fold bootstrap accuracy average = " + Math.mean(accuracy));
        System.out.println("100-fold bootstrap accuracy std.dev = " + Math.sd(accuracy));
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:23,
代码来源:ValidationTest.java
示例11: testBootstrap_4args_2
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of bootstrap method, of class Validation.
 *
 * <p>Runs 100 bootstrap rounds of an RBF network trainer (20 centers,
 * Euclidean distance) on the standardized CPU data (response at attribute
 * index 6) and prints the mean and standard deviation of the per-round RMSE.
 * No value is asserted; exceptions are rethrown so that a data-loading failure
 * fails the test.
 */
@Test
public void testBootstrap_4args_2() {
    System.out.println("bootstrap");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] y = data.toArray(new double[data.size()]);
        double[][] x = data.toArray(new double[data.size()][]);
        Math.standardize(x);
        RBFNetwork.Trainer<double[]> trainer = new RBFNetwork.Trainer<>(new EuclideanDistance());
        trainer.setNumCenters(20);
        double[] rmse = Validation.bootstrap(100, trainer, x, y);
        System.out.println("100-fold bootstrap RMSE average = " + Math.mean(rmse));
        System.out.println("100-fold bootstrap RMSE std.dev = " + Math.sd(rmse));
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:24,
代码来源:ValidationTest.java
示例12: testBootstrap_5args_2
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of bootstrap method, of class Validation.
 *
 * <p>Runs 100 bootstrap rounds of an RBF network trainer (20 centers,
 * Euclidean distance) on the standardized CPU data (response at attribute
 * index 6) with RMSE and mean-absolute-deviation measures. Prints the mean and
 * standard deviation of each measure across rounds ({@code results[m]} holds
 * the values for measure {@code m}). No value is asserted; exceptions are
 * rethrown so that a data-loading failure fails the test.
 */
@Test
public void testBootstrap_5args_2() {
    System.out.println("bootstrap");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] y = data.toArray(new double[data.size()]);
        double[][] x = data.toArray(new double[data.size()][]);
        Math.standardize(x);
        RBFNetwork.Trainer<double[]> trainer = new RBFNetwork.Trainer<>(new EuclideanDistance());
        trainer.setNumCenters(20);
        RegressionMeasure[] measures = {new RMSE(), new MeanAbsoluteDeviation()};
        double[][] results = Validation.bootstrap(100, trainer, x, y, measures);
        System.out.println("100-fold bootstrap RMSE average = " + Math.mean(results[0]));
        System.out.println("100-fold bootstrap RMSE std.dev = " + Math.sd(results[0]));
        System.out.println("100-fold bootstrap AbsoluteDeviation average = " + Math.mean(results[1]));
        System.out.println("100-fold bootstrap AbsoluteDeviation std.dev = " + Math.sd(results[1]));
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:27,
代码来源:ValidationTest.java
示例13: testAttributes
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of attributes method, of class DateFeature.
 *
 * <p>Builds a DateFeature for seven date/time fields and checks that one
 * attribute is produced per field. MONTH and DAY_OF_WEEK (indices 1 and 3)
 * must be nominal and all other attributes numeric. Exceptions are rethrown so
 * that a failure fails the test instead of only being printed.
 */
@Test
public void testAttributes() {
    System.out.println("attributes");
    try {
        ArffParser parser = new ArffParser();
        // NOTE(review): data is not used below; this parse only checks that the
        // date.arff fixture can be loaded.
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/date.arff"));
        DateFeature.Type[] features = {DateFeature.Type.YEAR, DateFeature.Type.MONTH, DateFeature.Type.DAY_OF_MONTH, DateFeature.Type.DAY_OF_WEEK, DateFeature.Type.HOURS, DateFeature.Type.MINUTES, DateFeature.Type.SECONDS};
        DateFeature df = new DateFeature(features);
        Attribute[] attributes = df.attributes();
        assertEquals(features.length, attributes.length);
        for (int i = 0; i < attributes.length; i++) {
            System.out.println(attributes[i]);
            // Indices 1 and 3 correspond to MONTH and DAY_OF_WEEK in features.
            if (i == 1 || i == 3) {
                assertEquals(Attribute.Type.NOMINAL, attributes[i].getType());
            } else {
                assertEquals(Attribute.Type.NUMERIC, attributes[i].getType());
            }
        }
    } catch (Exception ex) {
        // Rethrow instead of printing so that a failure fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:27,
代码来源:DateFeatureTest.java
示例14: testRank
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of rank method, of class SumSquaresRatio.
 *
 * <p>Ranks the four Iris features (response at attribute index 4) by sum of
 * squares ratio and checks each ratio to within 1E-7. Exceptions are rethrown
 * so that a data-loading failure fails the test.
 */
@Test
public void testRank() {
    System.out.println("rank");
    try {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        SumSquaresRatio ssr = new SumSquaresRatio();
        double[] ratio = ssr.rank(x, y);
        assertEquals(4, ratio.length);
        assertEquals(1.6226463, ratio[0], 1E-7);
        assertEquals(0.6444144, ratio[1], 1E-7);
        assertEquals(16.0412833, ratio[2], 1E-7);
        assertEquals(13.0520327, ratio[3], 1E-7);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:25,
代码来源:SumSquaresRatioTest.java
示例15: CoverTreeSpeedTest
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Loads the USPS training and test feature vectors (class label read from
 * column 0 as the response) into the fields {@code x} and {@code testx}. Then
 * builds {@code coverTree} over the training vectors with Euclidean distance,
 * printing the time taken by each step.
 *
 * @throws IllegalStateException if the USPS data cannot be loaded, since there
 *         is nothing to index without it
 */
public CoverTreeSpeedTest() {
    long start = System.currentTimeMillis();
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        x = train.toArray(new double[train.size()][]);
        testx = test.toArray(new double[test.size()][]);
    } catch (Exception ex) {
        // Fail fast: continuing would build the cover tree over data that was never loaded.
        throw new IllegalStateException("Failed to load USPS data", ex);
    }
    double time = (System.currentTimeMillis() - start) / 1000.0;
    System.out.format("Loading data: %.2fs%n", time);
    start = System.currentTimeMillis();
    coverTree = new CoverTree<>(x, new EuclideanDistance());
    time = (System.currentTimeMillis() - start) / 1000.0;
    System.out.format("Building cover tree: %.2fs%n", time);
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:23,
代码来源:CoverTreeSpeedTest.java
示例16: LSHTest
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Loads the USPS training and test feature vectors (class label read from
 * column 0 as the response) into the fields {@code x} and {@code testx}. Then
 * builds a linear-search baseline ({@code naive}, Euclidean distance) and an
 * LSH index ({@code lsh}) over the training vectors, each vector used as its
 * own key and value.
 *
 * @throws IllegalStateException if the USPS data cannot be loaded, since the
 *         search structures cannot be built without it
 */
public LSHTest() {
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        x = train.toArray(new double[train.size()][]);
        testx = test.toArray(new double[test.size()][]);
    } catch (Exception ex) {
        // Fail fast: continuing would build the indexes over data that was never loaded.
        throw new IllegalStateException("Failed to load USPS data", ex);
    }
    naive = new LinearSearch<>(x, new EuclideanDistance());
    lsh = new LSH<>(x, x);
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:24,
代码来源:LSHTest.java
示例17: testUSPSNystrom
点赞 3
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class SpectralClustering.
 *
 * <p>Clusters the USPS training vectors into 10 clusters with
 * {@code SpectralClustering(x, 10, 100, 8.0)}. Checks the Rand index against
 * the true labels exceeds 0.8 and the adjusted Rand index exceeds 0.35.
 * Exceptions are rethrown so that a data-loading failure fails the test.
 */
@Test
public void testUSPSNystrom() {
    System.out.println("USPS Nystrom approximation");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        SpectralClustering spectral = new SpectralClustering(x, 10, 100, 8.0);
        int[] clusterLabels = spectral.getClusterLabel();
        AdjustedRandIndex ari = new AdjustedRandIndex();
        RandIndex rand = new RandIndex();
        double r = rand.measure(y, clusterLabels);
        double r2 = ari.measure(y, clusterLabels);
        System.out.format("Training rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0 * r, 100.0 * r2);
        assertTrue(r > 0.8);
        assertTrue(r2 > 0.35);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:28,
代码来源:SpectralClusteringTest.java
示例18: testUSPS
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class RDA.
 *
 * <p>Trains RDA (parameter 0.7) on the USPS training set and checks that it
 * misclassifies exactly 235 vectors of the USPS test set. Exceptions are
 * rethrown so that a data-loading failure fails the test.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        RDA rda = new RDA(x, y, 0.7);
        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (rda.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertEquals(235, error);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:33,
代码来源:RDATest.java
示例19: testWeather
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class AdaBoost.
 *
 * <p>Runs leave-one-out evaluation of {@code AdaBoost(attributes, x, y, 200, 4)}
 * on the nominal weather data (response at attribute index 4) and checks that
 * exactly 3 held-out samples are misclassified. Exceptions are rethrown so
 * that a data-loading failure fails the test.
 */
@Test
public void testWeather() {
    System.out.println("Weather");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset weather = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/weather.nominal.arff"));
        double[][] x = weather.toArray(new double[weather.size()][]);
        int[] y = weather.toArray(new int[weather.size()]);
        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);
            AdaBoost forest = new AdaBoost(weather.attributes(), trainx, trainy, 200, 4);
            if (y[loocv.test[i]] != forest.predict(x[loocv.test[i]])) {
                error++;
            }
        }
        System.out.println("AdaBoost error = " + error);
        assertEquals(3, error);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:32,
代码来源:AdaBoostTest.java
示例20: testSegment
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class SVM.
 *
 * <p>Trains a one-vs-all SVM with a Gaussian kernel on the segment-challenge
 * data (response at attribute index 19) and checks it makes fewer than 70
 * errors on the segment-test data. Exceptions are rethrown so that a
 * data-loading failure fails the test instead of only printing a stack trace.
 */
@Test
public void testSegment() {
    System.out.println("Segment");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(19);
    try {
        AttributeDataset train = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/segment-challenge.arff"));
        AttributeDataset test = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/segment-test.arff"));
        System.out.println(train.size() + " " + test.size());
        double[][] x = train.toArray(new double[0][]);
        int[] y = train.toArray(new int[0]);
        double[][] testx = test.toArray(new double[0][]);
        int[] testy = test.toArray(new int[0]);
        // Number of classes is the largest label plus one.
        SVM<double[]> svm = new SVM<>(new GaussianKernel(8.0), 5.0, Math.max(y) + 1, SVM.Multiclass.ONE_VS_ALL);
        svm.learn(x, y);
        svm.finish();
        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (svm.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("Segment error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertTrue(error < 70);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:36,
代码来源:SVMTest.java
示例21: test
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Runs 10-fold cross-validation of {@code RegressionTree(attributes, x, y, 20)}
 * on an ARFF test dataset and prints the RMSE and mean absolute deviation of
 * the held-out predictions.
 *
 * @param dataset  display name printed before the results
 * @param url      test-data path passed to {@code IOUtils.getTestDataFile}
 * @param response attribute index of the regression response
 * @throws AssertionError wrapping any exception raised while loading the data
 *         or training, so that calling tests fail rather than pass silently
 */
public void test(String dataset, String url, int response) {
    System.out.println(dataset);
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(response);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile(url));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);
        int n = datax.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        double ad = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);
            RegressionTree tree = new RegressionTree(data.attributes(), trainx, trainy, 20);
            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - tree.predict(testx[j]);
                rss += r * r;
                ad += Math.abs(r);
            }
        }
        // Every sample is held out exactly once across the folds, so divide by n.
        System.out.format("10-CV RMSE = %.4f \t AbsoluteDeviation = %.4f%n", Math.sqrt(rss/n), ad/n);
    } catch (Exception ex) {
        // Rethrow instead of printing so that every calling test fails on bad data.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:36,
代码来源:RegressionTreeTest.java
示例22: testLearn
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class QDA.
 *
 * <p>Runs leave-one-out evaluation of QDA on the Iris data (response at
 * attribute index 4) and checks that exactly 4 held-out samples are
 * misclassified. Exceptions are rethrown so that a data-loading failure fails
 * the test.
 */
@Test
public void testLearn() {
    System.out.println("learn");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        // Output buffer filled by predict(); sized for the three Iris classes.
        double[] posteriori = new double[3];
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);
            QDA qda = new QDA(trainx, trainy);
            if (y[loocv.test[i]] != qda.predict(x[loocv.test[i]], posteriori)) {
                error++;
            }
        }
        System.out.println("QDA error = " + error);
        assertEquals(4, error);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:35,
代码来源:QDATest.java
示例23: testUSPS
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class LogisticRegression.
 *
 * <p>Trains {@code LogisticRegression(x, y, 0.3, 1E-3, 1000)} on the USPS
 * training set and checks that it misclassifies exactly 188 vectors of the
 * USPS test set. Exceptions are rethrown so that a data-loading failure fails
 * the test.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        LogisticRegression logit = new LogisticRegression(x, y, 0.3, 1E-3, 1000);
        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (logit.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertEquals(188, error);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:33,
代码来源:LogisticRegressionTest.java
示例24: testIris
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class LogisticRegression.
 *
 * <p>Runs leave-one-out evaluation of logistic regression on the Iris data
 * (response at attribute index 4) and checks that exactly 3 held-out samples
 * are misclassified. Exceptions are rethrown so that a data-loading failure
 * fails the test.
 */
@Test
public void testIris() {
    System.out.println("Iris");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);
            LogisticRegression logit = new LogisticRegression(trainx, trainy);
            if (y[loocv.test[i]] != logit.predict(x[loocv.test[i]])) {
                error++;
            }
        }
        System.out.println("Logistic Regression error = " + error);
        assertEquals(3, error);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:32,
代码来源:LogisticRegressionTest.java
示例25: getClassifier
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Returns a k-nearest-neighbour classifier (k = 5) for card shading.
 *
 * <p>On the first call, trains the model from the comma-separated file
 * {@code data/train-out-<suffix>}, where the first column is the shading label
 * ("0", "1" or "2"). Later calls return the cached instance.
 *
 * @return the cached or newly trained classifier
 * @throws IOException    propagated from the parser when the training file cannot be read
 * @throws ParseException propagated from the parser when the training file is malformed
 */
@Override
public Classifier<double[]> getClassifier() throws IOException, ParseException {
    if (classifier == null) {
        DelimitedTextParser csvParser = new DelimitedTextParser();
        csvParser.setDelimiter(",");
        NominalAttribute shading = new NominalAttribute("shading", new String[] { "0", "1", "2" });
        csvParser.setResponseIndex(shading, 0);
        AttributeDataset training = csvParser.parse("data/train-out-" + getFileSuffix());
        int rows = training.size();
        double[][] features = training.toArray(new double[rows][]);
        int[] labels = training.toArray(new int[rows]);
        classifier = KNN.learn(features, labels, 5);
    }
    return classifier;
}
开发者ID:tomwhite,
项目名称:set-game,
代码行数:15,
代码来源:FindCardShadingFeatures.java
示例26: testUSPS
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class RandomForest.
 *
 * <p>Trains {@code RandomForest(x, y, 200)} on the USPS training set. Prints
 * the test error count, the out-of-bag error rate and the test error rate, and
 * checks there are at most 225 test errors. Exceptions are rethrown so that a
 * data-loading failure fails the test.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        RandomForest forest = new RandomForest(x, y, 200);
        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (forest.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.println("USPS error = " + error);
        System.out.format("USPS OOB error rate = %.2f%%%n", 100.0 * forest.error());
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertTrue(error <= 225);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:35,
代码来源:RandomForestTest.java
示例27: testUSPS
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class LDA.
 *
 * <p>Trains LDA on the USPS training set and checks that it misclassifies
 * exactly 256 vectors of the USPS test set. Exceptions are rethrown so that a
 * data-loading failure fails the test.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        LDA lda = new LDA(x, y);
        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (lda.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertEquals(256, error);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:33,
代码来源:LDATest.java
示例28: testIris
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class AdaBoost.
 *
 * <p>Relabels the Iris data (response at attribute index 4) as binary: class 0
 * versus all others. Then runs leave-one-out evaluation of
 * {@code AdaBoost(attributes, x, y, 200)} and checks that no held-out sample
 * is misclassified. Exceptions are rethrown so that a data-loading failure
 * fails the test.
 */
@Test
public void testIris() {
    System.out.println("Iris");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        // Collapse to a two-class problem: class 0 stays 0, every other class becomes 1.
        for (int i = 0; i < y.length; i++) {
            if (y[i] != 0) {
                y[i] = 1;
            }
        }
        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);
            AdaBoost forest = new AdaBoost(iris.attributes(), trainx, trainy, 200);
            if (y[loocv.test[i]] != forest.predict(x[loocv.test[i]])) {
                error++;
            }
        }
        System.out.println("AdaBoost error = " + error);
        assertEquals(0, error);
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:35,
代码来源:AdaBoostTest.java
示例29: testUSPS
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class GradientTreeBoost.
 *
 * <p>Trains {@code GradientTreeBoost(attributes, x, y, 100)} on the USPS
 * training set and prints, for the USPS test set:
 * <ul>
 *   <li>the overall error rate;</li>
 *   <li>the accuracy reported by {@code boost.test} for each number of trees;</li>
 *   <li>the feature importances, from highest to lowest.</li>
 * </ul>
 * No value is asserted; exceptions are rethrown so that a data-loading failure
 * fails the test.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);
        GradientTreeBoost boost = new GradientTreeBoost(train.attributes(), x, y, 100);
        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (boost.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("Gradient Tree Boost error rate = %.2f%%%n", 100.0 * error / testx.length);
        double[] accuracy = boost.test(testx, testy);
        for (int i = 1; i <= accuracy.length; i++) {
            System.out.format("%d trees accuracy = %.2f%%%n", i, 100.0 * accuracy[i-1]);
        }
        double[] importance = boost.importance();
        // NOTE(review): assumes QuickSort.sort sorts importance in place and returns
        // the original index of each sorted entry, so index[i] pairs with importance[i].
        int[] index = QuickSort.sort(importance);
        for (int i = importance.length; i-- > 0; ) {
            System.out.format("%s importance is %.4f%n", train.attributes()[index[i]], importance[i]);
        }
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:43,
代码来源:GradientTreeBoostTest.java
示例30: testIris2
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of predict method, of class GradientTreeBoost.
 *
 * <p>Relabels the Iris data (response at attribute index 4) as binary: class 2
 * becomes 1 and all others become 0. Then runs leave-one-out evaluation of
 * {@code GradientTreeBoost(attributes, x, y, 100)} and prints the number of
 * misclassified held-out samples. Exceptions are rethrown so that a
 * data-loading failure fails the test.
 */
@Test
public void testIris2() {
    System.out.println("Iris binary");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        // Two-class problem: class 2 versus the rest.
        for (int i = 0; i < y.length; i++) {
            y[i] = (y[i] == 2) ? 1 : 0;
        }
        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);
            GradientTreeBoost boost = new GradientTreeBoost(iris.attributes(), trainx, trainy, 100);
            if (y[loocv.test[i]] != boost.predict(x[loocv.test[i]])) {
                error++;
            }
        }
        System.out.println("Gradient Tree Boost error = " + error);
        // NOTE(review): an assertEquals(6, error) check was disabled here, so this
        // test currently only reports the error count.
    } catch (Exception ex) {
        // Rethrow instead of printing so that missing or unparsable data fails the test.
        throw new AssertionError("Unexpected exception: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:40,
代码来源:GradientTreeBoostTest.java
示例31: testParse
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of parse method, of class DelimitedTextParser.
 *
 * <p>Parses the USPS training file with a nominal "class" response in
 * column 0. Checks that the response is nominal, every feature attribute is
 * numeric, there are 7291 rows and 256 features, and spot-checks labels and
 * values in the first and last rows. The method already declares
 * {@code throws Exception}, so any parse failure propagates and fails the test.
 */
@Test
public void testParse() throws Exception {
    System.out.println("parse");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    AttributeDataset usps = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
    double[][] x = usps.toArray(new double[usps.size()][]);
    int[] y = usps.toArray(new int[usps.size()]);
    assertEquals(Attribute.Type.NOMINAL, usps.responseAttribute().getType());
    for (Attribute attribute : usps.attributes()) {
        assertEquals(Attribute.Type.NUMERIC, attribute.getType());
    }
    assertEquals(7291, usps.size());
    assertEquals(256, usps.attributes().length);
    assertEquals("6", usps.responseAttribute().toString(y[0]));
    assertEquals("5", usps.responseAttribute().toString(y[1]));
    assertEquals("4", usps.responseAttribute().toString(y[2]));
    assertEquals(-1.0000, x[0][6], 1E-7);
    assertEquals(-0.6310, x[0][7], 1E-7);
    assertEquals(0.8620, x[0][8], 1E-7);
    assertEquals("1", usps.responseAttribute().toString(y[7290]));
    assertEquals(-1.0000, x[7290][4], 1E-7);
    assertEquals(-0.1080, x[7290][5], 1E-7);
    assertEquals(1.0000, x[7290][6], 1E-7);
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:37,
代码来源:DelimitedTextParserTest.java
示例32: test
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Runs 10-fold cross validation of a GradientTreeBoost regression model on an ARFF
 * data set and prints the root mean squared error and the mean absolute deviation of
 * the held-out predictions (both averaged over all n samples).
 *
 * @param loss     the loss function passed to GradientTreeBoost
 * @param dataset  display name of the data set, used only in printed output
 * @param url      test-data path of the ARFF file, resolved via IOUtils.getTestDataFile
 * @param response index of the response column in the ARFF file
 * @throws AssertionError if loading the data, training or prediction throws an exception
 */
public void test(GradientTreeBoost.Loss loss, String dataset, String url, int response) {
    System.out.println(dataset + "\t" + loss);
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(response);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile(url));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);
        int n = datax.length;
        int k = 10;
        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        double ad = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);
            // 100 trees; the remaining arguments (6, 0.05, 0.7) are presumably max nodes,
            // shrinkage and subsample rate -- confirm against GradientTreeBoost.
            GradientTreeBoost boost = new GradientTreeBoost(data.attributes(), trainx, trainy, loss, 100, 6, 0.05, 0.7);
            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - boost.predict(testx[j]);
                ad += Math.abs(r);
                rss += r * r;
            }
        }
        System.out.format("10-CV RMSE = %.4f \t AbsoluteDeviation = %.4f%n", Math.sqrt(rss/n), ad/n);
    } catch (Exception ex) {
        // Printing and ignoring the exception would let the calling test pass without running.
        throw new AssertionError(dataset + "\t" + loss + " failed: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:36,
代码来源:GradientTreeBoostTest.java
示例33: testParse
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of parse method, of class GCTParser.
 * Parses the ALL/AML GCT microarray file and checks that every attribute is numeric,
 * the dataset dimensions, and the row identifiers and values of the first and last rows.
 * Parse failures propagate through {@code throws Exception} so they fail the test
 * rather than being printed and ignored.
 */
@Test
public void testParse() throws Exception {
    System.out.println("parse");
    GCTParser parser = new GCTParser();
    AttributeDataset data = parser.parse("GCT", smile.data.parser.IOUtils.getTestDataFile("microarray/allaml.dataset.gct"));
    double[][] x = data.toArray(new double[data.size()][]);
    String[] id = data.toArray(new String[data.size()]);

    for (Attribute attribute : data.attributes()) {
        assertEquals(Attribute.Type.NUMERIC, attribute.getType());
        System.out.println(attribute.getName());
    }

    assertEquals(12564, data.size());
    assertEquals(48, data.attributes().length);

    // First row.
    assertEquals("AFFX-MurIL2_at", id[0]);
    assertEquals(-161.8, x[0][0], 1E-7);
    assertEquals(-231.0, x[0][1], 1E-7);
    assertEquals(-279.0, x[0][2], 1E-7);

    // Last row.
    assertEquals("128_at", id[12563]);
    assertEquals(95.0, x[12563][45], 1E-7);
    assertEquals(108.0, x[12563][46], 1E-7);
    assertEquals(346.0, x[12563][47], 1E-7);
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:35,
代码来源:GCTParserTest.java
示例34: testUSPS
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class AdaBoost.
 * Turns USPS into a binary task: the class with nominal index 0 stays 0 and every
 * other class becomes 1. It then trains AdaBoost on the training split and asserts
 * at most 25 test-set errors.
 * NOTE(review): index 0 is whichever label the shared NominalAttribute mapped first;
 * the first training row decodes to "6", so this is not necessarily digit 0 -- confirm.
 * Any exception (e.g. a missing data file) fails the test instead of being ignored.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        // Binarize labels: index 0 -> 0, all other classes -> 1.
        for (int i = 0; i < y.length; i++) {
            if (y[i] != 0) y[i] = 1;
        }
        for (int i = 0; i < testy.length; i++) {
            if (testy[i] != 0) testy[i] = 1;
        }

        AdaBoost boost = new AdaBoost(x, y, 100, 6);
        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (boost.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.println("AdaBoost error = " + error);
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertTrue(error <= 25);
    } catch (Exception ex) {
        // Printing and ignoring the exception would let the test pass without running.
        throw new AssertionError("AdaBoost USPS test failed: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:41,
代码来源:AdaBoostTest.java
示例35: testParse
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of parse method, of class PCLParser.
 * Parses the Dunham2002 PCL microarray file and checks that every attribute is numeric,
 * the dataset dimensions, and the row identifiers and values of the first and last rows.
 * Parse failures propagate through {@code throws Exception} so they fail the test
 * rather than being printed and ignored.
 */
@Test
public void testParse() throws Exception {
    System.out.println("parse");
    PCLParser parser = new PCLParser();
    AttributeDataset data = parser.parse("PCL", smile.data.parser.IOUtils.getTestDataFile("microarray/Dunham2002.pcl"));
    double[][] x = data.toArray(new double[data.size()][]);
    String[] id = data.toArray(new String[data.size()]);

    for (Attribute attribute : data.attributes()) {
        assertEquals(Attribute.Type.NUMERIC, attribute.getType());
        System.out.println(attribute.getName());
    }

    assertEquals(6694, data.size());
    assertEquals(16, data.attributes().length);

    // First row.
    assertEquals("YKR005C", id[0]);
    assertEquals(-0.43, x[0][0], 1E-7);
    assertEquals(-0.47, x[0][1], 1E-7);
    assertEquals(-0.39, x[0][2], 1E-7);

    // Last row.
    assertEquals("YKR004C", id[6693]);
    assertEquals(0.03, x[6693][13], 1E-7);
    assertEquals(-0.53, x[6693][14], 1E-7);
    assertEquals(0.3, x[6693][15], 1E-7);
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:35,
代码来源:PCLParserTest.java
示例36: testParseWeather
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of parse method, of class ArffParser, on the all-nominal weather data set.
 * Column 4 is the response. The test checks attribute types, dimensions, and decoded
 * string values of the first and last rows.
 * Parse failures propagate through {@code throws Exception} so they fail the test
 * rather than being printed and ignored.
 */
@Test
public void testParseWeather() throws Exception {
    System.out.println("weather");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    AttributeDataset weather = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/weather.nominal.arff"));
    double[][] x = weather.toArray(new double[weather.size()][]);
    int[] y = weather.toArray(new int[weather.size()]);

    assertEquals(Attribute.Type.NOMINAL, weather.responseAttribute().getType());
    for (Attribute attribute : weather.attributes()) {
        assertEquals(Attribute.Type.NOMINAL, attribute.getType());
    }

    assertEquals(14, weather.size());
    assertEquals(4, weather.attributes().length);

    // First rows.
    assertEquals("no", weather.responseAttribute().toString(y[0]));
    assertEquals("no", weather.responseAttribute().toString(y[1]));
    assertEquals("yes", weather.responseAttribute().toString(y[2]));
    assertEquals("sunny", weather.attributes()[0].toString(x[0][0]));
    assertEquals("hot", weather.attributes()[1].toString(x[0][1]));
    assertEquals("high", weather.attributes()[2].toString(x[0][2]));
    assertEquals("FALSE", weather.attributes()[3].toString(x[0][3]));

    // Last row.
    assertEquals("no", weather.responseAttribute().toString(y[13]));
    assertEquals("rainy", weather.attributes()[0].toString(x[13][0]));
    assertEquals("mild", weather.attributes()[1].toString(x[13][1]));
    assertEquals("high", weather.attributes()[2].toString(x[13][2]));
    assertEquals("TRUE", weather.attributes()[3].toString(x[13][3]));
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:38,
代码来源:ArffParserTest.java
示例37: testParseIris
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of parse method, of class ArffParser, on the Iris data set.
 * Column 4 is the nominal response. The test checks attribute types, dimensions, and
 * the labels and feature values of the first and last rows.
 * Parse failures propagate through {@code throws Exception} so they fail the test
 * rather than being printed and ignored.
 */
@Test
public void testParseIris() throws Exception {
    System.out.println("iris");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    assertEquals(Attribute.Type.NOMINAL, iris.responseAttribute().getType());
    for (Attribute attribute : iris.attributes()) {
        assertEquals(Attribute.Type.NUMERIC, attribute.getType());
    }

    assertEquals(150, iris.size());
    assertEquals(4, iris.attributes().length);

    // First rows.
    assertEquals("Iris-setosa", iris.responseAttribute().toString(y[0]));
    assertEquals("Iris-setosa", iris.responseAttribute().toString(y[1]));
    assertEquals("Iris-setosa", iris.responseAttribute().toString(y[2]));
    assertEquals(5.1, x[0][0], 1E-7);
    assertEquals(3.5, x[0][1], 1E-7);
    assertEquals(1.4, x[0][2], 1E-7);
    assertEquals(0.2, x[0][3], 1E-7);

    // Last row.
    assertEquals("Iris-virginica", iris.responseAttribute().toString(y[149]));
    assertEquals(5.9, x[149][0], 1E-7);
    assertEquals(3.0, x[149][1], 1E-7);
    assertEquals(5.1, x[149][2], 1E-7);
    assertEquals(1.8, x[149][3], 1E-7);
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:38,
代码来源:ArffParserTest.java
示例38: testParseSparse
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of parse method, of class ArffParser, on a sparse-format ARFF file.
 * No response column is set. The test checks the dimensions and that every cell,
 * including the implicit zeros of the sparse encoding, has the expected dense value.
 * Parse failures propagate through {@code throws Exception} so they fail the test
 * rather than being printed and ignored.
 */
@Test
public void testParseSparse() throws Exception {
    System.out.println("sparse");
    ArffParser arffParser = new ArffParser();
    AttributeDataset sparse = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/sparse.arff"));
    double[][] x = sparse.toArray(new double[sparse.size()][]);

    assertEquals(2, sparse.size());
    assertEquals(5, sparse.attributes().length);

    // Row 0.
    assertEquals(0.0, x[0][0], 1E-7);
    assertEquals(2.0, x[0][1], 1E-7);
    assertEquals(0.0, x[0][2], 1E-7);
    assertEquals(3.0, x[0][3], 1E-7);
    assertEquals(0.0, x[0][4], 1E-7);

    // Row 1.
    assertEquals(0.0, x[1][0], 1E-7);
    assertEquals(0.0, x[1][1], 1E-7);
    assertEquals(1.0, x[1][2], 1E-7);
    assertEquals(0.0, x[1][3], 1E-7);
    assertEquals(1.0, x[1][4], 1E-7);
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:30,
代码来源:ArffParserTest.java
示例39: testKPCAThreshold
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class KPCA, using a variance threshold (1E-4) instead of a
 * fixed number of components. It fits a Gaussian-kernel KPCA on the Iris features and
 * checks the variances against the expected {@code latent} values (declared elsewhere in
 * this class). It also checks that projecting the training data reproduces
 * {@code getCoordinates()}.
 * Any exception (e.g. a missing data file) fails the test instead of being ignored.
 */
@Test
public void testKPCAThreshold() {
    System.out.println("learn threshold");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        // Parameterized construction avoids the unchecked raw-type conversion.
        KPCA<double[]> kpca = new KPCA<double[]>(x, new GaussianKernel(Math.sqrt(2.5)), 1E-4);
        assertTrue(Math.equals(latent, kpca.getVariances(), 1E-3));

        double[][] points = kpca.project(x);
        // Also exercise the single-sample projection path.
        points[0] = kpca.project(x[0]);
        assertTrue(Math.equals(points, kpca.getCoordinates(), 1E-7));

        // NOTE(review): sign-normalized comparison against expected scores is disabled upstream.
        /*
        for (int j = 0; j < points[0].length; j++) {
            double sign = Math.signum(points[0][j] / scores[0][j]);
            for (int i = 0; i < points.length; i++) {
                points[i][j] *= sign;
            }
        }
        assertTrue(Math.equals(scores, points, 1E-1));
        */
    } catch (Exception ex) {
        // Printing and ignoring the exception would let the test pass without running.
        throw new AssertionError("KPCA threshold test failed: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:32,
代码来源:KPCATest.java
示例40: testUSPS
点赞 2
import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class DecisionTree.
 * Trains a decision tree with the ENTROPY split rule on the USPS training split and
 * asserts exactly 328 misclassifications on the test split. The argument 350 is
 * presumably the maximum number of leaf nodes; confirm against DecisionTree.
 * Any exception (e.g. a missing data file) fails the test instead of being ignored.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        DecisionTree tree = new DecisionTree(x, y, 350, DecisionTree.SplitRule.ENTROPY);
        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (tree.predict(testx[i]) != testy[i]) {
                error++;
            }
        }
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertEquals(328, error);
    } catch (Exception ex) {
        // Printing and ignoring the exception would let the test pass without running.
        throw new AssertionError("DecisionTree USPS test failed: " + ex, ex);
    }
}
开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:33,
代码来源:DecisionTreeTest.java